package org.elasticsearch.xpack.ml.inference.assignment.planning;

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
import org.ojalgo.optimisation.ExpressionsBasedModel;
import org.ojalgo.optimisation.Optimisation;
import org.ojalgo.optimisation.Variable;
import org.ojalgo.structure.Access1D;
import org.ojalgo.type.CalendarDateDuration;
import org.ojalgo.type.CalendarDateUnit;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.class */
/**
 * Computes an {@link AssignmentPlan} that distributes model allocations across nodes.
 *
 * <p>Two candidate plans are produced: a greedy bin-packing plan, and a plan obtained by solving a
 * relaxed (continuous) linear program whose result is turned into a concrete plan by
 * {@link RandomizedAssignmentRounding}. The better of the two, according to
 * {@link AssignmentPlan#compareTo}, is returned.
 *
 * <p>All memory quantities used by the solver are normalized by the memory of the largest model.
 */
class LinearProgrammingPlanSolver {
    private static final Logger logger = LogManager.getLogger(LinearProgrammingPlanSolver.class);

    /** Fixed seed so that the weights and the randomized rounding are reproducible. */
    private static final long RANDOMIZATION_SEED = 738921734;
    /** Coefficient of the per-model memory penalty in the LP objective. */
    private static final double L1 = 0.9d;
    /** Starting weight handed to the first bin-packing assignment; each later one gets a smaller weight. */
    private static final double INITIAL_W = 0.2d;
    /** Number of rounds passed to {@link RandomizedAssignmentRounding}. */
    private static final int RANDOMIZED_ROUNDING_ROUNDS = 20;
    /** Above this problem size the LP solver is switched to sparse mode. */
    private static final int MEMORY_COMPLEXITY_SPARSE_THRESHOLD = 4000000;
    /** Above this problem size the LP is not attempted and the bin-packing plan is used. */
    private static final int MEMORY_COMPLEXITY_LIMIT = 10000000;

    private final Random random = new Random(RANDOMIZATION_SEED);
    private final List<AssignmentPlan.Node> nodes;
    private final List<AssignmentPlan.Model> models;
    private final Map<AssignmentPlan.Node, Double> normalizedMemoryPerNode;
    private final Map<AssignmentPlan.Node, Integer> coresPerNode;
    private final Map<AssignmentPlan.Model, Double> normalizedMemoryPerModel;
    private final int maxNodeCores;
    private final long maxModelMemoryBytes;

    /**
     * Creates a solver for the given nodes and models.
     *
     * <p>A model is dropped from consideration when a single allocation of it can never be placed.
     * That is the case when it needs more threads than the largest node has cores, or when it has no
     * current allocations and needs more memory than any node has available.
     *
     * @param nodes  the nodes that allocations may be placed on
     * @param models the models to allocate
     */
    public LinearProgrammingPlanSolver(List<AssignmentPlan.Node> nodes, List<AssignmentPlan.Model> models) {
        this.nodes = nodes;
        int maxCores = nodes.stream().mapToInt(AssignmentPlan.Node::cores).max().orElse(0);
        this.maxNodeCores = maxCores;
        long maxNodeMemoryBytes = nodes.stream().mapToLong(AssignmentPlan.Node::availableMemoryBytes).max().orElse(0L);

        this.models = models.stream()
            // Models that already have allocations are kept even if no node has room for them right now.
            .filter(m -> !m.currentAllocationsByNodeId().isEmpty() || m.memoryBytes() <= maxNodeMemoryBytes)
            .filter(m -> m.threadsPerAllocation() <= maxCores)
            .toList();

        // Guard against a zero divisor when there are no models or all of them report 0 bytes.
        long maxModelMemory = Math.max(1L, this.models.stream().mapToLong(AssignmentPlan.Model::memoryBytes).max().orElse(1L));
        this.maxModelMemoryBytes = maxModelMemory;

        // Floating-point division: integer division would truncate every normalized size to 0 or 1.
        this.normalizedMemoryPerNode = nodes.stream()
            .collect(Collectors.toMap(Function.identity(), n -> (double) n.availableMemoryBytes() / maxModelMemory));
        this.coresPerNode = nodes.stream().collect(Collectors.toMap(Function.identity(), AssignmentPlan.Node::cores));
        this.normalizedMemoryPerModel = this.models.stream()
            .collect(Collectors.toMap(Function.identity(), m -> (double) m.memoryBytes() / maxModelMemory));
    }

