package edu.berkeley.nlp.PCFGLA.reranker;

import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.TreeAnnotations;
import edu.berkeley.nlp.PCFGLA.reranker.FeatureExtractorManager;
import edu.berkeley.nlp.parser.EnglishPennTreebankParseEvaluator;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Numberer;
import fig.basic.IOUtils;
import fig.basic.Indexer;
import fig.basic.Option;
import fig.basic.Pair;
import fig.exec.Execution;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/PCFGLA/reranker/MIRAWeightLearner.class */
public class MIRAWeightLearner implements Runnable {

    // ---- Options -------------------------------------------------------
    // Fields annotated with fig's @Option; main() hands this object to
    // Execution.init(), which presumably fills them from the command line.

    @Option(gloss = "Baseline model grammar file location", required = true)
    public String grammarFile;

    @Option(gloss = "Path to Corpus (Default: null)")
    public static String path = null;

    @Option(gloss = "True if you're using a single file with trees rather than WSJ")
    public static boolean singleFile = false;

    @Option(gloss = "Maximum sentence length (Default <=10000)")
    public static int maxSentenceLength = 10000;

    // Prefix of the per-iteration output files written by writeWeights().
    @Option(gloss = "File to write final weights to")
    public String weightFile;

    @Option(gloss = "Forest directory")
    public String forestDirectory;

    @Option(gloss = "Forest filefilter")
    public String forestFileFilter;

    @Option(gloss = "Filter for getting held out set")
    public String heldOutFilter;

    @Option(gloss = "Gold trees for held out set")
    public String heldOutTreePath;

    // ---- Collaborators created in init() -------------------------------
    FeatureExtractorManager manager;
    ForestReranker reranker;
    OracleTreeFinder oracle;
    // Only created when readFromForest is false.
    Pruner pruner;
    List<Tree<String>> trainTrees;
    List<Tree<String>> heldOutTrees;
    List<PrunedForest> heldOutForests;
    EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String> evaluator;

    @Option(gloss = "Sentence file to parse.")
    public String inputFile = null;

    // NOTE(review): the gloss says "probability" but the default is negative,
    // which suggests a log-probability - confirm against MaxRulePruner.
    @Option(gloss = "Threshold (in probability) for baseline model pruning")
    public double pruningThreshold = -10.0d;

    @Option(gloss = "Size of beam at each forest node during reranking")
    public int beamSize = 10;

    // Upper bound applied to the per-update step size in applyWeightChanges().
    @Option(gloss = "Hyperparameter for MIRA")
    public double maxChange = 1.0d;

    @Option(gloss = "Read from forest file")
    public boolean readFromForest = false;

    @Option(gloss = "Start number from WSJ for gold trees")
    public int goldTreeStart = -1;

    @Option(gloss = "Stop number from WSJ for gold trees")
    public int goldTreeStop = -1;

    @Option(gloss = "Number of MIRA iterations")
    public int miraIter = 10;

    // ---- Per-sentence caches, index-aligned with trainTrees ------------
    // An entry is null for a sentence whose oracle forest could not be built.
    List<int[][]> parseLocalFeaturesBinary = null;
    List<int[][]> parseLocalFeaturesUnary = null;
    List<Pair<Double, int[]>> oracleFeatures = null;

    // Filled by loadForests() when training from stored forests.
    List<Tree<String>> oracleTrees = new ArrayList<>();
    List<PrunedForest> trainForests = new ArrayList<>();
    List<RerankedForest> oracleForests = new ArrayList<>();

    // Occurrence count per local feature index (see addToFeatCounter()).
    Counter<Integer> localFeatureCounts = new Counter<>();

