package org.elasticsearch.xpack.ml.inference.assignment;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlanner;
import org.elasticsearch.xpack.ml.job.NodeLoad;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.class */
/**
 * Computes a new set of trained model assignments for the cluster.
 *
 * <p>Given the current assignment metadata, the load of every candidate node and, optionally, a model deployment
 * that is about to be added, {@link #rebalance()} either returns a copy of the current metadata (when no deployment
 * is being added, every existing deployment is satisfied and no routing entries are outdated) or runs the
 * {@link AssignmentPlanner} and converts the resulting {@link AssignmentPlan} into new assignment metadata.
 */
class TrainedModelAssignmentRebalancer {

    private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentRebalancer.class);

    private final TrainedModelAssignmentMetadata currentMetadata;
    private final Map<DiscoveryNode, NodeLoad> nodeLoads;
    private final Optional<StartTrainedModelDeploymentAction.TaskParams> modelToAdd;

    /**
     * Creates a rebalancer for one rebalance round.
     *
     * @param currentMetadata the assignments currently known
     * @param nodeLoads       the load of every node that may be considered for assignments
     * @param modelToAdd      a deployment to add during this rebalance, or empty if none
     * @throws NullPointerException if any argument is {@code null}
     */
    TrainedModelAssignmentRebalancer(
        TrainedModelAssignmentMetadata currentMetadata,
        Map<DiscoveryNode, NodeLoad> nodeLoads,
        Optional<StartTrainedModelDeploymentAction.TaskParams> modelToAdd
    ) {
        this.currentMetadata = Objects.requireNonNull(currentMetadata);
        this.nodeLoads = Objects.requireNonNull(nodeLoads);
        this.modelToAdd = Objects.requireNonNull(modelToAdd);
    }

    /**
     * Computes the assignments that should exist after this rebalance.
     *
     * @return a builder with the rebalanced assignments, or a copy of the current metadata when no rebalance is needed
     * @throws ResourceAlreadyExistsException if {@code modelToAdd} is present and an assignment for its model id
     *                                        already exists
     * @throws Exception                      kept in the signature for existing callers
     */
    TrainedModelAssignmentMetadata.Builder rebalance() throws Exception {
        if (modelToAdd.isPresent() && currentMetadata.hasModel(modelToAdd.get().getModelId())) {
            throw new ResourceAlreadyExistsException(
                "assignment for model with id [{}] already exists",
                modelToAdd.get().getModelId()
            );
        }

        // Only skip planning when nothing is being added and the existing assignments need no change.
        if (modelToAdd.isEmpty() && areAllModelsSatisfiedAndNoOutdatedRoutingEntries()) {
            logger.trace(() -> "No need to rebalance as all model deployments are satisfied");
            return TrainedModelAssignmentMetadata.Builder.fromMetadata(currentMetadata);
        }

        return buildAssignmentsFromPlan(computeAssignmentPlan());
    }

    /**
     * Returns {@code true} if every current assignment is satisfied with respect to the known nodes and none of them
     * has outdated routing entries.
     */
    private boolean areAllModelsSatisfiedAndNoOutdatedRoutingEntries() {
        Set<String> nodeIds = nodeLoads.keySet().stream().map(DiscoveryNode::getId).collect(Collectors.toSet());
        for (TrainedModelAssignment assignment : currentMetadata.modelAssignments().values()) {
            if (assignment.isSatisfied(nodeIds) == false || assignment.hasOutdatedRoutingEntries()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds the planner input from the current state and runs the {@link AssignmentPlanner}.
     *
     * <p>Nodes whose load reports an error are excluded from planning. Every existing assignment becomes a plan model
     * seeded with its usable current allocations, and {@code modelToAdd}, if present, is appended with no allocations.
     *
     * @return the computed assignment plan
     */
    AssignmentPlan computeAssignmentPlan() {
        List<AssignmentPlan.Node> planNodes = nodeLoads.entrySet()
            .stream()
            .filter(entry -> Strings.isNullOrEmpty(entry.getValue().getError()))
            .map(entry -> toPlanNode(entry.getKey(), entry.getValue()))
            .toList();
        Set<String> planNodeIds = planNodes.stream().map(AssignmentPlan.Node::id).collect(Collectors.toSet());

        List<AssignmentPlan.Model> planModels = new ArrayList<>(
            currentMetadata.modelAssignments().size() + (modelToAdd.isPresent() ? 1 : 0)
        );
        for (TrainedModelAssignment assignment : currentMetadata.modelAssignments().values()) {
            planModels.add(toPlanModel(assignment, planNodeIds));
        }
        modelToAdd.ifPresent(
            params -> planModels.add(
                new AssignmentPlan.Model(
                    params.getModelId(),
                    params.estimateMemoryUsageBytes(),
                    params.getNumberOfAllocations(),
                    params.getThreadsPerAllocation(),
                    Map.of(),
                    0
                )
            )
        );

        return new AssignmentPlanner(planNodes, planModels).computePlan();
    }

    /** Converts a node and its load into a planner node, using 0 processors when the attribute cannot be parsed. */
    private static AssignmentPlan.Node toPlanNode(DiscoveryNode node, NodeLoad load) {
        return new AssignmentPlan.Node(
            node.getId(),
            getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(load),
            getNodeAllocatedProcessors(node).orElse(0)
        );
    }

    /**
     * Converts an existing assignment into a planner model.
     *
     * <p>Only routing entries that point at a plannable node, have both current and target allocations above zero and
     * are STARTING, STARTED or FAILED are carried over, keyed by node id with their target allocations as values.
     */
    private static AssignmentPlan.Model toPlanModel(TrainedModelAssignment assignment, Set<String> planNodeIds) {
        Map<String, Integer> currentAllocationsByNodeId = assignment.getNodeRoutingTable()
            .entrySet()
            .stream()
            .filter(entry -> planNodeIds.contains(entry.getKey()))
            .filter(entry -> entry.getValue().getCurrentAllocations() > 0 && entry.getValue().getTargetAllocations() > 0)
            .filter(
                entry -> entry.getValue()
                    .getState()
                    .isAnyOf(new RoutingState[] { RoutingState.STARTING, RoutingState.STARTED, RoutingState.FAILED })
            )
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getTargetAllocations()));

        StartTrainedModelDeploymentAction.TaskParams params = assignment.getTaskParams();
        return new AssignmentPlan.Model(
            assignment.getModelId(),
            params.estimateMemoryUsageBytes(),
            params.getNumberOfAllocations(),
            params.getThreadsPerAllocation(),
            currentAllocationsByNodeId,
            assignment.getMaxAssignedAllocations()
        );
    }

    /**
     * Reads the allocated-processors node attribute.
     *
     * @return the parsed value, or empty if it is missing or not an integer (which trips an assertion when assertions
     *         are enabled, because the attribute is expected to always be valid)
     */
    private static OptionalInt getNodeAllocatedProcessors(DiscoveryNode node) {
        String allocatedProcessors = node.getAttributes().get(MachineLearning.ALLOCATED_PROCESSORS_NODE_ATTR);
        try {
            return OptionalInt.of(Integer.parseInt(allocatedProcessors));
        } catch (NumberFormatException e) {
            assert e == null
                : "ml.allocated_processors should parse because we set it internally: invalid value was " + allocatedProcessors;
            return OptionalInt.empty();
        }
    }

    /** Returns the node's free memory (excluding per-node overhead) minus memory already assigned to native inference. */
    private static long getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(NodeLoad load) {
        return load.getFreeMemoryExcludingPerNodeOverhead() - load.getAssignedNativeInferenceMemory();
    }

    /**
     * Converts an assignment plan into assignment metadata.
     *
     * <p>Existing assignments keep their start time and max assigned allocations. Nodes that are new to a model's
     * routing table start in STARTING with current = target = planned allocations. Nodes already routed keep their
     * current allocations and state, except that FAILED routes are reset to STARTING with an empty reason so they are
     * retried.
     */
    private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(AssignmentPlan assignmentPlan) {
        TrainedModelAssignmentMetadata.Builder metadataBuilder = TrainedModelAssignmentMetadata.Builder.empty();
        for (AssignmentPlan.Model model : assignmentPlan.models()) {
            TrainedModelAssignment existingAssignment = currentMetadata.getModelAssignment(model.id());

            // Plan models come either from current metadata or from modelToAdd, so a missing existing
            // assignment means this is the model being added.
            StartTrainedModelDeploymentAction.TaskParams taskParams = existingAssignment == null
                ? modelToAdd.orElseThrow()
                : existingAssignment.getTaskParams();
            TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.empty(taskParams);
            if (existingAssignment != null) {
                assignmentBuilder.setStartTime(existingAssignment.getStartTime());
                assignmentBuilder.setMaxAssignedAllocations(existingAssignment.getMaxAssignedAllocations());
            }

            Map<AssignmentPlan.Node, Integer> plannedAllocations = assignmentPlan.assignments(model).orElseGet(Map::of);
            for (Map.Entry<AssignmentPlan.Node, Integer> planned : plannedAllocations.entrySet()) {
                String nodeId = planned.getKey().id();
                int allocations = planned.getValue();
                if (existingAssignment != null && existingAssignment.isRoutedToNode(nodeId)) {
                    RoutingInfo existingRouting = existingAssignment.getNodeRoutingTable().get(nodeId);
                    RoutingState state = existingRouting.getState();
                    String reason = existingRouting.getReason();
                    if (state == RoutingState.FAILED) {
                        // Give previously failed routes another chance to start.
                        state = RoutingState.STARTING;
                        reason = "";
                    }
                    assignmentBuilder.addRoutingEntry(
                        nodeId,
                        new RoutingInfo(existingRouting.getCurrentAllocations(), allocations, state, reason)
                    );
                } else {
                    assignmentBuilder.addRoutingEntry(
                        nodeId,
                        new RoutingInfo(allocations, allocations, RoutingState.STARTING, "")
                    );
                }
            }
            assignmentBuilder.calculateAndSetAssignmentState();

            explainAssignments(assignmentPlan, nodeLoads, model).ifPresent(assignmentBuilder::setReason);
            metadataBuilder.addNewAssignment(model.id(), assignmentBuilder);
        }
        return metadataBuilder;
    }

    /**
     * Explains why {@code model} did not get all its requested allocations.
     *
     * @return empty if the plan satisfies the model's allocations or no node yields an explanation; otherwise the
     *         per-node reasons, ordered by node id and joined with {@code "|"}
     */
    private Optional<String> explainAssignments(
        AssignmentPlan assignmentPlan,
        Map<DiscoveryNode, NodeLoad> loads,
        AssignmentPlan.Model model
    ) {
        if (assignmentPlan.satisfiesAllocations(model)) {
            return Optional.empty();
        }
        if (loads.isEmpty()) {
            return Optional.of("No ML nodes exist in the cluster");
        }

        // TreeMap keeps the combined message ordered by node id, independent of map iteration order.
        Map<String, String> reasonsByNodeId = new TreeMap<>();
        for (Map.Entry<DiscoveryNode, NodeLoad> entry : loads.entrySet()) {
            explainAssignment(assignmentPlan, entry.getKey(), entry.getValue(), model).ifPresent(
                reason -> reasonsByNodeId.put(entry.getKey().getId(), reason)
            );
        }
        if (reasonsByNodeId.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(
            reasonsByNodeId.entrySet()
                .stream()
                .map(
                    entry -> org.elasticsearch.core.Strings.format(
                        "Could not assign (more) allocations on node [%s]. Reason: %s",
                        new Object[] { entry.getKey(), entry.getValue() }
                    )
                )
                .collect(Collectors.joining("|"))
        );
    }

    /**
     * Explains why {@code model} could not get (more) allocations on a single node.
     *
     * @return the node load error if there is one; otherwise a memory explanation if the model does not fit in the
     *         node's remaining memory; otherwise a processor explanation if the node lacks free processors for one
     *         allocation; otherwise empty
     */
    private Optional<String> explainAssignment(
        AssignmentPlan assignmentPlan,
        DiscoveryNode node,
        NodeLoad load,
        AssignmentPlan.Model model
    ) {
        if (Strings.isNullOrEmpty(load.getError()) == false) {
            return Optional.of(load.getError());
        }

        // Plan nodes are keyed by DiscoveryNode#getId (see toPlanNode), so query the plan with that id.
        String nodeId = node.getId();
        if (model.memoryBytes() > assignmentPlan.getRemainingNodeMemory(nodeId)) {
            // The per-node overhead is treated as already accounted for when the node had jobs/models before the
            // rebalance, or when the plan used some of its processors.
            boolean isPerNodeOverheadAccountedFor = load.getNumAssignedJobsAndModels() > 0
                || assignmentPlan.getRemainingNodeCores(nodeId) < getNodeAllocatedProcessors(node).orElse(0);
            long perNodeOverhead = isPerNodeOverheadAccountedFor
                ? 0L
                : MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes();
            long requiredMemory = model.memoryBytes() + perNodeOverhead;
            long nodeFreeMemory = assignmentPlan.getRemainingNodeMemory(nodeId) + perNodeOverhead;
            return Optional.of(
                ParameterizedMessage.format(
                    "This node has insufficient available memory. Available memory for ML [{} ({})], "
                        + "free memory [{} ({})], estimated memory required for this model [{} ({})].",
                    new Object[] {
                        load.getMaxMlMemory(),
                        ByteSizeValue.ofBytes(load.getMaxMlMemory()).toString(),
                        nodeFreeMemory,
                        ByteSizeValue.ofBytes(nodeFreeMemory).toString(),
                        requiredMemory,
                        ByteSizeValue.ofBytes(requiredMemory).toString() }
                )
            );
        }

        int freeProcessors = assignmentPlan.getRemainingNodeCores(nodeId);
        if (model.threadsPerAllocation() > freeProcessors) {
            return Optional.of(
                ParameterizedMessage.format(
                    "This node has insufficient allocated processors. Available processors [{}], free processors [{}], "
                        + "processors required for each allocation of this model [{}]",
                    new Object[] { getNodeAllocatedProcessors(node).orElse(0), freeProcessors, model.threadsPerAllocation() }
                )
            );
        }

        return Optional.empty();
    }
}
