package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.OptionalInt;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.class */
/**
 * The tokenized form of one or more input sequences, plus helpers that serialize it as
 * rectangular (padded) arrays into an {@link XContentBuilder}.
 *
 * <p>Every row written by the {@code write*} helpers is padded to the length of the longest
 * {@link Tokens#tokenIds()} array passed to the constructor. Subclasses define the request
 * layout in {@link #buildRequest(String, Tokenization.Truncate)}.
 */
public abstract class TokenizationResult {

    /**
     * Position value {@code -1}. Not used inside this class; presumably stored in
     * {@link Tokens#tokenMap()} for special tokens — confirm against the tokenizers.
     */
    public static final int SPECIAL_TOKEN_POSITION = -1;

    /** Value of {@link Tokens#spanPrev()} meaning that no windowing span is configured. */
    private static final int NO_SPAN = -1;

    private final List<String> vocab;
    private final List<Tokens> tokens;
    /** Length of the longest token id array; every written row is padded to this length. */
    private final int maxLength;
    private final int padTokenId;

    /**
     * The token ids of a single tokenized input (or window of an input).
     *
     * @param input         the original input text(s)
     * @param tokens        the delimited tokens, one list per input
     * @param truncated     whether the token ids were truncated; must be {@code false} when a
     *                      windowing span is set
     * @param tokenIds      vocabulary ids; must have the same length as {@code tokenMap}
     * @param tokenMap      per-token mapping parallel to {@code tokenIds} (presumably back to
     *                      positions in the input — confirm against the tokenizers)
     * @param spanPrev      windowing span, or {@code -1} when windowing is not enabled
     * @param sequenceId    id shared by all windows of the same input sequence
     * @param seqPairOffset position from which token type ids are written as {@code 1};
     *                      {@code <= 0} means every token type id is {@code 0}
     */
    public record Tokens(
        List<String> input,
        List<List<? extends DelimitedToken>> tokens,
        boolean truncated,
        int[] tokenIds,
        int[] tokenMap,
        int spanPrev,
        int sequenceId,
        int seqPairOffset
    ) {
        /**
         * Validates the components.
         *
         * @throws IllegalArgumentException if {@code truncated} is set while windowing is enabled
         */
        public Tokens {
            assert tokenIds.length == tokenMap.length;
            if (spanPrev != NO_SPAN && truncated) {
                throw new IllegalArgumentException("should not truncate when windowing is enabled");
            }
        }

        /**
         * Finds the first position holding the given vocabulary id.
         *
         * @param tokenId the vocabulary id to look for
         * @return the index into {@link #tokenIds()}, or empty if the id does not occur
         */
        public OptionalInt getTokenIndex(int tokenId) {
            return IntStream.range(0, tokenIds.length).filter(i -> tokenIds[i] == tokenId).findFirst();
        }
    }

    /**
     * Incrementally assembles a {@link Tokens} instance. Parameter names below are descriptive
     * only; confirm their exact meaning against the implementations.
     */
    public interface TokensBuilder {
        /** Adds a single sequence of token ids together with its parallel token map. */
        TokensBuilder addSequence(List<Integer> tokenIds, List<Integer> tokenMap);

        /** Adds a pair of sequences, each with its own parallel token map. */
        TokensBuilder addSequencePair(
            List<Integer> tokenId1s,
            List<Integer> tokenMap1,
            List<Integer> tokenId2s,
            List<Integer> tokenMap2
        );

        /**
         * Builds the result.
         *
         * @param input     the original input text(s)
         * @param truncated whether truncation took place
         * @param allTokens delimited tokens, one list per input
         * @param spanPrev  presumably forwarded to {@link Tokens#spanPrev()} — confirm
         * @param seqId     presumably forwarded to {@link Tokens#sequenceId()} — confirm
         */
        Tokens build(List<String> input, boolean truncated, List<List<? extends DelimitedToken>> allTokens, int spanPrev, int seqId);

        /** Single-input convenience overload that wraps its arguments in singleton lists. */
        default Tokens build(String input, boolean truncated, List<? extends DelimitedToken> allTokens, int spanPrev, int seqId) {
            return build(List.of(input), truncated, List.of(allTokens), spanPrev, seqId);
        }
    }

