package edu.berkeley.nlp.classify;

import edu.berkeley.nlp.math.CachingDifferentiableFunction;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Counter;
import fig.basic.Pair;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/classify/LinearRegression.class */
/**
 * Unregularised least-squares linear regression over string-named features.
 *
 * <p>Models are produced by {@link Factory#train}. A prediction is the dot product between the
 * feature counts returned by the {@link FeatureExtractor} and the learned weight vector; there is
 * no separate intercept term, so a bias must be supplied as a constant feature by the extractor
 * if one is wanted.
 *
 * @param <I> type of the raw inputs accepted by the feature extractor
 */
public class LinearRegression<I> {
    private FeatureExtractor<I, String> featureExtractor;
    private double[] weights;
    private FeatureManager featureManager;

    /**
     * Trains {@link LinearRegression} models by minimising squared error with L-BFGS.
     *
     * @param <I> type of the raw inputs accepted by the feature extractor
     */
    public static class Factory<I> {
        /** Convergence tolerance used by {@link #train(Collection)}. */
        private static final double DEFAULT_TOLERANCE = 1.0E-4d;

        double[] weights;
        FeatureManager featureManager = new FeatureManager();
        FeatureExtractor<I, String> featureExtractor;
        Collection<Pair<I, Double>> trainingData;

        /**
         * Squared-error objective: the sum over training pairs of {@code 0.5 * (score - target)^2}.
         * (Declared private in the original source; the decompiler widened it to public.)
         */
        public class ObjectiveFunction extends CachingDifferentiableFunction {
            private ObjectiveFunction() {
            }

            /**
             * Computes the objective value and its gradient at the given weight vector.
             *
             * <p>Side effect: stores {@code dArr} as the factory's current weights, because
             * {@link Factory#getScore} reads {@code Factory.this.weights}.
             *
             * @param dArr candidate weights, indexed by feature index
             * @return pair of (objective value, gradient)
             */
            @Override // edu.berkeley.nlp.math.CachingDifferentiableFunction
            protected Pair<Double, double[]> calculate(double[] dArr) {
                Factory.this.weights = dArr;
                double loss = 0.0d;
                double[] gradient = new double[dimension()];
                for (Pair<I, Double> example : Factory.this.trainingData) {
                    // Must be parameterised: with a raw Counter, keySet() erases to a raw Set and
                    // the Feature-typed loop below does not compile.
                    Counter<Feature> features = Factory.this.getFeatures(example.getFirst());
                    double residual = Factory.this.getScore(features) - example.getSecond().doubleValue();
                    loss += 0.5d * residual * residual;
                    // d/dw_j of 0.5 * residual^2 is residual * x_j.
                    for (Feature feature : features.keySet()) {
                        gradient[feature.getIndex()] += features.getCount(feature) * residual;
                    }
                }
                return Pair.newPair(Double.valueOf(loss), gradient);
            }

            /** Number of weights, i.e. the number of features registered with the manager. */
            @Override // edu.berkeley.nlp.math.CachingDifferentiableFunction, edu.berkeley.nlp.math.Function
            public int dimension() {
                return Factory.this.featureManager.getNumFeatures();
            }

            /** Not implemented by this objective; always returns {@code null}. */
            public double[] unregularizedDerivativeAt(double[] dArr) {
                return null;
            }

            /** Decompiler-generated bridge constructor; retained only for source compatibility. */
            /* synthetic */ ObjectiveFunction(Factory factory, ObjectiveFunction objectiveFunction) {
                this();
            }
        }

        /**
         * Creates a factory that will extract features with the given extractor.
         *
         * @param featureExtractor maps inputs to string-keyed feature counts
         */
        public Factory(FeatureExtractor<I, String> featureExtractor) {
            this.featureExtractor = featureExtractor;
        }

        /**
         * Extracts features from {@code i} and maps each feature name to its {@link Feature}
         * object via the feature manager, keeping the extracted counts.
         * (Private in the original source; widened to public by the decompiler.)
         */
        public Counter<Feature> getFeatures(I i) {
            Counter<String> extracted = this.featureExtractor.extractFeatures(i);
            Counter<Feature> counter = new Counter<>();
            for (String name : extracted.keySet()) {
                counter.setCount(this.featureManager.getFeature(name), extracted.getCount(name));
            }
            return counter;
        }

