package org.elasticsearch.xpack.ml.inference.assignment.planning;

import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Strings;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.class */
/**
 * Computes an {@link AssignmentPlan} that places model allocations onto nodes.
 *
 * <p>The optimisation itself is delegated to {@code LinearProgrammingPlanSolver}. This class
 * orchestrates several solving strategies around it:
 * <ul>
 *   <li>preserving one allocation per current assignment;</li>
 *   <li>preserving all current allocations;</li>
 *   <li>a fallback that tries to give each previously allocated model at least one allocation.</li>
 * </ul>
 * It then picks the best resulting plan.
 */
public class AssignmentPlanner {

    private static final Logger logger = LogManager.getLogger(AssignmentPlanner.class);

    private final List<AssignmentPlan.Node> nodes;
    private final List<AssignmentPlan.Model> models;

    /**
     * Creates a planner for the given nodes and models.
     *
     * <p>Both lists are copied into unmodifiable lists sorted by id. The planner's input order
     * therefore does not depend on the order in which the caller supplied them.
     *
     * @param nodes  the nodes that allocations may be placed on
     * @param models the models whose allocations need to be placed
     */
    public AssignmentPlanner(List<AssignmentPlan.Node> nodes, List<AssignmentPlan.Model> models) {
        this.nodes = nodes.stream().sorted(Comparator.comparing(AssignmentPlan.Node::id)).toList();
        this.models = models.stream().sorted(Comparator.comparing(AssignmentPlan.Model::id)).toList();
    }

    /**
     * Computes the best plan found for the configured nodes and models.
     *
     * <p>The fallback that tries to keep previously allocated models assigned is allowed to run.
     *
     * @return the selected assignment plan
     */
    public AssignmentPlan computePlan() {
        return computePlan(true);
    }

    /**
     * Computes a plan.
     *
     * <p>The planner first produces a plan that tries to satisfy the current assignments. If that
     * plan leaves some previously assigned model unassigned, and the fallback is permitted, a
     * second plan is computed. That second plan seeds every previously allocated model with a
     * single allocation.
     *
     * <p>If the second plan does not assign every previously assigned model either, the plan that
     * keeps more previously assigned models assigned is chosen. On a tie, the first plan wins.
     *
     * @param tryAssigningPreviouslyAssignedModels whether the fallback pass may run. The fallback
     *     itself calls this method with {@code false}, which prevents unbounded recursion.
     * @return the selected assignment plan
     */
    private AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
        logger.debug(() -> Strings.format("Computing plan for nodes = %s; models = %s", nodes, models));

        AssignmentPlan planSatisfyingCurrentAssignments = solveSatisfyingCurrentAssignments();
        logger.debug(() -> "Plan satisfying current assignments =\n" + planSatisfyingCurrentAssignments.prettyPrint());

