package edu.berkeley.nlp.PCFGLA.reranker;

import edu.berkeley.nlp.io.PennTreebankReader;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import javax.servlet.http.HttpServletResponse;

/* loaded from: input_file:edu/berkeley/nlp/PCFGLA/reranker/TreeLoader.class */
/**
 * Loads Penn-Treebank-style trees for the reranker and exposes the training portion.
 *
 * <p>Trees come either from a single bracketed file (each normalized with
 * {@code Trees.StandardTreeNormalizer}) or from file-number ranges read through
 * {@link PennTreebankReader} (each passed through {@code Trees.EmptyNodeRelabeler}). The public
 * constructors can additionally restrict the training list to a slice, or cut a held-out window
 * out of it.
 */
public class TreeLoader {
    // Ranges handed to PennTreebankReader.readTrees by loadWSJ. The training start was previously
    // spelled HttpServletResponse.SC_OK, a decompiler artifact whose value is 200.
    // NOTE(review): these presumably denote WSJ file numbers (200 = first file of section 02);
    // confirm against PennTreebankReader.
    private static final int WSJ_TRAIN_FROM = 200;
    private static final int WSJ_TRAIN_TO = 2199;
    private static final int WSJ_VALIDATION_FROM = 2100;
    private static final int WSJ_VALIDATION_TO = 2199;
    private static final int WSJ_DEV_TEST_FROM = 2200;
    private static final int WSJ_DEV_TEST_TO = 2299;
    private static final int WSJ_FINAL_TEST_FROM = 2300;
    private static final int WSJ_FINAL_TEST_TO = 2399;

    /** Training trees; package-private (no access modifier). */
    List<Tree<String>> trainTrees;
    private List<Tree<String>> validationTrees;
    private List<Tree<String>> devTestTrees;
    private List<Tree<String>> finalTestTrees;

    /**
     * Loads trees and returns the training slice {@code [fromIndex, toIndex)}.
     *
     * @param path       a single tree file, or the treebank root when {@code singleFile} is false
     * @param singleFile true to read {@code path} as one bracketed file, false to use the WSJ ranges
     * @param fromIndex  first training index to keep (inclusive), or -1 to keep everything
     * @param toIndex    end training index (exclusive), or -1 to keep everything
     * @return the selected training trees
     * @throws Error if loading fails
     */
    public static List<Tree<String>> getGoldTreesByNumber(String path, boolean singleFile, int fromIndex, int toIndex) {
        return new TreeLoader(path, singleFile, fromIndex, toIndex).getTrainTrees();
    }

    /** Returns the training list held by this loader (the live list, not a copy). */
    public List<Tree<String>> getTrainTrees() {
        return this.trainTrees;
    }

    /**
     * Loads trees and, unless either bound is -1, keeps only training trees {@code [fromIndex, toIndex)}.
     *
     * @throws IndexOutOfBoundsException if the bounds are invalid for the loaded training list
     * @throws Error if loading fails
     */
    public TreeLoader(String path, boolean singleFile, int fromIndex, int toIndex) {
        this(path, singleFile);
        if (fromIndex == -1 || toIndex == -1) {
            return;
        }
        // Copy so trainTrees does not remain a view backed by the full list.
        this.trainTrees = new ArrayList<Tree<String>>(this.trainTrees.subList(fromIndex, toIndex));
    }

    /**
     * Loads trees, then removes a held-out window from the training list.
     *
     * <p>Training keeps {@code [0, heldOutStart)} plus, when
     * {@code heldOutStart + numHeldOut < endIndex}, {@code [heldOutStart + numHeldOut, endIndex)}.
     * The trees between are "left out". Trees at or beyond {@code endIndex} are in neither list.
     * If {@code outFileName} is non-null, either the kept or the left-out trees are written to it
     * (one per line, UTF-8) and the JVM exits with status 0.
     *
     * @param numHeldOut   size of the held-out window, or -1 to leave the training list unchanged
     * @param endIndex     exclusive end of the region considered
     * @param heldOutStart first index of the held-out window
     * @param outFileName  file to dump trees into and then exit, or null to just update training
     * @param writeKept    true to dump the kept trees, false to dump the left-out trees
     * @throws IndexOutOfBoundsException if the indices are invalid for the loaded training list
     * @throws Error if loading fails
     */
    public TreeLoader(String path, boolean singleFile, int numHeldOut, int endIndex, int heldOutStart,
            String outFileName, boolean writeKept) {
        this(path, singleFile);
        if (numHeldOut == -1) {
            return;
        }
        int originalSize = this.trainTrees.size();
        int heldOutEnd = heldOutStart + numHeldOut;
        List<Tree<String>> kept = new ArrayList<Tree<String>>(this.trainTrees.subList(0, heldOutStart));
        List<Tree<String>> leftOut = new ArrayList<Tree<String>>();
        if (heldOutEnd < endIndex) {
            leftOut.addAll(this.trainTrees.subList(heldOutStart, heldOutEnd));
            kept.addAll(this.trainTrees.subList(heldOutEnd, endIndex));
        } else {
            leftOut.addAll(this.trainTrees.subList(heldOutStart, endIndex));
        }
        System.out.println("Including trees from (0, " + heldOutStart + ") and (" + heldOutEnd + ", " + endIndex + ")");
        if (outFileName != null) {
            if (writeKept) {
                // Dump the kept subset, not the full unsplit training list.
                System.out.println("Writing " + kept.size() + " kept trees");
                writeTrees(kept, outFileName);
            } else {
                System.out.println("Writing " + leftOut.size() + " left out trees");
                writeTrees(leftOut, outFileName);
            }
            // Dump mode: the written file is the only product, so terminate here.
            System.exit(0);
        }
        this.trainTrees = kept;
        int numWords = 0;
        for (Tree<String> tree : this.trainTrees) {
            numWords += tree.getYield().size();
        }
        System.out.println("In training set we have # of words: " + numWords);
        System.out.println("In training set we have # of sentences: " + this.trainTrees.size());
        System.out.println("reducing number of training trees from " + originalSize + " to " + this.trainTrees.size());
    }

