package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.io.NumberRangeFileFilter;
import edu.stanford.nlp.ling.StringLabelFactory;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.trees.BobChrisTreeNormalizer;
import edu.stanford.nlp.trees.DiskTreebank;
import edu.stanford.nlp.trees.LabeledScoredTreeFactory;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.PennTreeReader;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeReader;
import edu.stanford.nlp.trees.TreeReaderFactory;
import edu.stanford.nlp.trees.TreeVisitor;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.PriorityQueue;
import java.io.Reader;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/parser/lexparser/ParentAnnotationStats.class */
/**
 * Gathers statistics on how much parent and grandparent annotation sharpen the
 * distribution of child-label sequences in a treebank, and reports which
 * annotations ("splits") are worth making.
 *
 * <p>For every internal node visited, the sequence of its children's labels is
 * counted under three conditioning contexts: the node label alone, the
 * {@code [node, parent]} pair, and the {@code [node, parent, grandparent]}
 * triple. A refinement is scored as {@code KL(refined || base) * support(refined)},
 * where support is the total count of the refined context. Preterminals are
 * counted only when {@code doTags} is set, and then into separate tag maps.
 *
 * <p>NOTE(review): {@code doTags} is a static field shared by all instances and is
 * overwritten by {@code getSplitCategories(Treebank, boolean, int, double, double,
 * TreebankLanguagePack)}; concurrent use from several threads is therefore unsafe.
 */
public class ParentAnnotationStats implements TreeVisitor {
    /** Optional language pack used to reduce parent/grandparent labels to basic categories; may be null. */
    private final TreebankLanguagePack tlp;

    /** Phrasal node label -> counts of child-label sequences. */
    private final Map<String, ClassicCounter<List<String>>> nodeRules = new HashMap<String, ClassicCounter<List<String>>>();
    /** Phrasal {@code [node, parent]} -> counts of child-label sequences. */
    private final Map<List<String>, ClassicCounter<List<String>>> pRules = new HashMap<List<String>, ClassicCounter<List<String>>>();
    /** Phrasal {@code [node, parent, grandparent]} -> counts of child-label sequences. */
    private final Map<List<String>, ClassicCounter<List<String>>> gPRules = new HashMap<List<String>, ClassicCounter<List<String>>>();
    /** Preterminal counterpart of {@code nodeRules}; only filled when doTags is set. */
    private final Map<String, ClassicCounter<List<String>>> tagNodeRules = new HashMap<String, ClassicCounter<List<String>>>();
    /** Preterminal counterpart of {@code pRules}; only filled when doTags is set. */
    private final Map<List<String>, ClassicCounter<List<String>>> tagPRules = new HashMap<List<String>, ClassicCounter<List<String>>>();
    /** Preterminal counterpart of {@code gPRules}; only filled when doTags is set. */
    private final Map<List<String>, ClassicCounter<List<String>>> tagGPRules = new HashMap<List<String>, ClassicCounter<List<String>>>();

    /** Minimum support a {@code [node, parent]} context needs before its grandparent refinements are scored. */
    public static final double SUPPCUTOFF = 100.0d;

    /** When true, preterminal (tag) nodes are counted as well. Shared by all instances. */
    private static boolean doTags = false;

    /** Score thresholds for the generated {@code splitters1..splittersN} declarations printed by {@link #printStats()}. */
    public static final double[] CUTOFFS = {100.0d, 200.0d, 500.0d, 1000.0d};

    /** Orders scored refinements by descending score (Collections.sort keeps ties in insertion order). */
    private static final Comparator<Pair<List<String>, Double>> BY_SCORE_DESCENDING =
            new Comparator<Pair<List<String>, Double>>() {
                @Override
                public int compare(Pair<List<String>, Double> a, Pair<List<String>, Double> b) {
                    return b.second().compareTo(a.second());
                }
            };

    private static final String USAGE =
            "Usage: java edu.stanford.nlp.parser.lexparser.ParentAnnotationStats [-tags] [-cutOff n] treebankPath";

