package org.elasticsearch.xpack.ml.inference.persistence;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/persistence/ChunkedTrainedModelRestorer.class */
/**
 * Restores a trained model definition that is stored as a series of {@link TrainedModelDefinitionDoc}
 * chunks, fetching them page by page with {@code search_after}.
 * <p>
 * Each page is fetched with a blocking search on the supplied {@link ExecutorService}. Follow-up pages
 * are scheduled back onto the same executor. The instance keeps a running count of the docs it has
 * handed to the model consumer (see {@link #getNumDocsWritten()}).
 * NOTE(review): that counter is never reset, so an instance appears intended for a single restore — confirm.
 */
public class ChunkedTrainedModelRestorer {

    private static final Logger logger = LogManager.getLogger(ChunkedTrainedModelRestorer.class);

    /** Largest page size accepted by {@link #setSearchSize(int)}. */
    private static final int MAX_NUM_DEFINITION_DOCS = 20;

    private final Client client;
    private final NamedXContentRegistry xContentRegistry;
    private final ExecutorService executorService;
    private final String modelId;
    private String index = ".ml-inference-*";
    private int searchSize = 10;
    private int numDocsWritten = 0;

    /**
     * @param modelId          id of the model whose definition docs are restored
     * @param client           client used for the searches; wrapped so requests carry the {@code "ml"} origin
     * @param executorService  executor on which every search page is run
     * @param xContentRegistry registry used when parsing the definition docs
     */
    public ChunkedTrainedModelRestorer(
        String modelId,
        Client client,
        ExecutorService executorService,
        NamedXContentRegistry xContentRegistry
    ) {
        this.client = new OriginSettingClient(client, "ml");
        this.executorService = executorService;
        this.xContentRegistry = xContentRegistry;
        this.modelId = modelId;
    }

    /**
     * Sets how many definition docs are requested per search page.
     *
     * @param searchSize page size, in {@code [1, MAX_NUM_DEFINITION_DOCS]}
     * @throws IllegalArgumentException if {@code searchSize} is out of range
     */
    public void setSearchSize(int searchSize) {
        if (searchSize > MAX_NUM_DEFINITION_DOCS) {
            throw new IllegalArgumentException(
                "search size [" + searchSize + "] cannot be bigger than [" + MAX_NUM_DEFINITION_DOCS + "]"
            );
        }
        if (searchSize <= 0) {
            throw new IllegalArgumentException("search size [" + searchSize + "] must be greater than 0");
        }
        this.searchSize = searchSize;
    }

    /** Sets the index (or index pattern) that is searched for definition docs. */
    public void setSearchIndex(String index) {
        this.index = index;
    }

    /** @return the number of definition docs from fully consumed pages so far */
    public int getNumDocsWritten() {
        return numDocsWritten;
    }

    /**
     * Asynchronously streams the model's definition docs, in search order, to {@code modelConsumer}.
     *
     * @param modelConsumer   receives each parsed doc; returning {@code false} stops the restore early
     * @param successConsumer called with {@code TRUE} when all docs were consumed, or {@code FALSE} when
     *                        {@code modelConsumer} stopped the restore early
     * @param errorConsumer   called with the failure if the search or processing fails; a missing
     *                        definition is reported as a {@link ResourceNotFoundException}
     */
    public void restoreModelDefinition(
        CheckedFunction<TrainedModelDefinitionDoc, Boolean, IOException> modelConsumer,
        Consumer<Boolean> successConsumer,
        Consumer<Exception> errorConsumer
    ) {
        logger.debug("[{}] restoring model", modelId);
        SearchRequest searchRequest = buildSearch(client, modelId, index, searchSize, null);
        executorService.execute(() -> doSearch(searchRequest, modelConsumer, successConsumer, errorConsumer));
    }

