package edu.berkeley.nlp.PCFGLA.reranker;

import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Lists;
import edu.berkeley.nlp.util.PriorityQueue;
import fig.basic.Pair;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:edu/berkeley/nlp/PCFGLA/reranker/ForestReranker.class */
public class ForestReranker {
    private final BaseModel baseModel;
    private final FeatureExtractorManager featureExtractorManager;
    private final LocalFeatureExtractor localFeatureExtractor;
    private final NonlocalFeatureExtractor nonlocalFeatureExtractor;
    private final int beamSize;
    private double[] w = defaultWeightVector();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/berkeley/nlp/PCFGLA/reranker/ForestReranker$BeamKey.class */
    public static class BeamKey {
        public final boolean unaryEdge;
        public final int edgeIndex;
        public final int leftChildK;
        public final int rightChildK;

        public BeamKey(boolean z, int i, int i2, int i3) {
            this.unaryEdge = z;
            this.edgeIndex = i;
            this.leftChildK = i2;
            this.rightChildK = i3;
        }

        public boolean equals(Object obj) {
            if (!(obj instanceof BeamKey)) {
                return false;
            }
            BeamKey beamKey = (BeamKey) obj;
            return beamKey.unaryEdge == this.unaryEdge && beamKey.edgeIndex == this.edgeIndex && beamKey.leftChildK == this.leftChildK && beamKey.rightChildK == this.rightChildK;
        }

        public int hashCode() {
            return (this.edgeIndex * 39 * 39) + (this.leftChildK * 39) + this.rightChildK;
        }
    }

    public FeatureExtractorManager getFeatureExtractorManager() {
        return this.featureExtractorManager;
    }

    public ForestReranker(BaseModel baseModel, FeatureExtractorManager featureExtractorManager, int i) {
        this.baseModel = baseModel;
        this.featureExtractorManager = featureExtractorManager;
        this.localFeatureExtractor = featureExtractorManager.getLocalFeatureExtractor();
        this.nonlocalFeatureExtractor = featureExtractorManager.getNonlocalFeatureExtractor();
        this.beamSize = i;
    }

    private double[] defaultWeightVector() {
        return DoubleArrays.constantArray(0.0d, this.featureExtractorManager.getTotalNumFeatures());
    }

    public Tree<String> getBestParse(PrunedForest prunedForest) {
        return this.baseModel.relabelStates(getViterbiStateTree(prunedForest), prunedForest.getSentence());
    }

    public Tree<String> getBestParse(PrunedForest prunedForest, int[][] iArr, int[][] iArr2) {
        return this.baseModel.relabelStates(getViterbiStateTree(prunedForest, iArr, iArr2), prunedForest.getSentence());
    }

    public Tree<Node> getViterbiStateTree(PrunedForest prunedForest) {
        return getViterbiStateTree(prunedForest, null, null);
    }

    public Tree<Node> getViterbiStateTree(PrunedForest prunedForest, int[][] iArr, int[][] iArr2) {
        RerankedForest rerankForest = rerankForest(prunedForest, iArr, iArr2);
        if (rerankForest.hasParseFailure()) {
            return null;
        }
        return rerankForest.getViterbiTree();
    }

    public Pair<Double, int[]> getViterbiTreeFeatureVector(PrunedForest prunedForest) {
        return getViterbiTreeFeatureVector(prunedForest, null, null);
    }

    public Pair<Double, int[]> getViterbiTreeFeatureVector(PrunedForest prunedForest, int[][] iArr, int[][] iArr2) {
        RerankedForest rerankForest = rerankForest(prunedForest, iArr, iArr2);
        if (rerankForest.hasParseFailure()) {
            return null;
        }
        return getViterbiTreeFeatureVector(rerankForest);
    }