        /**
         * Dot product of the feature counts with the factory's current {@code weights}.
         * (Private in the original source; widened to public by the decompiler.)
         */
        public double getScore(Counter<Feature> counter) {
            double score = 0.0d;
            for (Feature feature : counter.keySet()) {
                score += counter.getCount(feature) * this.weights[feature.getIndex()];
            }
            return score;
        }

        /** Registers every feature name occurring in the training data, then locks the manager. */
        private void extractAllFeatures() {
            for (Pair<I, Double> example : this.trainingData) {
                for (String name : this.featureExtractor.extractFeatures(example.getFirst()).keySet()) {
                    this.featureManager.getFeature(name);
                }
            }
            this.featureManager.lock();
        }

        /** Debugging helper (not called in this class): renders each feature with its weight. */
        private String examineWeights() {
            Counter<Feature> counter = new Counter<>();
            for (int i = 0; i < this.weights.length; i++) {
                counter.setCount(this.featureManager.getFeature(i), this.weights[i]);
            }
            return counter.toString();
        }

        /**
         * Fits a model to {@code collection} using the default L-BFGS tolerance ({@code 1e-4}).
         *
         * @param collection (input, target) training pairs
         * @return the trained model
         */
        public LinearRegression<I> train(Collection<Pair<I, Double>> collection) {
            return train(collection, DEFAULT_TOLERANCE);
        }

        /**
         * Fits a model to {@code collection}, starting L-BFGS from the all-zero weight vector.
         *
         * <p>Registers all training features and locks the feature manager before optimising.
         *
         * @param collection (input, target) training pairs
         * @param tolerance convergence tolerance passed to {@link LBFGSMinimizer#minimize}
         * @return the trained model, sharing this factory's extractor and feature manager
         */
        public LinearRegression<I> train(Collection<Pair<I, Double>> collection, double tolerance) {
            this.trainingData = collection;
            extractAllFeatures();
            ObjectiveFunction objective = new ObjectiveFunction();
            this.weights = new LBFGSMinimizer().minimize(objective, new double[objective.dimension()], tolerance);
            return new LinearRegression<>(this.featureExtractor, this.featureManager, this.weights);
        }
    }

    private LinearRegression(FeatureExtractor<I, String> featureExtractor, FeatureManager featureManager, double[] dArr) {
        this.featureExtractor = featureExtractor;
        this.featureManager = featureManager;
        this.weights = dArr;
    }

    /**
     * Predicts the response for {@code i}: the sum of feature count times learned weight.
     *
     * @param i the input to score
     * @return the predicted response
     */
    public double getResponse(I i) {
        Counter<String> extracted = this.featureExtractor.extractFeatures(i);
        double response = 0.0d;
        for (String name : extracted.keySet()) {
            Feature feature = this.featureManager.getFeature(name);
            // NOTE(review): the manager is locked after training and its behaviour for unseen
            // names is not visible here. A null or out-of-range feature is treated as having
            // zero weight rather than crashing the prediction.
            if (feature == null) {
                continue;
            }
            int index = feature.getIndex();
            if (index < 0 || index >= this.weights.length) {
                continue;
            }
            response += extracted.getCount(name) * this.weights[index];
        }
        return response;
    }

    /** Small demo: fits bag-of-words counts to two examples and prints a prediction. */
    public static void main(String[] strArr) {
        FeatureExtractor<List<String>, String> bagOfWords = new FeatureExtractor<List<String>, String>() {
            @Override // edu.berkeley.nlp.classify.FeatureExtractor
            public Counter<String> extractFeatures(List<String> list) {
                Counter<String> counter = new Counter<>();
                for (String token : list) {
                    counter.incrementCount(token, 1.0d);
                }
                return counter;
            }
        };
        List<String> abc = CollectionUtils.makeList("a", "b", "c");
        List<String> ab = CollectionUtils.makeList("a", "b");
        List<Pair<List<String>, Double>> trainingData = CollectionUtils.makeList(
                Pair.newPair(abc, Double.valueOf(3.0d)),
                Pair.newPair(ab, Double.valueOf(2.0d)));
        LinearRegression<List<String>> model = new Factory<List<String>>(bagOfWords).train(trainingData);
        System.out.println("guess: " + model.getResponse(abc));
    }

    /** Decompiler-generated bridge constructor; retained only for source compatibility. */
    /* synthetic */ LinearRegression(FeatureExtractor featureExtractor, FeatureManager featureManager, double[] dArr, LinearRegression linearRegression) {
        this(featureExtractor, featureManager, dArr);
    }
}