    /**
     * Command-line entry point. Creates a learner, lets fig's
     * {@code Execution} initialise it from the arguments, echoes the corpus
     * settings, runs training and finally calls {@code Execution.finish()}.
     *
     * @param strArr raw command-line arguments, forwarded to {@code Execution.init}
     */
    public static void main(String[] strArr) {
        MIRAWeightLearner learner = new MIRAWeightLearner();
        Execution.init(strArr, learner);

        String corpusBanner = "Loading trees from " + path + " and using singleFile: " + singleFile;
        String lengthBanner = "Will remove sentences with more than " + maxSentenceLength + " words.";
        System.out.println(corpusBanner);
        System.out.println(lengthBanner);

        learner.run();
        Execution.finish();
    }

    /**
     * Builds everything training needs: loads the baseline grammar and
     * lexicon, creates the oracle finder, the pruner (only when forests are
     * not read from disk), the feature manager, evaluator and reranker, then
     * loads the training trees and precomputes their local features.
     *
     * The reranker starts from a single zero weight; index 0 is the slot that
     * {@link #applyWeightChanges} multiplies with the first component of a
     * feature-vector pair.
     */
    private void init() {
        Logger.setFig();

        Logger.logss("Loading baseline grammar...");
        ParserData parserData = ParserData.Load(this.grammarFile);
        Grammar grammar = parserData.getGrammar();
        Lexicon lexicon = parserData.getLexicon();
        Numberer.setNumberers(parserData.getNumbs());
        Logger.logss("Done.");

        Logger.logss("Initializing baseline parser...");
        BaseModel baseModel = new BaseModel(grammar);
        this.oracle = new OracleTreeFinder(baseModel);
        if (!this.readFromForest) {
            // A pruner is only needed when forests are produced on the fly.
            this.pruner = new MaxRulePruner(grammar, lexicon, parserData.getSpanPredictor(), this.pruningThreshold);
        }
        Logger.logss("Done.");

        this.manager = new FeatureExtractorManager(baseModel);
        this.evaluator = new EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String>(
                new HashSet<String>(), new HashSet<String>(0));
        this.reranker = new ForestReranker(baseModel, this.manager, this.beamSize);
        double[] initialWeights = {0.0d};
        this.reranker.setWeights(initialWeights);

        TreeLoader treeLoader = new TreeLoader(path, singleFile, this.goldTreeStart, this.goldTreeStop);
        this.trainTrees = treeLoader.getTrainTrees();
        if (this.readFromForest) {
            loadForests(this.trainTrees);
        }
        precomputeLocalFeatures(this.trainTrees);
    }

    /**
     * Stores the given trees as the held-out set. The list is kept by
     * reference, not copied.
     *
     * @param list held-out gold trees
     */
    private void loadHeldOutSet(List<Tree<String>> list) {
        heldOutTrees = list;
    }

    /**
     * Reads one stored forest per gold tree, in order, and records the forest
     * together with its oracle tree and the oracle tree in forest form.
     * The three target lists therefore stay index-aligned with {@code list}.
     *
     * @param list gold trees; the i-th tree is paired with the i-th forest
     *             returned by the reader
     */
    private void loadForests(List<Tree<String>> list) {
        File forestDir = new File(this.forestDirectory);
        PrunedForestReader reader = new PrunedForestReader(forestDir, this.forestFileFilter);
        for (Tree<String> goldTree : list) {
            PrunedForest forest = reader.getNextForest();
            Tree<String> oracleTree = this.oracle.getOracleTree(forest, goldTree);
            RerankedForest oracleForest = this.oracle.getOracleTreeAsForest(forest, goldTree);
            this.trainForests.add(forest);
            this.oracleTrees.add(oracleTree);
            this.oracleForests.add(oracleForest);
        }
    }