    public Pair<Double, int[]> getViterbiTreeFeatureVector(RerankedForest rerankedForest) {
        ArrayList arrayList = new ArrayList();
        double d = 0.0d;
        for (int[] iArr : this.localFeatureExtractor.precomputeLocalIndicatorFeatures(rerankedForest.getBinaryEdgesFromViterbiTree(), rerankedForest.sentence)) {
            concat(arrayList, iArr);
        }
        for (int[] iArr2 : this.localFeatureExtractor.precomputeLocalIndicatorFeatures(rerankedForest.getUnaryEdgesFromViterbiTree(), rerankedForest.sentence)) {
            concat(arrayList, iArr2);
        }
        for (int i : rerankedForest.getBinaryEdgeIndicesFromViterbiTree()) {
            arrayList.addAll(this.nonlocalFeatureExtractor.computeNonlocalIndicatorFeaturesForBinaryEdge(rerankedForest.baseForest, i, rerankedForest.isUnaryEdgeBacktrace, rerankedForest.edgeIndexBacktrace, rerankedForest.childKBacktrace, 0, 0, rerankedForest.sentence));
            d += rerankedForest.baseForest.getBinaryEdgeScore(i);
        }
        for (int i2 : rerankedForest.getUnaryEdgeIndicesFromViterbiTree()) {
            arrayList.addAll(this.nonlocalFeatureExtractor.computeNonlocalIndicatorFeaturesForUnaryEdge(rerankedForest.baseForest, i2, rerankedForest.isUnaryEdgeBacktrace, rerankedForest.edgeIndexBacktrace, rerankedForest.childKBacktrace, 0, rerankedForest.sentence));
            d += rerankedForest.baseForest.getUnaryEdgeScore(i2);
        }
        return Pair.makePair(Double.valueOf(d), Lists.toPrimitiveArray((List<Integer>) arrayList));
    }

