package org.elasticsearch.xpack.ml.dataframe.traintestsplit;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.AbstractReservoirTrainTestSplitter;

/* loaded from: input_file:org/elasticsearch/xpack/ml/dataframe/traintestsplit/StratifiedTrainTestSplitter.class */
/**
 * Train/test splitter that keeps one {@link AbstractReservoirTrainTestSplitter.SampleInfo} per class
 * of the dependent variable. Each row is routed to the bookkeeping of its own class.
 * <p>
 * The actual train/test decision is made by {@link AbstractReservoirTrainTestSplitter}. This subclass
 * only supplies the per-row {@code SampleInfo} via {@link #getSampleInfo(String[])}.
 * NOTE(review): the exact sampling semantics live in the superclass; confirm there.
 */
public class StratifiedTrainTestSplitter extends AbstractReservoirTrainTestSplitter {

    /** Per-class sample bookkeeping, keyed by the dependent variable's class value. */
    private final Map<String, AbstractReservoirTrainTestSplitter.SampleInfo> classSamples;

    /**
     * Creates a splitter with one {@code SampleInfo} per entry of {@code classCounts}.
     *
     * @param fieldNames        passed through to the superclass (presumably the row's field names; confirm)
     * @param dependentVariable passed through to the superclass (presumably the dependent variable's field name; confirm)
     * @param classCounts       class value to count; each entry becomes a {@code SampleInfo} built from that count
     * @param trainingPercent   passed through to the superclass (presumably the training share; confirm)
     * @param randomizeSeed     passed through to the superclass (presumably the RNG seed; confirm)
     * @throws NullPointerException if {@code classCounts} or any of its values is {@code null}
     */
    public StratifiedTrainTestSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCounts,
                                       double trainingPercent, long randomizeSeed) {
        super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
        Map<String, AbstractReservoirTrainTestSplitter.SampleInfo> samples = new HashMap<>();
        for (Map.Entry<String, Long> entry : classCounts.entrySet()) {
            // Unboxing a null count throws NullPointerException here, exactly as the original did.
            long count = entry.getValue();
            samples.put(entry.getKey(), new AbstractReservoirTrainTestSplitter.SampleInfo(count));
        }
        this.classSamples = samples;
    }

    /**
     * Returns the sample bookkeeping for the class of the given row.
     *
     * @param row row values; the class is read at the superclass-provided {@code dependentVariableIndex}
     * @return the {@code SampleInfo} registered for that class at construction time
     * @throws IllegalStateException if the row's class was not present in the constructor's {@code classCounts}
     */
    @Override
    protected AbstractReservoirTrainTestSplitter.SampleInfo getSampleInfo(String[] row) {
        String classValue = row[this.dependentVariableIndex];
        AbstractReservoirTrainTestSplitter.SampleInfo sample = this.classSamples.get(classValue);
        if (sample == null) {
            throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + this.classSamples.keySet());
        }
        return sample;
    }

    // isTraining(String[]) is inherited unchanged from AbstractReservoirTrainTestSplitter. The former
    // explicit override was a decompiled synthetic bridge method, which javac regenerates as needed.
}
