package edu.berkeley.nlp.PCFGLA.reranker;

import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Lists;
import edu.berkeley.nlp.util.PriorityQueue;
import fig.basic.Pair;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:edu/berkeley/nlp/PCFGLA/reranker/EarlyFeaturesForestReranker.class */
public class EarlyFeaturesForestReranker {
    private final BaseModel baseModel;
    private final FeatureExtractorManager featureExtractorManager;
    private final LocalFeatureExtractor localFeatureExtractor;
    private final NonlocalFeatureExtractor nonlocalFeatureExtractor;
    private final AnticipatedFeatureExtractor earlyFeatureExtractor;
    private final int beamSize;
    private double[] w = defaultWeightVector();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/berkeley/nlp/PCFGLA/reranker/EarlyFeaturesForestReranker$BeamKey.class */
    public static class BeamKey {
        public final boolean unaryEdge;
        public final int edgeIndex;
        public final int leftChildK;
        public final int rightChildK;
        private int[][] earlyFeatures;

        public BeamKey(boolean z, int i, int i2, int i3) {
            this.unaryEdge = z;
            this.edgeIndex = i;
            this.leftChildK = i2;
            this.rightChildK = i3;
        }

        public boolean equals(Object obj) {
            if (!(obj instanceof BeamKey)) {
                return false;
            }
            BeamKey beamKey = (BeamKey) obj;
            return beamKey.unaryEdge == this.unaryEdge && beamKey.edgeIndex == this.edgeIndex && beamKey.leftChildK == this.leftChildK && beamKey.rightChildK == this.rightChildK;
        }

        public int hashCode() {
            return (this.edgeIndex * 39 * 39) + (this.leftChildK * 39) + this.rightChildK;
        }

        public void setEarlyFeatures(int[][] iArr) {
            this.earlyFeatures = iArr;
        }

        public int[][] getEarlyFeatures() {
            return this.earlyFeatures;
        }
    }

    public EarlyFeaturesForestReranker(BaseModel baseModel, FeatureExtractorManager featureExtractorManager, int i) {
        this.baseModel = baseModel;
        this.featureExtractorManager = featureExtractorManager;
        this.localFeatureExtractor = featureExtractorManager.getLocalFeatureExtractor();
        this.nonlocalFeatureExtractor = featureExtractorManager.getNonlocalFeatureExtractor();
        this.earlyFeatureExtractor = featureExtractorManager.getAnticipatedFeatureExtractor();
        this.beamSize = i;
    }

    private double[] defaultWeightVector() {
        return DoubleArrays.constantArray(0.0d, this.featureExtractorManager.getTotalNumFeatures());
    }

    public Tree<String> getBestParse(PrunedForest prunedForest) {
        return this.baseModel.relabelStates(getViterbiStateTree(prunedForest), prunedForest.getSentence());
    }

    public Tree<Node> getViterbiStateTree(PrunedForest prunedForest) {
        RerankedForest rerankForest = rerankForest(prunedForest);
        if (rerankForest.hasParseFailure()) {
            return null;
        }
        return rerankForest.getViterbiTree();
    }

    public Pair<Double, int[]> getViterbiTreeFeatureVector(PrunedForest prunedForest) {
        RerankedForest rerankForest = rerankForest(prunedForest);
        if (rerankForest.hasParseFailure()) {
            return null;
        }
        return getViterbiTreeFeatureVector(rerankForest);
    }