    /**
     * @param vocab      vocabulary used by {@link #getFromVocab(int)}
     * @param tokens     the tokenized inputs
     * @param padTokenId id written into padding positions of {@link #writePaddedTokens}
     * @throws IllegalArgumentException if two entries share a sequence id while one of them has
     *                                  no windowing span
     */
    protected TokenizationResult(List<String> vocab, List<Tokens> tokens, int padTokenId) {
        this.vocab = vocab;
        this.tokens = tokens;
        this.padTokenId = padTokenId;
        int longest = 0;
        Set<Integer> seenSequenceIds = new HashSet<>();
        for (Tokens sequence : tokens) {
            longest = Math.max(sequence.tokenIds().length, longest);
            // A sequence id may only repeat when the entries are windows, i.e. carry a span.
            if (!seenSequenceIds.add(sequence.sequenceId()) && sequence.spanPrev() == NO_SPAN) {
                throw new IllegalArgumentException("cannot window a sequence without a configured span");
            }
        }
        this.maxLength = longest;
    }

    /** Groups the tokenized inputs by {@link Tokens#sequenceId()}. */
    public Map<Integer, List<Tokens>> getTokensBySequenceId() {
        return tokens.stream().collect(Collectors.groupingBy(Tokens::sequenceId));
    }

    List<Tokens> getTokens() {
        return tokens;
    }

    /** Returns the vocabulary entry at index {@code i}. */
    public String getFromVocab(int i) {
        return vocab.get(i);
    }

    /** Returns {@code str} unchanged; presumably overridden by subclasses that need decoding. */
    public String decode(String str) {
        return str;
    }

    /** Returns the {@code i}-th tokenized input. */
    public Tokens getTokenization(int i) {
        return tokens.get(i);
    }

    /** Returns {@code true} if any tokenized input was truncated. */
    public boolean anyTruncated() {
        return tokens.stream().anyMatch(Tokens::truncated);
    }

    /** Returns {@code true} if there are no inputs or every input has zero token ids. */
    public boolean isEmpty() {
        // allMatch on an empty stream is true, so an empty token list is covered as well.
        return tokens.stream().allMatch(t -> t.tokenIds().length == 0);
    }

    public abstract NlpTask.Request buildRequest(String str, Tokenization.Truncate truncate) throws IOException;

    /** Writes one row per input: its token ids followed by {@code padTokenId} up to the max length. */
    public void writePaddedTokens(String str, XContentBuilder xContentBuilder) throws IOException {
        xContentBuilder.startArray(str);
        for (Tokens sequence : tokens) {
            xContentBuilder.startArray();
            for (int tokenId : sequence.tokenIds()) {
                xContentBuilder.value(tokenId);
            }
            padFrom(xContentBuilder, sequence.tokenIds().length, padTokenId);
            xContentBuilder.endArray();
        }
        xContentBuilder.endArray();
    }

    /** Writes one row per input: {@code 1} for each real token, {@code 0} for each padding position. */
    public void writeAttentionMask(String str, XContentBuilder xContentBuilder) throws IOException {
        xContentBuilder.startArray(str);
        for (Tokens sequence : tokens) {
            xContentBuilder.startArray();
            int length = sequence.tokenIds().length;
            for (int i = 0; i < length; i++) {
                xContentBuilder.value(1);
            }
            // Padding must be masked out with 0. Writing padTokenId here was only correct for
            // vocabularies whose pad id happens to be 0.
            padFrom(xContentBuilder, length, 0);
            xContentBuilder.endArray();
        }
        xContentBuilder.endArray();
    }

    /**
     * Writes one row of length max-length per input: all {@code 0} when {@code seqPairOffset <= 0},
     * otherwise {@code 0} before {@code seqPairOffset} and {@code 1} from there on.
     */
    public void writeTokenTypeIds(String str, XContentBuilder xContentBuilder) throws IOException {
        xContentBuilder.startArray(str);
        for (Tokens sequence : tokens) {
            xContentBuilder.startArray();
            int pairOffset = sequence.seqPairOffset();
            if (pairOffset <= 0) {
                padFrom(xContentBuilder, 0, 0);
            } else {
                for (int i = 0; i < pairOffset; i++) {
                    xContentBuilder.value(0);
                }
                padFrom(xContentBuilder, pairOffset, 1);
            }
            xContentBuilder.endArray();
        }
        xContentBuilder.endArray();
    }

    /** Writes one row {@code 0, 1, ..., maxLength - 1} per input. */
    public void writePositionIds(String str, XContentBuilder xContentBuilder) throws IOException {
        xContentBuilder.startArray(str);
        for (int row = 0; row < tokens.size(); row++) {
            xContentBuilder.startArray();
            for (int position = 0; position < maxLength; position++) {
                xContentBuilder.value(position);
            }
            xContentBuilder.endArray();
        }
        xContentBuilder.endArray();
    }

    /** Writes {@code value} for every position from {@code from} (inclusive) up to {@link #maxLength}. */
    private void padFrom(XContentBuilder xContentBuilder, int from, int value) throws IOException {
        for (int i = from; i < maxLength; i++) {
            xContentBuilder.value(value);
        }
    }
}
