package edu.berkeley.nlp.crf;

import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.util.Lists;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.PriorityQueue;
import fig.basic.Indexer;
import fig.basic.Pair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/crf/ChainCRFTagger.class */
/**
 * Sequence tagger backed by a chain CRF: given trained weights it decodes the best (or the
 * {@code k} best) label sequences for an {@link InstanceSequence}.
 *
 * <p>Instances are normally obtained from {@link Factory#trainTagger(List)}.
 *
 * @param <V> vertex (per-position) instance type
 * @param <E> edge (transition) instance type
 * @param <L> label type
 */
public class ChainCRFTagger<V, E, L> implements Serializable {
    private static final long serialVersionUID = 9165167851374358823L;

    /** Maps label indices used in the charts back to label objects. */
    private final Encoding<?, L> encoding;

    /** Inference engine that produces the k-best chart and its backtrace. */
    private final Inference<V, E, ?, L> inf;

    /** Weight vector handed to the inference engine on every decode. */
    private final double[] w;

    /**
     * Trains {@link ChainCRFTagger}s from labeled sequences using L-BFGS.
     *
     * @param <V> vertex (per-position) instance type
     * @param <E> edge (transition) instance type
     * @param <F> feature type produced by the extractors
     * @param <L> label type
     */
    public static class Factory<V, E, F, L> {
        private final FeatureExtractor<V, F> vertexExtractor;
        private final FeatureExtractor<E, F> edgeExtractor;
        private final double sigma;
        private final int iterations;

        /**
         * @param vertexExtractor feature extractor applied to vertex instances
         * @param edgeExtractor   feature extractor applied to edge instances
         * @param sigma           value forwarded unchanged to the CRF objective function
         *                        (presumably the regularization strength; confirm against
         *                        {@code CRFObjectiveFunction})
         * @param iterations      value forwarded to the {@link LBFGSMinimizer} constructor
         */
        public Factory(FeatureExtractor<V, F> vertexExtractor, FeatureExtractor<E, F> edgeExtractor, double sigma, int iterations) {
            this.vertexExtractor = vertexExtractor;
            this.edgeExtractor = edgeExtractor;
            this.sigma = sigma;
            this.iterations = iterations;
        }

        /**
         * Builds the feature/label encoding from the training data, minimizes the CRF objective
         * with L-BFGS starting from an all-zero weight vector, and wraps the result in a tagger.
         *
         * @param trainingSequences labeled training sequences
         * @return a tagger using the learned weights
         */
        public ChainCRFTagger<V, E, L> trainTagger(List<? extends LabeledInstanceSequence<V, E, L>> trainingSequences) {
            Encoding<F, L> encoding = buildEncoding(trainingSequences);

            // NOTE(review): CRFObjectiveFunction's generic signature is not visible from this
            // file, so the raw type is kept exactly as before rather than guessed at.
            @SuppressWarnings({"rawtypes", "unchecked"})
            CRFObjectiveFunction objective = new CRFObjectiveFunction(trainingSequences, encoding, this.vertexExtractor, this.edgeExtractor, this.sigma);

            LBFGSMinimizer minimizer = new LBFGSMinimizer(this.iterations);
            // One weight per (feature, label) combination, all starting at zero.
            double[] initialWeights = DoubleArrays.constantArray(0.0d, encoding.getNumFeatures() * encoding.getNumLabels());

            double[] weights;
            Logger.startTrack("Training with LBFGS", new Object[0]);
            try {
                // 1.0E-4 and `true` are passed through as-is; their exact meaning is defined by
                // LBFGSMinimizer.minimize (presumably tolerance and a verbosity flag - TODO confirm).
                weights = minimizer.minimize(objective, initialWeights, 1.0E-4d, true);
            } finally {
                // Always close the log track, even when the optimizer throws.
                Logger.endTrack();
            }

            Inference<V, E, F, L> inference = new Inference<>(encoding, this.vertexExtractor, this.edgeExtractor);
            return new ChainCRFTagger<>(encoding, inference, weights);
        }

        /**
         * Indexes every gold label and every feature that the extractors can emit on the
         * training data.
         *
         * <p>Labels are indexed in a first pass because edge instances are requested for every
         * known label, so the complete label set must exist before features are collected.
         */
        private Encoding<F, L> buildEncoding(List<? extends LabeledInstanceSequence<V, E, L>> trainingSequences) {
            Indexer<F> featureIndexer = new Indexer<>();
            Indexer<L> labelIndexer = new Indexer<>();

            // Pass 1: collect all gold labels.
            for (LabeledInstanceSequence<V, E, L> sequence : trainingSequences) {
                int length = sequence.getSequenceLength();
                for (int position = 0; position < length; position++) {
                    labelIndexer.add(sequence.getGoldLabel(position));
                }
            }

            // Pass 2: collect vertex features at every position and, from the second position
            // on, edge features for every indexed label.
            int numLabels = labelIndexer.size();
            for (LabeledInstanceSequence<V, E, L> sequence : trainingSequences) {
                int length = sequence.getSequenceLength();
                for (int position = 0; position < length; position++) {
                    featureIndexer.addAll(this.vertexExtractor.extractFeatures(sequence.getVertexInstance(position)).keySet());
                    if (position > 0) {
                        for (int labelIndex = 0; labelIndex < numLabels; labelIndex++) {
                            E edge = sequence.getEdgeInstance(position, labelIndexer.getObject(labelIndex));
                            featureIndexer.addAll(this.edgeExtractor.extractFeatures(edge).keySet());
                        }
                    }
                }
            }
            return new Encoding<>(featureIndexer, labelIndexer);
        }
    }

