package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.class */
/**
 * WordPiece tokenizer for BERT-style models.
 *
 * <p>Text is split by a {@link WordPieceAnalyzer} built from the model vocabulary. The configured
 * special tokens (separator, class, pad, mask, unknown) are resolved to vocabulary ids at
 * construction time. {@value #MASK_TOKEN} is always added to the set of tokens that must never
 * be split.
 */
public class BertTokenizer extends NlpTokenizer {
    public static final String UNKNOWN_TOKEN = "[UNK]";
    public static final String SEPARATOR_TOKEN = "[SEP]";
    public static final String PAD_TOKEN = "[PAD]";
    public static final String CLASS_TOKEN = "[CLS]";
    public static final String MASK_TOKEN = "[MASK]";

    /** Tokens that are always protected from splitting, in addition to any user-supplied ones. */
    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

    private final WordPieceAnalyzer wordPieceAnalyzer;
    /** Vocabulary in its original order; a token's index in this list is its id. */
    protected final List<String> originalVocab;
    /** Token -> id lookup built from {@link #originalVocab}. */
    private final SortedMap<String, Integer> vocab;
    protected final boolean withSpecialTokens;
    private final int maxSequenceLength;
    /** Id of the separator token, or -1 when special tokens are disabled. */
    protected final int sepTokenId;
    /** Id of the class token, or -1 when special tokens are disabled. */
    private final int clsTokenId;
    private final String padToken;
    protected final int padTokenId;
    private final String maskToken;
    private final String unknownToken;

    /**
     * Fluent builder for {@link BertTokenizer}.
     *
     * <p>Defaults are taken from the supplied {@link Tokenization} config. When not set explicitly,
     * accent stripping defaults to the lower-casing flag, and the never-split set defaults to empty.
     */
    public static class Builder {
        protected final List<String> originalVocab;
        protected final SortedMap<String, Integer> vocab;
        protected boolean doLowerCase;
        protected boolean withSpecialTokens;
        protected int span;
        protected int maxSequenceLength;
        protected Set<String> neverSplit;
        protected boolean doTokenizeCjKChars = true;
        protected Boolean doStripAccents = null;

        protected Builder(List<String> vocab, Tokenization tokenization) {
            this.originalVocab = vocab;
            this.vocab = buildSortedVocab(vocab);
            this.doLowerCase = tokenization.doLowerCase();
            this.withSpecialTokens = tokenization.withSpecialTokens();
            this.maxSequenceLength = tokenization.maxSequenceLength();
            this.span = tokenization.getSpan();
        }

        /**
         * Maps each vocabulary entry to its list index.
         *
         * <p>If a token appears more than once, the last index wins.
         */
        private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
            SortedMap<String, Integer> sortedVocab = new TreeMap<>();
            for (int i = 0; i < vocab.size(); i++) {
                sortedVocab.put(vocab.get(i), i);
            }
            return sortedVocab;
        }

        public Builder setDoLowerCase(boolean doLowerCase) {
            this.doLowerCase = doLowerCase;
            return this;
        }

        public Builder setDoTokenizeCjKChars(boolean doTokenizeCjKChars) {
            this.doTokenizeCjKChars = doTokenizeCjKChars;
            return this;
        }

        /** Sets accent stripping; {@code null} means "same as the lower-casing flag". */
        public Builder setDoStripAccents(Boolean doStripAccents) {
            this.doStripAccents = doStripAccents;
            return this;
        }

        public Builder setNeverSplit(Set<String> neverSplit) {
            this.neverSplit = neverSplit;
            return this;
        }

        public Builder setMaxSequenceLength(int maxSequenceLength) {
            this.maxSequenceLength = maxSequenceLength;
            return this;
        }

        public Builder setWithSpecialTokens(boolean withSpecialTokens) {
            this.withSpecialTokens = withSpecialTokens;
            return this;
        }