    /**
     * Runs one page of the restore and, if more docs remain, schedules the next page on the executor.
     * Exactly one of the consumers is invoked per terminal outcome.
     */
    private void doSearch(
        SearchRequest searchRequest,
        CheckedFunction<TrainedModelDefinitionDoc, Boolean, IOException> modelConsumer,
        Consumer<Boolean> successConsumer,
        Consumer<Exception> errorConsumer
    ) {
        try {
            // The search below blocks, so it must only run on one of these ML pools.
            assert Thread.currentThread().getName().contains(MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME)
                || Thread.currentThread().getName().contains(MachineLearning.UTILITY_THREAD_POOL_NAME)
                : Strings.format(
                    "Must execute from [%s] or [%s] but thread is [%s]",
                    MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME,
                    MachineLearning.UTILITY_THREAD_POOL_NAME,
                    Thread.currentThread().getName()
                );

            SearchResponse searchResponse = client.search(searchRequest).actionGet();
            SearchHit[] hits = searchResponse.getHits().getHits();
            if (hits.length == 0) {
                errorConsumer.accept(modelDefinitionNotFound());
                return;
            }

            // doc_num of the last doc handed to the consumer; becomes part of the search_after key.
            int lastNum = numDocsWritten - 1;
            for (SearchHit hit : hits) {
                try {
                    TrainedModelDefinitionDoc doc = parseModelDefinitionDocLenientlyFromSource(
                        hit.getSourceRef(),
                        modelId,
                        xContentRegistry
                    );
                    lastNum = doc.getDocNum();
                    if (modelConsumer.apply(doc) == false) {
                        // The consumer asked to stop: signal early (non-exhaustive) completion.
                        successConsumer.accept(Boolean.FALSE);
                        return;
                    }
                } catch (IOException e) {
                    logger.error(() -> "[" + modelId + "] error writing model definition", e);
                    errorConsumer.accept(e);
                    return;
                }
            }

            numDocsWritten += hits.length;

            // Finished when this page was short, or when every matching doc has been consumed.
            // Total hits are tracked by buildSearchBuilder; the null check is purely defensive.
            if (hits.length < searchSize
                || (searchResponse.getHits().getTotalHits() != null
                    && searchResponse.getHits().getTotalHits().value == numDocsWritten)) {
                successConsumer.accept(Boolean.TRUE);
            } else {
                // search_after values must mirror the sort in buildSearchBuilder: _index, then doc_num.
                SearchHit lastHit = hits[hits.length - 1];
                SearchRequestBuilder nextPage = buildSearchBuilder(client, modelId, index, searchSize);
                nextPage.searchAfter(new Object[] { lastHit.getIndex(), lastNum });
                executorService.execute(() -> doSearch(nextPage.request(), modelConsumer, successConsumer, errorConsumer));
            }
        } catch (Exception e) {
            // Any (possibly wrapped) ResourceNotFoundException is reported as a missing model definition.
            if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                errorConsumer.accept(modelDefinitionNotFound());
            } else {
                errorConsumer.accept(e);
            }
        }
    }

    /** Builds the exception reported when no definition docs exist for {@link #modelId}. */
    private ResourceNotFoundException modelDefinitionNotFound() {
        return new ResourceNotFoundException(Messages.getMessage("Could not find trained model definition [{0}]", modelId));
    }

    /**
     * Builds a search for the definition docs of {@code modelId}. Results are sorted by {@code _index}
     * descending, then by doc number ascending, and total hits are tracked.
     */
    private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, String index, int searchSize) {
        return client.prepareSearch(new String[] { index })
            .setQuery(
                QueryBuilders.constantScoreQuery(
                    QueryBuilders.boolQuery()
                        .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
                        .filter(
                            QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME)
                        )
                )
            )
            .setSize(searchSize)
            .setTrackTotalHits(true)
            .addSort("_index", SortOrder.DESC)
            .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()).order(SortOrder.ASC).unmappedType("long"));
    }

    /**
     * Builds the first-page search request for a model's definition docs.
     *
     * @param parentTaskId if non-null, set as the request's parent task
     */
    public static SearchRequest buildSearch(Client client, String modelId, String index, int searchSize, @Nullable TaskId parentTaskId) {
        SearchRequest request = buildSearchBuilder(client, modelId, index, searchSize).request();
        if (parentTaskId != null) {
            request.setParentTask(parentTaskId);
        }
        return request;
    }

    /**
     * Leniently parses a definition doc from its JSON source.
     * Both the stream and the parser are always closed, including on failure.
     *
     * @throws IOException if the source cannot be read or parsed; the failure is logged before rethrowing
     */
    public static TrainedModelDefinitionDoc parseModelDefinitionDocLenientlyFromSource(
        BytesReference source,
        String modelId,
        NamedXContentRegistry xContentRegistry
    ) throws IOException {
        try (
            StreamInput stream = source.streamInput();
            XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)
        ) {
            return TrainedModelDefinitionDoc.fromXContent(parser, true).build();
        } catch (IOException e) {
            logger.error(() -> "[" + modelId + "] failed to parse model definition", e);
            throw e;
        }
    }
}
