package net.sf.javaml.classification;

import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.utils.ActiveSetsOptimization;
import nz.ac.waikato.cs.weka.Utils;

/**
 * Multinomial logistic regression classifier with a ridge (squared) penalty on the
 * non-intercept coefficients.
 *
 * <p>The last class ({@code m_NumClasses - 1}) is the reference class: its linear score is
 * fixed at 0, and the model stores one coefficient vector (intercept plus one weight per
 * attribute) for every other class. Training standardizes the attributes, minimizes the
 * weighted negative log-likelihood plus {@code m_Ridge * sum(coefficient^2)} with
 * {@link ActiveSetsOptimization}, and finally converts the coefficients back to the
 * original attribute units.
 */
public class Logistic implements Classifier {
    private static final long serialVersionUID = -5428362109088506874L;

    /**
     * Fitted coefficients: {@code m_Par[0][c]} is the intercept of class {@code c} and
     * {@code m_Par[j][c]} the weight of attribute {@code j - 1}, for every non-reference
     * class {@code c < m_NumClasses - 1}. {@code null} until a model has been built.
     */
    protected double[][] m_Par;
    /**
     * Standardized training rows; column 0 is the constant 1 (intercept term) and column
     * {@code j + 1} holds attribute {@code j}. Only non-null while training.
     */
    protected double[][] m_Data;
    /** Number of attributes (predictors) per instance, excluding the class. */
    protected int m_NumPredictors;
    /** Number of classes in the training data. */
    protected int m_NumClasses;
    /** When true, training diagnostics are printed to standard output. */
    protected boolean m_Debug;
    /** Negated value of the optimizer's {@code getMinFunction()} after training. */
    protected double m_LL;
    /** Ridge penalty factor applied to the squared non-intercept coefficients. */
    protected double m_Ridge = 1.0E-8d;
    /** Maximum optimizer iterations; -1 keeps restarting the optimizer until it converges. */
    private int m_MaxIts = -1;

    /**
     * Optimizer for the penalized, weighted multinomial negative log-likelihood.
     *
     * <p>The parameter vector stores the {@code m_NumPredictors + 1} coefficients of each
     * non-reference class {@code c} contiguously, starting at {@code c * (m_NumPredictors + 1)};
     * slot 0 of each slice is that class's intercept. Outer fields are accessed through
     * {@code Logistic.this} so they cannot be shadowed by fields of the optimizer base class.
     */
    private class OptEng extends ActiveSetsOptimization {
        /** Weight of each training instance. */
        private double[] weights;
        /** Class label of each training instance. */
        private int[] cls;

        public void setWeights(double[] instanceWeights) {
            this.weights = instanceWeights;
        }

        public void setClassLabels(int[] labels) {
            this.cls = labels;
        }

        /** Returns the linear score of every non-reference class for training row {@code row}. */
        private double[] linearScores(double[] x, int row) {
            int dim = Logistic.this.m_NumPredictors + 1;
            double[] data = Logistic.this.m_Data[row];
            double[] scores = new double[Logistic.this.m_NumClasses - 1];
            for (int c = 0; c < scores.length; c++) {
                int base = c * dim;
                double sum = 0.0d;
                for (int j = 0; j < dim; j++) {
                    sum += data[j] * x[base + j];
                }
                scores[c] = sum;
            }
            return scores;
        }

        /** Weighted negative log-likelihood of the training data plus the ridge penalty. */
        @Override // net.sf.javaml.utils.ActiveSetsOptimization
        protected double objectiveFunction(double[] x) {
            int dim = Logistic.this.m_NumPredictors + 1;
            int reference = Logistic.this.m_NumClasses - 1;
            double nll = 0.0d;
            for (int i = 0; i < this.cls.length; i++) {
                double[] v = linearScores(x, i);
                // Shift every score by the maximum so Math.exp cannot overflow.
                double max = v[Utils.maxIndex(v)];
                double denom = Math.exp(-max); // contribution of the reference class (score 0)
                double num = this.cls[i] == reference ? -max : v[this.cls[i]] - max;
                for (int c = 0; c < reference; c++) {
                    denom += Math.exp(v[c] - max);
                }
                nll -= this.weights[i] * (num - Math.log(denom));
            }
            // Ridge penalty; intercepts (slot 0 of each slice) are not penalized.
            for (int c = 0; c < reference; c++) {
                for (int j = 1; j < dim; j++) {
                    double coef = x[c * dim + j];
                    nll += Logistic.this.m_Ridge * coef * coef;
                }
            }
            return nll;
        }