        /** Resolves any unset defaults and creates the tokenizer. */
        public BertTokenizer build() {
            if (doStripAccents == null) {
                doStripAccents = doLowerCase;
            }
            if (neverSplit == null) {
                neverSplit = Collections.emptySet();
            }
            return new BertTokenizer(
                originalVocab,
                vocab,
                doLowerCase,
                doTokenizeCjKChars,
                doStripAccents,
                withSpecialTokens,
                maxSequenceLength,
                neverSplit
            );
        }
    }

    /**
     * Creates a tokenizer that uses the standard BERT special tokens.
     *
     * <p>{@value #MASK_TOKEN} is merged into {@code neverSplit}.
     */
    protected BertTokenizer(
        List<String> originalVocab,
        SortedMap<String, Integer> vocab,
        boolean doLowerCase,
        boolean doTokenizeCjKChars,
        boolean doStripAccents,
        boolean withSpecialTokens,
        int maxSequenceLength,
        Set<String> neverSplit
    ) {
        this(
            originalVocab,
            vocab,
            doLowerCase,
            doTokenizeCjKChars,
            doStripAccents,
            withSpecialTokens,
            maxSequenceLength,
            Sets.union(neverSplit, NEVER_SPLIT),
            SEPARATOR_TOKEN,
            CLASS_TOKEN,
            PAD_TOKEN,
            MASK_TOKEN,
            UNKNOWN_TOKEN
        );
    }

    /**
     * Creates a tokenizer with custom special tokens.
     *
     * <p>The vocabulary is validated before the {@link WordPieceAnalyzer} is created, so a failed
     * validation does not leave an unclosed analyzer behind.
     *
     * @throws RuntimeException (a conflict-status exception from {@code ExceptionsHelper}) if the
     *     unknown or pad token is missing from {@code vocab}, or, when {@code withSpecialTokens} is
     *     set, if the separator or class token is missing
     */
    public BertTokenizer(
        List<String> originalVocab,
        SortedMap<String, Integer> vocab,
        boolean doLowerCase,
        boolean doTokenizeCjKChars,
        boolean doStripAccents,
        boolean withSpecialTokens,
        int maxSequenceLength,
        Set<String> neverSplit,
        String sepToken,
        String clsToken,
        String padToken,
        String maskToken,
        String unknownToken
    ) {
        if (vocab.containsKey(unknownToken) == false) {
            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", unknownToken);
        }
        if (vocab.containsKey(padToken) == false) {
            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", padToken);
        }
        if (withSpecialTokens) {
            Set<String> missingTokens = Sets.difference(Set.of(sepToken, clsToken), vocab.keySet());
            if (missingTokens.isEmpty() == false) {
                throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required {} token(s)", missingTokens);
            }
            this.sepTokenId = vocab.get(sepToken);
            this.clsTokenId = vocab.get(clsToken);
        } else {
            this.sepTokenId = -1;
            this.clsTokenId = -1;
        }
        this.padTokenId = vocab.get(padToken);

        this.originalVocab = originalVocab;
        this.vocab = vocab;
        this.withSpecialTokens = withSpecialTokens;
        this.maxSequenceLength = maxSequenceLength;
        this.padToken = padToken;
        this.maskToken = maskToken;
        this.unknownToken = unknownToken;
        // Created last: the analyzer is closeable, so build it only once the config is known-good.
        this.wordPieceAnalyzer = new WordPieceAnalyzer(
            originalVocab,
            new ArrayList<>(neverSplit),
            doLowerCase,
            doTokenizeCjKChars,
            doStripAccents,
            unknownToken
        );
    }

    @Override
    int sepTokenId() {
        return sepTokenId;
    }

    @Override
    int maxSequenceLength() {
        return maxSequenceLength;
    }

    @Override
    boolean isWithSpecialTokens() {
        return withSpecialTokens;
    }

    @Override
    int clsTokenId() {
        return clsTokenId;
    }

    @Override
    public String getPadToken() {
        return padToken;
    }

    public String getUnknownToken() {
        return unknownToken;
    }

    @Override
    public OptionalInt getPadTokenId() {
        Integer id = vocab.get(padToken);
        return id != null ? OptionalInt.of(id) : OptionalInt.empty();
    }

    @Override
    public OptionalInt getMaskTokenId() {
        Integer id = vocab.get(maskToken);
        return id != null ? OptionalInt.of(id) : OptionalInt.empty();
    }

    @Override
    public String getMaskToken() {
        return maskToken;
    }

    /**
     * Wraps tokenized sequences in a {@link BertTokenizationResult}.
     *
     * <p>The pad id was validated in the constructor, so the cached {@link #padTokenId} is used.
     */
    @Override
    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> tokenizations) {
        return new BertTokenizationResult(originalVocab, tokenizations, padTokenId);
    }

    @Override
    TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) {
        return new BertTokenizationResult.BertTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
    }

    /**
     * Returns a request builder that tokenizes every input string and builds the model request.
     *
     * <p>Each input's list index is passed to {@code tokenize} as its sequence id.
     */
    @Override
    public NlpTask.RequestBuilder requestBuilder() {
        return (inputs, requestId, truncate, span) -> buildTokenizationResult(
            IntStream.range(0, inputs.size())
                .boxed()
                .flatMap(seqId -> tokenize(inputs.get(seqId), truncate, span, seqId).stream())
                .collect(Collectors.toList())
        ).buildRequest(requestId, truncate);
    }

    /**
     * Number of extra special tokens needed for a sequence pair.
     *
     * <p>Presumably {@code [CLS] A [SEP] B [SEP]}; this is an assumption, confirm against callers.
     */
    @Override
    int getNumExtraTokensForSeqPair() {
        return 3;
    }

    /**
     * Runs the WordPiece analyzer over {@code seq}.
     *
     * <p>For each emitted token, the position is computed by accumulating Lucene position
     * increments, starting from -1. The token stream is closed on every path, including when an
     * exception is thrown.
     *
     * @return the analyzer's tokens paired with the per-token position list
     * @throws UncheckedIOException if the token stream fails
     */
    @Override
    public NlpTokenizer.InnerTokenization innerTokenize(String seq) {
        List<Integer> tokenPositionMap = new ArrayList<>();
        try (TokenStream ts = wordPieceAnalyzer.tokenStream("input", seq)) {
            ts.reset();
            PositionIncrementAttribute tokenPos = ts.addAttribute(PositionIncrementAttribute.class);
            int currPos = -1;
            while (ts.incrementToken()) {
                currPos += tokenPos.getPositionIncrement();
                tokenPositionMap.add(currPos);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return new NlpTokenizer.InnerTokenization(new ArrayList<>(wordPieceAnalyzer.getTokens()), tokenPositionMap);
    }

    /** Closes the underlying WordPiece analyzer. */
    public void close() {
        wordPieceAnalyzer.close();
    }

    public int getMaxSequenceLength() {
        return maxSequenceLength;
    }

    /** Creates a builder seeded from the vocabulary and tokenization config. */
    public static Builder builder(List<String> vocab, Tokenization tokenization) {
        return new Builder(vocab, tokenization);
    }
}