    /**
     * Precomputes local indicator features for every stored training forest
     * and the feature vector of its oracle forest.
     *
     * For each sentence the three caches receive one entry; all three entries
     * are {@code null} when no oracle forest is available. Features of both
     * the pruned forest and the oracle's Viterbi-tree edges are tallied in
     * {@code localFeatureCounts}. A progress line is printed every 100
     * sentences (including index 0).
     */
    private void precomputeLocalFeaturesForest() {
        this.parseLocalFeaturesBinary = new ArrayList<int[][]>();
        this.parseLocalFeaturesUnary = new ArrayList<int[][]>();
        this.oracleFeatures = new ArrayList<Pair<Double, int[]>>();

        for (int idx = 0; idx < this.oracleForests.size(); idx++) {
            PrunedForest forest = this.trainForests.get(idx);
            RerankedForest oracleForest = this.oracleForests.get(idx);

            if (oracleForest != null) {
                boolean sameSentence = forest.getSentence().equals(oracleForest.getSentence());
                if (!sameSentence) {
                    System.out.println("FORESTS NOT EQUAL:");
                    System.out.println("curForest: " + forest.getSentence());
                    System.out.println("oracleTree: " + oracleForest.getSentence());
                }

                // Extraction order is kept as-is: forest binary, forest unary,
                // oracle binary, oracle unary.
                int[][] forestBinary = this.manager.getLocalFeatureExtractor()
                        .precomputeLocalIndicatorFeatures(forest.getBinaryEdges(), forest.getSentence());
                int[][] forestUnary = this.manager.getLocalFeatureExtractor()
                        .precomputeLocalIndicatorFeatures(forest.getUnaryEdges(), forest.getSentence());
                int[][] oracleBinary = this.manager.getLocalFeatureExtractor()
                        .precomputeLocalIndicatorFeatures(oracleForest.getBinaryEdgesFromViterbiTree(), oracleForest.getSentence());
                int[][] oracleUnary = this.manager.getLocalFeatureExtractor()
                        .precomputeLocalIndicatorFeatures(oracleForest.getUnaryEdgesFromViterbiTree(), oracleForest.getSentence());

                addToFeatCounter(forestBinary);
                addToFeatCounter(forestUnary);
                addToFeatCounter(oracleBinary);
                addToFeatCounter(oracleUnary);

                Pair<Double, int[]> oracleVector = this.reranker.getViterbiTreeFeatureVector(oracleForest);
                this.parseLocalFeaturesBinary.add(forestBinary);
                this.parseLocalFeaturesUnary.add(forestUnary);
                this.oracleFeatures.add(oracleVector);
            } else {
                // No oracle for this sentence: keep the caches index-aligned.
                this.parseLocalFeaturesBinary.add(null);
                this.parseLocalFeaturesUnary.add(null);
                this.oracleFeatures.add(null);
            }

            boolean reportProgress = idx % 100 == 0;
            if (reportProgress) {
                System.out.println("Precomputed features for " + idx + " trees.");
            }
        }
    }

    /**
     * Adds one to the count of every feature index contained in the given
     * jagged array.
     *
     * @param iArr feature indices, one row per edge (rows may differ in length)
     */
    private void addToFeatCounter(int[][] iArr) {
        for (int[] row : iArr) {
            for (int feature : row) {
                this.localFeatureCounts.incrementCount(Integer.valueOf(feature), 1.0d);
            }
        }
    }

