package edu.berkeley.nlp.PCFGLA.reranker;

import edu.berkeley.nlp.PCFGLA.CoarseToFineMaxRuleParser;
import edu.berkeley.nlp.PCFGLA.ConstrainedTwoChartsParser;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.util.Lists;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Triple;
import fig.basic.Pair;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/berkeley/nlp/PCFGLA/reranker/ViterbiPruner.class */
/**
 * {@link Pruner} that builds a {@link PrunedForest} from the inside/outside charts of a
 * {@link ConstrainedTwoChartsParser}, whose search space is first restricted by a
 * {@link CoarseToFineMaxRuleParser} pre-parse.
 *
 * <p>NOTE(review): this class was recovered by decompilation. {@code pruneUnaryEdges} and
 * {@code pruneBinaryEdges} could not be decompiled and currently throw
 * {@link UnsupportedOperationException}. {@link #getPrunedForest} calls {@code pruneUnaryEdges}
 * for every length-1 span, so it throws for any non-empty sentence until those bodies are
 * restored.
 */
public class ViterbiPruner implements Pruner {
    /** Grammar as passed to the constructor; used by the pre-parser and for preterminal tests. */
    private final Grammar baseGrammar;
    /** Lexicon as passed to the constructor; used by the pre-parser and for lexical scores. */
    private final Lexicon baseLexicon;
    /** Private copy of the grammar ({@code copyGrammar(false)}) used by {@link #parser}. */
    private final Grammar pruningGrammar;
    /** Coarse-to-fine parser whose allowed (sub)states constrain {@link #parser}. */
    private final CoarseToFineMaxRuleParser preParser;
    /** Constrained parser that produces the inside/outside charts. */
    private final ConstrainedTwoChartsParser parser;
    /** Substate counts reported by {@link #parser}. */
    private final short[] numSubStatesArray;
    /**
     * Threshold passed to the constructor. NOTE(review): not read by any decompiled method here;
     * presumably consumed by the missing pruneUnaryEdges/pruneBinaryEdges bodies.
     */
    private final double pruningThreshold;

    /** Snapshot of the constrained parser's charts for one sentence. */
    public static class Chart {
        /** Inside scores before unary rules ({@code parser.getPreUnaryInsideScores()}). */
        public final double[][][][] iScoresPreU;
        /** Inside scores after unary rules ({@code parser.getPostUnaryInsideScores()}). */
        public final double[][][][] iScoresPostU;
        /** Outside scores before unary rules ({@code parser.getPreUnaryOutsideScores()}). */
        public final double[][][][] oScoresPreU;
        /** Outside scores after unary rules ({@code parser.getPostUnaryOutsideScores()}). */
        public final double[][][][] oScoresPostU;
        /** Inside scaling factors ({@code parser.getInsideScalingFactors()}). */
        public final int[][][] iScale;
        /** Outside scaling factors ({@code parser.getOutsideScalingFactors()}). */
        public final int[][][] oScale;
        /** Allowed states from the pre-parser. */
        public final boolean[][][] allowedStates;
        /** Allowed substates used to constrain the parser. */
        public final boolean[][][][] allowedSubStates;

        public Chart(double[][][][] dArr, double[][][][] dArr2, double[][][][] dArr3, double[][][][] dArr4, int[][][] iArr, int[][][] iArr2, boolean[][][] zArr, boolean[][][][] zArr2) {
            this.iScoresPreU = dArr;
            this.iScoresPostU = dArr2;
            this.oScoresPreU = dArr3;
            this.oScoresPostU = dArr4;
            this.iScale = iArr;
            this.oScale = iArr2;
            this.allowedStates = zArr;
            this.allowedSubStates = zArr2;
        }
    }