    private Pair<Double, int[]> getViterbiTreeFeatureVector(RerankedForest rerankedForest) {
        ArrayList arrayList = new ArrayList();
        double d = 0.0d;
        for (int[] iArr : this.localFeatureExtractor.precomputeLocalIndicatorFeatures(rerankedForest.getBinaryEdgesFromViterbiTree(), rerankedForest.sentence)) {
            concat(arrayList, iArr);
        }
        for (int[] iArr2 : this.localFeatureExtractor.precomputeLocalIndicatorFeatures(rerankedForest.getUnaryEdgesFromViterbiTree(), rerankedForest.sentence)) {
            concat(arrayList, iArr2);
        }
        for (int i : rerankedForest.getBinaryEdgeIndicesFromViterbiTree()) {
            arrayList.addAll(this.nonlocalFeatureExtractor.computeNonlocalIndicatorFeaturesForBinaryEdge(rerankedForest.baseForest, i, rerankedForest.isUnaryEdgeBacktrace, rerankedForest.edgeIndexBacktrace, rerankedForest.childKBacktrace, 0, 0, rerankedForest.sentence));
            d += rerankedForest.baseForest.getBinaryEdgeScore(i);
        }
        for (int i2 : rerankedForest.getUnaryEdgeIndicesFromViterbiTree()) {
            arrayList.addAll(this.nonlocalFeatureExtractor.computeNonlocalIndicatorFeaturesForUnaryEdge(rerankedForest.baseForest, i2, rerankedForest.isUnaryEdgeBacktrace, rerankedForest.edgeIndexBacktrace, rerankedForest.childKBacktrace, 0, rerankedForest.sentence));
            d += rerankedForest.baseForest.getUnaryEdgeScore(i2);
        }
        return Pair.makePair(Double.valueOf(d), Lists.toPrimitiveArray((List<Integer>) arrayList));
    }

