package org.elasticsearch.xpack.ml.inference.deployment;

import java.io.IOException;
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.IdsQueryBuilder;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.class */
public class DeploymentManager {
    /** Class-wide logger; assigned in the static initializer at the bottom of the class. */
    private static final Logger logger;
    /** Counter that supplies the request id for every PyTorch action created by this manager. */
    private static final AtomicLong requestIdCounter;
    /** Client used for the trained-model lookup and the vocabulary search. */
    private final Client client;
    /** Registry handed to the vocabulary parser and to each state streamer. */
    private final NamedXContentRegistry xContentRegistry;
    /** Factory that creates the native PyTorch process for a deployment. */
    private final PyTorchProcessFactory pyTorchProcessFactory;
    /** Executor on which the start-and-load step of a deployment runs. */
    private final ExecutorService executorServiceForDeployment;
    /** Executor used for process-related work (worker start, result processing). */
    private final ExecutorService executorServiceForProcess;
    private final ThreadPool threadPool;
    /** Process contexts keyed by the id of their deployment task. */
    private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager$ProcessContext.class */
    public class ProcessContext {
        private final TrainedModelDeploymentTask task;
        private final PyTorchResultProcessor resultProcessor;
        private final PyTorchStateStreamer stateStreamer;
        private final PriorityProcessWorkerExecutorService executorService;
        private volatile Instant startTime;
        private volatile Integer numThreadsPerAllocation;
        private volatile Integer numAllocations;
        private final SetOnce<PyTorchProcess> process = new SetOnce<>();
        private final SetOnce<NlpTask.Processor> nlpTaskProcessor = new SetOnce<>();
        private final SetOnce<TrainedModelInput> modelInput = new SetOnce<>();
        private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
        private final AtomicInteger timeoutCount = new AtomicInteger();