    /**
     * Precomputes, for every training tree, the local indicator features of
     * its pruned forest and the feature vector of its oracle forest.
     *
     * When forests are read from disk the work is delegated to
     * {@link #precomputeLocalFeaturesForest()}. Otherwise each sentence is
     * pruned with the baseline model; sentences without an oracle forest get
     * {@code null} entries in all three caches. A progress line is printed
     * every 100 sentences (including index 0).
     *
     * @param list gold training trees
     */
    private void precomputeLocalFeatures(List<Tree<String>> list) {
        if (this.readFromForest) {
            precomputeLocalFeaturesForest();
            return;
        }

        this.parseLocalFeaturesBinary = new ArrayList<int[][]>();
        this.parseLocalFeaturesUnary = new ArrayList<int[][]>();
        this.oracleFeatures = new ArrayList<Pair<Double, int[]>>();

        for (int idx = 0; idx < list.size(); idx++) {
            Tree<String> goldTree = list.get(idx);
            PrunedForest forest = this.pruner.getPrunedForest(goldTree.getYield());
            RerankedForest oracleForest = this.oracle.getOracleTreeAsForest(forest, goldTree);

            if (oracleForest != null) {
                boolean sameSentence = forest.getSentence().equals(goldTree.getYield());
                if (!sameSentence) {
                    System.out.println("FORESTS NOT EQUAL:");
                    System.out.println("curForest: " + forest.getSentence());
                    System.out.println("oracleTree: " + goldTree.getYield());
                }
                int[][] binaryFeatures = this.manager.getLocalFeatureExtractor()
                        .precomputeLocalIndicatorFeatures(forest.getBinaryEdges(), forest.getSentence());
                this.parseLocalFeaturesBinary.add(binaryFeatures);
                int[][] unaryFeatures = this.manager.getLocalFeatureExtractor()
                        .precomputeLocalIndicatorFeatures(forest.getUnaryEdges(), forest.getSentence());
                this.parseLocalFeaturesUnary.add(unaryFeatures);
                this.oracleFeatures.add(this.reranker.getViterbiTreeFeatureVector(oracleForest));
            } else {
                // No oracle for this sentence: keep the caches index-aligned.
                this.parseLocalFeaturesBinary.add(null);
                this.parseLocalFeaturesUnary.add(null);
                this.oracleFeatures.add(null);
            }

            boolean reportProgress = idx % 100 == 0;
            if (reportProgress) {
                System.out.println("Precomputed features for " + idx + " trees.");
            }
        }
    }

    /**
     * Trains on forests that were loaded from disk (see {@code loadForests}).
     *
     * Each iteration reranks every stored forest with the current weights,
     * writes the best parse to {@code trees.out}, and - when the unannotated
     * best parse differs from the unannotated oracle tree - applies a weight
     * update via {@link #applyWeightChanges}. After every iteration the
     * current weights and feature indexer are written out.
     *
     * Behavioural notes:
     * <ul>
     *   <li>At least one iteration is always run; training then stops after
     *       {@code miraIter} iterations. Previously a non-null
     *       {@code heldOutFilter} disabled the stop condition and the loop
     *       never terminated.</li>
     *   <li>Sentences for which no oracle tree/forest is available are still
     *       parsed and printed but do not trigger a weight update.</li>
     *   <li>{@code taus.out} is still created (and left empty) so the set of
     *       files produced does not change.</li>
     * </ul>
     */
    public void runForest() {
        PrintWriter treeOut = IOUtils.openOutHard("trees.out");
        PrintWriter tauOut = IOUtils.openOutHard("taus.out");
        List<Integer> missedPerIteration = new ArrayList<Integer>();
        List<Double> hammingPerIteration = new ArrayList<Double>();
        try {
            int iteration = 0;
            do {
                double hammingError = 0.0d;
                int missed = 0;
                Logger.startTrack("Parsing sentences", new Object[0]);
                for (int s = 0; s < this.oracleTrees.size(); s++) {
                    PrunedForest forest = this.trainForests.get(s);
                    Tree<String> bestParse = this.reranker.getBestParse(forest, this.parseLocalFeaturesBinary.get(s), this.parseLocalFeaturesUnary.get(s));
                    if (bestParse == null || bestParse.isLeaf()) {
                        Logger.err("Error parsing sentence %d: %s", Integer.valueOf(s + 1), this.trainTrees.get(s).getYield());
                        treeOut.println("()");
                    } else {
                        Tree<String> plainGuess = TreeAnnotations.unAnnotateTree(bestParse);
                        Tree<String> oracleTree = this.oracleTrees.get(s);
                        RerankedForest oracleForest = this.oracleForests.get(s);
                        // A null oracle forest means the feature caches hold null
                        // for this sentence, so no update can be computed.
                        if (oracleTree != null && oracleForest != null) {
                            Tree<String> plainOracle = TreeAnnotations.unAnnotateTree(oracleTree);
                            if (!plainOracle.equals(plainGuess)) {
                                missed++;
                                hammingError += applyWeightChanges(forest, oracleForest, s, plainGuess, plainOracle);
                            }
                        }
                        treeOut.println("( " + plainGuess.getChildren().get(0) + " )");
                    }
                    // Use the loop index for progress so the count never drifts.
                    int parsed = s + 1;
                    if (parsed % 100 == 0) {
                        Logger.logs("Parsed %d sentences", Integer.valueOf(parsed));
                    }
                }
                missedPerIteration.add(Integer.valueOf(missed));
                System.out.println("Missed parses: " + missed);
                hammingPerIteration.add(Double.valueOf(hammingError));
                System.out.println("Hamming error: " + hammingError);
                System.out.println("Finished iteration " + iteration);
                Logger.endTrack();
                writeWeights(this.reranker.getCurrentWeights(), this.manager.features, iteration);
                iteration++;
            } while (iteration < this.miraIter);
        } finally {
            // Close both files even when an iteration fails part-way.
            treeOut.close();
            tauOut.close();
        }
        System.out.println("Num missed over all iterations: " + missedPerIteration);
        System.out.println("Hamming error all iterations: " + hammingPerIteration);
    }

