package org.elasticsearch.xpack.ml.action;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateUpdateTask;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.logging.HeaderWarning;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

/* loaded from: input_file:org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.class */
/**
 * Transport action that creates a trained-model alias, or reassigns an existing alias to a
 * different model, by validating the request against the stored model configurations and then
 * submitting a cluster state update that rewrites the {@link ModelAliasMetadata} custom.
 */
public class TransportPutTrainedModelAliasAction extends AcknowledgedTransportMasterNodeAction<PutTrainedModelAliasAction.Request> {
    private static final Logger logger = LogManager.getLogger(TransportPutTrainedModelAliasAction.class);
    private final XPackLicenseState licenseState;
    private final TrainedModelProvider trainedModelProvider;
    private final InferenceAuditor auditor;

    /**
     * Wires the action's collaborators and registers it under the
     * {@code cluster:admin/xpack/ml/inference/model_aliases/put} action name.
     */
    @Inject
    public TransportPutTrainedModelAliasAction(TransportService transportService, TrainedModelProvider trainedModelProvider, ClusterService clusterService, ThreadPool threadPool, XPackLicenseState xPackLicenseState, ActionFilters actionFilters, InferenceAuditor inferenceAuditor, IndexNameExpressionResolver indexNameExpressionResolver) {
        super("cluster:admin/xpack/ml/inference/model_aliases/put", transportService, clusterService, threadPool, actionFilters, PutTrainedModelAliasAction.Request::new, indexNameExpressionResolver, "same");
        this.licenseState = xPackLicenseState;
        this.trainedModelProvider = trainedModelProvider;
        this.auditor = inferenceAuditor;
    }

    /**
     * Validates the alias request and, if every check passes, submits the cluster state update.
     *
     * <p>The listener is failed (and no update is submitted) when:
     * <ul>
     *   <li>the alias already refers to a model and {@code reassign} is not set;</li>
     *   <li>the alias equals the id of an existing trained model;</li>
     *   <li>the target model does not exist;</li>
     *   <li>the license check fails for the target model;</li>
     *   <li>the target model is a PyTorch model;</li>
     *   <li>both the old and new models have inference configs and their names differ.</li>
     * </ul>
     *
     * @param task           the task this operation runs under (not used directly here)
     * @param request        the alias request (alias, target model id, reassign flag)
     * @param clusterState   the cluster state used to look up the alias's current target
     * @param actionListener notified of the acknowledged response or of the first failure
     */
    @Override
    protected void masterOperation(Task task, PutTrainedModelAliasAction.Request request, ClusterState clusterState, ActionListener<AcknowledgedResponse> actionListener) throws Exception {
        final boolean mlApiLicensed = MachineLearningField.ML_API_FEATURE.check(this.licenseState);
        // A model is usable when the ML API feature check passes or the model itself is BASIC licensed.
        final Predicate<TrainedModelConfig> isLicensed = config -> mlApiLicensed || config.getLicenseLevel() == License.OperationMode.BASIC;

        final String oldModelId = ModelAliasMetadata.fromState(clusterState).getModelId(request.getModelAlias());
        if (oldModelId != null && !request.isReassign()) {
            actionListener.onFailure(ExceptionsHelper.badRequestException("cannot assign model_alias [{}] to model_id [{}] as model_alias already refers to [{}]. Set parameter [reassign] to [true] if model_alias should be reassigned.", request.getModelAlias(), request.getModelId(), oldModelId));
            return;
        }

        // The alias itself is looked up as an id too, so a clash with a real model id can be detected.
        final Set<String> idsToFetch = new HashSet<>();
        idsToFetch.add(request.getModelAlias());
        idsToFetch.add(request.getModelId());
        if (oldModelId != null) {
            idsToFetch.add(oldModelId);
        }

        this.trainedModelProvider.getTrainedModels(idsToFetch, GetTrainedModelsAction.Includes.empty(), true, (TaskId) null, ActionListener.wrap(models -> {
            TrainedModelConfig newModel = null;
            TrainedModelConfig oldModel = null;
            for (TrainedModelConfig candidate : models) {
                final String candidateId = candidate.getModelId();
                if (candidateId.equals(request.getModelId())) {
                    newModel = candidate;
                }
                if (candidateId.equals(oldModelId)) {
                    oldModel = candidate;
                }
                if (candidateId.equals(request.getModelAlias())) {
                    actionListener.onFailure(ExceptionsHelper.badRequestException("model_alias cannot be the same as an existing trained model_id"));
                    return;
                }
            }
            if (newModel == null) {
                actionListener.onFailure(ExceptionsHelper.missingTrainedModel(request.getModelId()));
                return;
            }
            if (!isLicensed.test(newModel)) {
                actionListener.onFailure(LicenseUtils.newComplianceException("ml"));
                return;
            }
            if (newModel.getModelType() == TrainedModelType.PYTORCH) {
                actionListener.onFailure(ExceptionsHelper.badRequestException("model_alias is not supported on pytorch models"));
                return;
            }
            // When the old model's config was not returned, the compatibility checks are skipped
            // and the reassignment proceeds.
            if (oldModel != null) {
                if (newModel.getInferenceConfig() != null && oldModel.getInferenceConfig() != null && !newModel.getInferenceConfig().getName().equals(oldModel.getInferenceConfig().getName())) {
                    actionListener.onFailure(ExceptionsHelper.badRequestException("cannot reassign model_alias [{}] to model [{}] with inference config type [{}] from model [{}] with type [{}]", request.getModelAlias(), newModel.getModelId(), newModel.getInferenceConfig().getName(), oldModel.getModelId(), oldModel.getInferenceConfig().getName()));
                    return;
                }
                warnIfInputFieldsDiffer(request, oldModelId, oldModel, newModel);
            }
            submitUnbatchedTask("update-model-alias", new AckedClusterStateUpdateTask(request, actionListener) {
                @Override
                public ClusterState execute(ClusterState currentState) {
                    return TransportPutTrainedModelAliasAction.updateModelAlias(currentState, request);
                }
            });
        }, actionListener::onFailure));
    }