    /**
     * Creates a pruner.
     *
     * @param grammar base grammar; the pre-parser uses it directly, the constrained parser uses a copy
     * @param lexicon base lexicon; the pre-parser uses it directly, the constrained parser uses a copy
     * @param spanPredictor passed through to the {@link ConstrainedTwoChartsParser}
     * @param d pruning threshold, stored in {@link #pruningThreshold}
     */
    public ViterbiPruner(Grammar grammar, Lexicon lexicon, SpanPredictor spanPredictor, double d) {
        this.pruningGrammar = grammar.copyGrammar(false);
        Lexicon copyLexicon = lexicon.copyLexicon();
        this.baseGrammar = grammar;
        this.baseLexicon = lexicon;
        this.pruningThreshold = d;
        this.preParser = new CoarseToFineMaxRuleParser(grammar, lexicon, 1.0d, -1, true, false, false, false, false, false, false);
        this.preParser.initCascade(grammar, lexicon);
        this.parser = new ConstrainedTwoChartsParser(this.pruningGrammar, copyLexicon, spanPredictor);
        this.numSubStatesArray = this.parser.getNumSubStatesArray();
    }

    /**
     * Parses {@code list} and collects the binary/unary edges (and the nodes they touch) that
     * survive pruning into a {@link PrunedForest}.
     *
     * <p>Spans are visited shortest-first, so nodes are appended bottom-up. A root node
     * {@code new Node(0, size, 0, 0)} is appended last. NOTE(review): an empty sentence indexes
     * the chart at {@code [0][0]}; behaviour there is untested.
     */
    @Override
    public PrunedForest getPrunedForest(List<String> list) {
        List<BinaryEdge> binaryEdges = new ArrayList<BinaryEdge>();
        List<UnaryEdge> unaryEdges = new ArrayList<UnaryEdge>();
        // nodes preserves emission order; seenNodes mirrors it for O(1) duplicate checks.
        List<Node> nodes = new ArrayList<Node>();
        Set<Node> seenNodes = new HashSet<Node>();
        // Second/third components of the Triples returned by the prune methods (meaning defined there).
        List<Double> binarySecond = new ArrayList<Double>();
        List<Double> unarySecond = new ArrayList<Double>();
        List<Double> binaryThird = new ArrayList<Double>();
        List<Double> unaryThird = new ArrayList<Double>();

        Chart chart = getChart(list);
        int length = list.size();
        // Post-unary inside score (and its scale) of state 0, substate 0 over the whole sentence.
        double rootInside = chart.iScoresPostU[0][length][0][0];
        int rootScale = chart.iScale[0][length][0];

        for (int spanLength = 1; spanLength <= length; spanLength++) {
            for (int start = 0; start + spanLength <= length; start++) {
                int end = start + spanLength;
                if (spanLength > 1) {
                    Triple<List<BinaryEdge>, List<Double>, List<Double>> binary = pruneBinaryEdges(chart, start, end, rootInside, rootScale);
                    binaryEdges.addAll(binary.getFirst());
                    binarySecond.addAll(binary.getSecond());
                    binaryThird.addAll(binary.getThird());
                    addChildNodesFromBinaryRules(nodes, seenNodes, binary.getFirst());
                }
                Triple<List<UnaryEdge>, List<Double>, List<Double>> unary = pruneUnaryEdges(chart, start, end, rootInside, rootScale);
                unaryEdges.addAll(unary.getFirst());
                unarySecond.addAll(unary.getSecond());
                unaryThird.addAll(unary.getThird());
                addChildNodesFromUnaryRules(nodes, seenNodes, unary.getFirst());
            }
        }
        nodes.add(new Node(0, length, 0, 0));

        // Lexicon score for every node whose state is not a grammar tag; 0 for all others.
        double[] lexicalScores = new double[nodes.size()];
        for (int k = 0; k < nodes.size(); k++) {
            Node node = nodes.get(k);
            if (!this.baseGrammar.isGrammarTag(node.state)) {
                lexicalScores[k] = this.baseLexicon.score(list.get(node.startIndex), (short) node.state, node.startIndex, false, false)[node.substate];
            }
        }
        return new PrunedForest(nodes.toArray(new Node[0]), binaryEdges.toArray(new BinaryEdge[0]), unaryEdges.toArray(new UnaryEdge[0]), Lists.m80toPrimitiveArray(binarySecond), Lists.m80toPrimitiveArray(unarySecond), lexicalScores, Lists.m80toPrimitiveArray(binaryThird), Lists.m80toPrimitiveArray(unaryThird), list);
    }

