package edu.berkeley.nlp.scripts;

import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Option;
import edu.berkeley.nlp.PCFGLA.OptionParser;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.SophisticatedLexicon;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.PCFGLA.smoothing.NoSmoothing;
import edu.berkeley.nlp.PCFGLA.smoothing.SmoothAcrossParentSubstate;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees;
import edu.berkeley.nlp.util.Numberer;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/scripts/ObservedGrammarExtractor.class */
/**
 * Command-line tool that reads a corpus of trees whose node labels carry an
 * annotation suffix (everything from the first {@code '-'} onwards), builds a
 * {@link Grammar} and a {@link SophisticatedLexicon} from those observed
 * annotations, and saves the result as a {@link ParserData} file.
 *
 * <p>A label such as {@code NP-3} is split into the base tag {@code NP} and the
 * suffix {@code -3}. Each distinct suffix seen for a base tag is assigned its own
 * number via a per-tag {@link Numberer}.
 */
public class ObservedGrammarExtractor {
    /** Numberer for base tags (labels with the suffix removed); assigned in {@code createGrammar}. */
    static Numberer tagNumberer;

    /** Per-tag numberers over the suffixes seen for that tag, indexed by tag number. */
    static List<Numberer> substateNumberers;

    /* loaded from: input_file:edu/berkeley/nlp/scripts/ObservedGrammarExtractor$Options.class */
    /** Command-line options, populated by {@link OptionParser}. */
    public static class Options {

        @Option(name = "-out", required = true, usage = "Output File for Grammar (Required)")
        public String outFileName;

        @Option(name = "-path", usage = "Path to Corpus File (Default: null)")
        public String path = null;

        @Option(name = "-smooth", usage = "Smooth the grammar if possible")
        public boolean smooth = false;
    }

    /**
     * Entry point: parses the options, loads the corpus, builds the grammar and saves it.
     *
     * <p>The process exits with status 0 when the grammar was saved and with status 1
     * when saving reported failure, so that calling scripts can detect the error.
     *
     * @param strArr command-line arguments, see {@link Options}
     */
    public static void main(String[] strArr) {
        Options options = (Options) new OptionParser(Options.class).parse(strArr, true);
        ParserData parserData = createGrammar(loadTrees(options.path), options.smooth);
        boolean saved = parserData.Save(options.outFileName);
        if (saved) {
            System.out.println("Saved grammar.");
        } else {
            System.out.println("Saving failed!");
        }
        System.exit(saved ? 0 : 1);
    }

    /**
     * Builds grammar and lexicon from the annotated trees.
     *
     * @param trees annotated trees as read from the corpus
     * @param smooth whether to install {@link SmoothAcrossParentSubstate} smoothers
     *               on the grammar (0.01) and the lexicon (0.1) instead of {@link NoSmoothing}
     * @return the parser data bundling lexicon, grammar, numberers and per-tag suffix counts
     */
    private static ParserData createGrammar(List<Tree<String>> trees, boolean smooth) {
        tagNumberer = Numberer.getGlobalNumberer("tags");
        substateNumberers = new ArrayList<Numberer>();
        short[] numSubstates = countSymbols(trees);
        StateSetTreeList stateSetTreeList =
                new StateSetTreeList(stripOffAnnotation(trees), numSubstates, false, tagNumberer);
        Grammar grammar = new Grammar(numSubstates, false, new NoSmoothing(), null, -1.0d);
        // NOTE(review): the meaning of the {0.5, 0.1} array and the trailing 0.0 is defined by
        // SophisticatedLexicon, which is not visible here; values are kept as they were.
        SophisticatedLexicon lexicon = new SophisticatedLexicon(
                numSubstates,
                SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,
                new double[]{0.5d, 0.1d},
                new NoSmoothing(),
                0.0d);
        if (smooth) {
            System.out.println("Will smooth the grammar.");
            grammar.setSmoother(new SmoothAcrossParentSubstate(0.01d));
            lexicon.setSmoother(new SmoothAcrossParentSubstate(0.1d));
        }
        System.out.print("Creating grammar...");
        int processed = 0;
        int total = trees.size();
        // The state-set trees are walked in parallel with the annotated trees they were built from.
        Iterator<Tree<StateSet>> it = stateSetTreeList.iterator();
        while (it.hasNext()) {
            Tree<StateSet> stateSetTree = it.next();
            Tree<String> annotatedTree = trees.get(processed);
            processed++;
            // True once more than half of the trees have been consumed (1-based count).
            boolean pastHalf = ((double) processed) > ((double) total) / 2.0d;
            setScores(stateSetTree, annotatedTree);
            lexicon.trainTree(stateSetTree, 0.0d, null, pastHalf, false);
            grammar.tallyStateSetTree(stateSetTree, grammar);
        }
        lexicon.optimize();
        grammar.optimize(0.0d);
        System.out.println("done.");
        return new ParserData(lexicon, grammar, null, Numberer.getNumberers(), numSubstates, 1, 0, Binarization.RIGHT);
    }