    /**
     * Runs training. After {@link #init()}, forest-based training is delegated
     * to {@link #runForest()}; otherwise every training sentence is pruned
     * with the baseline model on each of the {@code miraIter} iterations,
     * reranked with the current weights, written to {@code trees.out}, and -
     * when the best parse differs from the oracle tree - used for a weight
     * update. Weights and the feature indexer are written after each
     * iteration.
     *
     * Local features are not recomputed here because {@code init()} has just
     * done so. Sentences without an oracle tree or without precomputed oracle
     * features are parsed and printed but skipped for updates.
     */
    @Override // java.lang.Runnable
    public void run() {
        init();
        if (this.readFromForest) {
            runForest();
            return;
        }
        // From here on readFromForest is false, so forests always come from the pruner.
        List<Integer> missedPerIteration = new ArrayList<Integer>();
        PrintWriter treeOut = IOUtils.openOutHard("trees.out");
        try {
            for (int iteration = 0; iteration < this.miraIter; iteration++) {
                Logger.startTrack("Parsing sentences", new Object[0]);
                int missed = 0;
                for (int s = 0; s < this.trainTrees.size(); s++) {
                    Tree<String> goldTree = this.trainTrees.get(s);
                    PrunedForest forest = this.pruner.getPrunedForest(goldTree.getYield());
                    Tree<String> bestParse = this.reranker.getBestParse(forest, this.parseLocalFeaturesBinary.get(s), this.parseLocalFeaturesUnary.get(s));
                    if (bestParse == null || bestParse.isLeaf()) {
                        Logger.err("Error parsing sentence %d: %s", Integer.valueOf(s + 1), goldTree.getYield());
                        treeOut.println("()");
                    } else {
                        Tree<String> oracleTree = this.oracle.getOracleTree(forest, goldTree);
                        System.out.println(bestParse);
                        System.out.println(oracleTree);
                        // The precomputed oracle features are null when no oracle
                        // forest could be built; such sentences cannot be updated on.
                        boolean oracleAvailable = oracleTree != null && this.oracleFeatures.get(s) != null;
                        if (oracleAvailable && !oracleTree.equals(bestParse)) {
                            missed++;
                            applyWeightChanges(forest, this.oracle.getOracleTreeAsForest(forest, oracleTree), s, bestParse, oracleTree);
                        }
                        treeOut.println("( " + TreeAnnotations.unAnnotateTree(bestParse).getChildren().get(0) + " )");
                    }
                    // Use the loop index for progress so the count never drifts.
                    int parsed = s + 1;
                    if (parsed % 100 == 0) {
                        Logger.logs("Parsed %d sentences", Integer.valueOf(parsed));
                    }
                }
                missedPerIteration.add(Integer.valueOf(missed));
                System.out.println("Missed parses: " + missed);
                System.out.println("Finished iteration " + iteration);
                Logger.endTrack();
                writeWeights(this.reranker.getCurrentWeights(), this.manager.features, iteration);
            }
        } finally {
            // Close the output file even when an iteration fails part-way.
            treeOut.close();
        }
        System.out.println("Num missed over all iterations: " + missedPerIteration);
    }