    /** Appends {@code node} to {@code list} unless {@code set} already contains it; keeps both in sync. */
    private static void addNode(List<Node> list, Set<Node> set, Node node) {
        if (set.add(node)) {
            list.add(node);
            assert list.size() == set.size();
        }
    }

    /**
     * Emits the children of {@code list2} into {@code list}/{@code set} in dependency order.
     *
     * <p>A child is emitted only after all of its own unary children within {@code list2} are
     * already in {@code set}. Passes repeat while progress is made. Nodes left over (a unary cycle)
     * are not emitted and an error is logged.
     */
    private void addChildNodesFromUnaryRules(List<Node> list, Set<Node> set, List<UnaryEdge> list2) {
        Set<Node> pending = new HashSet<Node>();
        Map<Node, List<Node>> unaryChildren = new HashMap<Node, List<Node>>();
        for (UnaryEdge edge : list2) {
            pending.add(edge.getChild());
            unaryChildren.put(edge.getChild(), new ArrayList<Node>());
        }
        // Only record dependencies whose parent is itself a unary child in this batch.
        for (UnaryEdge edge : list2) {
            List<Node> deps = unaryChildren.get(edge.getParent());
            if (deps != null) {
                deps.add(edge.getChild());
            }
        }

        boolean progress = true;
        while (!pending.isEmpty() && progress) {
            progress = false;
            Set<Node> ready = new HashSet<Node>();
            for (Node node : pending) {
                if (set.containsAll(unaryChildren.get(node))) {
                    progress = true;
                    ready.add(node);
                    addNode(list, set, node);
                }
            }
            pending.removeAll(ready);
        }
        if (!pending.isEmpty()) {
            Logger.err("Topological sort failed and not all unary child nodes were added!");
        }
    }

    /** Adds the left and right child of every edge in {@code list2} to {@code list}/{@code set}, skipping duplicates. */
    private void addChildNodesFromBinaryRules(List<Node> list, Set<Node> set, List<BinaryEdge> list2) {
        for (BinaryEdge edge : list2) {
            addNode(list, set, edge.getLeftChild());
            addNode(list, set, edge.getRightChild());
        }
    }

