package edu.berkeley.nlp.crf;

import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Logger;
import fig.basic.Pair;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/crf/Counts.class */
/**
 * Computes per-label feature statistics for a linear-chain CRF.
 *
 * <p>Two kinds of statistics are produced:
 * <ul>
 *   <li>empirical feature counts taken from gold-labeled sequences, and
 *   <li>model-expected feature counts, weighted by the vertex/edge posteriors computed by
 *       {@link Inference} for a given weight vector, together with the summed log of each
 *       sequence's normalization constant.
 * </ul>
 *
 * <p>Both results are returned as one {@link Counter} per label, where list position {@code k}
 * corresponds to label index {@code k} in the {@link Encoding}. Vertex features at position
 * {@code i} are credited to the label at {@code i}. Edge features at position {@code i > 0} are
 * extracted using the previous position's label and are also credited to the label at {@code i}.
 *
 * @param <V> vertex instance type (input to the vertex feature extractor)
 * @param <E> edge instance type (input to the edge feature extractor)
 * @param <F> feature type
 * @param <L> label type
 */
public class Counts<V, E, F, L> {
    private final Encoding<F, L> encoding;
    private final FeatureExtractor<V, F> vertexExtractor;
    private final FeatureExtractor<E, F> edgeExtractor;
    private final Inference<V, E, F, L> inf;

    /**
     * Creates a count accumulator and the {@link Inference} helper it uses for expected counts.
     *
     * @param encoding provides the label set and the label/index mapping
     * @param featureExtractor extracts features from vertex instances
     * @param featureExtractor2 extracts features from edge instances
     */
    public Counts(Encoding<F, L> encoding, FeatureExtractor<V, F> featureExtractor, FeatureExtractor<E, F> featureExtractor2) {
        this.encoding = encoding;
        this.vertexExtractor = featureExtractor;
        this.edgeExtractor = featureExtractor2;
        this.inf = new Inference<>(encoding, featureExtractor, featureExtractor2);
    }

    /**
     * Accumulates feature counts observed under the gold labeling of each sequence.
     *
     * @param list gold-labeled training sequences
     * @return one counter per label (indexed by label index). The counter for label {@code y}
     *     holds the vertex features of every position labeled {@code y}, plus the edge features
     *     of every non-initial position labeled {@code y}, extracted with the gold label of the
     *     previous position.
     */
    public List<Counter<F>> getEmpiricalCounts(List<? extends LabeledInstanceSequence<V, E, L>> list) {
        List<Counter<F>> counts = newLabelCounters();
        for (LabeledInstanceSequence<V, E, L> sequence : list) {
            for (int pos = 0; pos < sequence.getSequenceLength(); pos++) {
                Counter<F> vertexFeatures = this.vertexExtractor.extractFeatures(sequence.getVertexInstance(pos));
                Counter<F> goldCounts = counts.get(this.encoding.getLabelIndex(sequence.getGoldLabel(pos)));
                goldCounts.incrementAll(vertexFeatures);
                if (pos > 0) {
                    // Edge features depend on the previous gold label; they are credited to the current label.
                    goldCounts.incrementAll(this.edgeExtractor.extractFeatures(sequence.getEdgeInstance(pos, sequence.getGoldLabel(pos - 1))));
                }
            }
        }
        return counts;
    }

    /**
     * Computes the summed log normalization constant and the posterior-weighted (expected)
     * feature counts over all sequences under the given weights.
     *
     * @param list sequences to process
     * @param dArr weight vector passed through to {@link Inference}
     * @return a pair of (sum over sequences of {@code Math.log(normalizationConstant)},
     *     one expected-count counter per label indexed by label index)
     */
    public Pair<Double, List<Counter<F>>> getLogNormalizationAndExpectedCounts(List<? extends InstanceSequence<V, E, L>> list, double[] dArr) {
        int numLabels = this.encoding.getNumLabels();
        List<Counter<F>> counts = newLabelCounters();
        double logNormalization = 0.0d;
        Logger.startTrack("Computing expected counts", new Object[0]);
        try {
            int processed = 0;
            for (InstanceSequence<V, E, L> sequence : list) {
                double[][] alphas = this.inf.getAlphas(sequence, dArr);
                double[][] betas = this.inf.getBetas(sequence, dArr);
                // NOTE(review): log is taken of the raw constant; if Inference works in linear space
                // this can underflow to -Infinity on long sequences — confirm Inference's scaling.
                logNormalization += Math.log(this.inf.getNormalizationConstant(alphas, betas));
                double[][] vertexPosteriors = this.inf.getVertexPosteriors(alphas, betas);
                double[][][] edgePosteriors = this.inf.getEdgePosteriors(sequence, dArr, alphas, betas);
                for (int pos = 0; pos < sequence.getSequenceLength(); pos++) {
                    Counter<F> vertexFeatures = this.vertexExtractor.extractFeatures(sequence.getVertexInstance(pos));
                    for (int label = 0; label < numLabels; label++) {
                        counts.get(label).incrementAll(vertexFeatures.scaledClone(vertexPosteriors[pos][label]));
                    }
                    if (pos > 0) {
                        // Edge features depend on the previous label, so they are re-extracted for each one.
                        // Presumably edgePosteriors[pos][prev][cur] — TODO confirm against Inference.
                        for (int prev = 0; prev < numLabels; prev++) {
                            Counter<F> edgeFeatures = this.edgeExtractor.extractFeatures(sequence.getEdgeInstance(pos, this.encoding.getLabel(prev)));
                            for (int cur = 0; cur < numLabels; cur++) {
                                counts.get(cur).incrementAll(edgeFeatures.scaledClone(edgePosteriors[pos][prev][cur]));
                            }
                        }
                    }
                }
                processed++;
                Logger.logs("Processed %d/%d sentences", Integer.valueOf(processed), Integer.valueOf(list.size()));
            }
        } finally {
            // Keep the logger's track nesting balanced even if inference or extraction throws.
            Logger.endTrack();
        }
        return Pair.makePair(Double.valueOf(logNormalization), counts);
    }

    /** Creates one empty counter per label, in label-index order. */
    private List<Counter<F>> newLabelCounters() {
        int numLabels = this.encoding.getNumLabels();
        List<Counter<F>> counters = new ArrayList<>(numLabels);
        for (int k = 0; k < numLabels; k++) {
            counters.add(new Counter<>());
        }
        return counters;
    }
}