    /**
     * Looks up a weight, treating indices at or beyond the end of the array
     * as zero.
     *
     * @param dArr weight vector
     * @param i    feature index
     * @return {@code dArr[i]} when {@code i < dArr.length}, otherwise {@code 0.0}
     */
    private double getWeight(double[] dArr, int i) {
        return i < dArr.length ? dArr[i] : 0.0d;
    }

    /**
     * Writes the state of one training iteration to disk: the feature indexer
     * is Java-serialized to {@code <weightFile><i>.indexer} and the weights
     * are written one per line to {@code <weightFile><i>.weights}.
     *
     * Writing the indexer is best-effort: an I/O failure is reported on
     * standard error and the weights file is still written. All streams are
     * closed even when a write fails.
     *
     * @param dArr    weight vector to write
     * @param indexer feature indexer to serialize
     * @param i       iteration number, used in the file names
     */
    private void writeWeights(double[] dArr, Indexer<FeatureExtractorManager.Feature> indexer, int i) {
        System.out.println("Writing " + dArr.length + " weights.");
        // String.valueOf keeps the historical "null<i>" prefix when no weight file was given.
        String prefix = String.valueOf(this.weightFile) + i;

        try (FileOutputStream fileOut = new FileOutputStream(prefix + ".indexer");
                ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
            objectOut.writeObject(indexer);
        } catch (IOException e) {
            // Deliberately non-fatal: the weights below are still worth saving.
            e.printStackTrace();
        }

        try (PrintWriter weightOut = IOUtils.openOutHard(prefix + ".weights")) {
            for (double weight : dArr) {
                weightOut.println(weight);
            }
        }
        System.out.println("Finished writing weights.");
    }