    /**
     * Writes each tree's {@code toString()} on its own line to {@code fileName} in UTF-8.
     * Failures are reported on stdout and not rethrown. The writer is always closed.
     */
    private static void writeTrees(List<Tree<String>> trees, String fileName) {
        PrintWriter out = null;
        try {
            out = new PrintWriter(new OutputStreamWriter(new FileOutputStream(fileName), "UTF-8"), true);
            for (Tree<String> tree : trees) {
                out.println(tree.toString());
            }
            // PrintWriter never throws on write; it only records the failure.
            if (out.checkError()) {
                System.out.println("Problem writing trees.");
            }
        } catch (Exception e) {
            System.out.println("Problem writing trees.");
            System.out.println("  cause: " + e);
        } finally {
            if (out != null) {
                out.close();
            }
        }
    }

    /**
     * Loads all trees from {@code path}, either as one file or from the WSJ ranges.
     *
     * @throws Error wrapping any exception raised while loading
     */
    private TreeLoader(String path, boolean singleFile) {
        this.trainTrees = new ArrayList<Tree<String>>();
        this.validationTrees = new ArrayList<Tree<String>>();
        this.devTestTrees = new ArrayList<Tree<String>>();
        this.finalTestTrees = new ArrayList<Tree<String>>();
        try {
            if (singleFile) {
                System.out.println("Loading data from single file!");
                loadSingleFile(path);
            } else {
                System.out.println("Loading ENGLISH WSJ data!");
                loadWSJ(path);
            }
        } catch (Exception e) {
            System.out.println("Error loading trees!");
            // Print the real trace; the trace array's toString() is just an object identity.
            e.printStackTrace(System.out);
            throw new Error(e.getMessage(), e);
        }
    }

    /**
     * Reads every tree from a UTF-8 bracketed file, normalizes each one, and installs the result
     * as the training list. The dev-test list is set to the same list object.
     *
     * @throws Exception if the file cannot be read or contains no trees
     */
    private void loadSingleFile(String path) throws Exception {
        System.out.print("Loading trees from single file...");
        List<Tree<String>> rawTrees = new ArrayList<Tree<String>>();
        InputStreamReader in = new InputStreamReader(new FileInputStream(path), "UTF-8");
        try {
            Trees.PennTreeReader reader = new Trees.PennTreeReader(in);
            while (reader.hasNext()) {
                rawTrees.add(reader.next());
            }
        } finally {
            in.close();
        }
        Trees.StandardTreeNormalizer normalizer = new Trees.StandardTreeNormalizer();
        List<Tree<String>> normalized = new ArrayList<Tree<String>>(rawTrees.size());
        for (Tree<String> tree : rawTrees) {
            normalized.add(normalizer.transformTree(tree));
        }
        if (normalized.isEmpty()) {
            throw new Exception("failed to load any trees at " + path);
        }
        this.trainTrees = normalized;
        this.devTestTrees = this.trainTrees;
        System.out.println("done");
    }