    /**
     * Computes an assignment plan.
     *
     * @param useBinPackingOnly when {@code true}, only the greedy bin-packing plan is computed and returned
     * @return an empty plan if there are no placeable models or no cores. Otherwise the bin-packing
     *         plan, or the rounded LP plan when the LP is solvable and its plan does not compare worse.
     */
    public AssignmentPlan solvePlan(boolean useBinPackingOnly) {
        if (models.isEmpty() || maxNodeCores == 0) {
            return AssignmentPlan.builder(nodes, models).build();
        }

        Tuple<Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double>, AssignmentPlan> weightsAndBinPackingPlan =
            calculateWeightsAndBinPackingPlan();
        AssignmentPlan binPackingPlan = weightsAndBinPackingPlan.v2();
        if (useBinPackingOnly) {
            return binPackingPlan;
        }

        Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double> allocationValues = new HashMap<>();
        Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double> assignmentValues = new HashMap<>();
        if (!solveLinearProgram(weightsAndBinPackingPlan.v1(), allocationValues, assignmentValues)) {
            return binPackingPlan;
        }

        AssignmentPlan lpPlan = new RandomizedAssignmentRounding(random, RANDOMIZED_ROUNDING_ROUNDS, nodes, models).computePlan(
            allocationValues,
            assignmentValues
        );
        if (binPackingPlan.compareTo(lpPlan) > 0) {
            logger.debug(() -> "Best plan is from bin packing");
            return binPackingPlan;
        }
        logger.debug(() -> "Best plan is from LP solver");
        return lpPlan;
    }

    /**
     * Returns the LP objective coefficient for the allocation variable of ({@code m}, {@code n}).
     *
     * <p>The coefficient is 1 plus the pair's weight. It is reduced by 10 when the model needs more
     * memory than the node has available, and further by {@code L1 * normalizedMemory(m) / maxNodeCores}.
     */
    private double weightForAllocationVar(
        AssignmentPlan.Model m,
        AssignmentPlan.Node n,
        Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double> weights
    ) {
        int doesNotFitPenalty = m.memoryBytes() > n.availableMemoryBytes() ? 10 : 0;
        return ((1.0d + weights.get(Tuple.tuple(m, n))) - doesNotFitPenalty) - ((L1 * normalizedMemoryPerModel.get(m)) / maxNodeCores);
    }

    /**
     * Builds a greedy bin-packing plan and derives a weight for every (model, node) pair from it.
     *
     * <p>Pairs used by bin packing get decreasing weights starting at {@link #INITIAL_W}, in the
     * order they were assigned. All other pairs get a random weight bounded by the last weight reached
     * (see {@link #minWeight} and {@link #maxWeight}).
     *
     * @return the weights (v1) and the bin-packing plan (v2)
     */
    private Tuple<Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double>, AssignmentPlan> calculateWeightsAndBinPackingPlan() {
        logger.debug(() -> "Calculating weights and bin packing plan");

        double w = INITIAL_W;
        double dw = (INITIAL_W / nodes.size()) / models.size();
        Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double> weights = new HashMap<>();
        AssignmentPlan.Builder builder = AssignmentPlan.builder(nodes, models);

        List<AssignmentPlan.Model> orderedModels = models.stream()
            .sorted(Comparator.comparingDouble(this::descendingSizeAnyFitsModelOrder))
            .toList();
        for (AssignmentPlan.Model m : orderedModels) {
            boolean assignedThisRound;
            do {
                assignedThisRound = false;
                List<AssignmentPlan.Node> orderedNodes = nodes.stream()
                    .sorted(Comparator.comparingDouble(n -> descendingSizeAnyFitsNodeOrder(n, m, builder)))
                    .toList();
                for (AssignmentPlan.Node n : orderedNodes) {
                    int allocations = Math.min(
                        builder.getRemainingCores(n) / m.threadsPerAllocation(),
                        builder.getRemainingAllocations(m)
                    );
                    if (allocations > 0 && builder.canAssign(m, n, allocations)) {
                        builder.assignModelToNode(m, n, allocations);
                        weights.put(Tuple.tuple(m, n), w);
                        w -= dw;
                        assignedThisRound = true;
                        break;
                    }
                }
                // Stop once the model is fully allocated, or once no node accepted any allocation.
                // Without the progress check, a model that no node can take would loop forever.
            } while (assignedThisRound && builder.getRemainingAllocations(m) > 0);
        }

        final double lastW = w;
        for (AssignmentPlan.Model m : models) {
            for (AssignmentPlan.Node n : nodes) {
                weights.computeIfAbsent(Tuple.tuple(m, n), k -> random.nextDouble(minWeight(m, n, lastW), maxWeight(m, n, lastW)));
            }
        }
        logger.trace(() -> "Weights = " + weights);

        AssignmentPlan binPackingPlan = builder.build();
        logger.debug(() -> "Bin packing plan =\n" + binPackingPlan.prettyPrint());
        return Tuple.tuple(weights, binPackingPlan);
    }