    /** Creates a collector that uses parent/grandparent labels exactly as they appear in the trees. */
    public ParentAnnotationStats() {
        this(null);
    }

    /**
     * @param tlp if non-null, parent and grandparent labels are reduced with
     *     {@code tlp.basicCategory} before being counted
     */
    private ParentAnnotationStats(TreebankLanguagePack tlp) {
        this.tlp = tlp;
    }

    /** Counts every rule in {@code t}, treating the root's parent and grandparent as "TOP". */
    @Override // edu.stanford.nlp.trees.TreeVisitor
    public void visitTree(Tree t) {
        processTreeHelper("TOP", "TOP", t);
    }

    /**
     * Returns the label values of the children of {@code t}, in order.
     *
     * @param t a tree node
     * @return a new mutable list of the children's label values
     */
    public static List<String> kidLabels(Tree t) {
        Tree[] kids = t.children();
        List<String> labels = new ArrayList<String>(kids.length);
        for (Tree kid : kids) {
            labels.add(kid.label().value());
        }
        return labels;
    }

    /**
     * Records the child-label sequence of {@code t} under its node, node+parent and
     * node+parent+grandparent contexts, then recurses into the children.
     *
     * @param gP label of the grandparent of {@code t}
     * @param p label of the parent of {@code t}
     * @param t the subtree to count
     */
    public void processTreeHelper(String gP, String p, Tree t) {
        // Leaves carry no rule; preterminals are only counted when doTags is set.
        if (t.isLeaf() || (t.isPreTerminal() && !doTags)) {
            return;
        }
        boolean isTag = t.isPreTerminal();
        Map<String, ClassicCounter<List<String>>> nr = isTag ? tagNodeRules : nodeRules;
        Map<List<String>, ClassicCounter<List<String>>> pr = isTag ? tagPRules : pRules;
        Map<List<String>, ClassicCounter<List<String>>> gpr = isTag ? tagGPRules : gPRules;

        String n = t.label().value();
        // Only the context labels are normalized; the node's own label is kept as-is.
        if (tlp != null) {
            p = tlp.basicCategory(p);
            gP = tlp.basicCategory(gP);
        }
        List<String> kids = kidLabels(t);
        incrementRule(nr, n, kids);
        incrementRule(pr, ruleKey(n, p), kids);
        incrementRule(gpr, ruleKey(n, p, gP), kids);

        for (Tree kid : t.children()) {
            processTreeHelper(p, n, kid);
        }
    }

    /** Increments the count of {@code kids} in the counter stored under {@code key}, creating it if needed. */
    private static <K> void incrementRule(Map<K, ClassicCounter<List<String>>> rules, K key, List<String> kids) {
        ClassicCounter<List<String>> counts = rules.get(key);
        if (counts == null) {
            counts = new ClassicCounter<List<String>>();
            rules.put(key, counts);
        }
        counts.incrementCount(kids);
    }

    /** Builds a mutable list key from the given labels. */
    private static List<String> ruleKey(String... parts) {
        List<String> key = new ArrayList<String>(parts.length);
        Collections.addAll(key, parts);
        return key;
    }

    /**
     * Renders a context key as a split name: {@code node^parent} for a two-element key,
     * {@code node^parent~grandparent} for a three-element key.
     */
    private static String splitName(List<String> rule) {
        StringBuilder name = new StringBuilder(rule.get(0)).append('^').append(rule.get(1));
        if (rule.size() > 2) {
            name.append('~').append(rule.get(2));
        }
        return name.toString();
    }

