package edu.berkeley.nlp.syntax;

import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Factory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/* loaded from: input_file:edu/berkeley/nlp/syntax/UnaryClosureComputer.class */
/**
 * Computes the closure of a set of scored unary rules {@code parent -> child}.
 *
 * <p>Rules are registered with {@link #add} and the closure is built by {@link #solve}. For every
 * pair of labels connected by a chain of rules, the closure keeps the chain with the highest
 * combined score. Scores along a chain are either summed or multiplied, depending on the
 * constructor flag. Every label that appears in a rule also receives a reflexive rule
 * {@code X -> X} scored 0.0, unless an existing {@code X -> X} rule or cycle already scored
 * higher. The label sequence realizing each closed rule is available through {@link #getPath}.
 */
public class UnaryClosureComputer<V> {
    /** Closed rules indexed by child label; populated by {@link #solve}. */
    Map<V, List<Edge<V>>> closedUnaryRulesByChild = new HashMap<>();
    /** Closed rules indexed by parent label; populated by {@link #solve}. */
    Map<V, List<Edge<V>>> closedUnaryRulesByParent = new HashMap<>();
    /** Label sequence (parent first, child last) realizing each closed rule. */
    Map<Edge<V>, List<V>> pathMap = new HashMap<>();
    /** Rules registered through {@link #add}, keyed on parent and child only. */
    Set<Edge<V>> unaryRules = new HashSet<>();
    /** If true, the score of a chain is the sum of its rule scores; otherwise their product. */
    private final boolean sumInsteadOfMultiply;

    /**
     * A unary rule {@code parent -> child} with a score.
     *
     * <p>Equality and hashing use only the parent and child; the score is ignored. Two edges
     * between the same labels are therefore interchangeable as map keys and in
     * {@code List.indexOf}. {@link #setParent} and {@link #setChild} change the hash code, so do
     * not call them on an edge that is currently stored in a hash-based collection.
     */
    public static class Edge<V> {
        private V parent;
        private V child;
        private double score;

        private Edge(V parent, V child) {
            this.parent = parent;
            this.child = child;
        }

        public V getParent() {
            return this.parent;
        }

        public void setParent(V parent) {
            this.parent = parent;
        }

        public V getChild() {
            return this.child;
        }

        public void setChild(V child) {
            this.child = child;
        }

        public double getScore() {
            return this.score;
        }

        public void setScore(double score) {
            this.score = score;
        }

        @Override
        public int hashCode() {
            // Objects.hash(child, parent) == 31 * (31 + h(child)) + h(parent), with h(null) == 0,
            // which is exactly the formula this class has always used.
            return Objects.hash(this.child, this.parent);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Edge<?> other = (Edge<?>) obj;
            return Objects.equals(this.child, other.child)
                    && Objects.equals(this.parent, other.parent);
        }

        /** Human-readable form used by {@link UnaryClosureComputer#toString()}. */
        @Override
        public String toString() {
            return this.parent + " -> " + this.child + " : " + this.score;
        }
    }

    /**
     * Creates an empty closure computer.
     *
     * @param sumInsteadOfMultiply if {@code true}, the score of a chain of rules is the sum of
     *     the rule scores; otherwise it is their product
     */
    public UnaryClosureComputer(boolean sumInsteadOfMultiply) {
        this.sumInsteadOfMultiply = sumInsteadOfMultiply;
    }

    /**
     * Returns the live internal index of closed rules by child label.
     *
     * @return the internal map (not a copy); modifying it affects this object
     */
    public Map<V, List<Edge<V>>> getAllClosedRulesByChildren() {
        return this.closedUnaryRulesByChild;
    }

    /**
     * Returns the closed rules whose child is {@code child}.
     *
     * <p>NOTE(review): the result for an unknown label is whatever
     * {@code CollectionUtils.getValueList} returns for a missing key (presumably an empty list).
     */
    public List<Edge<V>> getClosedUnaryRulesByChild(V child) {
        return CollectionUtils.getValueList(this.closedUnaryRulesByChild, child);
    }