    /**
     * Fills the train/validation/dev-test/final-test lists from fixed ranges under {@code path},
     * using the platform default charset.
     * NOTE(review): the validation range lies inside the training range (both end at 2199);
     * confirm this overlap is intended.
     */
    private void loadWSJ(String path) throws Exception {
        System.out.print("Loading WSJ trees...");
        Charset charset = Charset.defaultCharset();
        this.trainTrees.addAll(readTrees(path, WSJ_TRAIN_FROM, WSJ_TRAIN_TO, charset));
        this.validationTrees.addAll(readTrees(path, WSJ_VALIDATION_FROM, WSJ_VALIDATION_TO, charset));
        this.devTestTrees.addAll(readTrees(path, WSJ_DEV_TEST_FROM, WSJ_DEV_TEST_TO, charset));
        this.finalTestTrees.addAll(readTrees(path, WSJ_FINAL_TEST_FROM, WSJ_FINAL_TEST_TO, charset));
        System.out.println("done");
    }

    /**
     * Unimplemented stub: always returns {@code null}. The arguments are ignored.
     * Kept returning null (not an empty list) so existing null checks in callers still behave.
     */
    public static List<Tree<String>> readDev(String path, Charset charset) {
        return null;
    }

    /**
     * Reads trees in the range {@code [from, to]} via {@link PennTreebankReader} and relabels
     * empty nodes with {@code Trees.EmptyNodeRelabeler}.
     *
     * @return a new mutable list of the relabeled trees
     * @throws Exception if reading fails or no trees were found
     */
    public static List<Tree<String>> readTrees(String path, int from, int to, Charset charset) throws Exception {
        Collection<Tree<String>> rawTrees = PennTreebankReader.readTrees(path, from, to, charset);
        System.out.println("in readTrees: " + path);
        Trees.EmptyNodeRelabeler relabeler = new Trees.EmptyNodeRelabeler();
        List<Tree<String>> relabeled = new ArrayList<Tree<String>>(rawTrees.size());
        for (Tree<String> tree : rawTrees) {
            relabeled.add(relabeler.transformTree(tree));
        }
        if (relabeled.isEmpty()) {
            throw new Exception("failed to load any trees at " + path + " from " + from + " to " + to);
        }
        System.out.println("Read " + relabeled.size() + " trees.");
        return relabeled;
    }

    /**
     * Distributes {@code trees} by position modulo 10: residues 0-6 go to {@code train}, 7 to
     * {@code validation}, 8 to {@code devTest}, 9 to {@code finalTest}. Output lists are appended to.
     */
    public static void splitTrainValidTest(List<Tree<String>> trees, List<Tree<String>> train,
            List<Tree<String>> validation, List<Tree<String>> devTest, List<Tree<String>> finalTest) {
        int index = 0;
        for (Tree<String> tree : trees) {
            int bucket = index % 10;
            if (bucket < 7) {
                train.add(tree);
            } else if (bucket == 7) {
                validation.add(tree);
            } else if (bucket == 8) {
                devTest.add(tree);
            } else {
                finalTest.add(tree);
            }
            index++;
        }
    }

    /**
     * Returns the trees suitable for conditional training. It drops trees whose yield has exactly
     * one token, trees containing a node whose label contains "WHNP" (when {@code filterWhnp}),
     * and trees with unaries other than at the root (when {@code filterAllUnaries}).
     * When {@code collapseUnaryChains} is set, unary chains are removed in place from the input
     * trees, which mutates the caller's trees.
     */
    public static List<Tree<String>> filterTreesForConditional(List<Tree<String>> trees,
            boolean filterAllUnaries, boolean filterWhnp, boolean collapseUnaryChains) {
        List<Tree<String>> filtered = new ArrayList<Tree<String>>(trees.size());
        for (Tree<String> tree : trees) {
            if (tree.getYield().size() == 1) {
                continue;
            }
            // NOTE(review): when collapseUnaryChains is false, trees with unary chains are kept
            // as-is here; upstream Berkeley code may have skipped them instead. Confirm.
            if (collapseUnaryChains && tree.hasUnaryChain()) {
                tree.removeUnaryChains();
            }
            // Previously this scan only broke out of its own loop, so it filtered nothing.
            if (filterWhnp && containsWhnp(tree)) {
                continue;
            }
            if (filterAllUnaries && tree.hasUnariesOtherThanRoot()) {
                continue;
            }
            filtered.add(tree);
        }
        return filtered;
    }

    /** True if any nonterminal of {@code tree} has a label containing "WHNP". */
    private static boolean containsWhnp(Tree<String> tree) {
        for (Tree<String> node : tree.getNonTerminals()) {
            if (node.getLabel().contains("WHNP")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Lowercases every terminal label of every tree in place. Uses {@link Locale#ROOT} so the
     * result does not depend on the JVM's default locale.
     */
    public static void lowercaseWords(List<Tree<String>> trees) {
        for (Tree<String> tree : trees) {
            for (Tree<String> terminal : tree.getTerminals()) {
                terminal.setLabel(terminal.getLabel().toLowerCase(Locale.ROOT));
            }
        }
    }
}