        ProcessContext(TrainedModelDeploymentTask trainedModelDeploymentTask, ExecutorService executorService) {
            this.task = (TrainedModelDeploymentTask) Objects.requireNonNull(trainedModelDeploymentTask);
            this.resultProcessor = new PyTorchResultProcessor(trainedModelDeploymentTask.getModelId(), threadSettings -> {
                this.numThreadsPerAllocation = Integer.valueOf(threadSettings.numThreadsPerAllocation());
                this.numAllocations = Integer.valueOf(threadSettings.numAllocations());
            });
            this.stateStreamer = new PyTorchStateStreamer(DeploymentManager.this.client, executorService, DeploymentManager.this.xContentRegistry);
            this.executorService = new PriorityProcessWorkerExecutorService(DeploymentManager.this.threadPool.getThreadContext(), "inference process", trainedModelDeploymentTask.getParams().getQueueCapacity());
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public PyTorchResultProcessor getResultProcessor() {
            return this.resultProcessor;
        }

        synchronized void startProcess() {
            this.process.set(DeploymentManager.this.pyTorchProcessFactory.createProcess(this.task, DeploymentManager.this.executorServiceForProcess, onProcessCrash()));
            this.startTime = Instant.now();
            ExecutorService executorService = DeploymentManager.this.executorServiceForProcess;
            PriorityProcessWorkerExecutorService priorityProcessWorkerExecutorService = this.executorService;
            Objects.requireNonNull(priorityProcessWorkerExecutorService);
            executorService.submit(priorityProcessWorkerExecutorService::start);
        }

        synchronized void stopProcess() {
            this.resultProcessor.stop();
            this.executorService.shutdown();
            try {
                try {
                } catch (IOException e) {
                    DeploymentManager.logger.error(() -> {
                        return "[" + this.task.getModelId() + "] Failed to kill process";
                    }, e);
                    if (this.nlpTaskProcessor.get() != null) {
                        ((NlpTask.Processor) this.nlpTaskProcessor.get()).close();
                    }
                }
                if (this.process.get() == null) {
                    if (this.nlpTaskProcessor.get() != null) {
                        ((NlpTask.Processor) this.nlpTaskProcessor.get()).close();
                    }
                } else {
                    this.stateStreamer.cancel();
                    ((PyTorchProcess) this.process.get()).kill(true);
                    DeploymentManager.this.processContextByAllocation.remove(Long.valueOf(this.task.getId()));
                    if (this.nlpTaskProcessor.get() != null) {
                        ((NlpTask.Processor) this.nlpTaskProcessor.get()).close();
                    }
                }
            } catch (Throwable th) {
                if (this.nlpTaskProcessor.get() != null) {
                    ((NlpTask.Processor) this.nlpTaskProcessor.get()).close();
                }
                throw th;
            }
        }

        private Consumer<String> onProcessCrash() {
            return str -> {
                DeploymentManager.logger.error("[{}] inference process crashed due to reason [{}]", this.task.getModelId(), str);
                this.resultProcessor.stop();
                this.executorService.shutdownWithError(new IllegalStateException(str));
                DeploymentManager.this.processContextByAllocation.remove(Long.valueOf(this.task.getId()));
                if (this.nlpTaskProcessor.get() != null) {
                    ((NlpTask.Processor) this.nlpTaskProcessor.get()).close();
                }
                this.task.setFailed("inference process crashed due to reason [" + str + "]");
            };
        }

        void loadModel(TrainedModelLocation trainedModelLocation, ActionListener<Boolean> actionListener) {
            if (trainedModelLocation instanceof IndexLocation) {
                ((PyTorchProcess) this.process.get()).loadModel(this.task.getModelId(), ((IndexLocation) trainedModelLocation).getIndexName(), this.stateStreamer, actionListener);
            } else {
                actionListener.onFailure(new IllegalStateException("unsupported trained model location [" + trainedModelLocation.getClass().getSimpleName() + "]"));
            }
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public AtomicInteger getTimeoutCount() {
            return this.timeoutCount;
        }

        PriorityProcessWorkerExecutorService getExecutorService() {
            return this.executorService;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public AtomicInteger getRejectedExecutionCount() {
            return this.rejectedExecutionCount;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public SetOnce<TrainedModelInput> getModelInput() {
            return this.modelInput;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public SetOnce<PyTorchProcess> getProcess() {
            return this.process;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public SetOnce<NlpTask.Processor> getNlpTaskProcessor() {
            return this.nlpTaskProcessor;
        }
    }

    /**
     * Creates a manager. All four collaborators are required; the two executors
     * are looked up from the thread pool by their machine-learning pool names.
     */
    public DeploymentManager(
        Client client,
        NamedXContentRegistry namedXContentRegistry,
        ThreadPool threadPool,
        PyTorchProcessFactory pyTorchProcessFactory
    ) {
        // Null checks happen in this order: client, registry, factory, thread pool.
        this.client = Objects.requireNonNull(client);
        this.xContentRegistry = Objects.requireNonNull(namedXContentRegistry);
        this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
        this.threadPool = Objects.requireNonNull(threadPool);

        this.executorServiceForDeployment = this.threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
        this.executorServiceForProcess = this.threadPool.executor(MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME);
    }

    /**
     * Public entry point for starting a deployment; delegates straight to
     * {@code doStartDeployment} with the same arguments.
     */
    public void startDeployment(
        TrainedModelDeploymentTask trainedModelDeploymentTask,
        ActionListener<TrainedModelDeploymentTask> actionListener
    ) {
        this.doStartDeployment(trainedModelDeploymentTask, actionListener);
    }

    /**
     * Builds a stats snapshot for the given deployment task.
     *
     * @return the stats, or an empty optional when no process context is registered for the task id
     */
    public Optional<ModelStats> getStats(TrainedModelDeploymentTask trainedModelDeploymentTask) {
        ProcessContext context = processContextByAllocation.get(trainedModelDeploymentTask.getId());
        if (context == null) {
            return Optional.empty();
        }

        PyTorchResultProcessor.ResultStats stats = context.getResultProcessor().getResultStats();
        PyTorchResultProcessor.RecentStats recent = stats.recentStats();
        ModelStats snapshot = new ModelStats(
            context.startTime,
            stats.timingStats(),
            stats.lastUsed(),
            // pending = requests still queued in the worker + requests awaiting a result
            context.executorService.queueSize() + stats.numberOfPendingResults(),
            stats.errorCount(),
            stats.cacheHitCount(),
            context.rejectedExecutionCount.get(),
            context.timeoutCount.get(),
            context.numThreadsPerAllocation,
            context.numAllocations,
            stats.peakThroughput(),
            recent.requestsProcessed(),
            recent.avgInferenceTime(),
            recent.cacheHitCount()
        );
        return Optional.of(snapshot);
    }

    /**
     * Registers a context under the given id unless one is already present.
     *
     * @return the previously registered context, or {@code null} if the new one was stored
     */
    ProcessContext addProcessContext(Long l, ProcessContext processContext) {
        ProcessContext previous = processContextByAllocation.putIfAbsent(l, processContext);
        return previous;
    }

    /**
     * Starts a deployment for the given task.
     *
     * <p>Steps, each chained through a listener:
     * <ol>
     *   <li>register a new {@link ProcessContext} for the task id (fails if one already exists);</li>
     *   <li>fetch the trained model config and check that it carries an {@link NlpConfig};</li>
     *   <li>search for the model's vocabulary document and build the NLP processor from it;</li>
     *   <li>start the native process and load the model on the deployment executor;</li>
     *   <li>once loaded, begin result processing on the process executor and notify the caller.</li>
     * </ol>
     * Any failure after step 1 removes the registered context before the caller's listener is failed.
     *
     * @param trainedModelDeploymentTask the task to deploy
     * @param actionListener             notified with the task on success or with the failure
     */
    private void doStartDeployment(
        TrainedModelDeploymentTask trainedModelDeploymentTask,
        ActionListener<TrainedModelDeploymentTask> actionListener
    ) {
        logger.info("[{}] Starting model deployment", trainedModelDeploymentTask.getModelId());

        ProcessContext processContext = new ProcessContext(trainedModelDeploymentTask, this.executorServiceForProcess);
        if (addProcessContext(trainedModelDeploymentTask.getId(), processContext) != null) {
            actionListener.onFailure(
                ExceptionsHelper.serverError(
                    "[{}] Could not create inference process as one already exists",
                    trainedModelDeploymentTask.getModelId()
                )
            );
            return;
        }

        // Success passes through unchanged; failure first unregisters the context created above.
        ActionListener<TrainedModelDeploymentTask> failedDeploymentListener = ActionListener.wrap(
            actionListener::onResponse,
            failure -> {
                this.processContextByAllocation.remove(trainedModelDeploymentTask.getId());
                actionListener.onFailure(failure);
            }
        );

        ActionListener<Boolean> modelLoadedListener = ActionListener.wrap(loaded -> {
            this.executorServiceForProcess.execute(
                () -> processContext.getResultProcessor().process(processContext.process.get())
            );
            failedDeploymentListener.onResponse(trainedModelDeploymentTask);
        }, failedDeploymentListener::onFailure);

        ActionListener<GetTrainedModelsAction.Response> getModelListener = ActionListener.wrap(response -> {
            // Explicit form of `assert results.size() == 1`; the class declares $assertionsDisabled itself.
            if (!$assertionsDisabled && response.getResources().results().size() != 1) {
                throw new AssertionError();
            }
            TrainedModelConfig trainedModelConfig = response.getResources().results().get(0);
            processContext.modelInput.set(trainedModelConfig.getInput());

            InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig();
            if (!(inferenceConfig instanceof NlpConfig)) {
                failedDeploymentListener.onFailure(
                    new IllegalArgumentException(
                        Strings.format(
                            "[%s] must be a pytorch model; found inference config of kind [%s]",
                            trainedModelConfig.getModelId(),
                            inferenceConfig.getWriteableName()
                        )
                    )
                );
                return;
            }
            NlpConfig nlpConfig = (NlpConfig) inferenceConfig;
            trainedModelDeploymentTask.init(nlpConfig);

            SearchRequest vocabSearchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), trainedModelConfig.getModelId());
            ClientHelper.executeAsyncWithOrigin(
                this.client,
                "ml",
                SearchAction.INSTANCE,
                vocabSearchRequest,
                ActionListener.wrap(searchResponse -> {
                    if (searchResponse.getHits().getHits().length == 0) {
                        failedDeploymentListener.onFailure(
                            new ResourceNotFoundException(
                                Messages.getMessage(
                                    "Could not find vocabulary document [{1}] for trained model [{0}]",
                                    trainedModelDeploymentTask.getModelId(),
                                    VocabularyConfig.docId(trainedModelConfig.getModelId())
                                )
                            )
                        );
                        return;
                    }
                    Vocabulary vocabulary = parseVocabularyDocLeniently(searchResponse.getHits().getAt(0));
                    processContext.nlpTaskProcessor.set(new NlpTask(nlpConfig, vocabulary).createProcessor());
                    // Hand the start-and-load step to the deployment executor instead of running it on the callback thread.
                    this.executorServiceForDeployment.execute(
                        () -> startAndLoad(processContext, trainedModelConfig.getLocation(), modelLoadedListener)
                    );
                }, failedDeploymentListener::onFailure)
            );
        }, failedDeploymentListener::onFailure);

        ClientHelper.executeAsyncWithOrigin(
            this.client,
            "ml",
            GetTrainedModelsAction.INSTANCE,
            new GetTrainedModelsAction.Request(trainedModelDeploymentTask.getModelId()),
            getModelListener
        );
    }

    /**
     * Builds a search for a single vocabulary document: the configured vocabulary
     * index is queried by the document id derived from {@code str} (a model id),
     * with size 1 and total-hit tracking switched off.
     */
    private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig, String str) {
        String vocabularyIndex = vocabularyConfig.getIndex();
        IdsQueryBuilder byVocabularyDocId = new IdsQueryBuilder().addIds(VocabularyConfig.docId(str));
        return client.prepareSearch(vocabularyIndex)
            .setQuery(byVocabularyDocId)
            .setSize(1)
            .setTrackTotalHits(false)
            .request();
    }

    /**
     * Parses the source of a search hit into a {@link Vocabulary}.
     *
     * <p>Both the source stream and the parser are closed on every path by
     * try-with-resources.
     *
     * @param searchHit the hit whose source holds the vocabulary document
     * @return the parsed vocabulary
     * @throws IOException if reading or parsing fails; the failure is logged with the hit id and rethrown
     */
    Vocabulary parseVocabularyDocLeniently(SearchHit searchHit) throws IOException {
        try (
            StreamInput stream = searchHit.getSourceRef().streamInput();
            XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                .createParser(
                    XContentParserConfiguration.EMPTY.withRegistry(this.xContentRegistry)
                        .withDeprecationHandler(LoggingDeprecationHandler.INSTANCE),
                    stream
                )
        ) {
            return Vocabulary.PARSER.apply(parser, null);
        } catch (IOException e) {
            logger.error(() -> "failed to parse trained model vocabulary [" + searchHit.getId() + "]", e);
            throw e;
        }
    }

    /**
     * Starts the native process for the context and then asks it to load the
     * model. Any exception thrown by either step is routed to the listener
     * instead of propagating to the caller.
     */
    private void startAndLoad(
        ProcessContext processContext,
        TrainedModelLocation trainedModelLocation,
        ActionListener<Boolean> actionListener
    ) {
        try {
            processContext.startProcess();
            processContext.loadModel(trainedModelLocation, actionListener);
        } catch (Exception startOrLoadFailure) {
            actionListener.onFailure(startOrLoadFailure);
        }
    }

    /**
     * Stops the process belonging to the given task. Logs a warning and does
     * nothing else when no context is registered for the task id.
     */
    public void stopDeployment(TrainedModelDeploymentTask trainedModelDeploymentTask) {
        final ProcessContext context;
        synchronized (processContextByAllocation) {
            context = processContextByAllocation.get(trainedModelDeploymentTask.getId());
        }

        if (context == null) {
            logger.warn("[{}] No process context to stop", trainedModelDeploymentTask.getModelId());
            return;
        }

        logger.info(
            "[{}] Stopping deployment, reason [{}]",
            trainedModelDeploymentTask.getModelId(),
            trainedModelDeploymentTask.stoppedReason().orElse("unknown")
        );
        context.stopProcess();
    }

    /**
     * Queues an inference request on the deployment's worker.
     *
     * <p>If the task is stopped or has no process context, the listener is failed
     * and nothing is queued. Requests flagged with {@code z} are queued with
     * {@code HIGH} priority, all others with {@code NORMAL}.
     */
    public void infer(
        TrainedModelDeploymentTask trainedModelDeploymentTask,
        InferenceConfig inferenceConfig,
        Map<String, Object> map,
        boolean z,
        TimeValue timeValue,
        Task task,
        ActionListener<InferenceResults> actionListener
    ) {
        Objects.requireNonNull(actionListener);
        ProcessContext context = getProcessContext(trainedModelDeploymentTask, actionListener::onFailure);
        if (context == null) {
            return;
        }

        PriorityProcessWorkerExecutorService.RequestPriority priority;
        if (z) {
            priority = PriorityProcessWorkerExecutorService.RequestPriority.HIGH;
        } else {
            priority = PriorityProcessWorkerExecutorService.RequestPriority.NORMAL;
        }

        InferencePyTorchAction inferenceAction = new InferencePyTorchAction(
            trainedModelDeploymentTask.getModelId(),
            requestIdCounter.getAndIncrement(),
            timeValue,
            context,
            inferenceConfig,
            map,
            threadPool,
            task,
            actionListener
        );
        executePyTorchAction(context, priority, inferenceAction);
    }

    /**
     * Queues a control message carrying the new allocation count {@code i} with
     * {@code HIGHEST} priority. If the task is stopped or has no process
     * context, the listener is failed and nothing is queued.
     */
    public void updateNumAllocations(
        TrainedModelDeploymentTask trainedModelDeploymentTask,
        int i,
        TimeValue timeValue,
        ActionListener<ThreadSettings> actionListener
    ) {
        Objects.requireNonNull(actionListener);
        ProcessContext context = getProcessContext(trainedModelDeploymentTask, actionListener::onFailure);
        if (context == null) {
            return;
        }

        ControlMessagePyTorchAction controlAction = new ControlMessagePyTorchAction(
            trainedModelDeploymentTask.getModelId(),
            requestIdCounter.getAndIncrement(),
            i,
            timeValue,
            context,
            threadPool,
            actionListener
        );
        executePyTorchAction(context, PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST, controlAction);
    }

    /**
     * Submits the action to the context's priority worker. Any exception raised
     * by the submission is delivered to the action's {@code onFailure}; a
     * rejection additionally bumps the context's rejected-execution counter.
     */
    public void executePyTorchAction(
        ProcessContext processContext,
        PriorityProcessWorkerExecutorService.RequestPriority requestPriority,
        AbstractPyTorchAction<?> abstractPyTorchAction
    ) {
        try {
            processContext.getExecutorService()
                .executeWithPriority(abstractPyTorchAction, requestPriority, abstractPyTorchAction.getRequestId());
        } catch (Exception submissionFailure) {
            if (submissionFailure instanceof EsRejectedExecutionException) {
                processContext.getRejectedExecutionCount().incrementAndGet();
            }
            abstractPyTorchAction.onFailure(submissionFailure);
        }
    }

    /**
     * Looks up the process context for a task.
     *
     * @param trainedModelDeploymentTask the task whose context is wanted
     * @param consumer                   receives a conflict-status exception when the task is stopped
     *                                   or no context is registered for it
     * @return the context, or {@code null} after {@code consumer} has been notified
     */
    private ProcessContext getProcessContext(TrainedModelDeploymentTask trainedModelDeploymentTask, Consumer<Exception> consumer) {
        if (trainedModelDeploymentTask.isStopped()) {
            consumer.accept(
                ExceptionsHelper.conflictStatusException(
                    "[{}] is stopping or stopped due to [{}]",
                    trainedModelDeploymentTask.getModelId(),
                    trainedModelDeploymentTask.stoppedReason().orElse("")
                )
            );
            return null;
        }

        ProcessContext context = processContextByAllocation.get(trainedModelDeploymentTask.getId());
        if (context == null) {
            consumer.accept(
                ExceptionsHelper.conflictStatusException("[{}] process context missing", trainedModelDeploymentTask.getModelId())
            );
        }
        return context;
    }

    // Static initializer: assigns the three static final fields declared at the top of the class.
    static {
        // True when assertions are NOT enabled for this class; read by the explicit
        // assertion check in doStartDeployment.
        $assertionsDisabled = !DeploymentManager.class.desiredAssertionStatus();
        logger = LogManager.getLogger(DeploymentManager.class);
        // Request ids handed to PyTorch actions start at 1.
        requestIdCounter = new AtomicLong(1L);
    }
}