    /**
     * Returns the closed rules whose parent is {@code parent}.
     *
     * <p>NOTE(review): the result for an unknown label is whatever
     * {@code CollectionUtils.getValueList} returns for a missing key (presumably an empty list).
     */
    public List<Edge<V>> getClosedUnaryRulesByParent(V parent) {
        return CollectionUtils.getValueList(this.closedUnaryRulesByParent, parent);
    }

    /**
     * Returns the label sequence realizing a closed rule.
     *
     * @param edge a rule; only its parent and child are used for the lookup
     * @return labels from parent to child inclusive (a single label for a reflexive rule with no
     *     intermediate label), or {@code null} if no such closed rule exists
     */
    public List<V> getPath(Edge<?> edge) {
        return this.pathMap.get(edge);
    }

    /** Lists every closed rule followed by its path, one per line. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (List<Edge<V>> rules : this.closedUnaryRulesByParent.values()) {
            for (Edge<V> edge : rules) {
                sb.append(edge).append("  ").append(getPath(edge)).append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Registers the rule {@code parent -> child} with the given score.
     *
     * <p>Rules are stored in a set keyed on parent and child. If such a rule was already
     * registered, this call has no effect and the earlier score is kept.
     */
    public void add(V parent, V child, double score) {
        Edge<V> rule = new Edge<>(parent, child);
        rule.setScore(score);
        this.unaryRules.add(rule);
    }

    /**
     * Computes the closure of all registered rules and indexes the result.
     *
     * <p>Results of any previous call are discarded first. Calling this again, for example after
     * adding more rules, therefore does not leave duplicate entries in the per-label lists. The
     * closure step overwrites the scores of the registered rule objects with their closed scores.
     */
    public void solve() {
        // Clear in place so maps already handed out by getAllClosedRulesByChildren stay valid.
        this.closedUnaryRulesByChild.clear();
        this.closedUnaryRulesByParent.clear();
        this.pathMap.clear();
        Map<Edge<V>, List<V>> closure = computeUnaryClosure(this.unaryRules);
        for (Map.Entry<Edge<V>, List<V>> entry : closure.entrySet()) {
            addUnary(entry.getKey(), entry.getValue());
        }
    }

    /** Indexes one closed rule by child and by parent, and records its path. */
    private void addUnary(Edge<V> edge, List<V> path) {
        CollectionUtils.addToValueList(this.closedUnaryRulesByChild, edge.getChild(), edge);
        CollectionUtils.addToValueList(this.closedUnaryRulesByParent, edge.getParent(), edge);
        this.pathMap.put(edge, path);
    }

    /**
     * Floyd–Warshall-style closure over the labels mentioned by {@code rules}.
     *
     * <p>Each label in turn serves as a candidate midpoint. Every known edge into it is combined
     * with every known edge out of it, and the best-scoring combination is kept.
     *
     * @return each closed edge (with its score set to the best chain score) mapped to its path
     */
    private Map<Edge<V>, List<V>> computeUnaryClosure(Collection<Edge<V>> rules) {
        Map<Edge<V>, V> midpoints = new HashMap<>();
        Counter<Edge<V>> scores = new Counter<>();
        Map<V, List<Edge<V>>> edgesByChild = new HashMap<>();
        Map<V, List<Edge<V>>> edgesByParent = new HashMap<>();
        Set<V> labels = new HashSet<>();

        for (Edge<V> rule : rules) {
            relax(scores, midpoints, edgesByChild, edgesByParent, rule, null, rule.getScore());
            labels.add(rule.getParent());
            labels.add(rule.getChild());
        }

        for (V mid : labels) {
            List<Edge<V>> into = edgesByChild.get(mid);
            List<Edge<V>> outOf = edgesByParent.get(mid);
            if (into == null || outOf == null) {
                continue;
            }
            // relax() never appends to these two lists while they are iterated: a new edge
            // would have mid as an endpoint, and relax rejects edges whose midpoint is an
            // endpoint.
            for (Edge<V> first : into) {
                for (Edge<V> second : outOf) {
                    relax(scores, midpoints, edgesByChild, edgesByParent,
                            new Edge<>(first.getParent(), second.getChild()), mid,
                            combinePathCosts(scores, first, second));
                }
            }
        }

        // Reflexive rules: kept at 0.0 unless a higher-scoring X -> X already exists.
        for (V label : labels) {
            relax(scores, midpoints, edgesByChild, edgesByParent,
                    new Edge<>(label, label), null, 0.0d);
        }

        Map<Edge<V>, List<V>> closure = new HashMap<>();
        for (Edge<V> edge : scores.keySet()) {
            edge.setScore(scores.getCount(edge));
            closure.put(edge, extractPath(edge, midpoints));
        }
        return closure;
    }