        // Assigned exactly once on every path, so it is effectively final and usable in the lambdas below.
        final AssignmentPlan bestPlan;
        if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() || !tryAssigningPreviouslyAssignedModels) {
            bestPlan = planSatisfyingCurrentAssignments;
        } else {
            AssignmentPlan planWithAtLeastOneAllocation = solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated();
            logger.debug(
                () -> "Plan with at least one allocation for previously assigned models =\n" + planWithAtLeastOneAllocation.prettyPrint()
            );
            if (planWithAtLeastOneAllocation.arePreviouslyAssignedModelsAssigned()) {
                bestPlan = planWithAtLeastOneAllocation;
            } else {
                // Neither plan assigns every previously assigned model. Keep whichever retains
                // more of them; ties go to the plan that satisfies current assignments.
                int keptByCurrent = planSatisfyingCurrentAssignments.countPreviouslyAssignedModelsThatAreStillAssigned();
                int keptByFallback = planWithAtLeastOneAllocation.countPreviouslyAssignedModelsThatAreStillAssigned();
                bestPlan = keptByCurrent >= keptByFallback ? planSatisfyingCurrentAssignments : planWithAtLeastOneAllocation;
            }
        }

        logger.debug(() -> "Best plan =\n" + bestPlan.prettyPrint());
        logger.debug(() -> prettyPrintOverallStats(bestPlan));
        return bestPlan;
    }

    /**
     * Produces a plan that tries to honour the current assignments.
     *
     * <p>First, one allocation is kept on each current assignment. If the resulting plan does not
     * satisfy the current assignments, a plan preserving all current allocations is used instead.
     *
     * <p>If the plan does satisfy the current assignments but does not satisfy every model, a plan
     * preserving all current allocations is also computed. That plan is chosen when it compares
     * greater than or equal to the first plan.
     *
     * @return the chosen plan
     */
    private AssignmentPlan solveSatisfyingCurrentAssignments() {
        AssignmentPlan planKeepingOneAllocation = solveKeepingOneAllocationOnCurrentAssignments();
        if (!planKeepingOneAllocation.satisfiesCurrentAssignments()) {
            return solvePreservingAllAllocationsOnCurrentAssignments();
        }
        if (planKeepingOneAllocation.satisfiesAllModels()) {
            return planKeepingOneAllocation;
        }
        AssignmentPlan planPreservingAllAllocations = solvePreservingAllAllocationsOnCurrentAssignments();
        return planPreservingAllAllocations.compareTo(planKeepingOneAllocation) >= 0
            ? planPreservingAllAllocations
            : planKeepingOneAllocation;
    }

    /**
     * Fallback strategy for getting each previously allocated model at least one allocation.
     *
     * <p>This runs in two passes:
     * <ol>
     *   <li>Solve only the models that have ever been allocated, each requesting a single
     *       allocation, and note which node (if any) each one landed on.</li>
     *   <li>Re-plan all models, treating that single placement as the model's only current
     *       allocation.</li>
     * </ol>
     *
     * <p>The re-plan is invoked with the fallback disabled, so it does not call back into this
     * method.
     *
     * @return the plan computed in the second pass
     */
    private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated() {
        logger.debug(() -> "Attempting to solve assigning at least one allocations to previously assigned models");

        List<AssignmentPlan.Model> previouslyAllocatedModelsWithSingleAllocation = models.stream()
            .filter(AssignmentPlan.Model::hasEverBeenAllocated)
            .map(
                model -> new AssignmentPlan.Model(
                    model.id(),
                    model.memoryBytes(),
                    1,
                    model.threadsPerAllocation(),
                    model.currentAllocationsByNodeId(),
                    model.maxAssignedAllocations()
                )
            )
            .toList();
        // NOTE(review): the meaning of the boolean passed to solvePlan is defined in
        // LinearProgrammingPlanSolver. This pass uses true; the preserving strategies use false.
        AssignmentPlan planWithSingleAllocation = new LinearProgrammingPlanSolver(nodes, previouslyAllocatedModelsWithSingleAllocation)
            .solvePlan(true);

        Map<String, String> nodeIdByModelId = new HashMap<>();
        for (AssignmentPlan.Model model : planWithSingleAllocation.models()) {
            Set<AssignmentPlan.Node> assignedNodes = planWithSingleAllocation.assignments(model).orElse(Map.of()).keySet();
            if (!assignedNodes.isEmpty()) {
                // Each model requested exactly one allocation, so it can occupy at most one node.
                assert assignedNodes.size() == 1;
                nodeIdByModelId.put(model.id(), assignedNodes.iterator().next().id());
            }
        }

        List<AssignmentPlan.Model> modelsSeededWithSingleAllocation = models.stream().map(model -> {
            Map<String, Integer> currentAllocationsByNodeId = nodeIdByModelId.containsKey(model.id())
                ? Map.of(nodeIdByModelId.get(model.id()), 1)
                : Map.of();
            return new AssignmentPlan.Model(
                model.id(),
                model.memoryBytes(),
                model.allocations(),
                model.threadsPerAllocation(),
                currentAllocationsByNodeId,
                model.maxAssignedAllocations()
            );
        }).toList();
        return new AssignmentPlanner(nodes, modelsSeededWithSingleAllocation).computePlan(false);
    }

    /**
     * Solves while preserving one allocation on each current assignment.
     *
     * @return the merged plan
     */
    private AssignmentPlan solveKeepingOneAllocationOnCurrentAssignments() {
        logger.trace(() -> Strings.format("Solving preserving one allocation on current assignments"));
        return solvePreservingCurrentAssignments(new PreserveOneAllocation(nodes, models));
    }

    /**
     * Solves while preserving all allocations on current assignments.
     *
     * @return the merged plan
     */
    private AssignmentPlan solvePreservingAllAllocationsOnCurrentAssignments() {
        logger.trace(() -> Strings.format("Solving preserving all allocations on current assignments"));
        return solvePreservingCurrentAssignments(new PreserveAllAllocations(nodes, models));
    }

    /**
     * Runs the linear-programming solver on the nodes and models produced by the given preserving
     * strategy. The preserved allocations are then merged back into the solved plan.
     *
     * @param preserveAllocations the strategy that decides which current allocations are preserved
     * @return the solved plan with the preserved allocations merged in
     */
    private AssignmentPlan solvePreservingCurrentAssignments(AbstractPreserveAllocations preserveAllocations) {
        List<AssignmentPlan.Node> planNodes = preserveAllocations.nodesPreservingAllocations();
        List<AssignmentPlan.Model> planModels = preserveAllocations.modelsPreservingAllocations();
        logger.trace(() -> Strings.format("Nodes after applying allocation preserving strategy = %s", planNodes));
        logger.trace(() -> Strings.format("Models after applying allocation preserving strategy = %s", planModels));
        return preserveAllocations.mergePreservedAllocations(new LinearProgrammingPlanSolver(planNodes, planModels).solvePlan(false));
    }

    /**
     * Builds a one-line summary comparing a plan's resource usage with the total available.
     *
     * <p>The summary covers memory, allocations and cores. It is used for debug logging only.
     *
     * @param assignmentPlan the plan to summarise
     * @return a human-readable summary line
     */
    private String prettyPrintOverallStats(AssignmentPlan assignmentPlan) {
        long totalAvailableMemoryBytes = nodes.stream().mapToLong(AssignmentPlan.Node::availableMemoryBytes).sum();
        int totalCores = nodes.stream().mapToInt(AssignmentPlan.Node::cores).sum();

        int totalAllocationsRequired = 0;
        int totalAllocationsAssigned = 0;
        int totalCoresUsed = 0;
        long totalUsedMemoryBytes = 0;
        for (AssignmentPlan.Model model : models) {
            totalAllocationsRequired += model.allocations();
            // An absent assignment contributes nothing, exactly like an empty one.
            Map<AssignmentPlan.Node, Integer> assignments = assignmentPlan.assignments(model).orElse(Map.of());
            int allocations = assignments.values().stream().mapToInt(Integer::intValue).sum();
            totalAllocationsAssigned += allocations;
            totalCoresUsed += allocations * model.threadsPerAllocation();
            // Model memory is counted once per node entry in the model's assignment map,
            // regardless of how many allocations that node holds.
            totalUsedMemoryBytes += model.memoryBytes() * assignments.size();
        }
        return "Overall Stats: (used memory = "
            + ByteSizeValue.ofBytes(totalUsedMemoryBytes)
            + ") (total available memory = "
            + ByteSizeValue.ofBytes(totalAvailableMemoryBytes)
            + ") (allocations = "
            + totalAllocationsAssigned
            + "/"
            + totalAllocationsRequired
            + ") (cores = "
            + totalCoresUsed
            + "/"
            + totalCores
            + ")";
    }
}