    /**
     * Prints the KL-divergence statistics for all parent and grandparent refinements
     * of the phrasal rules to standard output. It then prints Java source declaring
     * one splitter array per entry of {@link #CUTOFFS}, followed by a
     * {@code splitters} set choosing among them.
     */
    public void printStats() {
        NumberFormat nf = NumberFormat.getNumberInstance();
        nf.setMaximumFractionDigits(2);
        StringBuilder[] splitterDecls = new StringBuilder[CUTOFFS.length];
        for (int i = 0; i < CUTOFFS.length; i++) {
            splitterDecls[i] = new StringBuilder("  private static String[] splitters" + (i + 1) + " = new String[] {");
        }
        // Every score computed below, for the final "All scores" listing.
        ClassicCounter<List<String>> allScores = new ClassicCounter<List<String>>();

        // Parent annotation: how much does knowing the parent change P(kids | node)?
        for (Map.Entry<String, ClassicCounter<List<String>>> nodeEntry : nodeRules.entrySet()) {
            String node = nodeEntry.getKey();
            ClassicCounter<List<String>> nodeCounts = nodeEntry.getValue();
            System.out.println("Node " + node + " support is " + nodeCounts.totalCount());
            List<Pair<List<String>, Double>> scored = new ArrayList<Pair<List<String>, Double>>();
            for (Map.Entry<List<String>, ClassicCounter<List<String>>> pEntry : pRules.entrySet()) {
                List<String> pRule = pEntry.getKey();
                if (pRule.get(0).equals(node)) {
                    ClassicCounter<List<String>> pCounts = pEntry.getValue();
                    double support = pCounts.totalCount();
                    double kl = Counters.klDivergence(pCounts, nodeCounts);
                    System.out.println("KL(" + pRule + "||" + node + ") = " + nf.format(kl) + "\tsupport(" + pRule + ") = " + support);
                    double score = kl * support;
                    scored.add(new Pair<List<String>, Double>(pRule, Double.valueOf(score)));
                    allScores.setCount(pRule, score);
                }
            }
            printSortedScores(scored, nf, splitterDecls);
            System.out.println();
        }

        // Grandparent annotation, only for [node, parent] contexts with enough support.
        for (Map.Entry<List<String>, ClassicCounter<List<String>>> pEntry : pRules.entrySet()) {
            List<String> pRule = pEntry.getKey();
            ClassicCounter<List<String>> pCounts = pEntry.getValue();
            double pSupport = pCounts.totalCount();
            if (pSupport >= SUPPCUTOFF) {
                System.out.println("Node " + pRule + " support is " + pSupport);
                List<Pair<List<String>, Double>> scored = new ArrayList<Pair<List<String>, Double>>();
                for (Map.Entry<List<String>, ClassicCounter<List<String>>> gpEntry : gPRules.entrySet()) {
                    List<String> gpRule = gpEntry.getKey();
                    if (gpRule.get(0).equals(pRule.get(0)) && gpRule.get(1).equals(pRule.get(1))) {
                        ClassicCounter<List<String>> gpCounts = gpEntry.getValue();
                        double support = gpCounts.totalCount();
                        double kl = Counters.klDivergence(gpCounts, pCounts);
                        System.out.println("KL(" + gpRule + "||" + pRule + ") = " + nf.format(kl) + "\tsupport(" + gpRule + ") = " + support);
                        double score = kl * support;
                        scored.add(new Pair<List<String>, Double>(gpRule, Double.valueOf(score)));
                        allScores.setCount(gpRule, score);
                    }
                }
                printSortedScores(scored, nf, splitterDecls);
                System.out.println();
            }
        }

        System.out.println();
        System.out.println("All scores:");
        PriorityQueue<List<String>> pq = Counters.toPriorityQueue(allScores);
        while (!pq.isEmpty()) {
            List<String> key = pq.getFirst();
            double priority = pq.getPriority(key);
            pq.removeFirst();
            System.out.println(key + "\t" + priority);
        }

        System.out.println("  // Automatically generated by ParentAnnotationStats -- preferably don't edit");
        for (StringBuilder decl : splitterDecls) {
            // Drop the trailing ", " only if an entry was added; otherwise keep " {" so the
            // empty initializer stays valid Java ("new String[] {};").
            int len = decl.length();
            if (len >= 2 && decl.charAt(len - 2) == ',' && decl.charAt(len - 1) == ' ') {
                decl.setLength(len - 2);
            }
            decl.append("};");
            System.out.println(decl);
        }
        System.out.print("  public static HashSet splitters = new HashSet(Arrays.asList(");
        for (int level = CUTOFFS.length; level > 0; level--) {
            if (level == 1) {
                System.out.print("splitters1");
            } else {
                System.out.print("selectiveSplit" + level + " ? splitters" + level + " : (");
            }
        }
        // Close the CUTOFFS.length - 1 nested "(" plus the two opened by "new HashSet(Arrays.asList(".
        for (int level = CUTOFFS.length; level >= 0; level--) {
            System.out.print(")");
        }
        System.out.println(";");
    }

