package edu.berkeley.nlp.PCFGLA.reranker;

import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.TreeAnnotations;
import edu.berkeley.nlp.PCFGLA.reranker.FeatureExtractorManager;
import edu.berkeley.nlp.parser.EnglishPennTreebankParseEvaluator;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Lists;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Numberer;
import fig.basic.IOUtils;
import fig.basic.Indexer;
import fig.basic.Option;
import fig.basic.Pair;
import fig.exec.Execution;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/PCFGLA/reranker/MIRAWeightTester.class */
/**
 * Command-line tool that reranks parse forests with a fixed set of previously trained reranker
 * weights and writes the resulting trees to {@link #treeOutFile}.
 *
 * <p>Weights are read from {@code <weightFile>.weights} (one number per line) and the feature
 * indexer from {@code <weightFile>.indexer} (a Java-serialized {@link Indexer}). For each gold tree
 * in the requested corpus slice, a pruned forest is obtained either from a {@link MaxRulePruner}
 * over the gold yield or, when {@link #readFromForest} is set, from pre-computed forest files.
 * The best reranked parse is written out and scored against the gold tree.
 */
public class MIRAWeightTester {

    @Option(gloss = "Baseline model grammar file location", required = true)
    public String grammarFile;

    @Option(gloss = "True if you're using a single file with trees rather than WSJ")
    public static boolean singleFile = false;

    @Option(gloss = "Path to Corpus (Default: null)")
    public static String path = null;

    // NOTE(review): only printed by main(); no length filtering is applied within this class.
    @Option(gloss = "Maximum sentence length (Default <=10000)")
    public static int maxSentenceLength = 10000;

    @Option(gloss = "File to get weights from")
    public String weightFile;

    @Option(gloss = "Forest directory")
    public String forestDirectory;

    @Option(gloss = "Forest filefilter")
    public String forestFileFilter;
    OracleTreeFinder oracle;
    FeatureExtractorManager manager;
    ForestReranker reranker;
    EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String> eval;
    // Only initialised when readFromForest is false.
    Pruner pruner;

    // NOTE(review): not read anywhere in this class.
    @Option(gloss = "Sentence file to parse.")
    public String inputFile = null;

    @Option(gloss = "Threshold (in probability) for baseline model pruning")
    public double pruningThreshold = -10.0d;

    @Option(gloss = "Size of beam at each forest node during reranking")
    public int beamSize = 10;

    // NOTE(review): not read anywhere in this class (weights are loaded, not trained, here).
    @Option(gloss = "Hyperparameter for MIRA")
    public double maxChange = 1.0d;

    @Option(gloss = "Read from forest file")
    public boolean readFromForest = false;

    @Option(gloss = "Start number from WSJ for gold trees")
    public int goldTreeStart = -1;

    @Option(gloss = "Stop number from WSJ for gold trees")
    public int goldTreeStop = -1;

    @Option(gloss = "debug comments on/off")
    public boolean debug = true;

    // NOTE(review): not read anywhere in this class.
    @Option(gloss = "use wsj dev set")
    public boolean useDev = false;

    @Option(gloss = "tree out file")
    public String treeOutFile = "trees.out";

    /**
     * Loads the baseline grammar, builds the pruner (unless forests are read from disk), the
     * evaluator, the feature manager, the oracle finder and the reranker, and installs the
     * weights read from {@link #weightFile}.
     */
    private void init() {
        Logger.setFig();
        Logger.logss("Loading baseline grammar...");
        ParserData parserData = ParserData.Load(this.grammarFile);
        Grammar grammar = parserData.getGrammar();
        Lexicon lexicon = parserData.getLexicon();
        Numberer.setNumberers(parserData.getNumbs());
        Logger.logss("Done.");
        Logger.logss("Initializing baseline parser...");
        BaseModel baseModel = new BaseModel(grammar);
        if (!this.readFromForest) {
            this.pruner = new MaxRulePruner(grammar, lexicon, parserData.getSpanPredictor(), this.pruningThreshold);
        }
        Logger.logss("Done.");
        // ROOT is ignored as a labelled constituent; these tags are passed as punctuation.
        this.eval = new EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String>(
                Collections.singleton("ROOT"),
                new HashSet<String>(Arrays.asList("''", "``", ".", ":", ",")));
        double[] weights = readWeights(this.weightFile);
        this.manager = new FeatureExtractorManager(baseModel, readIndexer(this.weightFile));
        this.oracle = new OracleTreeFinder(baseModel);
        this.reranker = new ForestReranker(baseModel, this.manager, this.beamSize);
        this.reranker.setWeights(weights);
    }