    /**
     * Recursively marks, on every non-leaf node of {@code tree}, the suffix observed
     * on the corresponding node of {@code tree2}: inside and outside score of that
     * suffix's number are set to 1.0 and both scales to 0.
     *
     * <p>A differing child count is reported on stderr; recursion then still iterates
     * over the child count of {@code tree2}.
     *
     * @param tree  state-set tree to update
     * @param tree2 annotated tree with the same shape
     */
    private static void setScores(Tree<StateSet> tree, Tree<String> tree2) {
        if (tree2.isLeaf()) {
            return;
        }
        String[] splitLabel = splitLabel(tree2.getLabel());
        StateSet label = tree.getLabel();
        int number = substateNumberers.get(label.getState()).number(splitLabel[1]);
        label.setIScore(number, 1.0d);
        label.setIScale(0);
        label.setOScore(number, 1.0d);
        label.setOScale(0);
        int size = tree2.getChildren().size();
        if (size != tree.getChildren().size()) {
            System.err.println("Mismatch!");
        }
        for (int i = 0; i < size; i++) {
            setScores(tree.getChildren().get(i), tree2.getChildren().get(i));
        }
    }

    /**
     * Returns copies of the given trees (made with {@code shallowClone()}) in which
     * every non-leaf label is cut at its first {@code '-'}.
     *
     * <p>Leaves are skipped per node, so their labels are left untouched even when they
     * contain a hyphen.
     *
     * @param list annotated trees
     * @return copies with the suffix removed from all non-leaf labels
     */
    private static List<Tree<String>> stripOffAnnotation(List<Tree<String>> list) {
        List<Tree<String>> stripped = new ArrayList<Tree<String>>(list.size());
        for (Tree<String> tree : list) {
            stripped.add(tree.shallowClone());
        }
        for (Tree<String> tree : stripped) {
            for (Tree<String> node : tree.getPostOrderTraversal()) {
                // Test the visited node, not the root of the sentence tree.
                if (node.isLeaf()) {
                    continue;
                }
                node.setLabel(splitLabel(node.getLabel())[0]);
            }
        }
        return stripped;
    }

    /**
     * Registers all base tags and suffixes occurring in the trees and reports, per tag,
     * how many distinct suffixes were seen. The table is also printed to stdout.
     *
     * @param list annotated trees
     * @return array indexed by tag number holding the number of suffixes of that tag
     */
    private static short[] countSymbols(List<Tree<String>> list) {
        System.out.print("Counting symbols...");
        for (Tree<String> tree : list) {
            processTree(tree);
        }
        short[] counts = new short[tagNumberer.total()];
        for (int i = 0; i < counts.length; i++) {
            counts[i] = (short) substateNumberers.get(i).total();
        }
        System.out.println("done.");
        for (int i = 0; i < tagNumberer.size(); i++) {
            System.out.println((String) tagNumberer.object(i) + "\t" + counts[i]);
        }
        return counts;
    }

    /**
     * Registers the base tag and suffix of {@code tree}'s label and recurses into all
     * non-leaf children.
     *
     * @param tree non-leaf node of an annotated tree
     */
    private static void processTree(Tree<String> tree) {
        String[] splitLabel = splitLabel(tree.getLabel());
        int number = tagNumberer.number(splitLabel[0]);
        // Grow until index `number` exists; a single add() would not be enough if the
        // tag number were more than one past the current list size.
        while (number >= substateNumberers.size()) {
            substateNumberers.add(new Numberer());
        }
        substateNumberers.get(number).number(splitLabel[1]);
        for (Tree<String> child : tree.getChildren()) {
            if (!child.isLeaf()) {
                processTree(child);
            }
        }
    }

    /**
     * Splits a label at its first {@code '-'}.
     *
     * @param str the label
     * @return a two-element array: the part before the hyphen, and the rest including
     *         the hyphen; for a label without hyphen, the label itself and {@code ""}
     */
    private static String[] splitLabel(String str) {
        int indexOf = str.indexOf("-");
        if (indexOf < 0) {
            return new String[]{str, ""};
        }
        return new String[]{str.substring(0, indexOf), str.substring(indexOf)};
    }

    /**
     * Reads all trees from a UTF-8 encoded corpus file and closes the file afterwards.
     *
     * @param str path of the corpus file
     * @return the trees in file order
     * @throws IllegalArgumentException if no path was given or the file cannot be opened
     */
    private static List<Tree<String>> loadTrees(String str) {
        if (str == null) {
            throw new IllegalArgumentException("No corpus file given; specify one with -path.");
        }
        System.out.print("Loading trees...");
        InputStreamReader reader;
        try {
            reader = new InputStreamReader(new FileInputStream(str), "UTF-8");
        } catch (FileNotFoundException e) {
            // Fail here with the path in the message instead of passing a null reader on.
            throw new IllegalArgumentException("Cannot open corpus file: " + str, e);
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 encoding is not available", e);
        }
        List<Tree<String>> trees = new ArrayList<Tree<String>>();
        try {
            Trees.PennTreeReader pennTreeReader = new Trees.PennTreeReader(reader);
            while (pennTreeReader.hasNext()) {
                trees.add(pennTreeReader.next());
            }
        } finally {
            try {
                reader.close();
            } catch (IOException e) {
                // All trees have been read (or a failure is already propagating); only report it.
                System.err.println("Warning: could not close corpus file " + str + ": " + e.getMessage());
            }
        }
        System.out.println("done.");
        return trees;
    }
}