        /** Gradient of {@link #objectiveFunction(double[])} with respect to {@code x}. */
        @Override // net.sf.javaml.utils.ActiveSetsOptimization
        protected double[] evaluateGradient(double[] x) {
            int dim = Logistic.this.m_NumPredictors + 1;
            int reference = Logistic.this.m_NumClasses - 1;
            double[][] data = Logistic.this.m_Data;
            double[] grad = new double[x.length];
            for (int i = 0; i < this.cls.length; i++) {
                double[] p = linearScores(x, i);
                double max = p[Utils.maxIndex(p)];
                double denom = Math.exp(-max);
                for (int c = 0; c < reference; c++) {
                    p[c] = Math.exp(p[c] - max);
                    denom += p[c];
                }
                // Assumed (Weka semantics) to divide each entry by denom, leaving p[c] = P(c | row i).
                Utils.normalize(p, denom);
                for (int c = 0; c < reference; c++) {
                    int base = c * dim;
                    double scale = this.weights[i] * p[c];
                    for (int j = 0; j < dim; j++) {
                        grad[base + j] += scale * data[i][j];
                    }
                }
                if (this.cls[i] != reference) {
                    int base = this.cls[i] * dim;
                    for (int j = 0; j < dim; j++) {
                        grad[base + j] -= this.weights[i] * data[i][j];
                    }
                }
            }
            for (int c = 0; c < reference; c++) {
                for (int j = 1; j < dim; j++) {
                    int idx = c * dim + j;
                    grad[idx] += 2.0d * Logistic.this.m_Ridge * x[idx];
                }
            }
            return grad;
        }
    }