    /**
     * Performs one MIRA-style update of the reranker weights towards the
     * oracle parse and away from the current best parse.
     *
     * The step size is {@code (hamming - (oracleScore - guessScore)) / |diff|^2},
     * clipped to {@code [0, maxChange]}, where {@code diff} is the difference
     * between the oracle and guess feature counts. Index 0 of the weight
     * vector is paired with the first component of each feature-vector pair.
     *
     * No update is applied (and the weights are left untouched) when either
     * feature vector is unavailable or when the step size is NaN, e.g. because
     * both the feature difference and the loss-augmented margin are zero.
     *
     * @param prunedForest   forest the best parse was taken from
     * @param rerankedForest oracle forest; only consulted when no local
     *                       features were precomputed
     * @param i              sentence index into the precomputed feature caches
     * @param tree           current best parse
     * @param tree2          oracle parse
     * @return the Hamming distance between {@code tree} and {@code tree2}
     */
    private double applyWeightChanges(PrunedForest prunedForest, RerankedForest rerankedForest, int i, Tree<String> tree, Tree<String> tree2) {
        Pair<Double, int[]> guessVector;
        Pair<Double, int[]> oracleVector;
        if (this.parseLocalFeaturesBinary == null) {
            guessVector = this.reranker.getViterbiTreeFeatureVector(prunedForest);
            oracleVector = this.reranker.getViterbiTreeFeatureVector(rerankedForest);
        } else {
            guessVector = this.reranker.getViterbiTreeFeatureVector(prunedForest, this.parseLocalFeaturesBinary.get(i), this.parseLocalFeaturesUnary.get(i));
            oracleVector = this.oracleFeatures.get(i);
        }
        double hammingDistance = this.evaluator.getHammingDistance(tree, tree2);
        if (guessVector == null || oracleVector == null) {
            // The caches store null for sentences without an oracle forest.
            Logger.err("No feature vector for sentence %d; skipping weight update", Integer.valueOf(i + 1));
            return hammingDistance;
        }

        double[] weights = this.reranker.getCurrentWeights();
        double oracleScore = dotProduct(weights, oracleVector.getSecond()) + (weights[0] * oracleVector.getFirst().doubleValue());
        double guessScore = dotProduct(weights, guessVector.getSecond()) + (weights[0] * guessVector.getFirst().doubleValue());

        Counter<Integer> oracleCounts = makeCounts(oracleVector);
        Counter<Integer> guessCounts = makeCounts(guessVector);
        Counter<Integer> difference = oracleCounts.difference(guessCounts);
        double squaredNorm = difference.dotProduct(difference);

        double margin = oracleScore - guessScore;
        double step = (hammingDistance - margin) / squaredNorm;
        if (Double.isNaN(step)) {
            System.out.println("NAN badness");
            // Applying a NaN step would turn every touched weight into NaN.
            return hammingDistance;
        }
        if (step > this.maxChange) {
            step = this.maxChange;
        }
        if (step < 0.0d) {
            step = 0.0d;
        }

        for (Integer feature : difference.keySet()) {
            double delta = step * (oracleCounts.getCount(feature) - guessCounts.getCount(feature));
            // updateWeights may return a new, longer array.
            weights = updateWeights(weights, feature.intValue(), delta);
        }
        this.reranker.setWeights(weights);
        return hammingDistance;
    }

    /**
     * Sums the weights of the listed feature indices. Each occurrence of an
     * index contributes once. Indices that are not strictly positive, or that
     * lie beyond the end of the weight vector, contribute nothing; in
     * particular index 0 is never included here.
     *
     * @param dArr weight vector
     * @param iArr feature indices (may contain repeats)
     * @return sum of {@code dArr[k]} over all accepted indices {@code k}
     */
    private double dotProduct(double[] dArr, int[] iArr) {
        double total = 0.0d;
        for (int pos = 0; pos < iArr.length; pos++) {
            int feature = iArr[pos];
            if (feature <= 0 || feature >= dArr.length) {
                continue;
            }
            total += dArr[feature];
        }
        return total;
    }

    /**
     * Converts a feature-vector pair into a counter: every index in the
     * pair's array adds one to that key, and the pair's first component is
     * added to key 0.
     *
     * @param pair (value for key 0, feature indices)
     * @return a new counter holding the resulting counts
     */
    private Counter<Integer> makeCounts(Pair<Double, int[]> pair) {
        Counter<Integer> counts = new Counter<>();
        int[] features = pair.getSecond();
        for (int pos = 0; pos < features.length; pos++) {
            counts.incrementCount(Integer.valueOf(features[pos]), 1.0d);
        }
        counts.incrementCount(Integer.valueOf(0), pair.getFirst().doubleValue());
        return counts;
    }

    /**
     * Adds {@code d} to the weight at index {@code i}.
     *
     * When the index fits, the given array is modified in place and returned.
     * Otherwise a copy of length {@code i + 1} is returned whose new slots are
     * zero except for index {@code i}, which holds {@code d}; the input array
     * is left untouched in that case.
     *
     * @param dArr current weight vector
     * @param i    index of the weight to change
     * @param d    amount to add
     * @return the array that now holds the updated weights
     */
    private double[] updateWeights(double[] dArr, int i, double d) {
        if (i >= dArr.length) {
            // Arrays.copyOf zero-fills the extension.
            double[] grown = Arrays.copyOf(dArr, i + 1);
            grown[i] = d;
            return grown;
        }
        dArr[i] += d;
        return dArr;
    }
}