    /**
     * Emits a warning (audit message, log line and header warning) when the input fields of the
     * new model differ substantially from those of the old one. This never fails the request.
     *
     * <p>"Substantially" means: more than half of the old model's fields are absent from the new
     * model, or fewer than half of them are shared. "Half" uses integer division.
     */
    private void warnIfInputFieldsDiffer(PutTrainedModelAliasAction.Request request, String oldModelId, TrainedModelConfig oldModel, TrainedModelConfig newModel) {
        final Set<String> oldFields = new HashSet<>(oldModel.getInput().getFieldNames());
        final Set<String> newFields = new HashSet<>(newModel.getInput().getFieldNames());
        final int halfOfOldFields = oldFields.size() / 2;
        final boolean tooManyDropped = Sets.difference(oldFields, newFields).size() > halfOfOldFields;
        if (tooManyDropped || Sets.intersection(newFields, oldFields).size() < halfOfOldFields) {
            final String warning = Messages.getMessage("The input fields for new model [{0}] and for old model [{1}] differ significantly, model results may change drastically.", request.getModelId(), oldModelId);
            this.auditor.warning(oldModelId, warning);
            logger.warn("[{}] {}", oldModelId, warning);
            HeaderWarning.addWarning(warning);
        }
    }

    /** Submits {@code clusterStateUpdateTask} to the cluster service as an unbatched state update. */
    @SuppressForbidden(reason = "legacy usage of unbatched task")
    private void submitUnbatchedTask(String str, ClusterStateUpdateTask clusterStateUpdateTask) {
        this.clusterService.submitUnbatchedStateUpdateTask(str, clusterStateUpdateTask);
    }

    /**
     * Builds a new cluster state in which the requested alias points at the requested model.
     *
     * <p>The alias map is copied from the {@link ModelAliasMetadata} of {@code clusterState}, the
     * entry for the alias is added or overwritten, and the result is stored under
     * {@link ModelAliasMetadata#NAME}. An info line records whether the alias was created or updated.
     *
     * @param clusterState the state to derive the new state from
     * @param request      the alias request supplying the alias and target model id
     * @return the new cluster state carrying the updated alias metadata
     */
    static ClusterState updateModelAlias(ClusterState clusterState, PutTrainedModelAliasAction.Request request) {
        final ModelAliasMetadata currentAliases = ModelAliasMetadata.fromState(clusterState);
        final String previousModelId = currentAliases.getModelId(request.getModelAlias());
        final Map<String, ModelAliasMetadata.ModelAliasEntry> updatedAliases = new HashMap<>(currentAliases.modelAliases());
        if (previousModelId == null) {
            logger.info("creating new model_alias [{}] for model [{}]", request.getModelAlias(), request.getModelId());
        } else {
            logger.info("updating model_alias [{}] to refer to model [{}] from model [{}]", request.getModelAlias(), request.getModelId(), previousModelId);
        }
        updatedAliases.put(request.getModelAlias(), new ModelAliasMetadata.ModelAliasEntry(request.getModelId()));

        final Metadata updatedMetadata = Metadata.builder(clusterState.getMetadata())
            .putCustom(ModelAliasMetadata.NAME, new ModelAliasMetadata(updatedAliases))
            .build();
        return ClusterState.builder(clusterState).metadata(updatedMetadata).build();
    }

    /**
     * Returns the global block exception for the {@link ClusterBlockLevel#METADATA_WRITE} level,
     * as reported by the cluster state's blocks.
     */
    @Override
    public ClusterBlockException checkBlock(PutTrainedModelAliasAction.Request request, ClusterState clusterState) {
        return clusterState.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
    }
}