    /** Combines the current best scores of two consecutive edges by sum or product. */
    private double combinePathCosts(Counter<Edge<V>> scores, Edge<V> first, Edge<V> second) {
        return this.sumInsteadOfMultiply
                ? scores.getCount(first) + scores.getCount(second)
                : scores.getCount(first) * scores.getCount(second);
    }

    /**
     * Rebuilds the label sequence of {@code edge} by recursively expanding recorded midpoints.
     *
     * @return labels from parent to child inclusive; a reflexive edge with no midpoint yields a
     *     single-element list
     */
    private List<V> extractPath(Edge<V> edge, Map<Edge<V>, V> midpoints) {
        List<V> path = new ArrayList<>();
        path.add(edge.getParent());
        V midpoint = midpoints.get(edge);
        if (midpoint != null) {
            appendInterior(path, extractPath(new Edge<>(edge.getParent(), midpoint), midpoints));
            path.add(midpoint);
            appendInterior(path, extractPath(new Edge<>(midpoint, edge.getChild()), midpoints));
        }
        // Use equals(), matching Edge equality; '==' broke this for equal but distinct labels.
        if (path.size() == 1 && Objects.equals(edge.getParent(), edge.getChild())) {
            return path;
        }
        path.add(edge.getChild());
        return path;
    }

    /** Appends every element of {@code source} except its first and last to {@code target}. */
    private static <T> void appendInterior(List<T> target, List<T> source) {
        for (int i = 1; i < source.size() - 1; i++) {
            target.add(source.get(i));
        }
    }

    /**
     * Records {@code edge} with {@code score} and {@code midpoint} if that does not lower its
     * current score.
     *
     * <p>Candidates whose midpoint equals one of their own endpoints are ignored. An edge seen for
     * the first time is also indexed by child and parent. Ties replace the recorded midpoint.
     */
    private void relax(Counter<Edge<V>> scores, Map<Edge<V>, V> midpoints,
            Map<V, List<Edge<V>>> edgesByChild, Map<V, List<Edge<V>>> edgesByParent,
            Edge<V> edge, V midpoint, double score) {
        if (midpoint != null
                && (midpoint.equals(edge.getParent()) || midpoint.equals(edge.getChild()))) {
            return;
        }
        boolean isNew = !scores.containsKey(edge);
        if (!isNew && scores.getCount(edge) > score) {
            return;
        }
        if (isNew) {
            CollectionUtils.addToValueList(edgesByChild, edge.getChild(), edge);
            CollectionUtils.addToValueList(edgesByParent, edge.getParent(), edge);
        }
        scores.setCount(edge, score);
        midpoints.put(edge, midpoint);
    }

    /**
     * Returns the closed score of {@code parent -> child}.
     *
     * <p>NOTE(review): the 0.0 / +infinity sentinels read like costs, while the closure keeps the
     * highest score. Confirm the intended convention with callers.
     *
     * @return 0.0 if the labels are equal, the closed rule's score if {@link #solve} produced one,
     *     otherwise {@link Double#POSITIVE_INFINITY}
     */
    public double getProb(V parent, V child) {
        // equals() rather than '==': labels are compared by equality everywhere else.
        if (Objects.equals(parent, child)) {
            return 0.0d;
        }
        List<Edge<V>> rules = this.closedUnaryRulesByParent.get(parent);
        if (rules != null) {
            int index = rules.indexOf(new Edge<>(parent, child));
            if (index >= 0) {
                return rules.get(index).getScore();
            }
        }
        return Double.POSITIVE_INFINITY;
    }
}
