package org.elasticsearch.xpack.ml.inference.assignment;

import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.ClusterStateUpdateTask;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.NodesShutdownMetadata;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.gateway.GatewayService;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.job.NodeLoad;
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
import org.elasticsearch.xpack.ml.notifications.SystemAuditor;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.class */
/**
 * Maintains the trained model assignment metadata held in the cluster state.
 * <p>
 * On master-eligible nodes this service registers itself as a {@link ClusterStateListener}. While the local node is
 * the elected master it reacts to cluster changes by rebalancing model deployments or, when some node is older than
 * {@link #DISTRIBUTED_MODEL_ALLOCATION_VERSION}, by only dropping routing entries to nodes that were removed or are
 * shutting down. It also provides the cluster state updates used to create, stop and delete model assignments and
 * to update per-node routing entries.
 */
public class TrainedModelAssignmentClusterService implements ClusterStateListener {

    private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentClusterService.class);

    /**
     * From this minimum node version on, assignments are written under {@code TrainedModelAssignmentMetadata.NAME};
     * below it they are written under the deprecated name.
     */
    private static final Version RENAME_ALLOCATION_TO_ASSIGNMENT_VERSION = Version.V_8_3_0;

    /** Minimum node version required for rebalancing and for creating new model assignments. */
    public static final Version DISTRIBUTED_MODEL_ALLOCATION_VERSION = Version.V_8_4_0;

    /** Custom metadata type checked for changes to node shutdown records. */
    private static final String NODE_SHUTDOWN_METADATA_TYPE = "node_shutdown";

    /** Custom metadata type checked for changes to persistent tasks. */
    private static final String PERSISTENT_TASKS_METADATA_TYPE = "persistent_tasks";

    private final ClusterService clusterService;
    private final ThreadPool threadPool;
    private final NodeLoadDetector nodeLoadDetector;
    private final SystemAuditor systemAuditor;

    // Dynamic settings used when computing node loads. Written by cluster-settings consumers (registered only on
    // master-eligible nodes) and read from other threads, hence volatile.
    private volatile int maxMemoryPercentage;
    private volatile boolean useAuto;
    private volatile int maxOpenJobs;

    /**
     * @param settings          node settings providing the initial values of the ML memory / open-jobs settings
     * @param clusterService    used to read the cluster state and submit cluster state updates
     * @param threadPool        provides the ML utility executor used for rebalancing and auditing
     * @param nodeLoadDetector  computes the current load of each assignable node
     * @param systemAuditor     receives an audit message whenever a rebalance changes the cluster state
     */
    public TrainedModelAssignmentClusterService(
        Settings settings,
        ClusterService clusterService,
        ThreadPool threadPool,
        NodeLoadDetector nodeLoadDetector,
        SystemAuditor systemAuditor
    ) {
        this.clusterService = Objects.requireNonNull(clusterService);
        this.threadPool = Objects.requireNonNull(threadPool);
        this.nodeLoadDetector = Objects.requireNonNull(nodeLoadDetector);
        this.systemAuditor = Objects.requireNonNull(systemAuditor);
        this.maxMemoryPercentage = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings);
        this.useAuto = MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings);
        this.maxOpenJobs = MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(settings);
        // Only master-eligible nodes can become master and run the logic below, so only they listen for changes.
        if (DiscoveryNode.isMasterNode(settings)) {
            clusterService.addListener(this);
            clusterService.getClusterSettings()
                .addSettingsUpdateConsumer(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, value -> setMaxMemoryPercentage(value));
            clusterService.getClusterSettings()
                .addSettingsUpdateConsumer(MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT, value -> setUseAuto(value));
            clusterService.getClusterSettings()
                .addSettingsUpdateConsumer(MachineLearning.MAX_OPEN_JOBS_PER_NODE, value -> setMaxOpenJobs(value));
        }
    }

    private void setMaxMemoryPercentage(int maxMemoryPercentage) {
        this.maxMemoryPercentage = maxMemoryPercentage;
    }

    private void setUseAuto(boolean useAuto) {
        this.useAuto = useAuto;
    }

    private void setMaxOpenJobs(int maxOpenJobs) {
        this.maxOpenJobs = maxOpenJobs;
    }

    @SuppressForbidden(reason = "legacy usage of unbatched task")
    private void submitUnbatchedTask(String source, ClusterStateUpdateTask task) {
        clusterService.submitUnbatchedStateUpdateTask(source, task);
    }

    /**
     * Reacts to cluster state changes on the elected master once the cluster state has been recovered.
     * <p>
     * If any node is older than {@link #DISTRIBUTED_MODEL_ALLOCATION_VERSION}, only routing entries to removed or
     * shutting-down nodes are dropped. Otherwise a full rebalance is started when
     * {@link #detectReasonToRebalanceModels} finds a reason; its outcome is only logged.
     */
    @Override
    public void clusterChanged(ClusterChangedEvent event) {
        if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) {
            return;
        }
        if (event.localNodeMaster() == false) {
            return;
        }
        if (event.state().nodes().getMinNodeVersion().before(DISTRIBUTED_MODEL_ALLOCATION_VERSION)) {
            removeRoutingToRemovedOrShuttingDownNodes(event);
            return;
        }
        Optional<String> rebalanceReason = detectReasonToRebalanceModels(event);
        if (rebalanceReason.isPresent()) {
            rebalanceAssignments(
                event.state(),
                Optional.empty(),
                rebalanceReason.get(),
                ActionListener.wrap(
                    metadata -> logger.debug(
                        () -> Strings.format(
                            "rebalanced model assignments [%s]",
                            org.elasticsearch.common.Strings.toString(metadata, false, true)
                        )
                    ),
                    e -> logger.warn("failed to rebalance models", e)
                )
            );
        }
    }

    /**
     * Submits a cluster state update removing routing entries to unassignable nodes, but only if some assignment is
     * currently routed to a node that was removed or is shutting down. Failures are logged.
     */
    private void removeRoutingToRemovedOrShuttingDownNodes(ClusterChangedEvent event) {
        if (areAssignedNodesRemoved(event) == false) {
            return;
        }
        submitUnbatchedTask("removing routing entries for removed or shutting down nodes", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return removeRoutingToUnassignableNodes(currentState);
            }

            @Override
            public void onFailure(Exception e) {
                logger.error("could not remove routing entries for removed or shutting down nodes", e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                logger.debug(
                    () -> Strings.format(
                        "updated model assignments based on node changes in the cluster; new metadata [%s]",
                        org.elasticsearch.common.Strings.toString(TrainedModelAssignmentMetadata.fromState(newState), false, true)
                    )
                );
            }
        });
    }

    /**
     * Returns whether any model assignment is routed to a node that either left the cluster in this event or is
     * currently marked as shutting down. Returns {@code false} early when the event neither removed nodes nor changed
     * the node shutdown metadata.
     */
    static boolean areAssignedNodesRemoved(ClusterChangedEvent event) {
        boolean isNodeShutdownMetadataChanged = event.changedCustomMetadataSet().contains(NODE_SHUTDOWN_METADATA_TYPE);
        if (event.nodesRemoved() == false && isNodeShutdownMetadataChanged == false) {
            return false;
        }
        Set<String> unavailableNodeIds = new HashSet<>(nodesShuttingDown(event.state()));
        for (DiscoveryNode removedNode : event.nodesDelta().removedNodes()) {
            unavailableNodeIds.add(removedNode.getId());
        }
        return TrainedModelAssignmentMetadata.fromState(event.state())
            .modelAssignments()
            .values()
            .stream()
            .anyMatch(assignment -> assignment.getNodeRoutingTable().keySet().stream().anyMatch(unavailableNodeIds::contains));
    }

    /**
     * Removes, from every assignment, the routing entries whose node is not in {@link #getAssignableNodes} and
     * recalculates each touched assignment's state.
     *
     * @return the updated cluster state, or {@code clusterState} itself if nothing changed
     */
    static ClusterState removeRoutingToUnassignableNodes(ClusterState clusterState) {
        Set<String> assignableNodeIds = getAssignableNodes(clusterState).stream().map(DiscoveryNode::getId).collect(Collectors.toSet());
        TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(clusterState);
        TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState);
        for (TrainedModelAssignment assignment : metadata.modelAssignments().values()) {
            Set<String> unassignableNodeIds = Sets.difference(assignment.getNodeRoutingTable().keySet(), assignableNodeIds);
            if (unassignableNodeIds.isEmpty()) {
                continue;
            }
            logger.debug(
                () -> Strings.format(
                    "[%s] removing routing entries to nodes %s because they have been removed or are shutting down",
                    assignment.getModelId(),
                    unassignableNodeIds
                )
            );
            TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.fromAssignment(assignment);
            unassignableNodeIds.forEach(assignmentBuilder::removeRoutingEntry);
            builder.updateAssignment(assignment.getModelId(), assignmentBuilder.calculateAndSetAssignmentState());
        }
        return update(clusterState, builder);
    }

    /**
     * Submits a cluster state update applying a node's routing-info update (see
     * {@link #updateModelRoutingTable(ClusterState, UpdateTrainedModelAssignmentRoutingInfoAction.Request)}).
     * The listener is acknowledged once the new state is processed, or receives the failure.
     */
    public void updateModelRoutingTable(
        final UpdateTrainedModelAssignmentRoutingInfoAction.Request request,
        final ActionListener<AcknowledgedResponse> listener
    ) {
        logger.debug(
            () -> Strings.format(
                "[%s] updating routing table entry for node [%s], update [%s]",
                request.getModelId(),
                request.getNodeId(),
                request.getUpdate()
            )
        );
        submitUnbatchedTask("updating model routing for node assignment", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return updateModelRoutingTable(currentState, request);
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                listener.onResponse(AcknowledgedResponse.TRUE);
            }
        });
    }

    /**
     * Creates an assignment for the given model by running a rebalance that includes it.
     * <p>
     * Fails with {@code CONFLICT} if any node is older than {@link #DISTRIBUTED_MODEL_ALLOCATION_VERSION} or an ML
     * feature reset is in progress. On success the listener receives the model's assignment from the rebalanced
     * metadata, or an empty assignment built from {@code params} if the metadata holds none for the model.
     */
    public void createNewModelAssignment(
        StartTrainedModelDeploymentAction.TaskParams params,
        ActionListener<TrainedModelAssignment> listener
    ) {
        // Take a single snapshot so the precondition checks and the rebalance see the same cluster state.
        ClusterState currentState = clusterService.state();
        if (currentState.nodes().getMinNodeVersion().before(DISTRIBUTED_MODEL_ALLOCATION_VERSION)) {
            listener.onFailure(
                new ElasticsearchStatusException(
                    "cannot create new assignment for model [{}] while there are nodes older than version [{}]",
                    RestStatus.CONFLICT,
                    params.getModelId(),
                    DISTRIBUTED_MODEL_ALLOCATION_VERSION
                )
            );
            return;
        }
        if (MlMetadata.getMlMetadata(currentState).isResetMode()) {
            listener.onFailure(
                new ElasticsearchStatusException(
                    "cannot create new assignment for model [{}] while feature reset is in progress.",
                    RestStatus.CONFLICT,
                    params.getModelId()
                )
            );
            return;
        }
        rebalanceAssignments(currentState, Optional.of(params), "model deployment started", ActionListener.wrap(metadata -> {
            TrainedModelAssignment assignment = metadata.getModelAssignment(params.getModelId());
            listener.onResponse(assignment != null ? assignment : TrainedModelAssignment.Builder.empty(params).build());
        }, listener::onFailure));
    }

    /**
     * Submits a cluster state update marking the model's assignment as stopping with reason "client API call".
     * Fails with {@link ResourceNotFoundException} if the model has no assignment.
     */
    public void setModelAssignmentToStopping(final String modelId, final ActionListener<AcknowledgedResponse> listener) {
        submitUnbatchedTask("set model assignment stopping", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return setToStopping(currentState, modelId, "client API call");
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                listener.onResponse(AcknowledgedResponse.TRUE);
            }
        });
    }

    /**
     * Submits a cluster state update deleting the model's assignment. Once processed, a rebalance of the remaining
     * deployments is started asynchronously (its outcome is only logged) and the listener is acknowledged without
     * waiting for it.
     */
    public void removeModelAssignment(final String modelId, final ActionListener<AcknowledgedResponse> listener) {
        submitUnbatchedTask("delete model assignment", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return removeAssignment(currentState, modelId);
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                rebalanceAssignments(
                    newState,
                    Optional.empty(),
                    "model deployment stopped",
                    ActionListener.wrap(
                        metadata -> logger.debug(
                            () -> Strings.format(
                                "Successfully rebalanced model deployments after deployment for model [%s] was stopped",
                                modelId
                            )
                        ),
                        e -> logger.error(
                            Strings.format(
                                "Failed to rebalance model deployments after deployment for model [%s] was stopped",
                                modelId
                            ),
                            e
                        )
                    )
                );
                listener.onResponse(AcknowledgedResponse.TRUE);
            }
        });
    }

    /** Submits a cluster state update deleting every model assignment. */
    public void removeAllModelAssignments(final ActionListener<AcknowledgedResponse> listener) {
        submitUnbatchedTask("delete all model assignments", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return removeAllAssignments(currentState);
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                listener.onResponse(AcknowledgedResponse.TRUE);
            }
        });
    }

    /**
     * Applies {@code builder} to {@code currentState} unless the built metadata equals the current metadata, in which
     * case {@code currentState} is returned unchanged (same instance, which callers compare by reference).
     */
    private static ClusterState update(ClusterState currentState, TrainedModelAssignmentMetadata.Builder builder) {
        if (builder.build().equals(TrainedModelAssignmentMetadata.fromState(currentState))) {
            return currentState;
        }
        return forceUpdate(currentState, builder);
    }

    /**
     * Unconditionally writes the built assignment metadata into a new cluster state. When every node is on or after
     * {@link #RENAME_ALLOCATION_TO_ASSIGNMENT_VERSION} it is stored under the current name and the deprecated entry
     * is removed; otherwise the old format is stored under the deprecated name.
     */
    private static ClusterState forceUpdate(ClusterState currentState, TrainedModelAssignmentMetadata.Builder builder) {
        logger.debug(() -> Strings.format("updated assignments: %s", builder.build()));
        Metadata.Builder metadataBuilder = Metadata.builder(currentState.metadata());
        if (currentState.getNodes().getMinNodeVersion().onOrAfter(RENAME_ALLOCATION_TO_ASSIGNMENT_VERSION)) {
            metadataBuilder.putCustom(TrainedModelAssignmentMetadata.NAME, builder.build())
                .removeCustom(TrainedModelAssignmentMetadata.DEPRECATED_NAME);
        } else {
            metadataBuilder.putCustom(TrainedModelAssignmentMetadata.DEPRECATED_NAME, builder.buildOld());
        }
        return ClusterState.builder(currentState).metadata(metadataBuilder).build();
    }

    /** Synchronously rebalances {@code currentState} including {@code params} and returns the resulting state. */
    ClusterState createModelAssignment(ClusterState currentState, StartTrainedModelDeploymentAction.TaskParams params)
        throws Exception {
        return update(currentState, rebalanceAssignments(currentState, Optional.of(params)));
    }

    /**
     * Computes a rebalanced assignment plan on the ML utility executor and then submits a cluster state update
     * applying it.
     * <p>
     * If the cluster state has changed in a way that invalidates the plan (see
     * {@link #areClusterStatesCompatibleForRebalance}) by the time the update executes, the update is a no-op and a
     * fresh rebalance is started from the newer state; the listener is then completed by that later attempt. When
     * applying the plan actually changes the cluster state an audit message is emitted.
     *
     * @param clusterState the state the plan is computed from
     * @param modelToAdd   an optional new model deployment to include in the plan
     * @param reason       why the rebalance runs; used as the task source and in log/audit messages
     * @param listener     receives the assignment metadata of the resulting cluster state, or the failure
     */
    private void rebalanceAssignments(
        ClusterState clusterState,
        Optional<StartTrainedModelDeploymentAction.TaskParams> modelToAdd,
        String reason,
        ActionListener<TrainedModelAssignmentMetadata> listener
    ) {
        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
            logger.debug(() -> Strings.format("Rebalancing model allocations because [%s]", reason));
            try {
                TrainedModelAssignmentMetadata.Builder rebalancedMetadata = rebalanceAssignments(clusterState, modelToAdd);
                submitUnbatchedTask(reason, new ClusterStateUpdateTask() {

                    // Set by execute() when the precomputed plan was applied rather than a new rebalance scheduled.
                    private volatile boolean isUpdated;
                    // Set by execute() when applying the plan produced a different cluster state.
                    private volatile boolean isChanged;

                    @Override
                    public ClusterState execute(ClusterState currentState) {
                        if (areClusterStatesCompatibleForRebalance(clusterState, currentState) == false) {
                            // The plan is stale: recompute from the current state; that attempt notifies the listener.
                            rebalanceAssignments(currentState, modelToAdd, reason, listener);
                            return currentState;
                        }
                        isUpdated = true;
                        ClusterState updatedState = update(currentState, rebalancedMetadata);
                        isChanged = updatedState != currentState;
                        return updatedState;
                    }

                    @Override
                    public void onFailure(Exception e) {
                        listener.onFailure(e);
                    }

                    @Override
                    public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                        if (isUpdated == false) {
                            return;
                        }
                        if (isChanged) {
                            threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
                                .execute(
                                    () -> systemAuditor.info(
                                        Messages.getMessage("Rebalanced trained model allocations because [{0}]", reason)
                                    )
                                );
                        }
                        listener.onResponse(TrainedModelAssignmentMetadata.fromState(newState));
                    }
                });
            } catch (Exception e) {
                listener.onFailure(e);
            }
        });
    }

    /**
     * Returns whether a plan computed from {@code source} is still valid for {@code target}: the assignable nodes,
     * their loads, the ML metadata and the assignment metadata must all be equal.
     */
    private boolean areClusterStatesCompatibleForRebalance(ClusterState source, ClusterState target) {
        List<DiscoveryNode> sourceNodes = getAssignableNodes(source);
        List<DiscoveryNode> targetNodes = getAssignableNodes(target);
        return sourceNodes.equals(targetNodes)
            && detectNodeLoads(sourceNodes, source).equals(detectNodeLoads(targetNodes, target))
            && MlMetadata.getMlMetadata(source).equals(MlMetadata.getMlMetadata(target))
            && TrainedModelAssignmentMetadata.fromState(source).equals(TrainedModelAssignmentMetadata.fromState(target));
    }

    /** Synchronously computes a rebalanced assignment plan for the assignable nodes of {@code clusterState}. */
    private TrainedModelAssignmentMetadata.Builder rebalanceAssignments(
        ClusterState clusterState,
        Optional<StartTrainedModelDeploymentAction.TaskParams> modelToAdd
    ) throws Exception {
        List<DiscoveryNode> assignableNodes = getAssignableNodes(clusterState);
        logger.debug(() -> Strings.format("assignable nodes are %s", assignableNodes.stream().map(DiscoveryNode::getId).toList()));
        return new TrainedModelAssignmentRebalancer(
            TrainedModelAssignmentMetadata.fromState(clusterState),
            detectNodeLoads(assignableNodes, clusterState),
            modelToAdd
        ).rebalance();
    }

    /** Returns the nodes that {@code TaskParams.mayAssignToNode} accepts and that are not marked as shutting down. */
    private static List<DiscoveryNode> getAssignableNodes(ClusterState clusterState) {
        Set<String> shuttingDownNodeIds = nodesShuttingDown(clusterState);
        return clusterState.getNodes()
            .getNodes()
            .values()
            .stream()
            .filter(StartTrainedModelDeploymentAction.TaskParams::mayAssignToNode)
            .filter(node -> shuttingDownNodeIds.contains(node.getId()) == false)
            .toList();
    }

    /** Computes the load of each given node using the current ML memory and open-jobs settings. */
    private Map<DiscoveryNode, NodeLoad> detectNodeLoads(List<DiscoveryNode> nodes, ClusterState clusterState) {
        // Read the dynamic settings once so every node in this call is evaluated against the same values.
        final int maxOpenJobsSnapshot = maxOpenJobs;
        final int maxMemoryPercentageSnapshot = maxMemoryPercentage;
        final boolean useAutoSnapshot = useAuto;
        return nodes.stream()
            .collect(
                Collectors.toMap(
                    Function.identity(),
                    // NOTE(review): the second argument is passed as null; confirm its meaning in NodeLoadDetector.
                    node -> nodeLoadDetector.detectNodeLoad(
                        clusterState,
                        null,
                        node,
                        maxOpenJobsSnapshot,
                        maxMemoryPercentageSnapshot,
                        useAutoSnapshot
                    )
                )
            );
    }

    /**
     * Marks the model's assignment as stopping with the given reason.
     *
     * @return the updated state, or {@code clusterState} if the assignment is already stopping
     * @throws ResourceNotFoundException if the model has no assignment
     */
    static ClusterState setToStopping(ClusterState clusterState, String modelId, String reason) {
        TrainedModelAssignment existingAssignment = TrainedModelAssignmentMetadata.fromState(clusterState).getModelAssignment(modelId);
        if (existingAssignment == null) {
            throw new ResourceNotFoundException("assignment for model with id [{}] not found", modelId);
        }
        if (existingAssignment.getAssignmentState().equals(AssignmentState.STOPPING)) {
            return clusterState;
        }
        TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState);
        builder.getAssignment(modelId).stopAssignment(reason);
        return update(clusterState, builder);
    }

    /**
     * Applies a node's routing-info update to the model's assignment.
     * <ul>
     *   <li>An update to {@code STOPPED} removes the node's routing entry; it is a no-op if the assignment or the
     *       route no longer exists.</li>
     *   <li>Any other update is ignored while the assignment is stopping.</li>
     * </ul>
     *
     * @throws ResourceNotFoundException for a non-STOPPED update when the assignment does not exist or is not routed
     *                                   to the node
     */
    static ClusterState updateModelRoutingTable(ClusterState clusterState, UpdateTrainedModelAssignmentRoutingInfoAction.Request request) {
        final String modelId = request.getModelId();
        final String nodeId = request.getNodeId();
        TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(clusterState);
        logger.trace(
            () -> Strings.format(
                "[%s] [%s] current metadata before update %s",
                modelId,
                nodeId,
                org.elasticsearch.common.Strings.toString(metadata)
            )
        );
        TrainedModelAssignment existingAssignment = metadata.getModelAssignment(modelId);
        TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState);

        boolean isStoppedUpdate = request.getUpdate()
            .getStateAndReason()
            .map(stateAndReason -> stateAndReason.getState().equals(RoutingState.STOPPED))
            .orElse(false);
        if (isStoppedUpdate) {
            if (existingAssignment == null || existingAssignment.isRoutedToNode(nodeId) == false) {
                return clusterState;
            }
            builder.getAssignment(modelId).removeRoutingEntry(nodeId).calculateAndSetAssignmentState();
            return update(clusterState, builder);
        }

        if (existingAssignment == null) {
            throw new ResourceNotFoundException("assignment for model with id [{}] not found", modelId);
        }
        if (existingAssignment.getAssignmentState().equals(AssignmentState.STOPPING)) {
            logger.debug(
                () -> Strings.format(
                    "[%s] requested update from node [%s] while stopping; update was [%s]",
                    modelId,
                    nodeId,
                    request.getUpdate()
                )
            );
            return clusterState;
        }
        if (existingAssignment.isRoutedToNode(nodeId) == false) {
            throw new ResourceNotFoundException("assignment for model with id [{}] is not routed to node [{}]", modelId, nodeId);
        }
        RoutingInfo currentRoute = existingAssignment.getNodeRoutingTable().get(nodeId);
        builder.getAssignment(modelId)
            .updateExistingRoutingEntry(nodeId, request.getUpdate().apply(currentRoute))
            .calculateAndSetAssignmentState();
        return update(clusterState, builder);
    }

    /**
     * Removes the model's assignment.
     *
     * @throws ResourceNotFoundException if the model has no assignment
     */
    static ClusterState removeAssignment(ClusterState clusterState, String modelId) {
        TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState);
        if (builder.hasModel(modelId) == false) {
            throw new ResourceNotFoundException("assignment for model with id [{}] not found", modelId);
        }
        logger.debug(() -> Strings.format("[%s] removing assignment", modelId));
        return update(clusterState, builder.removeAssignment(modelId));
    }

    /** Removes every assignment; returns {@code clusterState} unchanged if there are none. */
    static ClusterState removeAllAssignments(ClusterState clusterState) {
        if (TrainedModelAssignmentMetadata.fromState(clusterState).modelAssignments().isEmpty()) {
            return clusterState;
        }
        return forceUpdate(clusterState, TrainedModelAssignmentMetadata.Builder.empty());
    }

    /**
     * Returns why the models should be rebalanced after {@code event}, or empty if they need not be.
     * Nothing is reported when there are no assignments. Otherwise, in order of precedence: ML jobs stopped,
     * "nodes changed", "outdated assignments detected".
     */
    static Optional<String> detectReasonToRebalanceModels(final ClusterChangedEvent event) {
        TrainedModelAssignmentMetadata newMetadata = TrainedModelAssignmentMetadata.fromState(event.state());
        if (newMetadata == null || newMetadata.modelAssignments().isEmpty()) {
            return Optional.empty();
        }
        return detectReasonIfMlJobsStopped(event).or(() -> {
            if (haveMlNodesChanged(event, newMetadata)) {
                return Optional.of("nodes changed");
            }
            if (newMetadata.hasOutdatedAssignments()) {
                return Optional.of("outdated assignments detected");
            }
            return Optional.empty();
        });
    }

    /**
     * Returns a reason naming the types of ML process tasks that existed before {@code event} but no longer exist,
     * or empty if none disappeared or persistent tasks did not change.
     */
    static Optional<String> detectReasonIfMlJobsStopped(ClusterChangedEvent event) {
        if (event.changedCustomMetadataSet().contains(PERSISTENT_TASKS_METADATA_TYPE) == false) {
            return Optional.empty();
        }
        PersistentTasksCustomMetadata previousTasks = event.previousState().getMetadata().custom(PERSISTENT_TASKS_METADATA_TYPE);
        if (previousTasks == null) {
            // No tasks existed before, so none can have stopped. Returning here also avoids dereferencing the
            // missing metadata below.
            return Optional.empty();
        }
        PersistentTasksCustomMetadata currentTasks = event.state().getMetadata().custom(PERSISTENT_TASKS_METADATA_TYPE);
        Set<String> currentMlTaskIds = findMlProcessTaskIds(currentTasks);
        // Filter into a new list rather than mutating the result of findMlProcessTaskIds, which may be immutable.
        List<String> stoppedTaskIds = findMlProcessTaskIds(previousTasks).stream()
            .filter(taskId -> currentMlTaskIds.contains(taskId) == false)
            .toList();
        if (stoppedTaskIds.isEmpty()) {
            return Optional.empty();
        }
        Set<String> stoppedTaskTypes = stoppedTaskIds.stream()
            .map(previousTasks::getTask)
            .map(task -> task.getTaskName())
            .map(MlTasks::prettyPrintTaskName)
            .collect(Collectors.toSet());
        if (stoppedTaskIds.size() == 1) {
            return Optional.of("ML [" + stoppedTaskTypes.iterator().next() + "] job stopped");
        }
        return Optional.of("ML " + stoppedTaskTypes + " jobs stopped");
    }

    /** Returns the ids of the ML process tasks in {@code tasks}, or an empty (immutable) set when it is null. */
    private static Set<String> findMlProcessTaskIds(@Nullable PersistentTasksCustomMetadata tasks) {
        if (tasks == null) {
            return Set.of();
        }
        return MlTasks.findMlProcessTasks(tasks).stream().map(task -> task.getId()).collect(Collectors.toSet());
    }

    /**
     * Returns whether the node changes in {@code event} warrant a rebalance of a non-stopping assignment.
     * <p>
     * Nodes newly marked for shutdown count as removed and nodes no longer marked for shutdown count as added.
     * A rebalance is warranted when a non-stopping assignment is routed to a newly shutting-down node or to a
     * removed node that is not shutting down, or when an ML-eligible, non-shutting-down node was added.
     */
    static boolean haveMlNodesChanged(ClusterChangedEvent event, TrainedModelAssignmentMetadata newMetadata) {
        boolean isNodeShutdownMetadataChanged = event.changedCustomMetadataSet().contains(NODE_SHUTDOWN_METADATA_TYPE);
        if (event.nodesChanged() == false && isNodeShutdownMetadataChanged == false) {
            return false;
        }
        Set<String> shuttingDownNodes = nodesShuttingDown(event.state());
        DiscoveryNodes.Delta nodesDelta = event.nodesDelta();
        // Explicitly mutable sets: both are extended below.
        Set<String> removedNodes = nodesDelta.removedNodes()
            .stream()
            .map(DiscoveryNode::getId)
            .collect(Collectors.toCollection(HashSet::new));
        Set<String> addedNodes = nodesDelta.addedNodes()
            .stream()
            .map(DiscoveryNode::getId)
            .collect(Collectors.toCollection(HashSet::new));

        // Nodes marked for shutdown in this event only (reported as "exiting shutdown nodes" in the debug log).
        final Set<String> exitingShutDownNodes;
        if (isNodeShutdownMetadataChanged) {
            Set<String> previousShuttingDownNodes = nodesShuttingDown(event.previousState());
            addedNodes.addAll(Sets.difference(previousShuttingDownNodes, shuttingDownNodes));
            exitingShutDownNodes = Sets.difference(shuttingDownNodes, previousShuttingDownNodes);
            removedNodes.addAll(exitingShutDownNodes);
        } else {
            exitingShutDownNodes = Collections.emptySet();
        }
        logger.debug(
            () -> Strings.format(
                "added nodes %s; removed nodes %s; shutting down nodes %s; exiting shutdown nodes %s",
                addedNodes,
                removedNodes,
                shuttingDownNodes,
                exitingShutDownNodes
            )
        );

        for (TrainedModelAssignment assignment : newMetadata.modelAssignments().values()) {
            if (assignment.getAssignmentState().equals(AssignmentState.STOPPING)) {
                continue;
            }
            for (String nodeId : exitingShutDownNodes) {
                if (assignment.isRoutedToNode(nodeId)) {
                    logger.debug(
                        () -> Strings.format(
                            "should rebalance because model [%s] has allocations on shutting down node [%s]",
                            assignment.getModelId(),
                            nodeId
                        )
                    );
                    return true;
                }
            }
            for (String nodeId : removedNodes) {
                if (assignment.isRoutedToNode(nodeId) && shuttingDownNodes.contains(nodeId) == false) {
                    logger.debug(
                        () -> Strings.format(
                            "should rebalance because model [%s] has allocations on removed node [%s]",
                            assignment.getModelId(),
                            nodeId
                        )
                    );
                    return true;
                }
            }
            for (String nodeId : addedNodes) {
                // An id taken from the shutdown-metadata diff may have no node in the current state (e.g. its
                // shutdown record was deleted after it left), so skip ids that do not resolve to a node.
                DiscoveryNode addedNode = event.state().nodes().get(nodeId);
                if (addedNode != null
                    && StartTrainedModelDeploymentAction.TaskParams.mayAssignToNode(addedNode)
                    && shuttingDownNodes.contains(nodeId) == false) {
                    logger.debug(() -> Strings.format("should rebalance because ML eligible node [%s] was added", nodeId));
                    return true;
                }
            }
        }
        return false;
    }

    /** Returns the ids of all nodes that have a shutdown record in {@code clusterState}, or an empty set. */
    static Set<String> nodesShuttingDown(final ClusterState clusterState) {
        return NodesShutdownMetadata.getShutdowns(clusterState)
            .map(NodesShutdownMetadata::getAllNodeMetadataMap)
            .map(Map::keySet)
            .orElse(Collections.emptySet());
    }
}