    /**
     * Deserializes the feature indexer stored next to the weights.
     *
     * @param str weight-file prefix; the indexer is read from {@code str + ".indexer"}
     * @return the deserialized indexer (never {@code null})
     * @throws IllegalStateException if the file cannot be read or deserialized
     */
    @SuppressWarnings("unchecked") // generic type is erased; the cast cannot be checked at runtime
    private Indexer<FeatureExtractorManager.Feature> readIndexer(String str) {
        String indexerPath = str + ".indexer";
        // NOTE(review): Java native deserialization — only use with indexer files this
        // toolchain produced, never with untrusted input.
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(indexerPath))) {
            return (Indexer<FeatureExtractorManager.Feature>) in.readObject();
        } catch (IOException | ClassNotFoundException e) {
            // Fail fast: a null indexer would otherwise only blow up later inside the manager.
            throw new IllegalStateException("Could not read feature indexer from " + indexerPath, e);
        }
    }

    /**
     * Reads one weight per line from {@code str + ".weights"}.
     *
     * @param str weight-file prefix
     * @return the weights in file order
     * @throws IllegalStateException if the file cannot be read or a line is not a number
     */
    private double[] readWeights(String str) {
        String weightsPath = str + ".weights";
        List<Double> weights = new ArrayList<Double>();
        try (BufferedReader in = IOUtils.openInHard(weightsPath)) {
            String line;
            int lineNumber = 0;
            while ((line = in.readLine()) != null) {
                lineNumber++;
                try {
                    weights.add(Double.valueOf(line));
                } catch (NumberFormatException e) {
                    throw new IllegalStateException(
                            "Malformed weight on line " + lineNumber + " of " + weightsPath, e);
                }
            }
        } catch (IOException e) {
            throw new IllegalStateException("Could not read weights from " + weightsPath, e);
        }
        double[] result = new double[weights.size()];
        for (int k = 0; k < result.length; k++) {
            result[k] = weights.get(k);
        }
        return result;
    }

    /** Entry point: parses fig options into a new tester, runs it and finishes the execution. */
    public static void main(String[] strArr) {
        MIRAWeightTester tester = new MIRAWeightTester();
        Execution.init(strArr, tester);
        System.out.println("Loading trees from " + path + "; singleFile:  " + singleFile);
        System.out.println("Will remove sentences with more than " + maxSentenceLength + " words.");
        tester.run();
        Execution.finish();
    }

    /**
     * Reranks a forest for every gold tree, writes each best parse (or {@code "()"} on failure)
     * to {@link #treeOutFile}, accumulates labelled-constituent scores and prints them at the end.
     */
    public void run() {
        init();
        List<Tree<String>> goldTrees = TreeLoader.getGoldTreesByNumber(path, singleFile, this.goldTreeStart, this.goldTreeStop);
        PrunedForestReader forestReader = this.readFromForest ? new PrunedForestReader(new File(this.forestDirectory), this.forestFileFilter) : null;
        PrintWriter treeOut = IOUtils.openOutHard(this.treeOutFile);
        try {
            Logger.startTrack("Parsing sentences", new Object[0]);
            int parsed = 0;
            for (int s = 0; s < goldTrees.size(); s++) {
                Tree<String> goldTree = goldTrees.get(s);
                // Forests are consumed in order, one per gold tree, so the reader stays aligned.
                PrunedForest forest = this.readFromForest ? forestReader.getNextForest() : this.pruner.getPrunedForest(goldTree.getYield());
                Tree<String> bestParse = this.reranker.getBestParse(forest);
                if (bestParse == null || bestParse.isLeaf()) {
                    Logger.err("Error parsing sentence %d: %s", Integer.valueOf(s + 1), goldTree.getYield());
                    // Keep one output line per sentence so the file stays aligned with the gold trees.
                    treeOut.println("()");
                } else {
                    Tree<String> unannotated = TreeAnnotations.unAnnotateTree(bestParse);
                    if (this.debug) {
                        printDebugScores(forest, goldTree);
                    }
                    // Without this the evaluator displayed below would never see any sentence.
                    this.eval.evaluate(unannotated, goldTree);
                    treeOut.println(unannotated);
                }
                parsed++;
                Logger.logs("Parsed %d sentences", Integer.valueOf(parsed));
            }
            System.out.println("Parsed " + parsed + " sentences.");
            this.eval.display(true);
            System.out.println("The computed F1,LP,LR scores are just a rough guide. They are typically 0.1-0.2 lower than the official EVALB scores.");
            Logger.endTrack();
        } finally {
            treeOut.close();
        }
    }

    /**
     * Prints the current-weight score of the oracle tree and of the reranker's Viterbi tree for
     * one sentence.
     */
    private void printDebugScores(PrunedForest forest, Tree<String> goldTree) {
        Pair<Double, int[]> parseVector = this.reranker.getViterbiTreeFeatureVector(forest);
        Pair<Double, int[]> oracleVector = this.reranker.getViterbiTreeFeatureVector(this.oracle.getOracleTreeAsForest(forest, goldTree));
        double[] weights = this.reranker.getCurrentWeights();
        double oracleScore = modelScore(weights, oracleVector);
        double parseScore = modelScore(weights, parseVector);
        System.out.println("Dot score oracle: " + oracleScore + "; dot score parse: " + parseScore);
    }

    /**
     * Scores a feature vector under {@code weights}.
     *
     * <p>The score is the sum of the indexed feature weights plus {@code weights[0]} times the
     * pair's first component. Slot 0 is applied separately here and skipped by
     * {@link #dotProduct}.
     */
    private double modelScore(double[] weights, Pair<Double, int[]> vector) {
        return dotProduct(weights, vector.getSecond()) + (weights[0] * vector.getFirst().doubleValue());
    }

    /**
     * Sums {@code dArr[i]} over every feature index {@code i} in {@code iArr}.
     *
     * <p>Index 0 and indices outside {@code [1, dArr.length)} are ignored. Duplicate indices are
     * counted once per occurrence.
     */
    private double dotProduct(double[] dArr, int[] iArr) {
        double d = 0.0d;
        for (int i : iArr) {
            if (i > 0 && i < dArr.length) {
                d += dArr[i];
            }
        }
        return d;
    }
}
