package org.elasticsearch.xpack.ml.inference.assignment;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdate;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.class */
public class TrainedModelAssignmentNodeService implements ClusterStateListener {
    /** Stop reason used when an assignment still exists but no longer routes to this node. */
    private static final String NODE_NO_LONGER_REFERENCED = "node no longer referenced in model routing table";
    /** Stop reason used when a locally tracked task has no matching assignment in the cluster state. */
    private static final String ASSIGNMENT_NO_LONGER_EXISTS = "model assignment no longer exists";
    /** Fixed delay between runs of {@link #loadQueuedModels()}. */
    private static final TimeValue MODEL_LOADING_CHECK_INTERVAL = TimeValue.timeValueSeconds(1);
    /** Timeout handed to the deployment manager when changing the number of allocations. */
    private static final TimeValue UPDATE_NUMBER_OF_ALLOCATIONS_TIMEOUT = TimeValue.timeValueSeconds(60);
    private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentNodeService.class);
    private final TrainedModelAssignmentService trainedModelAssignmentService;
    private final DeploymentManager deploymentManager;
    private final TaskManager taskManager;
    private final ThreadPool threadPool;
    private final XPackLicenseState licenseState;
    private final IndexNameExpressionResolver expressionResolver;
    /** Handle of the periodic model-loading job; assigned in {@link #start()} and cancelled in {@link #stop()}. */
    private volatile Scheduler.Cancellable scheduledFuture;
    /** Most recent cluster state seen by {@link #clusterChanged}; {@code null} until the first event arrives. */
    private volatile ClusterState latestState;
    /** Set by {@link #stop()}; while {@code true}, state updates and deployment stops are skipped. */
    private volatile boolean stopped;
    /** Id of the local node, sent with every routing info update. */
    private volatile String nodeId;
    /** Deployment tasks known to this node, keyed by model id. */
    private final Map<String, TrainedModelDeploymentTask> modelIdToTask = new ConcurrentHashMap<>();
    /** Tasks waiting to be loaded by {@link #loadQueuedModels()}. */
    private final Deque<TrainedModelDeploymentTask> loadingModels = new ConcurrentLinkedDeque<>();

    /**
     * Creates the service and hooks it into the cluster service lifecycle.
     * <p>
     * The local node id is resolved from {@code clusterService} once the cluster service has
     * started, immediately before {@link #start()} is called; {@link #stop()} is called before
     * the cluster service stops.
     */
    public TrainedModelAssignmentNodeService(
        TrainedModelAssignmentService trainedModelAssignmentService,
        final ClusterService clusterService,
        DeploymentManager deploymentManager,
        IndexNameExpressionResolver indexNameExpressionResolver,
        TaskManager taskManager,
        ThreadPool threadPool,
        XPackLicenseState xPackLicenseState
    ) {
        // Plain collaborator wiring first; none of these assignments depend on one another.
        this.trainedModelAssignmentService = trainedModelAssignmentService;
        this.deploymentManager = deploymentManager;
        this.expressionResolver = indexNameExpressionResolver;
        this.taskManager = taskManager;
        this.threadPool = threadPool;
        this.licenseState = xPackLicenseState;

        final LifecycleListener lifecycleHook = new LifecycleListener() {
            @Override
            public void afterStart() {
                // The local node is only read here, after the cluster service has started.
                TrainedModelAssignmentNodeService.this.nodeId = clusterService.localNode().getId();
                TrainedModelAssignmentNodeService.this.start();
            }

            @Override
            public void beforeStop() {
                TrainedModelAssignmentNodeService.this.stop();
            }
        };
        clusterService.addLifecycleListener(lifecycleHook);
    }

    /**
     * Package-private variant that takes the local node id ({@code str}) up front instead of
     * resolving it from the cluster service when it starts. The lifecycle hooks still call
     * {@link #start()} and {@link #stop()}.
     */
    TrainedModelAssignmentNodeService(
        TrainedModelAssignmentService trainedModelAssignmentService,
        ClusterService clusterService,
        DeploymentManager deploymentManager,
        IndexNameExpressionResolver indexNameExpressionResolver,
        TaskManager taskManager,
        ThreadPool threadPool,
        String str,
        XPackLicenseState xPackLicenseState
    ) {
        this.nodeId = str;
        this.trainedModelAssignmentService = trainedModelAssignmentService;
        this.deploymentManager = deploymentManager;
        this.expressionResolver = indexNameExpressionResolver;
        this.taskManager = taskManager;
        this.threadPool = threadPool;
        this.licenseState = xPackLicenseState;

        final LifecycleListener lifecycleHook = new LifecycleListener() {
            @Override
            public void afterStart() {
                TrainedModelAssignmentNodeService.this.start();
            }

            @Override
            public void beforeStop() {
                TrainedModelAssignmentNodeService.this.stop();
            }
        };
        clusterService.addLifecycleListener(lifecycleHook);
    }

    /**
     * Marks the task as stopped and tears the deployment down on the ML utility thread pool.
     * <p>
     * The asynchronous part stops the deployment, unregisters the task from the task manager and
     * removes the model id from {@code modelIdToTask}, then completes the listener. Any exception
     * thrown along the way is reported through {@link ActionListener#onFailure}.
     * <p>
     * NOTE: when this service has already been stopped the method returns immediately and the
     * listener is never invoked.
     *
     * @param trainedModelDeploymentTask the deployment task to stop
     * @param str                        the reason recorded on the task
     * @param actionListener             notified once the asynchronous teardown succeeds or fails
     */
    void stopDeploymentAsync(TrainedModelDeploymentTask trainedModelDeploymentTask, String str, ActionListener<Void> actionListener) {
        if (this.stopped) {
            return;
        }
        trainedModelDeploymentTask.markAsStopped(str);
        this.threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
            try {
                this.deploymentManager.stopDeployment(trainedModelDeploymentTask);
                this.taskManager.unregister(trainedModelDeploymentTask);
                this.modelIdToTask.remove(trainedModelDeploymentTask.getModelId());
                // A Void listener can only be completed with null; no cast is needed (or legal).
                actionListener.onResponse(null);
            } catch (Exception e) {
                actionListener.onFailure(e);
            }
        });
    }

    /**
     * Clears the stopped flag and schedules {@link #loadQueuedModels()} to run repeatedly on the
     * ML utility thread pool with a fixed delay of {@code MODEL_LOADING_CHECK_INTERVAL}.
     */
    public void start() {
        this.stopped = false;
        final Scheduler.Cancellable loaderJob = this.threadPool.scheduleWithFixedDelay(
            this::loadQueuedModels,
            MODEL_LOADING_CHECK_INTERVAL,
            MachineLearning.UTILITY_THREAD_POOL_NAME
        );
        this.scheduledFuture = loaderJob;
    }

    /**
     * Flags the service as stopped and cancels the periodic model-loading job, if one was scheduled.
     */
    public void stop() {
        this.stopped = true;
        // Read the volatile field once so the null check and the cancel act on the same handle.
        final Scheduler.Cancellable loaderJob = this.scheduledFuture;
        if (loaderJob == null) {
            return;
        }
        loaderJob.cancel();
    }

    /**
     * Drains the queue of models waiting to be loaded and starts a deployment for each one.
     * <p>
     * Nothing is loaded while any primary shard of the inference indices is reported as inactive
     * for the latest known cluster state. Tasks that were stopped while queued are skipped. Each
     * remaining task is started synchronously (the call blocks on the deployment future):
     * <ul>
     *   <li>on success the master is notified through {@code handleLoadSuccess};</li>
     *   <li>a {@link ResourceNotFoundException} cause is reported as a missing trained model;</li>
     *   <li>a {@link SearchPhaseExecutionException} cause puts the task back on the queue so it
     *       is retried on the next run;</li>
     *   <li>any other failure is passed to {@code handleLoadFailure}.</li>
     * </ul>
     * If the service is stopped part-way through, the method returns at once; the task just
     * polled and any tasks collected for retry are not re-queued.
     */
    void loadQueuedModels() {
        if (this.loadingModels.isEmpty()) {
            return;
        }
        // Read the volatile state once so the null check and the shard check see the same snapshot.
        final ClusterState state = this.latestState;
        if (state != null) {
            final List<String> unavailableIndices = AbstractJobPersistentTasksExecutor.verifyIndicesPrimaryShardsAreActive(
                state,
                this.expressionResolver,
                true,
                ".ml-inference-*",
                InferenceIndexConstants.nativeDefinitionStore()
            );
            if (unavailableIndices.isEmpty() == false) {
                logger.trace("not loading models as indices {} primary shards are unassigned", unavailableIndices);
                return;
            }
        }
        logger.trace("attempting to load all currently queued models");
        final Deque<TrainedModelDeploymentTask> loadingToRetry = new ArrayDeque<>();
        TrainedModelDeploymentTask task;
        while ((task = this.loadingModels.poll()) != null) {
            final String modelId = task.getModelId();
            if (task.isStopped()) {
                if (logger.isTraceEnabled()) {
                    logger.trace("[{}] attempted to load stopped task with reason [{}]", modelId, task.stoppedReason().orElse("_unknown_"));
                }
                continue;
            }
            if (this.stopped) {
                return;
            }
            logger.trace(() -> "[" + modelId + "] attempting to load model");
            // Must be typed as PlainActionFuture: actionGet() is not part of ActionListener.
            final PlainActionFuture<TrainedModelDeploymentTask> startFuture = new PlainActionFuture<>();
            try {
                this.deploymentManager.startDeployment(task, startFuture);
                handleLoadSuccess(startFuture.actionGet());
            } catch (Exception e) {
                logger.warn(() -> "[" + modelId + "] Start deployment failed", e);
                final Throwable cause = ExceptionsHelper.unwrapCause(e);
                if (cause instanceof ResourceNotFoundException) {
                    logger.debug(() -> "[" + modelId + "] Start deployment failed as model was not found", e);
                    handleLoadFailure(task, ExceptionsHelper.missingTrainedModel(modelId, e));
                } else if (cause instanceof SearchPhaseExecutionException) {
                    logger.debug(() -> "[" + modelId + "] Start deployment failed, will retry", e);
                    // Keep the task aside and re-queue it after the drain so this run does not spin on it.
                    loadingToRetry.add(task);
                } else {
                    handleLoadFailure(task, e);
                }
            }
        }
        this.loadingModels.addAll(loadingToRetry);
    }

    /**
     * Stops a deployment and records the transition in the stored routing info.
     * <p>
     * The sequence is: request the {@code STOPPING} routing state, stop the deployment
     * asynchronously (regardless of whether that request succeeded), then request the
     * {@code STOPPED} routing state (regardless of whether the stop succeeded). The outcome of
     * the final {@code STOPPED} update is what completes {@code actionListener}.
     * <p>
     * NOTE: {@code updateStoredState} and {@code stopDeploymentAsync} both return without calling
     * their listeners once this service has been stopped, in which case {@code actionListener}
     * is never completed.
     *
     * @param trainedModelDeploymentTask the deployment task to stop
     * @param str                        the reason stored with the routing state and on the task
     * @param actionListener             completed with the result of the final routing update
     */
    public void stopDeploymentAndNotify(TrainedModelDeploymentTask trainedModelDeploymentTask, String str, ActionListener<AcknowledgedResponse> actionListener) {
        final RoutingInfoUpdate stoppedUpdate = RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STOPPED, str));
        final RoutingInfoUpdate stoppingUpdate = RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STOPPING, str));

        // Runs once the local stop finished; STOPPED is reported whether or not the stop succeeded.
        final ActionListener<Void> afterLocalStop = ActionListener.wrap(
            ignored -> updateStoredState(trainedModelDeploymentTask.getModelId(), stoppedUpdate, actionListener),
            stopFailure -> {
                logger.warn(() -> "[" + trainedModelDeploymentTask.getModelId() + "] failed to stop due to error", stopFailure);
                updateStoredState(trainedModelDeploymentTask.getModelId(), stoppedUpdate, actionListener);
            }
        );

        updateStoredState(
            trainedModelDeploymentTask.getModelId(),
            stoppingUpdate,
            ActionListener.wrap(acknowledged -> stopDeploymentAsync(trainedModelDeploymentTask, str, afterLocalStop), updateFailure -> {
                if (ExceptionsHelper.unwrapCause(updateFailure) instanceof ResourceNotFoundException) {
                    // The assignment is already gone, so there is no routing entry to update.
                    logger.debug(
                        () -> Strings.format(
                            "[%s] failed to set routing state to stopping as assignment already removed",
                            trainedModelDeploymentTask.getModelId()
                        ),
                        updateFailure
                    );
                } else {
                    logger.warn(
                        () -> "[" + trainedModelDeploymentTask.getModelId() + "] failed to set routing state to stopping due to error",
                        updateFailure
                    );
                }
                // The deployment is stopped locally even if the STOPPING update could not be stored.
                stopDeploymentAsync(trainedModelDeploymentTask, str, afterLocalStop);
            })
        );
    }

    /**
     * Forwards an inference request, unchanged, to the deployment manager.
     * All arguments are passed through in the order received.
     */
    public void infer(
        TrainedModelDeploymentTask trainedModelDeploymentTask,
        InferenceConfig inferenceConfig,
        Map<String, Object> map,
        boolean z,
        TimeValue timeValue,
        Task task,
        ActionListener<InferenceResults> actionListener
    ) {
        this.deploymentManager.infer(
            trainedModelDeploymentTask,
            inferenceConfig,
            map,
            z,
            timeValue,
            task,
            actionListener
        );
    }

    /**
     * Looks up the stats the deployment manager holds for the given task.
     *
     * @return whatever {@code DeploymentManager#getStats} returns for the task
     */
    public Optional<ModelStats> modelStats(TrainedModelDeploymentTask trainedModelDeploymentTask) {
        final Optional<ModelStats> stats = this.deploymentManager.getStats(trainedModelDeploymentTask);
        return stats;
    }

    /**
     * Builds the request object handed to the task manager when registering a deployment task.
     * The returned request has no parent task, rejects attempts to set one, and creates
     * {@link TrainedModelDeploymentTask} instances bound to the given task params.
     */
    private TaskAwareRequest taskAwareRequest(final StartTrainedModelDeploymentAction.TaskParams taskParams) {
        return new TaskAwareRequest() {
            @Override
            public void setParentTask(TaskId parentTaskId) {
                throw new UnsupportedOperationException("parent task id for model assignment tasks shouldn't change");
            }

            @Override
            public TaskId getParentTask() {
                return TaskId.EMPTY_TASK_ID;
            }

            @Override
            public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
                return new TrainedModelDeploymentTask(
                    id,
                    type,
                    action,
                    parentTaskId,
                    headers,
                    taskParams,
                    this,
                    TrainedModelAssignmentNodeService.this.licenseState,
                    MachineLearning.ML_PYTORCH_MODEL_INFERENCE_FEATURE
                );
            }
        };
    }

    /**
     * Reconciles the locally tracked deployment tasks with the model assignments in the new
     * cluster state.
     * <p>
     * The latest state is always recorded. When the metadata changed, this method then:
     * <ol>
     *   <li>applies pending allocation-count changes (skipped in reset mode or when the minimum
     *       node version predates distributed model allocation);</li>
     *   <li>for every assignment routed to this node in {@code STARTING} or {@code STARTED} state
     *       without a local task, queues the model for loading (not in reset mode); a local task
     *       that is marked failed while its route is {@code STARTING} is dropped first so the
     *       model can be queued again;</li>
     *   <li>stops local tasks whose assignment no longer routes to this node;</li>
     *   <li>stops local tasks whose assignment no longer exists at all.</li>
     * </ol>
     *
     * @param clusterChangedEvent the cluster state change to react to
     */
    @Override
    public void clusterChanged(ClusterChangedEvent clusterChangedEvent) {
        this.latestState = clusterChangedEvent.state();
        if (clusterChangedEvent.metadataChanged() == false) {
            return;
        }
        final boolean isResetMode = MlMetadata.getMlMetadata(clusterChangedEvent.state()).isResetMode();
        final TrainedModelAssignmentMetadata assignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterChangedEvent.state());
        final String localNodeId = clusterChangedEvent.state().nodes().getLocalNodeId();
        final boolean distributedAllocationSupported = clusterChangedEvent.state()
            .getNodes()
            .getMinNodeVersion()
            .onOrAfter(TrainedModelAssignmentClusterService.DISTRIBUTED_MODEL_ALLOCATION_VERSION);

        if (isResetMode == false && distributedAllocationSupported) {
            updateNumberOfAllocations(assignmentMetadata);
        }

        for (TrainedModelAssignment assignment : assignmentMetadata.modelAssignments().values()) {
            final RoutingInfo routingInfo = assignment.getNodeRoutingTable().get(localNodeId);
            if (routingInfo == null) {
                // The assignment exists but this node is no longer part of its routing table.
                final TrainedModelDeploymentTask unrouted = this.modelIdToTask.remove(assignment.getTaskParams().getModelId());
                if (unrouted != null) {
                    stopDeploymentAsyncAndLog(unrouted, NODE_NO_LONGER_REFERENCED);
                }
                continue;
            }
            if (distributedAllocationSupported == false) {
                continue;
            }
            if (routingInfo.getState() == RoutingState.STARTING) {
                // Single lookup: the map is modified from other threads, so a containsKey()
                // followed by get() could observe the entry disappearing in between.
                final TrainedModelDeploymentTask existing = this.modelIdToTask.get(assignment.getModelId());
                if (existing != null && existing.isFailed()) {
                    // Drop the failed task so the model is queued for loading again below.
                    this.taskManager.unregister(existing);
                    this.modelIdToTask.remove(assignment.getModelId());
                }
            }
            if (routingInfo.getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED)
                && this.modelIdToTask.containsKey(assignment.getTaskParams().getModelId()) == false
                && isResetMode == false) {
                prepareModelToLoad(
                    new StartTrainedModelDeploymentAction.TaskParams(
                        assignment.getModelId(),
                        assignment.getTaskParams().getModelBytes(),
                        assignment.getTaskParams().getThreadsPerAllocation(),
                        routingInfo.getCurrentAllocations(),
                        assignment.getTaskParams().getQueueCapacity(),
                        (ByteSizeValue) assignment.getTaskParams().getCacheSize().orElse(null)
                    )
                );
            }
        }

        // Any task still tracked locally without a matching assignment must be cancelled.
        final List<TrainedModelDeploymentTask> toCancel = new ArrayList<>();
        for (String orphanedModelId : Sets.difference(this.modelIdToTask.keySet(), assignmentMetadata.modelAssignments().keySet())) {
            final TrainedModelDeploymentTask orphaned = this.modelIdToTask.remove(orphanedModelId);
            // The entry may have been removed concurrently by an in-flight asynchronous stop.
            if (orphaned != null) {
                toCancel.add(orphaned);
            }
        }
        for (TrainedModelDeploymentTask orphaned : toCancel) {
            stopDeploymentAsyncAndLog(orphaned, ASSIGNMENT_NO_LONGER_EXISTS);
        }
    }

    /** Stops the given deployment asynchronously and only logs the outcome. */
    private void stopDeploymentAsyncAndLog(TrainedModelDeploymentTask task, String reason) {
        stopDeploymentAsync(
            task,
            reason,
            ActionListener.wrap(
                ignored -> logger.trace(() -> "[" + task.getModelId() + "] stopped deployment"),
                e -> logger.warn(() -> "[" + task.getModelId() + "] failed to fully stop deployment", e)
            )
        );
    }

    /**
     * Asks the deployment manager to change the allocation count of every locally started model
     * whose current and target allocations differ.
     * <p>
     * Assignments are skipped when any of their routes is still {@code STARTING}, when they are
     * not routed to this node, or when the local route is not {@code STARTED}. On success the
     * task is updated and the new count is written back to the stored routing info; failures
     * are only logged.
     */
    private void updateNumberOfAllocations(TrainedModelAssignmentMetadata trainedModelAssignmentMetadata) {
        for (TrainedModelAssignment assignment : trainedModelAssignmentMetadata.modelAssignments().values()) {
            if (hasStartingAssignments(assignment)) {
                continue;
            }
            if (assignment.isRoutedToNode(this.nodeId) == false) {
                continue;
            }
            final RoutingInfo localRoute = assignment.getNodeRoutingTable().get(this.nodeId);
            final boolean needsUpdate = localRoute.getState() == RoutingState.STARTED
                && localRoute.getCurrentAllocations() != localRoute.getTargetAllocations();
            if (needsUpdate == false) {
                continue;
            }

            final TrainedModelDeploymentTask task = this.modelIdToTask.get(assignment.getModelId());
            if (task == null) {
                logger.debug(() -> Strings.format("[%s] task was removed whilst updating number of allocations", assignment.getModelId()));
                continue;
            }

            this.deploymentManager.updateNumAllocations(
                task,
                localRoute.getTargetAllocations(),
                UPDATE_NUMBER_OF_ALLOCATIONS_TIMEOUT,
                ActionListener.wrap(threadSettings -> {
                    logger.debug("[{}] Updated number of allocations to [{}]", assignment.getModelId(), threadSettings.numAllocations());
                    task.updateNumberOfAllocations(threadSettings.numAllocations());
                    updateStoredState(
                        assignment.getModelId(),
                        RoutingInfoUpdate.updateNumberOfAllocations(threadSettings.numAllocations()),
                        ActionListener.noop()
                    );
                }, failure -> {
                    logger.error(
                        Strings.format(
                            "[%s] Could not update number of allocations to [%s]",
                            assignment.getModelId(),
                            localRoute.getTargetAllocations()
                        ),
                        failure
                    );
                })
            );
        }
    }

    /**
     * @return {@code true} if at least one entry of the assignment's node routing table is in
     *         the {@code STARTING} state
     */
    private boolean hasStartingAssignments(TrainedModelAssignment trainedModelAssignment) {
        for (RoutingInfo route : trainedModelAssignment.getNodeRoutingTable().values()) {
            if (route.getState().isAnyOf(RoutingState.STARTING)) {
                return true;
            }
        }
        return false;
    }

    /**
     * @param str the model id to look up
     * @return the locally tracked deployment task for that model id, or {@code null} if none
     */
    TrainedModelDeploymentTask getTask(String str) {
        final TrainedModelDeploymentTask tracked = this.modelIdToTask.get(str);
        return tracked;
    }

    /**
     * Registers a deployment task for the given params and queues it for loading.
     * <p>
     * If a task for the same model id is already tracked, the freshly registered task is
     * unregistered again and nothing is queued.
     *
     * @param taskParams parameters of the deployment to start on this node
     */
    void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams) {
        logger.debug(() -> Strings.format("[%s] preparing to load model with task params: %s", taskParams.getModelId(), taskParams));
        // The request built by taskAwareRequest() always creates a TrainedModelDeploymentTask,
        // so narrowing the registered task to that type is safe.
        final TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) this.taskManager.register(
            TrainedModelAssignmentMetadata.NAME,
            "xpack/ml/trained_model_assignment[n]",
            taskAwareRequest(taskParams)
        );
        // putIfAbsent makes "track and queue" atomic with respect to concurrent callers.
        if (this.modelIdToTask.putIfAbsent(taskParams.getModelId(), task) == null) {
            this.loadingModels.add(task);
        } else {
            // Another task already owns this model id; discard the one just registered.
            this.taskManager.unregister(task);
        }
    }

    /**
     * Reports a successfully loaded model by requesting the {@code STARTED} routing state.
     * If the task was stopped in the meantime, nothing is reported and the fact is only logged.
     * The result of the routing update is logged, never propagated.
     */
    private void handleLoadSuccess(TrainedModelDeploymentTask trainedModelDeploymentTask) {
        final String loadedModelId = trainedModelDeploymentTask.getModelId();
        logger.debug(() -> "[" + loadedModelId + "] model successfully loaded and ready for inference. Notifying master node");

        if (trainedModelDeploymentTask.isStopped()) {
            logger.debug(
                () -> Strings.format(
                    "[%s] model loaded successfully, but stopped before routing table was updated; reason [%s]",
                    loadedModelId,
                    trainedModelDeploymentTask.stoppedReason().orElse("_unknown_")
                )
            );
            return;
        }

        final RoutingInfoUpdate startedUpdate = RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STARTED, ""));
        updateStoredState(
            loadedModelId,
            startedUpdate,
            ActionListener.wrap(acknowledged -> logger.debug(() -> "[" + loadedModelId + "] model loaded and accepting routes"), failure -> {
                final boolean assignmentGone = ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException;
                if (assignmentGone) {
                    logger.debug(
                        () -> Strings.format(
                            "[%s] model loaded but failed to start accepting routes as assignment to this node was removed",
                            loadedModelId
                        ),
                        failure
                    );
                } else {
                    logger.warn(() -> "[" + loadedModelId + "] model loaded but failed to start accepting routes", failure);
                }
            })
        );
    }

    /**
     * Sends a routing info update for model {@code str} on this node through the assignment
     * service, then relays the outcome to {@code actionListener}: success is always reported as
     * {@code AcknowledgedResponse.TRUE}, failure is logged and passed on unchanged.
     * <p>
     * Once the service has been stopped this is a no-op and the listener is not invoked.
     */
    private void updateStoredState(String str, RoutingInfoUpdate routingInfoUpdate, ActionListener<AcknowledgedResponse> actionListener) {
        if (this.stopped == false) {
            final UpdateTrainedModelAssignmentRoutingInfoAction.Request updateRequest =
                new UpdateTrainedModelAssignmentRoutingInfoAction.Request(this.nodeId, str, routingInfoUpdate);
            this.trainedModelAssignmentService.updateModelAssignmentState(updateRequest, ActionListener.wrap(ack -> {
                logger.debug(
                    () -> Strings.format("[%s] model routing info was updated with [%s] and master notified", str, routingInfoUpdate)
                );
                actionListener.onResponse(AcknowledgedResponse.TRUE);
            }, failure -> {
                logger.warn(() -> Strings.format("[%s] failed to update model routing info with [%s]", str, routingInfoUpdate), failure);
                actionListener.onFailure(failure);
            }));
        }
    }

    /**
     * Handles a model that could not be loaded: logs the error, requests the {@code FAILED}
     * routing state (with the unwrapped cause's message as reason) and then stops the local
     * deployment, whether or not the routing update succeeded.
     * <p>
     * A task that is already stopped is additionally logged at debug level but otherwise
     * follows the same path.
     */
    private void handleLoadFailure(TrainedModelDeploymentTask trainedModelDeploymentTask, Exception exc) {
        logger.error(() -> "[" + trainedModelDeploymentTask.getModelId() + "] model failed to load", exc);
        if (trainedModelDeploymentTask.isStopped()) {
            logger.debug(
                () -> Strings.format(
                    "[%s] model failed to load, but is now stopped; reason [%s]",
                    trainedModelDeploymentTask.getModelId(),
                    trainedModelDeploymentTask.stoppedReason().orElse("_unknown_")
                )
            );
        }

        // Shared continuation: the local task is stopped regardless of the routing update outcome.
        final Runnable stopLocalTask = () -> stopDeploymentAsync(
            trainedModelDeploymentTask,
            "model failed to load; reason [" + exc.getMessage() + "]",
            ActionListener.noop()
        );

        final RoutingStateAndReason failedState = new RoutingStateAndReason(RoutingState.FAILED, ExceptionsHelper.unwrapCause(exc).getMessage());
        updateStoredState(
            trainedModelDeploymentTask.getModelId(),
            RoutingInfoUpdate.updateStateAndReason(failedState),
            ActionListener.wrap(acknowledged -> stopLocalTask.run(), updateFailure -> stopLocalTask.run())
        );
    }

    /**
     * Requests the {@code FAILED} routing state for the task's model on this node and logs
     * whether the update went through.
     *
     * @param trainedModelDeploymentTask the task whose assignment should be marked failed
     * @param str                        the failure reason stored with the routing state
     */
    public void failAssignment(TrainedModelDeploymentTask trainedModelDeploymentTask, String str) {
        final RoutingInfoUpdate failedUpdate = RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.FAILED, str));
        updateStoredState(
            trainedModelDeploymentTask.getModelId(),
            failedUpdate,
            ActionListener.wrap(
                acknowledged -> logger.debug(
                    () -> Strings.format(
                        "[%s] Successfully updating assignment state to [%s] with reason [%s]",
                        trainedModelDeploymentTask.getModelId(),
                        RoutingState.FAILED,
                        str
                    )
                ),
                failure -> logger.error(
                    () -> Strings.format(
                        "[%s] Error while updating assignment state to [%s] with reason [%s]",
                        trainedModelDeploymentTask.getModelId(),
                        RoutingState.FAILED,
                        str
                    ),
                    failure
                )
            )
        );
    }
}
