package org.elasticsearch.xpack.ml.dataframe.traintestsplit;

import java.util.List;
import java.util.Random;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;

/* loaded from: input_file:org/elasticsearch/xpack/ml/dataframe/traintestsplit/AbstractReservoirTrainTestSplitter.class */
/**
 * Base class for {@link TrainTestSplitter}s that decide, row by row, whether a row is used for training,
 * aiming to select {@code trainingPercent} of the rows of every sampling bucket.
 *
 * <p>Subclasses define the buckets by returning a {@link SampleInfo} for each row from
 * {@link #getSampleInfo(String[])}. Within a bucket a row is selected with probability
 * {@code (target - selectedSoFar) / (classCount - observedSoFar)} (sequential selection sampling). When a
 * bucket's {@code classCount} is at least 1 and equals the number of rows it actually receives, exactly
 * {@code target} of them are selected.
 *
 * <p>NOTE(review): the per-bucket counters are mutated without synchronization; this presumably assumes
 * rows are fed from a single thread — confirm with callers.
 */
abstract class AbstractReservoirTrainTestSplitter implements TrainTestSplitter {

    /** Position of the dependent variable within the rows passed to {@link #isTraining(String[])}. */
    protected final int dependentVariableIndex;

    /** {@code trainingPercent / 100}; expected to lie in {@code [0.01, 1.0]} (checked by assertion only). */
    private final double samplingRatio;

    /** Seeded source of randomness, so a given seed and row sequence always produce the same split. */
    private final Random random;

    /**
     * Mutable sampling state of one bucket: the number of rows the bucket is expected to receive
     * (supplied by the subclass), how many have been selected for training so far, and how many have
     * been observed so far.
     */
    static class SampleInfo {

        private final long classCount;
        private long training;
        private long observed;

        SampleInfo(long classCount) {
            this.classCount = classCount;
        }
    }

    /**
     * Creates a splitter.
     *
     * <p>Fails with the server error built by {@code ExceptionsHelper.serverError} if
     * {@code dependentVariable} is not one of {@code fieldNames}.
     *
     * @param fieldNames        field names, in the same order as the values of each row
     * @param dependentVariable name of the dependent variable; must be contained in {@code fieldNames}
     * @param trainingPercent   percentage of each bucket to use for training, expected in {@code [1, 100]}
     * @param randomizeSeed     seed for the random generator, making the split reproducible
     */
    AbstractReservoirTrainTestSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
                                       long randomizeSeed) {
        // Written positively so that NaN is rejected as well.
        assert trainingPercent >= 1.0 && trainingPercent <= 100.0
            : "training percent must be in [1, 100] but was [" + trainingPercent + "]";
        this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
        this.samplingRatio = trainingPercent / 100.0;
        this.random = new Random(randomizeSeed);
    }

    /**
     * Returns the position of {@code dependentVariable} in {@code fieldNames}, failing with a server error
     * when it is absent.
     */
    private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
        int index = fieldNames.indexOf(dependentVariable);
        if (index < 0) {
            throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
        }
        return index;
    }

    /**
     * Decides whether {@code row} is used for training and updates the row's bucket counters.
     *
     * @param row field values, ordered like the field names given to the constructor
     * @return {@code true} if the row was selected for training; always {@code false} when the dependent
     *         variable is missing
     */
    @Override
    public boolean isTraining(String[] row) {
        if (canBeUsedForTraining(row) == false) {
            return false;
        }

        SampleInfo sample = getSampleInfo(row);

        // Aim for at least one training row per bucket: for a small class count the plain ratio could
        // truncate to zero and leave the whole class out of training.
        long targetSampleCount = (long) Math.max(1.0, samplingRatio * sample.classCount);

        // Rows still needed divided by rows still to come: the probability rises as the remaining chances
        // to reach the target shrink, and hits 1 once every remaining row is needed.
        double probability = (double) (targetSampleCount - sample.training) / (sample.classCount - sample.observed);

        // Strict '<' because nextDouble() is in [0, 1): a probability of 0 (target reached) can never select
        // a row, a probability >= 1 always does, and NaN (0/0 once the bucket is exhausted) never does.
        boolean selected = random.nextDouble() < probability;

        sample.observed++;
        if (selected) {
            sample.training++;
        }
        return selected;
    }

    /**
     * Returns whether the row has a dependent variable value; rows holding the
     * {@link DataFrameDataExtractor#NULL_VALUE} sentinel are never used for training.
     */
    private boolean canBeUsedForTraining(String[] row) {
        // Deliberate reference comparison against the sentinel instance.
        // NOTE(review): assumes the extractor stores the NULL_VALUE constant itself for missing values — confirm.
        return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
    }

    /**
     * Returns the sampling state of the bucket that {@code row} belongs to. Implementations are expected to
     * return the same instance for every row of the same bucket so that its counters accumulate.
     */
    protected abstract SampleInfo getSampleInfo(String[] row);
}