    /**
     * Sort key for models. Models are sorted ascending, so more negative keys come first.
     *
     * <p>Larger normalized memory times threads per allocation sorts earlier. The key is doubled for
     * models that already have allocations.
     */
    private double descendingSizeAnyFitsModelOrder(AssignmentPlan.Model m) {
        return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -normalizedMemoryPerModel.get(m) * m.threadsPerAllocation();
    }

    /**
     * Sort key for candidate nodes when placing model {@code m}. Nodes are sorted ascending.
     *
     * <p>The key adds 1 when {@code m} has no current allocation on the node and 1 when the node's
     * remaining cores are fewer than {@code m}'s remaining threads. Two small 0.01-scaled terms then
     * favor a closer core fit and more normalized node memory.
     */
    private double descendingSizeAnyFitsNodeOrder(AssignmentPlan.Node n, AssignmentPlan.Model m, AssignmentPlan.Builder builder) {
        int remainingCores = builder.getRemainingCores(n);
        int remainingThreads = builder.getRemainingThreads(m);
        int notCurrentlyOnNode = m.currentAllocationsByNodeId().containsKey(n.id()) ? 0 : 1;
        int notEnoughCores = remainingCores >= remainingThreads ? 0 : 1;
        return ((notCurrentlyOnNode + notEnoughCores) + (0.01d * distance(remainingCores, remainingThreads))) - (0.01d
            * normalizedMemoryPerNode.get(n));
    }

    /** Absolute difference of two ints, saturating to {@link Integer#MAX_VALUE} when the difference is {@code MIN_VALUE}. */
    @SuppressForbidden(reason = "Math#abs(int) is safe here as we protect against MIN_VALUE")
    private static int distance(int x, int y) {
        int diff = x - y;
        if (diff == Integer.MIN_VALUE) {
            return Integer.MAX_VALUE;
        }
        return Math.abs(diff);
    }

    /** Lower bound of the random weight for a pair: {@code w / 2} if {@code m} already has allocations on {@code n}, else 0. */
    private double minWeight(AssignmentPlan.Model m, AssignmentPlan.Node n, double w) {
        return m.currentAllocationsByNodeId().containsKey(n.id()) ? w / 2.0d : 0.0d;
    }

    /** Upper bound of the random weight for a pair: {@code w} if {@code m} already has allocations on {@code n}, else {@code w / 2}. */
    private double maxWeight(AssignmentPlan.Model m, AssignmentPlan.Node n, double w) {
        return m.currentAllocationsByNodeId().containsKey(n.id()) ? w : w / 2.0d;
    }

    /**
     * Builds and solves the relaxed (continuous) allocation linear program.
     *
     * @param weights          per-pair weights used in the objective
     * @param allocationValues output: LP value of the allocation variable for each pair
     * @param assignmentValues output: allocations times threads per allocation, divided by the node's cores
     * @return {@code false} if the problem is too large or the solution is not feasible. In that case
     *         the output maps are not filled.
     */
    private boolean solveLinearProgram(
        Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double> weights,
        Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double> allocationValues,
        Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double> assignmentValues
    ) {
        long complexity = memoryComplexity();
        if (complexity > MEMORY_COMPLEXITY_LIMIT) {
            logger.debug(() -> "Problem size to big to solve with linear programming; falling back to bin packing solution");
            return false;
        }

        Optimisation.Options options = new Optimisation.Options().abort(new CalendarDateDuration(10.0d, CalendarDateUnit.SECOND));
        if (complexity > MEMORY_COMPLEXITY_SPARSE_THRESHOLD) {
            logger.debug(() -> "Problem size is large enough to switch to sparse solver");
            options.sparse = true;
        }

        ExpressionsBasedModel lp = new ExpressionsBasedModel(options);

        // One continuous, non-negative variable per (model, node) pair.
        Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Variable> allocationVars = new HashMap<>();
        for (AssignmentPlan.Model m : models) {
            for (AssignmentPlan.Node n : nodes) {
                Variable allocationVar = lp.addVariable("allocations_of_model_" + m.id() + "_on_node_" + n.id())
                    .integer(false)
                    .lower(0.0d)
                    .weight(weightForAllocationVar(m, n, weights));
                allocationVars.put(Tuple.tuple(m, n), allocationVar);
            }
        }

        // Each model's total allocations lie between its currently assigned and its requested allocations.
        for (AssignmentPlan.Model m : models) {
            lp.addExpression("allocations_of_model_" + m.id() + "_not_more_than_required")
                .lower(m.getCurrentAssignedAllocations())
                .upper(m.allocations())
                .setLinearFactorsSimple(varsForModel(m, allocationVars));
        }

        // Threads used on each node must not exceed its cores.
        double[] threadsPerAllocation = models.stream().mapToDouble(m -> m.threadsPerAllocation()).toArray();
        for (AssignmentPlan.Node n : nodes) {
            lp.addExpression("threads_on_node_" + n.id() + "_not_more_than_cores")
                .upper(coresPerNode.get(n))
                .setLinearFactors(varsForNode(n, allocationVars), Access1D.wrap(threadsPerAllocation));
        }

        // Normalized memory used on each node must not exceed its normalized available memory.
        // Models that already have allocations on the node are excluded. Presumably their memory is
        // already reflected in the node's available memory; confirm against AssignmentPlan.Node.
        for (AssignmentPlan.Node n : nodes) {
            List<Variable> vars = new ArrayList<>();
            List<Double> memoryFactors = new ArrayList<>();
            for (AssignmentPlan.Model m : models) {
                if (m.currentAllocationsByNodeId().containsKey(n.id())) {
                    continue;
                }
                vars.add(allocationVars.get(Tuple.tuple(m, n)));
                memoryFactors.add((normalizedMemoryPerModel.get(m) * m.threadsPerAllocation()) / coresPerNode.get(n));
            }
            lp.addExpression("used_memory_on_node_" + n.id() + "_not_more_than_available")
                .upper(normalizedMemoryPerNode.get(n))
                .setLinearFactors(vars, Access1D.wrap(memoryFactors));
        }

        Optimisation.Result result = privilegedModelMaximise(lp);
        if (!result.getState().isFeasible()) {
            logger.debug("Linear programming solution state [{}] is not feasible", result.getState());
            return false;
        }

        for (AssignmentPlan.Model m : models) {
            for (AssignmentPlan.Node n : nodes) {
                Tuple<AssignmentPlan.Model, AssignmentPlan.Node> key = Tuple.tuple(m, n);
                double allocations = allocationVars.get(key).getValue().doubleValue();
                allocationValues.put(key, allocations);
                assignmentValues.put(key, (allocations * m.threadsPerAllocation()) / coresPerNode.get(n));
            }
        }
        logger.debug(() -> "LP solver result =\n" + prettyPrintSolverResult(assignmentValues, allocationValues));
        return true;
    }