    /**
     * Prints the scored refinements in descending score order. Each refinement is
     * appended to every splitter declaration whose cutoff its score reaches.
     */
    private static void printSortedScores(List<Pair<List<String>, Double>> scored, NumberFormat nf,
                                          StringBuilder[] splitterDecls) {
        System.out.println("----");
        System.out.println("Sorted descending support * KL");
        Collections.sort(scored, BY_SCORE_DESCENDING);
        for (Pair<List<String>, Double> pair : scored) {
            double score = pair.second().doubleValue();
            System.out.println(pair.first() + ": " + nf.format(score));
            String entry = "\"" + splitName(pair.first()) + "\", ";
            for (int i = 0; i < CUTOFFS.length; i++) {
                if (score >= CUTOFFS[i]) {
                    splitterDecls[i].append(entry);
                }
            }
        }
    }

    /**
     * Adds to {@code splitters} every parent split ({@code node^parent}) and grandparent
     * split ({@code node^parent~grandparent}) whose {@code KL * support} score is at least
     * {@code cutOff}. Grandparent splits are considered only under {@code [node, parent]}
     * contexts whose support is at least {@link #SUPPCUTOFF}.
     *
     * @param cutOff minimum score for a split to be selected
     * @param nr node label -> child-sequence counts
     * @param pr {@code [node, parent]} -> child-sequence counts
     * @param gpr {@code [node, parent, grandparent]} -> child-sequence counts
     * @param splitters output set receiving the selected split names
     */
    private static void getSplitters(double cutOff,
                                     Map<String, ClassicCounter<List<String>>> nr,
                                     Map<List<String>, ClassicCounter<List<String>>> pr,
                                     Map<List<String>, ClassicCounter<List<String>>> gpr,
                                     Set<String> splitters) {
        // Each refined context has exactly one base context (its key prefix), so a
        // direct lookup replaces a scan over all base contexts.
        for (Map.Entry<List<String>, ClassicCounter<List<String>>> pEntry : pr.entrySet()) {
            List<String> pRule = pEntry.getKey();
            ClassicCounter<List<String>> nodeCounts = nr.get(pRule.get(0));
            if (nodeCounts != null) {
                ClassicCounter<List<String>> pCounts = pEntry.getValue();
                if (Counters.klDivergence(pCounts, nodeCounts) * pCounts.totalCount() >= cutOff) {
                    splitters.add(splitName(pRule));
                }
            }
        }
        for (Map.Entry<List<String>, ClassicCounter<List<String>>> gpEntry : gpr.entrySet()) {
            List<String> gpRule = gpEntry.getKey();
            // subList equals/hashes like any List with the same elements (List contract).
            ClassicCounter<List<String>> pCounts = pr.get(gpRule.subList(0, 2));
            if (pCounts != null && pCounts.totalCount() >= SUPPCUTOFF) {
                ClassicCounter<List<String>> gpCounts = gpEntry.getValue();
                if (Counters.klDivergence(gpCounts, pCounts) * gpCounts.totalCount() >= cutOff) {
                    splitters.add(splitName(gpRule));
                }
            }
        }
    }