    /**
     * @param encoding  encoding used to translate label indices into labels
     * @param inference inference engine used for decoding
     * @param weights   weight vector passed to {@code inference}; stored by reference, not copied
     */
    public ChainCRFTagger(Encoding<?, L> encoding, Inference<V, E, ?, L> inference, double[] weights) {
        this.encoding = encoding;
        this.inf = inference;
        this.w = weights;
    }

    /**
     * Returns the single highest-scoring label sequence.
     *
     * @param instanceSequence sequence to label
     * @return one label per position; an empty list for an empty sequence
     */
    public List<L> getViterbiLabelSequence(InstanceSequence<V, E, L> instanceSequence) {
        if (instanceSequence.getSequenceLength() == 0) {
            // Nothing to label: the only labeling of zero positions is the empty one.
            return new ArrayList<>();
        }
        return getTopKLabelSequencesAndScores(instanceSequence, 1).get(0).getFirst();
    }

    /**
     * Returns up to {@code k} label sequences paired with their scores, in the order they are
     * produced by the ranked score queue built over the chart cells of the last position.
     *
     * @param instanceSequence sequence to label
     * @param k                maximum number of label sequences to return
     * @return at most {@code k} (label sequence, score) pairs; empty when the sequence is empty
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public List<Pair<List<L>, Double>> getTopKLabelSequencesAndScores(InstanceSequence<V, E, L> instanceSequence, int k) {
        if (k < 0) {
            throw new IllegalArgumentException("Number of requested label sequences must be non-negative: " + k);
        }
        int length = instanceSequence.getSequenceLength();
        if (length == 0) {
            // There is no final chart position to rank, so there is nothing to return.
            return new ArrayList<>();
        }

        Pair<int[][][][], double[][][]> chartAndBacktrace = this.inf.getKBestChartAndBacktrace(instanceSequence, this.w, k);
        int[][][][] backtrace = chartAndBacktrace.getFirst();
        double[][] finalScores = chartAndBacktrace.getSecond()[length - 1];

        PriorityQueue<Pair<Integer, Integer>> ranked = buildRankedScoreQueue(finalScores);
        List<Pair<List<L>, Double>> results = new ArrayList<>(k);
        for (int rank = 0; rank < k && ranked.hasNext(); rank++) {
            // The priority must be read before next() removes the head of the queue.
            double score = ranked.getPriority();
            Pair<Integer, Integer> cell = ranked.next();
            List<L> labels = rebuildChain(backtrace, cell.getFirst().intValue(), cell.getSecond().intValue());
            results.add(Pair.makePair(labels, Double.valueOf(score)));
        }
        return results;
    }

    /**
     * Follows the backtrace from a cell of the last position back to the start and returns the
     * labels in sequence order.
     *
     * @param backtrace  indexed as {@code [position][label][slot]}, each entry holding
     *                   {@code {previousLabel, previousSlot}}
     * @param finalLabel label index of the starting cell at the last position
     * @param finalSlot  slot index of the starting cell at the last position
     */
    private List<L> rebuildChain(int[][][][] backtrace, int finalLabel, int finalSlot) {
        int length = backtrace.length;
        List<L> labels = new ArrayList<>(length);
        int label = finalLabel;
        int slot = finalSlot;
        for (int position = length - 1; position >= 0; position--) {
            labels.add(this.encoding.getLabel(label));
            int[] previous = backtrace[position][label][slot];
            label = previous[0];
            slot = previous[1];
        }
        // Stepping back from position 0 must land on the start marker (label -1, slot 0).
        assert label == -1 && slot == 0 : "backtrace ended at label " + label + ", slot " + slot;

        // Labels were collected last-to-first; Lists.reverse is relied upon to flip the list
        // in place (its return value, if any, was never used).
        Lists.reverse(labels);
        return labels;
    }

    /**
     * Puts every cell of a {@code [label][slot]} score table into a priority queue, keyed by
     * its (label index, slot index) pair and prioritized by its score.
     */
    private PriorityQueue<Pair<Integer, Integer>> buildRankedScoreQueue(double[][] scores) {
        PriorityQueue<Pair<Integer, Integer>> queue = new PriorityQueue<>();
        for (int label = 0; label < scores.length; label++) {
            double[] slotScores = scores[label];
            for (int slot = 0; slot < slotScores.length; slot++) {
                queue.add(Pair.makePair(Integer.valueOf(label), Integer.valueOf(slot)), slotScores[slot]);
            }
        }
        return queue;
    }
}