    /**
     * Runs {@link ExpressionsBasedModel#maximise()} inside a privileged block.
     *
     * <p>The explicit {@link PrivilegedAction} cast resolves the ambiguity between the
     * {@code PrivilegedAction} and {@code PrivilegedExceptionAction} overloads of {@code doPrivileged}.
     */
    @SuppressWarnings("removal")
    private static Optimisation.Result privilegedModelMaximise(ExpressionsBasedModel lp) {
        return AccessController.doPrivileged((PrivilegedAction<Optimisation.Result>) lp::maximise);
    }

    /**
     * Size estimate used to decide whether the LP is attempted and whether it runs in sparse mode.
     *
     * <p>The value is computed in {@code long} so that large inputs cannot overflow to a negative
     * number and slip under {@link #MEMORY_COMPLEXITY_LIMIT}.
     */
    private long memoryComplexity() {
        long nodeCount = nodes.size();
        long modelCount = models.size();
        return (nodeCount + modelCount) * nodeCount * modelCount;
    }

    /** Allocation variables of model {@code m}, one per node, in {@link #nodes} order. */
    private List<Variable> varsForModel(AssignmentPlan.Model m, Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Variable> vars) {
        return nodes.stream().map(n -> vars.get(Tuple.tuple(m, n))).toList();
    }

    /** Allocation variables on node {@code n}, one per model, in {@link #models} order. */
    private List<Variable> varsForNode(AssignmentPlan.Node n, Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Variable> vars) {
        return models.stream().map(m -> vars.get(Tuple.tuple(m, n))).toList();
    }

    /**
     * Renders the LP solution with one line per node.
     *
     * <p>Each line lists every model with a positive allocation value on that node.
     *
     * @param assignmentValues per-pair values printed as {@code y}
     * @param allocationValues per-pair allocation values, printed against the model's requested allocations
     */
    private String prettyPrintSolverResult(
        Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double> assignmentValues,
        Map<Tuple<AssignmentPlan.Model, AssignmentPlan.Node>, Double> allocationValues
    ) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < nodes.size(); i++) {
            AssignmentPlan.Node n = nodes.get(i);
            sb.append(n).append(" ->");
            for (AssignmentPlan.Model m : models) {
                Tuple<AssignmentPlan.Model, AssignmentPlan.Node> key = Tuple.tuple(m, n);
                if (allocationValues.get(key) > 0.0d) {
                    sb.append(" ")
                        .append(m.id())
                        .append(" (mem = ")
                        .append(ByteSizeValue.ofBytes(m.memoryBytes()))
                        .append(") (allocations = ")
                        .append(allocationValues.get(key))
                        .append("/")
                        .append(m.allocations())
                        .append(") (y = ")
                        .append(assignmentValues.get(key))
                        .append(")");
                }
            }
            if (i < nodes.size() - 1) {
                sb.append('\n');
            }
        }
        return sb.toString();
    }
}