    /**
     * Command-line entry point. Usage: {@code [-tags] [-cutOff n] treebankPath}.
     * With {@code -cutOff}, prints the selected split categories. Otherwise, prints
     * the full statistics via {@link #printStats()}.
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.out.println(USAGE);
            return;
        }
        int argIndex = 0;
        boolean useCutOff = false;
        double cutOff = 0.0d;
        // Bounds-checked so option-only argument lists cannot run past the end of args.
        while (argIndex < args.length && args[argIndex].startsWith("-")) {
            if (args[argIndex].equals("-tags")) {
                doTags = true;
                argIndex++;
            } else if (args[argIndex].equals("-cutOff") && argIndex + 1 < args.length) {
                useCutOff = true;
                cutOff = Double.parseDouble(args[argIndex + 1]);
                argIndex += 2;
            } else {
                System.err.println("Unknown option: " + args[argIndex]);
                argIndex++;
            }
        }
        if (argIndex >= args.length) {
            // Options were given but no treebank path followed them.
            System.out.println(USAGE);
            return;
        }
        DiskTreebank treebank = new DiskTreebank(new TreeReaderFactory() {
            @Override // edu.stanford.nlp.trees.TreeReaderFactory
            public TreeReader newTreeReader(Reader in) {
                return new PennTreeReader(in, new LabeledScoredTreeFactory(new StringLabelFactory()), new BobChrisTreeNormalizer());
            }
        });
        treebank.loadPath(args[argIndex]);
        if (useCutOff) {
            System.out.println(getSplitCategories(treebank, doTags, 0, cutOff, cutOff, null));
        } else {
            ParentAnnotationStats stats = new ParentAnnotationStats();
            treebank.apply(stats);
            stats.printStats();
        }
    }

    /**
     * Returns the selected phrasal and tag splits using the same cutoff for both.
     * Tags are always included.
     *
     * @param t treebank to analyze
     * @param cutOff minimum {@code KL * support} score for a split
     * @param tlp optional language pack for reducing context labels; may be null
     */
    public static Set<String> getSplitCategories(Treebank t, double cutOff, TreebankLanguagePack tlp) {
        return getSplitCategories(t, true, 0, cutOff, cutOff, tlp);
    }

    /**
     * Collects statistics over {@code t} and returns the selected split names.
     *
     * <p>Side effect: sets the static {@code doTags} flag to {@code includeTags}; the
     * value persists after this call returns.
     *
     * @param t treebank to analyze
     * @param includeTags whether preterminals are counted (and tag splits produced)
     * @param algorithm not used by this implementation
     * @param phrasalCutOff minimum score for splits of phrasal nodes
     * @param tagCutOff minimum score for splits of preterminal nodes
     * @param tlp optional language pack for reducing context labels; may be null
     * @return split names of the form {@code node^parent} or {@code node^parent~grandparent}
     */
    public static Set<String> getSplitCategories(Treebank t, boolean includeTags, int algorithm,
                                                 double phrasalCutOff, double tagCutOff,
                                                 TreebankLanguagePack tlp) {
        doTags = includeTags;
        ParentAnnotationStats stats = new ParentAnnotationStats(tlp);
        t.apply(stats);
        Set<String> splitters = new HashSet<String>();
        getSplitters(phrasalCutOff, stats.nodeRules, stats.pRules, stats.gPRules, splitters);
        getSplitters(tagCutOff, stats.tagNodeRules, stats.tagPRules, stats.tagGPRules, splitters);
        return splitters;
    }

    /**
     * Loads the English treebank files numbered 200-2199 (inclusive, per the
     * {@code NumberRangeFileFilter} call) under {@code path}. Returns the selected
     * splits at a cutoff of 300.
     */
    public static Set<String> getEnglishSplitCategories(String path) {
        EnglishTreebankParserParams tlpParams = new EnglishTreebankParserParams();
        MemoryTreebank treebank = tlpParams.memoryTreebank();
        treebank.loadPath(path, new NumberRangeFileFilter(200, 2199, true));
        return getSplitCategories(treebank, 300.0d, tlpParams.treebankLanguagePack());
    }
}