    private void concat(List<Integer> list, int[] iArr) {
        for (int i : iArr) {
            list.add(Integer.valueOf(i));
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v12, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v15, types: [boolean[], boolean[][]] */
    /* JADX WARN: Type inference failed for: r0v18, types: [int[], int[][]] */
    /* JADX WARN: Type inference failed for: r0v21, types: [int[][], int[][][]] */
    /* JADX WARN: Type inference failed for: r0v24, types: [int[][][], int[][][][]] */
    private RerankedForest rerankForest(PrunedForest prunedForest) {
        List<String> sentence = prunedForest.getSentence();
        Node[] nodes = prunedForest.getNodes();
        int[][] precomputeLocalIndicatorFeatures = this.localFeatureExtractor.precomputeLocalIndicatorFeatures(prunedForest.getBinaryEdges(), sentence);
        int[][] precomputeLocalIndicatorFeatures2 = this.localFeatureExtractor.precomputeLocalIndicatorFeatures(prunedForest.getUnaryEdges(), sentence);
        ?? r0 = new double[nodes.length];
        ?? r02 = new boolean[nodes.length];
        ?? r03 = new int[nodes.length];
        ?? r04 = new int[nodes.length];
        ?? r05 = new int[nodes.length][];
        for (int i = 0; i < nodes.length; i++) {
            if (isPosNode(nodes[i])) {
                double[] dArr = new double[1];
                dArr[0] = this.w[0] * prunedForest.getLexicalNodeScore(i);
                r0[i] = dArr;
            } else {
                PriorityQueue priorityQueue = new PriorityQueue();
                ArrayList arrayList = new ArrayList();
                for (int i2 : prunedForest.getBinaryEdgesByNode(i)) {
                    pushOntoHeap(new BeamKey(false, i2, 0, 0), priorityQueue, prunedForest, precomputeLocalIndicatorFeatures, precomputeLocalIndicatorFeatures2, r0, r02, r03, r04, r05, sentence);
                }
                for (int i3 : prunedForest.getUnaryEdgesByNode(i)) {
                    pushOntoHeap(new BeamKey(true, i3, 0, -1), priorityQueue, prunedForest, precomputeLocalIndicatorFeatures, precomputeLocalIndicatorFeatures2, r0, r02, r03, r04, r05, sentence);
                }
                while (arrayList.size() < this.beamSize && priorityQueue.hasNext()) {
                    double priority = priorityQueue.getPriority();
                    BeamKey beamKey = (BeamKey) priorityQueue.next();
                    arrayList.add(Pair.makePair(Double.valueOf(priority), beamKey));
                    if (beamKey.unaryEdge) {
                        pushOntoHeap(new BeamKey(true, beamKey.edgeIndex, beamKey.leftChildK + 1, -1), priorityQueue, prunedForest, precomputeLocalIndicatorFeatures, precomputeLocalIndicatorFeatures2, r0, r02, r03, r04, r05, sentence);
                    } else {
                        pushOntoHeap(new BeamKey(false, beamKey.edgeIndex, beamKey.leftChildK + 1, beamKey.rightChildK), priorityQueue, prunedForest, precomputeLocalIndicatorFeatures, precomputeLocalIndicatorFeatures2, r0, r02, r03, r04, r05, sentence);
                        pushOntoHeap(new BeamKey(false, beamKey.edgeIndex, beamKey.leftChildK, beamKey.rightChildK + 1), priorityQueue, prunedForest, precomputeLocalIndicatorFeatures, precomputeLocalIndicatorFeatures2, r0, r02, r03, r04, r05, sentence);
                    }
                }
                storeBuffer(i, arrayList, r0, r02, r03, r04, r05);
            }
        }
        return new RerankedForest(prunedForest, sentence, r0, r02, r03, r04);
    }

    private boolean isPosNode(Node node) {
        return this.baseModel.isPosTag(node.state);
    }

    /* JADX WARN: Multi-variable type inference failed */
    private void storeBuffer(int i, List<Pair<Double, BeamKey>> list, double[][] dArr, boolean[][] zArr, int[][] iArr, int[][][] iArr2, int[][][][] iArr3) {
        int size = list.size();
        dArr[i] = new double[size];
        zArr[i] = new boolean[size];
        iArr[i] = new int[size];
        iArr2[i] = new int[size];
        iArr3[i] = new int[size];
        for (int i2 = 0; i2 < size; i2++) {
            dArr[i][i2] = list.get(i2).getFirst().doubleValue();
            BeamKey second = list.get(i2).getSecond();
            zArr[i][i2] = second.unaryEdge;
            iArr[i][i2] = second.edgeIndex;
            iArr3[i][i2] = second.getEarlyFeatures();
            if (second.unaryEdge) {
                int[] iArr4 = new int[1];
                iArr4[0] = second.leftChildK;
                iArr2[i][i2] = iArr4;
            } else {
                int[] iArr5 = new int[2];
                iArr5[0] = second.leftChildK;
                iArr5[1] = second.rightChildK;
                iArr2[i][i2] = iArr5;
            }
        }
    }

    private void pushOntoHeap(BeamKey beamKey, PriorityQueue<BeamKey> priorityQueue, PrunedForest prunedForest, int[][] iArr, int[][] iArr2, double[][] dArr, boolean[][] zArr, int[][] iArr3, int[][][] iArr4, int[][][][] iArr5, List<String> list) {
        int[] iArr6 = new int[0];
        if (beamKey.unaryEdge) {
            UnaryEdge unaryEdge = prunedForest.getUnaryEdges()[beamKey.edgeIndex];
            int unaryChildNodeIndex = prunedForest.getUnaryChildNodeIndex(beamKey.edgeIndex);
            if (beamKey.leftChildK >= dArr[unaryChildNodeIndex].length) {
                return;
            }
            int[] iArr7 = iArr2[beamKey.edgeIndex];
            List<Integer> computeNonlocalIndicatorFeaturesForUnaryEdge = this.nonlocalFeatureExtractor.computeNonlocalIndicatorFeaturesForUnaryEdge(prunedForest, beamKey.edgeIndex, zArr, iArr3, iArr4, beamKey.leftChildK, list);
            List<int[]> computeAnticipatedIndicatorFeaturesForUnaryEdge = this.earlyFeatureExtractor.computeAnticipatedIndicatorFeaturesForUnaryEdge(prunedForest, beamKey.edgeIndex, zArr, iArr3, iArr4, beamKey.leftChildK, list);
            Pair<List<int[]>, List<int[]>> splitEarlyFeatures = splitEarlyFeatures(iArr5[unaryChildNodeIndex][beamKey.leftChildK], unaryEdge.startIndex, unaryEdge.stopIndex);
            double unaryEdgeScore = (((((this.w[0] * prunedForest.getUnaryEdgeScore(beamKey.edgeIndex)) + dotProduct(this.w, iArr7)) + listDotProduct(this.w, computeNonlocalIndicatorFeaturesForUnaryEdge)) + dArr[unaryChildNodeIndex][beamKey.leftChildK]) + dotProduct(this.w, computeAnticipatedIndicatorFeaturesForUnaryEdge)) - dotProduct(this.w, splitEarlyFeatures.getSecond());
            ArrayList arrayList = new ArrayList(computeAnticipatedIndicatorFeaturesForUnaryEdge);
            arrayList.addAll(splitEarlyFeatures.getFirst());
            beamKey.setEarlyFeatures((int[][]) arrayList.toArray((Object[]) iArr6));
            priorityQueue.add(beamKey, unaryEdgeScore);
            return;
        }
        BinaryEdge binaryEdge = prunedForest.getBinaryEdges()[beamKey.edgeIndex];
        int leftChildNodeIndex = prunedForest.getLeftChildNodeIndex(beamKey.edgeIndex);
        int rightChildNodeIndex = prunedForest.getRightChildNodeIndex(beamKey.edgeIndex);
        if (beamKey.leftChildK >= dArr[leftChildNodeIndex].length || beamKey.rightChildK >= dArr[rightChildNodeIndex].length) {
            return;
        }
        int[] iArr8 = iArr[beamKey.edgeIndex];
        List<Integer> computeNonlocalIndicatorFeaturesForBinaryEdge = this.nonlocalFeatureExtractor.computeNonlocalIndicatorFeaturesForBinaryEdge(prunedForest, beamKey.edgeIndex, zArr, iArr3, iArr4, beamKey.leftChildK, beamKey.rightChildK, list);
        List<int[]> computeAnticipatedIndicatorFeaturesForBinaryEdge = this.earlyFeatureExtractor.computeAnticipatedIndicatorFeaturesForBinaryEdge(prunedForest, beamKey.edgeIndex, zArr, iArr3, iArr4, beamKey.leftChildK, beamKey.rightChildK, list);
        Pair<List<int[]>, List<int[]>> splitEarlyFeatures2 = splitEarlyFeatures(iArr5[leftChildNodeIndex][beamKey.leftChildK], binaryEdge.startIndex, binaryEdge.stopIndex);
        Pair<List<int[]>, List<int[]>> splitEarlyFeatures3 = splitEarlyFeatures(iArr5[rightChildNodeIndex][beamKey.rightChildK], binaryEdge.startIndex, binaryEdge.stopIndex);
        double binaryEdgeScore = (((((((this.w[0] * prunedForest.getBinaryEdgeScore(beamKey.edgeIndex)) + dotProduct(this.w, iArr8)) + listDotProduct(this.w, computeNonlocalIndicatorFeaturesForBinaryEdge)) + dArr[leftChildNodeIndex][beamKey.leftChildK]) + dArr[rightChildNodeIndex][beamKey.rightChildK]) + dotProduct(this.w, computeAnticipatedIndicatorFeaturesForBinaryEdge)) - dotProduct(this.w, splitEarlyFeatures2.getSecond())) - dotProduct(this.w, splitEarlyFeatures3.getSecond());
        ArrayList arrayList2 = new ArrayList(computeAnticipatedIndicatorFeaturesForBinaryEdge);
        arrayList2.addAll(splitEarlyFeatures2.getFirst());
        arrayList2.addAll(splitEarlyFeatures3.getFirst());
        beamKey.setEarlyFeatures((int[][]) arrayList2.toArray((Object[]) iArr6));
        priorityQueue.add(beamKey, binaryEdgeScore);
    }

    private Pair<List<int[]>, List<int[]>> splitEarlyFeatures(int[][] iArr, int i, int i2) {
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int[] iArr2 : iArr) {
            if (iArr2[1] < i || iArr2[2] > i2) {
                arrayList.add(iArr2);
            } else {
                arrayList2.add(iArr2);
            }
        }
        return Pair.makePair(arrayList, arrayList2);
    }

    private static double dotProduct(double[] dArr, int[] iArr) {
        double d = 0.0d;
        for (int i : iArr) {
            if (i > 0 && i < dArr.length) {
                d += dArr[i];
            }
        }
        return d;
    }

    private static double listDotProduct(double[] dArr, List<Integer> list) {
        double d = 0.0d;
        Iterator<Integer> it = list.iterator();
        while (it.hasNext()) {
            int intValue = it.next().intValue();
            if (intValue > 0 && intValue < dArr.length) {
                d += dArr[intValue];
            }
        }
        return d;
    }

    private static double dotProduct(double[] dArr, List<int[]> list) {
        double d = 0.0d;
        Iterator<int[]> it = list.iterator();
        while (it.hasNext()) {
            int i = it.next()[0];
            if (i > 0 && i < dArr.length) {
                d += dArr[i];
            }
        }
        return d;
    }

    public void setWeights(double[] dArr) {
        this.w = dArr;
    }

    public double[] getCurrentWeights() {
        return this.w;
    }
}
