package org.elasticsearch.xpack.ml.action;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
import org.elasticsearch.action.support.tasks.BaseTasksResponse;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;

/* loaded from: input_file:org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.class */
public class TransportGetDeploymentStatsAction extends TransportTasksAction<TrainedModelDeploymentTask, GetDeploymentStatsAction.Request, GetDeploymentStatsAction.Response, AssignmentStats> {
    @Inject
    public TransportGetDeploymentStatsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService) {
        super("cluster:internal/xpack/ml/trained_models/deployments/stats/get", clusterService, transportService, actionFilters, GetDeploymentStatsAction.Request::new, GetDeploymentStatsAction.Response::new, AssignmentStats::new, "management");
    }

    protected GetDeploymentStatsAction.Response newResponse(GetDeploymentStatsAction.Request request, List<AssignmentStats> list, List<TaskOperationFailure> list2, List<FailedNodeException> list3) {
        return new GetDeploymentStatsAction.Response(list2, list3, new ArrayList(((TreeMap) list.stream().collect(Collectors.toMap((v0) -> {
            return v0.getModelId();
        }, Function.identity(), (assignmentStats, assignmentStats2) -> {
            assignmentStats.getNodeStats().addAll(assignmentStats2.getNodeStats());
            return assignmentStats;
        }, TreeMap::new))).values()), r0.size());
    }

    protected void doExecute(Task task, GetDeploymentStatsAction.Request request, ActionListener<GetDeploymentStatsAction.Response> actionListener) {
        ClusterState state = this.clusterService.state();
        TrainedModelAssignmentMetadata fromState = TrainedModelAssignmentMetadata.fromState(state);
        ExpandedIdsMatcher.SimpleIdsMatcher simpleIdsMatcher = new ExpandedIdsMatcher.SimpleIdsMatcher(Strings.tokenizeToStringArray(request.getDeploymentId(), ","));
        ArrayList arrayList = new ArrayList();
        HashSet hashSet = new HashSet();
        HashMap hashMap = new HashMap();
        for (Map.Entry<String, TrainedModelAssignment> entry : fromState.modelAssignments().entrySet()) {
            String key = entry.getKey();
            if (simpleIdsMatcher.idMatches(key)) {
                arrayList.add(key);
                hashSet.addAll(Arrays.asList(entry.getValue().getStartedNodes()));
                hashMap.put(entry.getValue(), (Map) entry.getValue().getNodeRoutingTable().entrySet().stream().filter(entry2 -> {
                    return !RoutingState.STARTED.equals(((RoutingInfo) entry2.getValue()).getState());
                }).collect(Collectors.toMap((v0) -> {
                    return v0.getKey();
                }, (v0) -> {
                    return v0.getValue();
                })));
            }
        }
        if (arrayList.isEmpty()) {
            actionListener.onResponse(new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0L));
            return;
        }
        request.setNodes((String[]) hashSet.toArray(i -> {
            return new String[i];
        }));
        request.setExpandedIds(arrayList);
        super.doExecute(task, request, actionListener.delegateFailure((actionListener2, response) -> {
            GetDeploymentStatsAction.Response addFailedRoutes = addFailedRoutes(response, hashMap, state.nodes());
            for (AssignmentStats assignmentStats : addFailedRoutes.getStats().results()) {
                TrainedModelAssignment modelAssignment = fromState.getModelAssignment(assignmentStats.getModelId());
                if (modelAssignment != null) {
                    assignmentStats.setState(modelAssignment.getAssignmentState()).setReason((String) modelAssignment.getReason().orElse(null));
                    if (!modelAssignment.getNodeRoutingTable().isEmpty() && modelAssignment.getNodeRoutingTable().values().stream().allMatch(routingInfo -> {
                        return routingInfo.getState().equals(RoutingState.FAILED);
                    })) {
                        assignmentStats.setState(AssignmentState.FAILED);
                        if (assignmentStats.getReason() == null) {
                            assignmentStats.setReason("All node routes are failed; see node route reason for details");
                        }
                    }
                    if (modelAssignment.getAssignmentState().isAnyOf(new AssignmentState[]{AssignmentState.STARTED, AssignmentState.STARTING})) {
                        assignmentStats.setAllocationStatus((AllocationStatus) modelAssignment.calculateAllocationStatus().orElse(null));
                    }
                }
            }
            actionListener2.onResponse(addFailedRoutes);
        }));
    }

    static GetDeploymentStatsAction.Response addFailedRoutes(GetDeploymentStatsAction.Response response, Map<TrainedModelAssignment, Map<String, RoutingInfo>> map, DiscoveryNodes discoveryNodes) {
        Map map2 = (Map) map.keySet().stream().collect(Collectors.toMap((v0) -> {
            return v0.getModelId();
        }, Function.identity()));
        ArrayList arrayList = new ArrayList();
        for (AssignmentStats assignmentStats : response.getStats().results()) {
            if (map2.containsKey(assignmentStats.getModelId())) {
                Map<String, RoutingInfo> map3 = map.get(map2.get(assignmentStats.getModelId()));
                ArrayList arrayList2 = new ArrayList();
                HashSet hashSet = new HashSet();
                for (AssignmentStats.NodeStats nodeStats : assignmentStats.getNodeStats()) {
                    if (map3.containsKey(nodeStats.getNode().getId())) {
                        RoutingInfo routingInfo = map3.get(nodeStats.getNode().getId());
                        arrayList2.add(AssignmentStats.NodeStats.forNotStartedState(nodeStats.getNode(), routingInfo.getState(), routingInfo.getReason()));
                    } else {
                        arrayList2.add(nodeStats);
                    }
                    hashSet.add(nodeStats.getNode().getId());
                }
                for (Map.Entry<String, RoutingInfo> entry : map3.entrySet()) {
                    if (!hashSet.contains(entry.getKey())) {
                        arrayList2.add(AssignmentStats.NodeStats.forNotStartedState(discoveryNodes.get(entry.getKey()), entry.getValue().getState(), entry.getValue().getReason()));
                    }
                }
                arrayList2.sort(Comparator.comparing(nodeStats2 -> {
                    return nodeStats2.getNode().getId();
                }));
                arrayList.add(new AssignmentStats(assignmentStats.getModelId(), assignmentStats.getThreadsPerAllocation(), assignmentStats.getNumberOfAllocations(), assignmentStats.getQueueCapacity(), assignmentStats.getCacheSize(), assignmentStats.getStartTime(), arrayList2));
            } else {
                arrayList.add(assignmentStats);
            }
        }
        for (Map.Entry<TrainedModelAssignment, Map<String, RoutingInfo>> entry2 : map.entrySet()) {
            TrainedModelAssignment key = entry2.getKey();
            String modelId = key.getTaskParams().getModelId();
            if (!response.getStats().results().stream().anyMatch(assignmentStats2 -> {
                return modelId.equals(assignmentStats2.getModelId());
            })) {
                ArrayList arrayList3 = new ArrayList();
                for (Map.Entry<String, RoutingInfo> entry3 : entry2.getValue().entrySet()) {
                    arrayList3.add(AssignmentStats.NodeStats.forNotStartedState(discoveryNodes.get(entry3.getKey()), entry3.getValue().getState(), entry3.getValue().getReason()));
                }
                arrayList3.sort(Comparator.comparing(nodeStats3 -> {
                    return nodeStats3.getNode().getId();
                }));
                arrayList.add(new AssignmentStats(modelId, Integer.valueOf(key.getTaskParams().getThreadsPerAllocation()), Integer.valueOf(key.getTaskParams().getNumberOfAllocations()), Integer.valueOf(key.getTaskParams().getQueueCapacity()), (ByteSizeValue) key.getTaskParams().getCacheSize().orElse(null), key.getStartTime(), arrayList3));
            }
        }
        arrayList.sort(Comparator.comparing((v0) -> {
            return v0.getModelId();
        }));
        return new GetDeploymentStatsAction.Response(response.getTaskFailures(), response.getNodeFailures(), arrayList, arrayList.size());
    }

    protected void taskOperation(Task task, GetDeploymentStatsAction.Request request, TrainedModelDeploymentTask trainedModelDeploymentTask, ActionListener<AssignmentStats> actionListener) {
        Optional<ModelStats> modelStats = trainedModelDeploymentTask.modelStats();
        ArrayList arrayList = new ArrayList();
        if (modelStats.isPresent()) {
            ModelStats modelStats2 = modelStats.get();
            arrayList.add(AssignmentStats.NodeStats.forStartedState(this.clusterService.localNode(), modelStats2.timingStats().getCount(), Double.valueOf(modelStats2.timingStats().getAverage()), modelStats2.pendingCount(), modelStats2.errorCount(), modelStats2.cacheHitCount(), modelStats2.rejectedExecutionCount(), modelStats2.timeoutCount(), modelStats2.lastUsed(), modelStats2.startTime(), modelStats2.threadsPerAllocation(), modelStats2.numberOfAllocations(), modelStats2.peakThroughput(), modelStats2.throughputLastPeriod(), modelStats2.avgInferenceTimeLastPeriod(), modelStats2.cacheHitCountLastPeriod()));
        } else {
            arrayList.add(AssignmentStats.NodeStats.forNotStartedState(this.clusterService.localNode(), RoutingState.STOPPED, ""));
        }
        TrainedModelAssignment modelAssignment = TrainedModelAssignmentMetadata.fromState(this.clusterService.state()).getModelAssignment(trainedModelDeploymentTask.getModelId());
        actionListener.onResponse(new AssignmentStats(trainedModelDeploymentTask.getModelId(), Integer.valueOf(trainedModelDeploymentTask.getParams().getThreadsPerAllocation()), Integer.valueOf(modelAssignment == null ? trainedModelDeploymentTask.getParams().getNumberOfAllocations() : modelAssignment.getTaskParams().getNumberOfAllocations()), Integer.valueOf(trainedModelDeploymentTask.getParams().getQueueCapacity()), (ByteSizeValue) trainedModelDeploymentTask.getParams().getCacheSize().orElse(null), TrainedModelAssignmentMetadata.fromState(this.clusterService.state()).getModelAssignment(trainedModelDeploymentTask.getModelId()).getStartTime(), arrayList));
    }

    protected /* bridge */ /* synthetic */ void taskOperation(Task task, BaseTasksRequest baseTasksRequest, Task task2, ActionListener actionListener) {
        taskOperation(task, (GetDeploymentStatsAction.Request) baseTasksRequest, (TrainedModelDeploymentTask) task2, (ActionListener<AssignmentStats>) actionListener);
    }

    protected /* bridge */ /* synthetic */ BaseTasksResponse newResponse(BaseTasksRequest baseTasksRequest, List list, List list2, List list3) {
        return newResponse((GetDeploymentStatsAction.Request) baseTasksRequest, (List<AssignmentStats>) list, (List<TaskOperationFailure>) list2, (List<FailedNodeException>) list3);
    }

    protected /* bridge */ /* synthetic */ void doExecute(Task task, BaseTasksRequest baseTasksRequest, ActionListener actionListener) {
        doExecute(task, (GetDeploymentStatsAction.Request) baseTasksRequest, (ActionListener<GetDeploymentStatsAction.Response>) actionListener);
    }

    protected /* bridge */ /* synthetic */ void doExecute(Task task, ActionRequest actionRequest, ActionListener actionListener) {
        doExecute(task, (GetDeploymentStatsAction.Request) actionRequest, (ActionListener<GetDeploymentStatsAction.Response>) actionListener);
    }
}