    /*  JADX ERROR: NullPointerException in pass: LoopRegionVisitor
        java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.SSAVar.use(jadx.core.dex.instructions.args.RegisterArg)" because "ssaVar" is null
        	at jadx.core.dex.nodes.InsnNode.rebindArgs(InsnNode.java:489)
        	at jadx.core.dex.nodes.InsnNode.rebindArgs(InsnNode.java:492)
        */
    private edu.berkeley.nlp.util.Triple<java.util.List<edu.berkeley.nlp.PCFGLA.reranker.UnaryEdge>, java.util.List<java.lang.Double>, java.util.List<java.lang.Double>> pruneUnaryEdges(edu.berkeley.nlp.PCFGLA.reranker.ViterbiPruner.Chart r10, int r11, int r12, double r13, int r15) {
        /*
            Method dump skipped, instructions count: 535
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: edu.berkeley.nlp.PCFGLA.reranker.ViterbiPruner.pruneUnaryEdges(edu.berkeley.nlp.PCFGLA.reranker.ViterbiPruner$Chart, int, int, double, int):edu.berkeley.nlp.util.Triple");
    }

    /**
     * Adds {@code pair} to {@code hashSet}, plus every pair implied by composing it transitively
     * with pairs already present. This assumes the set was transitively closed beforehand.
     *
     * <p>Uses an explicit work stack and skips pairs that are already present. The earlier
     * recursive form re-expanded existing pairs. Once the relation contained a cycle (e.g. a
     * self-loop {@code (x, x)} next to {@code (a, x)}) it re-derived the same pair forever and
     * overflowed the stack.
     */
    private void addTransitiveClosure(HashSet<Pair<Integer, Integer>> hashSet, Pair<Integer, Integer> pair) {
        Deque<Pair<Integer, Integer>> pending = new ArrayDeque<Pair<Integer, Integer>>();
        pending.push(pair);
        while (!pending.isEmpty()) {
            Pair<Integer, Integer> current = pending.pop();
            if (hashSet.contains(current)) {
                continue;
            }
            for (Pair<Integer, Integer> existing : hashSet) {
                // existing = (x, a), current = (a, b)  =>  (x, b)
                if (existing.getSecond().equals(current.getFirst())) {
                    pending.push(Pair.makePair(existing.getFirst(), current.getSecond()));
                }
                // current = (a, b), existing = (b, y)  =>  (a, y)
                if (existing.getFirst().equals(current.getSecond())) {
                    pending.push(Pair.makePair(current.getFirst(), existing.getSecond()));
                }
            }
            // Inserted after the scan so the set is never modified while being iterated.
            hashSet.add(current);
        }
    }

    /*  JADX ERROR: NullPointerException in pass: LoopRegionVisitor
        java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.SSAVar.use(jadx.core.dex.instructions.args.RegisterArg)" because "ssaVar" is null
        	at jadx.core.dex.nodes.InsnNode.rebindArgs(InsnNode.java:489)
        	at jadx.core.dex.nodes.InsnNode.rebindArgs(InsnNode.java:492)
        */
    private edu.berkeley.nlp.util.Triple<java.util.List<edu.berkeley.nlp.PCFGLA.reranker.BinaryEdge>, java.util.List<java.lang.Double>, java.util.List<java.lang.Double>> pruneBinaryEdges(edu.berkeley.nlp.PCFGLA.reranker.ViterbiPruner.Chart r13, int r14, int r15, double r16, int r18) {
        /*
            Method dump skipped, instructions count: 563
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: edu.berkeley.nlp.PCFGLA.reranker.ViterbiPruner.pruneBinaryEdges(edu.berkeley.nlp.PCFGLA.reranker.ViterbiPruner$Chart, int, int, double, int):edu.berkeley.nlp.util.Triple");
    }

    /**
     * Runs the pre-parser on {@code list} to obtain allowed (sub)states. It then runs the
     * constrained inside/outside pass and captures the resulting charts.
     */
    private Chart getChart(List<String> list) {
        this.preParser.getBestParse(list);
        // NOTE(review): array clone() is shallow. If projectConstraints writes into the inner
        // arrays it also mutates the pre-parser's mask; confirm whether a deep copy is needed.
        boolean[][][][] allowedSubStates = this.preParser.getAllowedSubStates().clone();
        this.parser.projectConstraints(allowedSubStates, false);
        this.parser.doConstrainedInsideOutsideScores(convertToStateSetList(list), allowedSubStates, false, null, null, true);
        return new Chart(this.parser.getPreUnaryInsideScores(), this.parser.getPostUnaryInsideScores(), this.parser.getPreUnaryOutsideScores(), this.parser.getPostUnaryOutsideScores(), this.parser.getInsideScalingFactors(), this.parser.getOutsideScalingFactors(), this.preParser.getAllowedStates(), allowedSubStates);
    }

    /**
     * Wraps each word in a {@link StateSet} covering {@code [i, i + 1)} with state -1 and one
     * substate. {@code wordIndex}/{@code sigIndex} are set to -2 (sentinel whose meaning is
     * defined by StateSet/Lexicon — confirm).
     */
    private List<StateSet> convertToStateSetList(List<String> list) {
        List<StateSet> stateSets = new ArrayList<StateSet>(list.size());
        short position = 0;
        for (String word : list) {
            StateSet stateSet = new StateSet((short) -1, (short) 1, word, position, (short) (position + 1));
            stateSet.wordIndex = -2;
            stateSet.sigIndex = -2;
            stateSets.add(stateSet);
            position = (short) (position + 1);
        }
        return stateSets;
    }
}