    private void concat(List<Integer> list, int[] iArr) {
        for (int i : iArr) {
            list.add(Integer.valueOf(i));
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v10, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v13, types: [boolean[], boolean[][]] */
    /* JADX WARN: Type inference failed for: r0v16, types: [int[], int[][]] */
    /* JADX WARN: Type inference failed for: r0v19, types: [int[][], int[][][]] */
    private RerankedForest rerankForest(PrunedForest prunedForest, int[][] iArr, int[][] iArr2) {
        List<String> sentence = prunedForest.getSentence();
        Node[] nodes = prunedForest.getNodes();
        int[][] iArr3 = iArr;
        if (iArr3 == null) {
            iArr3 = this.localFeatureExtractor.precomputeLocalIndicatorFeatures(prunedForest.getBinaryEdges(), sentence);
        }
        int[][] iArr4 = iArr2;
        if (iArr4 == null) {
            iArr4 = this.localFeatureExtractor.precomputeLocalIndicatorFeatures(prunedForest.getUnaryEdges(), sentence);
        }
        ?? r0 = new double[nodes.length];
        ?? r02 = new boolean[nodes.length];
        ?? r03 = new int[nodes.length];
        ?? r04 = new int[nodes.length];
        for (int i = 0; i < nodes.length; i++) {
            if (isPosNode(nodes[i])) {
                double[] dArr = new double[1];
                dArr[0] = this.w[0] * prunedForest.getLexicalNodeScore(i);
                r0[i] = dArr;
            } else {
                PriorityQueue priorityQueue = new PriorityQueue();
                ArrayList arrayList = new ArrayList();
                for (int i2 : prunedForest.getBinaryEdgesByNode(i)) {
                    pushOntoHeap(new BeamKey(false, i2, 0, 0), priorityQueue, prunedForest, iArr3, iArr4, r0, r02, r03, r04, sentence);
                }
                for (int i3 : prunedForest.getUnaryEdgesByNode(i)) {
                    pushOntoHeap(new BeamKey(true, i3, 0, -1), priorityQueue, prunedForest, iArr3, iArr4, r0, r02, r03, r04, sentence);
                }
                while (arrayList.size() < this.beamSize && priorityQueue.hasNext()) {
                    double priority = priorityQueue.getPriority();
                    BeamKey beamKey = (BeamKey) priorityQueue.next();
                    arrayList.add(Pair.makePair(Double.valueOf(priority), beamKey));
                    if (beamKey.unaryEdge) {
                        pushOntoHeap(new BeamKey(true, beamKey.edgeIndex, beamKey.leftChildK + 1, -1), priorityQueue, prunedForest, iArr3, iArr4, r0, r02, r03, r04, sentence);
                    } else {
                        pushOntoHeap(new BeamKey(false, beamKey.edgeIndex, beamKey.leftChildK + 1, beamKey.rightChildK), priorityQueue, prunedForest, iArr3, iArr4, r0, r02, r03, r04, sentence);
                        pushOntoHeap(new BeamKey(false, beamKey.edgeIndex, beamKey.leftChildK, beamKey.rightChildK + 1), priorityQueue, prunedForest, iArr3, iArr4, r0, r02, r03, r04, sentence);
                    }
                }
                storeBuffer(i, arrayList, r0, r02, r03, r04);
            }
        }
        return new RerankedForest(prunedForest, sentence, r0, r02, r03, r04);
    }

    private boolean isPosNode(Node node) {
        return this.baseModel.isPosTag(node.state);
    }

    /* JADX WARN: Multi-variable type inference failed */
    private void storeBuffer(int i, List<Pair<Double, BeamKey>> list, double[][] dArr, boolean[][] zArr, int[][] iArr, int[][][] iArr2) {
        int size = list.size();
        dArr[i] = new double[size];
        zArr[i] = new boolean[size];
        iArr[i] = new int[size];
        iArr2[i] = new int[size];
        for (int i2 = 0; i2 < size; i2++) {
            dArr[i][i2] = list.get(i2).getFirst().doubleValue();
            BeamKey second = list.get(i2).getSecond();
            zArr[i][i2] = second.unaryEdge;
            iArr[i][i2] = second.edgeIndex;
            if (second.unaryEdge) {
                int[] iArr3 = new int[1];
                iArr3[0] = second.leftChildK;
                iArr2[i][i2] = iArr3;
            } else {
                int[] iArr4 = new int[2];
                iArr4[0] = second.leftChildK;
                iArr4[1] = second.rightChildK;
                iArr2[i][i2] = iArr4;
            }
        }
    }

    private void pushOntoHeap(BeamKey beamKey, PriorityQueue<BeamKey> priorityQueue, PrunedForest prunedForest, int[][] iArr, int[][] iArr2, double[][] dArr, boolean[][] zArr, int[][] iArr3, int[][][] iArr4, List<String> list) {
        if (beamKey.unaryEdge) {
            int unaryChildNodeIndex = prunedForest.getUnaryChildNodeIndex(beamKey.edgeIndex);
            if (beamKey.leftChildK >= dArr[unaryChildNodeIndex].length) {
                return;
            }
            int[] iArr5 = iArr2[beamKey.edgeIndex];
            List<Integer> computeNonlocalIndicatorFeaturesForUnaryEdge = this.nonlocalFeatureExtractor.computeNonlocalIndicatorFeaturesForUnaryEdge(prunedForest, beamKey.edgeIndex, zArr, iArr3, iArr4, beamKey.leftChildK, list);
            priorityQueue.add(beamKey, (this.w[0] * prunedForest.getUnaryEdgeScore(beamKey.edgeIndex)) + dotProduct(this.w, iArr5) + dotProduct(this.w, computeNonlocalIndicatorFeaturesForUnaryEdge) + dArr[unaryChildNodeIndex][beamKey.leftChildK]);
            return;
        }
        int leftChildNodeIndex = prunedForest.getLeftChildNodeIndex(beamKey.edgeIndex);
        int rightChildNodeIndex = prunedForest.getRightChildNodeIndex(beamKey.edgeIndex);
        if (beamKey.leftChildK >= dArr[leftChildNodeIndex].length || beamKey.rightChildK >= dArr[rightChildNodeIndex].length) {
            return;
        }
        int[] iArr6 = iArr[beamKey.edgeIndex];
        List<Integer> computeNonlocalIndicatorFeaturesForBinaryEdge = this.nonlocalFeatureExtractor.computeNonlocalIndicatorFeaturesForBinaryEdge(prunedForest, beamKey.edgeIndex, zArr, iArr3, iArr4, beamKey.leftChildK, beamKey.rightChildK, list);
        priorityQueue.add(beamKey, (this.w[0] * prunedForest.getBinaryEdgeScore(beamKey.edgeIndex)) + dotProduct(this.w, iArr6) + dotProduct(this.w, computeNonlocalIndicatorFeaturesForBinaryEdge) + dArr[leftChildNodeIndex][beamKey.leftChildK] + dArr[rightChildNodeIndex][beamKey.rightChildK]);
    }

    private static double dotProduct(double[] dArr, int[] iArr) {
        double d = 0.0d;
        for (int i : iArr) {
            if (i > 0 && i < dArr.length) {
                d += dArr[i];
            }
        }
        return d;
    }

    private static double dotProduct(double[] dArr, List<Integer> list) {
        double d = 0.0d;
        Iterator<Integer> it = list.iterator();
        while (it.hasNext()) {
            int intValue = it.next().intValue();
            if (intValue > 0 && intValue < dArr.length) {
                d += dArr[intValue];
            }
        }
        return d;
    }

    public void setWeights(double[] dArr) {
        this.w = dArr;
    }

    public double[] getCurrentWeights() {
        return this.w;
    }
}