    /**
     * Fits the model to {@code dataset}.
     *
     * @param dataset training data; every instance is assumed to have the same number of
     *     attributes as the first one and an integer class label in {@code [0, numClasses)}
     * @throws IllegalArgumentException if the dataset is empty or has fewer than two classes
     */
    @Override // net.sf.javaml.classification.Classifier
    public void buildClassifier(Dataset dataset) {
        int numInstances = dataset.size();
        if (numInstances == 0) {
            throw new IllegalArgumentException("Cannot build a logistic model from an empty dataset");
        }
        this.m_NumClasses = dataset.getNumClasses();
        if (this.m_NumClasses < 2) {
            throw new IllegalArgumentException(
                    "Logistic regression needs at least two classes, got " + this.m_NumClasses);
        }
        int numNonRef = this.m_NumClasses - 1;
        int numAttributes = dataset.getInstance(0).size();
        this.m_NumPredictors = numAttributes;
        int dim = numAttributes + 1;

        this.m_Data = new double[numInstances][dim];
        int[] labels = new int[numInstances];
        double[] weights = new double[numInstances];
        double[] mean = new double[dim];
        double[] sd = new double[dim];
        double[] classCounts = new double[this.m_NumClasses];
        double totalWeight = 0.0d;
        this.m_Par = new double[dim][numNonRef];

        if (this.m_Debug) {
            System.out.println("Extracting data...");
        }
        for (int i = 0; i < numInstances; i++) {
            Instance inst = dataset.getInstance(i);
            labels[i] = inst.getClassValue();
            weights[i] = inst.getWeight();
            totalWeight += weights[i];
            this.m_Data[i][0] = 1.0d;
            // Attribute j goes to column j + 1; the class is not among the instance values.
            for (int j = 0; j < numAttributes; j++) {
                double value = inst.getValue(j);
                this.m_Data[i][j + 1] = value;
                mean[j + 1] += weights[i] * value;
                sd[j + 1] += weights[i] * value * value;
            }
            classCounts[labels[i]] += 1.0d;
        }

        // Weighted mean / standard deviation per column; the intercept column is left as-is.
        mean[0] = 0.0d;
        sd[0] = 1.0d;
        for (int j = 1; j < dim; j++) {
            mean[j] /= totalWeight;
            sd[j] = totalWeight > 1.0d
                    ? Math.sqrt(Math.abs(sd[j] - totalWeight * mean[j] * mean[j]) / (totalWeight - 1.0d))
                    : 0.0d;
        }
        if (this.m_Debug) {
            printDescriptives(classCounts, mean, sd);
        }

        // Standardize; constant columns (sd == 0) are left untouched.
        for (double[] row : this.m_Data) {
            for (int j = 0; j < dim; j++) {
                if (sd[j] != 0.0d) {
                    row[j] = (row[j] - mean[j]) / sd[j];
                }
            }
        }
        if (this.m_Debug) {
            System.out.println("\nIteration History...");
        }

        // Start from add-one-smoothed log-odds intercepts and zero slopes. The NaN bounds
        // presumably mean "unconstrained" to ActiveSetsOptimization — confirm there.
        double[] start = new double[dim * numNonRef];
        double[][] bounds = new double[2][start.length];
        for (int c = 0; c < numNonRef; c++) {
            int base = c * dim;
            start[base] = Math.log(classCounts[c] + 1.0d) - Math.log(classCounts[numNonRef] + 1.0d);
            for (int j = 0; j < dim; j++) {
                bounds[0][base + j] = Double.NaN;
                bounds[1][base + j] = Double.NaN;
            }
        }

        OptEng opt = new OptEng();
        opt.setDebug(this.m_Debug);
        opt.setWeights(weights);
        opt.setClassLabels(labels);
        double[] solution;
        try {
            solution = optimize(opt, start, bounds);
            this.m_LL = -opt.getMinFunction();
        } finally {
            // The standardized copy of the training data is only needed while optimizing.
            this.m_Data = null;
        }

        // Convert the coefficients back to the original (unstandardized) attribute units.
        for (int c = 0; c < numNonRef; c++) {
            int base = c * dim;
            this.m_Par[0][c] = solution[base];
            for (int j = 1; j < dim; j++) {
                this.m_Par[j][c] = solution[base + j];
                if (sd[j] != 0.0d) {
                    this.m_Par[j][c] /= sd[j];
                    this.m_Par[0][c] -= this.m_Par[j][c] * mean[j];
                }
            }
        }
    }

    /** Prints class counts and per-attribute weighted mean / standard deviation. */
    private void printDescriptives(double[] classCounts, double[] mean, double[] sd) {
        System.out.println("Descriptives...");
        for (int c = 0; c < classCounts.length; c++) {
            System.out.println(classCounts[c] + " cases have class " + c);
        }
        System.out.println("\n Variable     Avg       SD    ");
        for (int j = 1; j < mean.length; j++) {
            System.out.println(Utils.doubleToString(j, 8, 4)
                    + Utils.doubleToString(mean[j], 10, 4)
                    + Utils.doubleToString(sd[j], 10, 4));
        }
    }

    /**
     * Runs the optimizer from {@code start}. With {@code m_MaxIts == -1} it is restarted from
     * its current point until {@code findArgmin} returns a non-null (converged) result;
     * otherwise it is capped at {@code m_MaxIts} iterations and the last point is returned.
     */
    private double[] optimize(OptEng opt, double[] start, double[][] bounds) {
        if (this.m_MaxIts != -1) {
            opt.setMaxIteration(this.m_MaxIts);
            double[] result = opt.findArgmin(start, bounds);
            return result != null ? result : opt.getVarbValues();
        }
        double[] result = opt.findArgmin(start, bounds);
        while (result == null) {
            double[] resumeFrom = opt.getVarbValues();
            if (this.m_Debug) {
                System.out.println("200 iterations finished, not enough!");
            }
            result = opt.findArgmin(resumeFrom, bounds);
        }
        if (this.m_Debug) {
            System.out.println(" -------------<Converged>--------------");
        }
        return result;
    }

