package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.Comparator;
import java.util.Objects;
import java.util.PriorityQueue;
import org.elasticsearch.search.aggregations.pipeline.MovingFunctions;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/nlp/NlpHelpers.class */
/**
 * Static numeric helpers over arrays of scores: softmax normalization, argmax and top-k selection.
 */
public final class NlpHelpers {

    /**
     * A (score, index) pair as returned by {@link NlpHelpers#topK(int, double[])}: {@code score} is the
     * value found at position {@code index} of the input array.
     */
    public static class ScoreAndIndex {
        final double score;
        final int index;

        ScoreAndIndex(double score, int index) {
            this.score = score;
            this.index = index;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            ScoreAndIndex that = (ScoreAndIndex) obj;
            return Double.compare(that.score, score) == 0 && index == that.index;
        }

        @Override
        public int hashCode() {
            return Objects.hash(score, index);
        }

        @Override
        public String toString() {
            return "ScoreAndIndex{score=" + score + ", index=" + index + "}";
        }
    }

    private NlpHelpers() {
    }

    /**
     * Applies {@link #convertToProbabilitiesBySoftMax(double[])} independently to every row.
     *
     * @param scores rows of raw scores; rows may have different lengths
     * @return a newly allocated array with the same shape as {@code scores}, where each non-empty row of
     *         finite values sums to 1 (up to rounding)
     */
    public static double[][] convertToProbabilitiesBySoftMax(double[][] scores) {
        double[][] probabilities = new double[scores.length][];
        for (int row = 0; row < scores.length; row++) {
            probabilities[row] = convertToProbabilitiesBySoftMax(scores[row]);
        }
        return probabilities;
    }

    /**
     * Converts raw scores to probabilities with the softmax function.
     *
     * @param scores raw scores; not modified
     * @return a newly allocated array of the same length; it sums to 1 (up to rounding) when all inputs are
     *         finite, and contains NaN values if any input is NaN
     */
    public static double[] convertToProbabilitiesBySoftMax(double[] scores) {
        // Subtracting the maximum before exponentiating keeps exp() from overflowing for large scores;
        // the shift cancels out in the normalization, so the probabilities are unchanged.
        double maxScore = max(scores);
        double[] probabilities = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probabilities[i] = Math.exp(scores[i] - maxScore);
            sum += probabilities[i];
        }
        for (int i = 0; i < probabilities.length; i++) {
            probabilities[i] /= sum;
        }
        return probabilities;
    }

    /** Returns the largest non-NaN value in {@code values}, or negative infinity if there is none. */
    private static double max(double[] values) {
        double result = Double.NEGATIVE_INFINITY;
        for (double value : values) {
            // NaN never compares greater, so NaN entries are skipped here.
            if (value > result) {
                result = value;
            }
        }
        return result;
    }

    /**
     * Returns the index of the largest value. On ties, the first occurrence wins.
     *
     * <p>NOTE(review): an empty array yields 0, which is not a valid index. If {@code values[0]} is NaN,
     * 0 is always returned because no value compares greater than NaN.
     *
     * @param values the values to scan
     * @return the index of the maximum value
     */
    public static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Selects the {@code k} highest scores in {@code arr} together with their positions.
     *
     * @param k   the number of entries to return; values larger than {@code arr.length} are clamped
     * @param arr the scores to select from; not modified
     * @return at most {@code k} entries ordered by descending score; empty when {@code k} is 0 or
     *         {@code arr} is empty
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public static ScoreAndIndex[] topK(int k, double[] arr) {
        if (k < 0) {
            throw new IllegalArgumentException("k must be non-negative but was [" + k + "]");
        }
        int size = Math.min(k, arr.length);
        if (size == 0) {
            // PriorityQueue rejects a zero initial capacity, and there would be no head to peek at.
            return new ScoreAndIndex[0];
        }

        // Min-heap of the best `size` entries seen so far; its head is the weakest of them.
        PriorityQueue<ScoreAndIndex> minHeap =
            new PriorityQueue<>(size, Comparator.comparingDouble((ScoreAndIndex s) -> s.score));
        for (int i = 0; i < size; i++) {
            minHeap.add(new ScoreAndIndex(arr[i], i));
        }
        double threshold = minHeap.peek().score;
        for (int i = size; i < arr.length; i++) {
            // Strictly greater: a later score equal to the threshold does not displace the held entry.
            if (arr[i] > threshold) {
                minHeap.poll();
                minHeap.add(new ScoreAndIndex(arr[i], i));
                threshold = minHeap.peek().score;
            }
        }

        ScoreAndIndex[] result = new ScoreAndIndex[size];
        // Polling yields ascending scores, so fill from the back to produce descending order.
        for (int i = size - 1; i >= 0; i--) {
            result[i] = minHeap.poll();
        }
        return result;
    }
}