    /**
     * Returns the class-membership probabilities of {@code instance}.
     *
     * @throws IllegalStateException if no model has been built yet
     */
    @Override // net.sf.javaml.classification.Classifier
    public double[] distributionForInstance(Instance instance) {
        if (this.m_Par == null) {
            throw new IllegalStateException("No model built yet.");
        }
        double[] row = new double[this.m_NumPredictors + 1];
        row[0] = 1.0d; // intercept term
        for (int j = 0; j < this.m_NumPredictors; j++) {
            row[j + 1] = instance.getValue(j);
        }
        return evaluateProbability(row);
    }

    /**
     * Computes the posterior of every class for a row laid out like {@link #m_Data}
     * (constant 1 in column 0, attribute j in column j + 1).
     */
    private double[] evaluateProbability(double[] row) {
        int reference = this.m_NumClasses - 1;
        double[] score = new double[this.m_NumClasses];
        for (int c = 0; c < reference; c++) {
            for (int j = 0; j <= this.m_NumPredictors; j++) {
                score[c] += this.m_Par[j][c] * row[j];
            }
        }
        score[reference] = 0.0d; // the reference class is the baseline
        double[] prob = new double[this.m_NumClasses];
        for (int m = 0; m < this.m_NumClasses; m++) {
            // P(m) = exp(v_m) / (1 + sum_c exp(v_c)), rewritten relative to v_m to avoid overflow.
            double sum = 0.0d;
            for (int c = 0; c < reference; c++) {
                sum += Math.exp(score[c] - score[m]);
            }
            prob[m] = 1.0d / (sum + Math.exp(-score[m]));
        }
        return prob;
    }

    /** Returns a table of the coefficients, intercepts and odds ratios of the fitted model. */
    @Override
    public String toString() {
        String header = "Logistic Regression with ridge parameter of " + this.m_Ridge;
        if (this.m_Par == null) {
            return header + ": No model built yet.";
        }
        int numNonRef = this.m_NumClasses - 1;
        StringBuilder sb = new StringBuilder(header);
        sb.append("\nCoefficients...\nVariable      Coeff.\n");
        for (int j = 1; j <= this.m_NumPredictors; j++) {
            sb.append(Utils.doubleToString(j, 8, 0));
            for (int c = 0; c < numNonRef; c++) {
                sb.append(' ').append(Utils.doubleToString(this.m_Par[j][c], 12, 4));
            }
            sb.append('\n');
        }
        sb.append("Intercept ");
        for (int c = 0; c < numNonRef; c++) {
            sb.append(' ').append(Utils.doubleToString(this.m_Par[0][c], 10, 4));
        }
        sb.append('\n');
        sb.append("\nOdds Ratios...\nVariable         O.R.\n");
        for (int j = 1; j <= this.m_NumPredictors; j++) {
            sb.append(Utils.doubleToString(j, 8, 0));
            for (int c = 0; c < numNonRef; c++) {
                double odds = Math.exp(this.m_Par[j][c]);
                // Very large ratios are printed in plain Double form rather than fixed width.
                sb.append(' ').append(odds > 1.0E10d ? String.valueOf(odds) : Utils.doubleToString(odds, 12, 4));
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /** Returns the index of the most probable class; ties go to the lowest index. */
    @Override // net.sf.javaml.classification.Classifier
    public int classifyInstance(Instance instance) {
        double[] dist = distributionForInstance(instance);
        int best = 0;
        for (int c = 1; c < dist.length; c++) {
            if (dist[c] > dist[best]) {
                best = c;
            }
        }
        return best;
    }
}
