package iitb.CRF;

import cern.colt.function.DoubleDoubleFunction;
import cern.colt.function.DoubleFunction;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import riso.numerical.LBFGS;

/* loaded from: input_file:iitb/CRF/Trainer.class */
public class Trainer {
    // Number of features; set in init() from featureGenerator.numFeatures().
    protected int numF;
    // Number of labels; copied from crf.numY in init().
    protected int numY;
    // Gradient buffer of length numF; filled by computeFunctionGradient() and sign-flipped in doTrain() before each LBFGS call.
    double[] gradLogli;
    // Work array of length numF handed to LBFGS.lbfgs() as its diag argument.
    double[] diag;
    // Weight vector supplied by the caller of init(); doTrain() overwrites every entry with getInitValue() and passes it to LBFGS.lbfgs() each iteration.
    double[] lambda;
    // Copied from params.reuseM in init(); forwarded to computeLogMi() as its reuse flag.
    protected boolean reuseM;
    // Per-sequence accumulator indexed by feature: plain sums in computeFunctionGradient(), log-space sums in computeFunctionGradientLL().
    protected double[] ExpF;
    // Per-position scaling factors used by computeFunctionGradient(); (re)allocated together with beta_Y there.
    double[] scale;
    // Not referenced anywhere in the visible part of this class.
    double[] rLogScale;
    // numY x numY matrix filled by computeLogMi() from features whose yprev() is >= 0.
    protected DoubleMatrix2D Mi_YY;
    // Length-numY vector filled by computeLogMi() from features whose yprev() is < 0.
    protected DoubleMatrix1D Ri_Y;
    // Forward vector carried from one position to the next.
    protected DoubleMatrix1D alpha_Y;
    // Forward vector being built for the current position; copied into alpha_Y at the end of each step.
    protected DoubleMatrix1D newAlpha_Y;
    // One backward vector per sequence position; grown to twice the sequence length when too short.
    protected DoubleMatrix1D[] beta_Y;
    // Scratch vector of length numY.
    protected DoubleMatrix1D tmp_Y;
    // Shared element-wise multiply / add functors used with DoubleMatrix1D.assign(other, func).
    static MultFunc multFunc = new MultFunc();
    protected static SumFunc sumFunc = new SumFunc();
    // Training-data iterator supplied to init().
    protected DataIter diter;
    // Taken from crf.featureGenerator; replaced by a FeatureGenCache wrapper when the "cache" misc option equals "true".
    FeatureGenerator featureGenerator;
    // Training parameters given to the constructor.
    protected CrfParams params;
    // Taken from crf.edgeGen; passed through to the RobustMath multiply helpers.
    EdgeGenerator edgeGen;
    // Number of LBFGS.lbfgs() calls that returned normally in the current doTrain() run.
    protected int icall;
    // Non-null only when the "cache" misc option equals "true" (see init()).
    FeatureGenCache featureGenCache;
    // Flag threaded through computeLogMi(); reset to false at the start of computeFunctionGradient().
    protected boolean initMDone = false;
    // Reusable "multiply by constant" functor; its multiplicator is reset before each use.
    MultSingle constMultiplier = new MultSingle();
    // Optional callback; doTrain() stops as soon as evaluate() returns false.
    Evaluator evaluator = null;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:iitb/CRF/Trainer$MultFunc.class */
    /** Binary functor that multiplies its two arguments. */
    public static class MultFunc implements DoubleDoubleFunction {
        MultFunc() {
        }

        /** Returns the product of {@code d} and {@code d2}. */
        public double apply(double d, double d2) {
            final double product = d * d2;
            return product;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:iitb/CRF/Trainer$MultSingle.class */
    /** Unary functor that scales its argument by the current {@link #multiplicator}. */
    public class MultSingle implements DoubleFunction {
        /** Factor applied by {@link #apply(double)}; callers set it before each use. */
        public double multiplicator = 1.0d;

        MultSingle() {
        }

        /** Returns {@code d} scaled by {@link #multiplicator}. */
        public double apply(double d) {
            final double factor = this.multiplicator;
            return d * factor;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:iitb/CRF/Trainer$SumFunc.class */
    /** Binary functor that adds its two arguments. */
    public static class SumFunc implements DoubleDoubleFunction {
        SumFunc() {
        }

        /** Returns the sum of {@code d} and {@code d2}. */
        public double apply(double d, double d2) {
            final double sum = d + d2;
            return sum;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Computes the Euclidean (L2) norm of a vector.
     *
     * @param dArr the vector components
     * @return the square root of the sum of squared components; 0 for an empty array
     */
    public double norm(double[] dArr) {
        double sumOfSquares = 0.0d;
        for (double component : dArr) {
            sumOfSquares += component * component;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Creates a trainer bound to the given parameter set.
     *
     * @param crfParams training parameters read throughout training
     */
    public Trainer(CrfParams crfParams) {
        params = crfParams;
    }

    /**
     * Initialises the trainer for the given model and data, then runs the optimisation loop.
     *
     * @param crf       model supplying the label count, feature generator and edge generator
     * @param dataIter  iterator over the training sequences
     * @param dArr      weight vector, updated in place
     * @param evaluator optional callback consulted each iteration; may be {@code null}
     */
    public void train(CRF crf, DataIter dataIter, double[] dArr, Evaluator evaluator) {
        init(crf, dataIter, dArr);
        this.evaluator = evaluator;

        final boolean verbose = this.params.debugLvl > 0;
        if (verbose) {
            final int featureCount = this.lambda.length;
            Util.printDbg("Number of features :" + featureCount);
        }

        doTrain();
    }

    /** Returns the value every weight is set to before optimisation starts ({@code params.initValue}). */
    double getInitValue() {
        final CrfParams config = this.params;
        return config.initValue;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Binds this trainer to a model, a data set and a weight vector, and allocates all work buffers.
     *
     * <p>If the {@code "cache"} misc option equals {@code "true"}, the feature generator is wrapped in a
     * {@link FeatureGenCache}; otherwise the cache reference is cleared.
     *
     * @param crf      model supplying {@code edgeGen}, {@code numY} and {@code featureGenerator}
     * @param dataIter iterator over the training sequences
     * @param dArr     weight vector to optimise
     */
    public void init(CRF crf, DataIter dataIter, double[] dArr) {
        // Model and data references.
        this.edgeGen = crf.edgeGen;
        this.lambda = dArr;
        this.numY = crf.numY;
        this.diter = dataIter;
        this.featureGenerator = crf.featureGenerator;

        // Buffers sized by the feature count (ExpF follows the weight vector's length).
        final int featureCount = this.featureGenerator.numFeatures();
        this.numF = featureCount;
        this.gradLogli = new double[featureCount];
        this.diag = new double[featureCount];
        this.ExpF = new double[dArr.length];

        initMatrices();
        this.reuseM = this.params.reuseM;

        final boolean cachingRequested = this.params.miscOptions.getProperty("cache", "false").equals("true");
        if (cachingRequested) {
            this.featureGenCache = new FeatureGenCache(this.featureGenerator);
            this.featureGenerator = this.featureGenCache;
        } else {
            this.featureGenCache = null;
        }
    }

    /** Allocates the dense matrices and vectors used by the forward/backward passes, all sized by {@code numY}. */
    void initMatrices() {
        final int labelCount = this.numY;
        System.out.println("Trainer.initMatrices() numY:" + labelCount);

        this.Mi_YY = new DenseDoubleMatrix2D(labelCount, labelCount);
        this.Ri_Y = new DenseDoubleMatrix1D(labelCount);
        this.alpha_Y = new DenseDoubleMatrix1D(labelCount);
        this.newAlpha_Y = new DenseDoubleMatrix1D(labelCount);
        this.tmp_Y = new DenseDoubleMatrix1D(labelCount);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Runs the optimisation loop.
     *
     * <p>Every weight is first set to {@link #getInitValue()}. Each iteration evaluates the objective and
     * gradient, flips the sign of both (so the routine minimises the negated value), consults the optional
     * evaluator, and makes one {@code LBFGS.lbfgs} call. The loop ends when the evaluator returns
     * {@code false}, when the LBFGS flag comes back as 0, when LBFGS throws, or once {@code icall}
     * exceeds {@code params.maxIters}.
     */
    public void doTrain() {
        this.icall = 0;

        final int[] iprint = new int[2];
        iprint[0] = this.params.debugLvl - 2;
        iprint[1] = this.params.debugLvl - 1;
        final int[] iflag = new int[1];

        for (int w = 0; w < this.lambda.length; w++) {
            this.lambda[w] = getInitValue();
        }

        while (true) {
            // The objective and its gradient are negated before being handed to LBFGS.
            final double negatedObjective = -computeFunctionGradient(this.lambda, this.gradLogli);
            for (int g = 0; g < this.lambda.length; g++) {
                this.gradLogli[g] = -this.gradLogli[g];
            }

            final boolean keepGoing = this.evaluator == null || this.evaluator.evaluate();
            if (!keepGoing) {
                return;
            }

            try {
                LBFGS.lbfgs(this.numF, this.params.mForHessian, this.lambda, negatedObjective, this.gradLogli, false, this.diag, iprint, this.params.epsForConvergence, 1.0E-16d, iflag);
            } catch (LBFGS.ExceptionWithIflag e) {
                System.err.println("CRF: lbfgs failed.\n" + e);
                if (e.iflag == -1) {
                    System.err.println("Possible reasons could be: \n \t 1. Bug in the feature generation or data handling code\n\t 2. Not enough features to make observed feature value==expected value\n");
                }
                return;
            }

            this.icall++;
            if (iflag[0] == 0) {
                return;
            }
            if (this.icall > this.params.maxIters) {
                return;
            }
        }
    }

    /**
     * Computes the penalised log-likelihood of the training data and its gradient, working with
     * exponentiated potentials and optional per-position scaling.
     *
     * <p>If {@code params.trainerType} equals {@code "ll"} the call is delegated to
     * {@link #computeFunctionGradientLL(double[], double[])} instead.
     *
     * @param dArr  current weight vector (read only)
     * @param dArr2 output: overwritten with the gradient
     * @return the objective value: quadratic weight penalty plus the summed per-sequence log-likelihood
     */
    protected double computeFunctionGradient(double[] dArr, double[] dArr2) {
        this.initMDone = false;
        if (this.params.trainerType.equals("ll")) {
            return computeFunctionGradientLL(dArr, dArr2);
        }
        // Start from the quadratic penalty: gradient = -w * invSigmaSquare, value = -w^2 * invSigmaSquare / 2.
        double d = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            try {
                dArr2[i] = (-1.0d) * dArr[i] * this.params.invSigmaSquare;
                d -= ((dArr[i] * dArr[i]) * this.params.invSigmaSquare) / 2.0d;
            } catch (Exception e) {
                // NOTE(review): any failure here dumps the work matrices and terminates the JVM with status 0.
                System.out.println("Alpha-i " + this.alpha_Y.toString());
                System.out.println("Ri " + this.Ri_Y.toString());
                System.out.println("Mi " + this.Mi_YY.toString());
                e.printStackTrace();
                System.exit(0);
            }
        }
        // z: whether per-position scaling is enabled.
        boolean z = this.params.doScaling;
        this.diter.startScan();
        if (this.featureGenCache != null) {
            this.featureGenCache.startDataScan();
        }
        // i2 counts the sequences processed.
        int i2 = 0;
        while (this.diter.hasNext()) {
            DataSequence next = this.diter.next();
            if (this.featureGenCache != null) {
                this.featureGenCache.nextDataIndex();
            }
            if (this.params.debugLvl > 1) {
                Util.printDbg("Read next seq: " + i2 + " logli " + d);
            }
            // Reset the per-sequence state.
            this.alpha_Y.assign(1.0d);
            for (int i3 = 0; i3 < dArr.length; i3++) {
                this.ExpF[i3] = 0.0d;
            }
            // Grow the backward vectors and scale array to twice the sequence length when too short.
            if (this.beta_Y == null || this.beta_Y.length < next.length()) {
                this.beta_Y = new DenseDoubleMatrix1D[2 * next.length()];
                for (int i4 = 0; i4 < this.beta_Y.length; i4++) {
                    this.beta_Y[i4] = new DenseDoubleMatrix1D(this.numY);
                }
                this.scale = new double[2 * next.length()];
            }
            // NOTE(review): a zero-length sequence would index scale[-1] here.
            this.scale[next.length() - 1] = z ? this.numY : 1;
            this.beta_Y[next.length() - 1].assign(1.0d / this.scale[next.length() - 1]);
            // Backward pass: derive beta[length - 1] from beta[length], from the last position down to 1.
            for (int length = next.length() - 1; length > 0; length--) {
                if (this.params.debugLvl > 2) {
                    Util.printDbg("Features fired");
                }
                this.initMDone = computeLogMi(this.featureGenerator, dArr, next, length, this.Mi_YY, this.Ri_Y, true, this.reuseM, this.initMDone);
                // tmp = beta[length] multiplied element-wise by Ri.
                this.tmp_Y.assign(this.beta_Y[length]);
                this.tmp_Y.assign(this.Ri_Y, multFunc);
                RobustMath.Mult(this.Mi_YY, this.tmp_Y, this.beta_Y[length - 1], 1.0d, 0.0d, false, this.edgeGen);
                // Scale by the vector's sum, but treat sums strictly inside (-1, 1) as 1.
                this.scale[length - 1] = z ? this.beta_Y[length - 1].zSum() : 1.0d;
                if (this.scale[length - 1] < 1.0d && this.scale[length - 1] > -1.0d) {
                    this.scale[length - 1] = 1.0d;
                }
                this.constMultiplier.multiplicator = 1.0d / this.scale[length - 1];
                this.beta_Y[length - 1].assign(this.constMultiplier);
            }
            // d2 accumulates the weighted sum of features that agree with the observed labels.
            double d2 = 0.0d;
            // Forward pass.
            for (int i5 = 0; i5 < next.length(); i5++) {
                this.initMDone = computeLogMi(this.featureGenerator, dArr, next, i5, this.Mi_YY, this.Ri_Y, true, this.reuseM, this.initMDone);
                // computeLogMi consumed the generator; restart the scan for the accumulation loop below.
                this.featureGenerator.startScanFeaturesAt(next, i5);
                if (i5 > 0) {
                    this.tmp_Y.assign(this.alpha_Y);
                    RobustMath.Mult(this.Mi_YY, this.tmp_Y, this.newAlpha_Y, 1.0d, 0.0d, true, this.edgeGen);
                    this.newAlpha_Y.assign(this.Ri_Y, multFunc);
                } else {
                    this.newAlpha_Y.assign(this.Ri_Y);
                }
                while (this.featureGenerator.hasNext()) {
                    Feature next2 = this.featureGenerator.next();
                    int index = next2.index();
                    int y = next2.y();
                    int yprev = next2.yprev();
                    float value = next2.value();
                    // Feature matches the observed label (and previous label, for features with yprev >= 0).
                    if (next.y(i5) == y && ((i5 - 1 >= 0 && yprev == next.y(i5 - 1)) || yprev < 0)) {
                        dArr2[index] = dArr2[index] + value;
                        d2 += value * dArr[index];
                    }
                    // Accumulate the (unnormalised) expected value of the feature.
                    if (yprev < 0) {
                        double[] dArr3 = this.ExpF;
                        dArr3[index] = dArr3[index] + (this.newAlpha_Y.get(y) * value * this.beta_Y[i5].get(y));
                    } else {
                        double[] dArr4 = this.ExpF;
                        dArr4[index] = dArr4[index] + (this.alpha_Y.get(yprev) * this.Ri_Y.get(y) * this.Mi_YY.get(yprev, y) * value * this.beta_Y[i5].get(y));
                    }
                }
                // Carry the new forward vector on, divided by this position's scale.
                this.alpha_Y.assign(this.newAlpha_Y);
                this.constMultiplier.multiplicator = 1.0d / this.scale[i5];
                this.alpha_Y.assign(this.constMultiplier);
                if (this.params.debugLvl > 2) {
                    System.out.println("Alpha-i " + this.alpha_Y.toString());
                    System.out.println("Ri " + this.Ri_Y.toString());
                    System.out.println("Mi " + this.Mi_YY.toString());
                    System.out.println("Beta-i " + this.beta_Y[i5].toString());
                }
            }
            // Sequence log-likelihood: matched score minus log of the final alpha sum and of every scale factor.
            double zSum = this.alpha_Y.zSum();
            double log = d2 - log(zSum);
            for (int i6 = 0; i6 < next.length(); i6++) {
                log -= log(this.scale[i6]);
            }
            d += log;
            // Subtract the expected feature values, normalised by the final alpha sum.
            for (int i7 = 0; i7 < dArr2.length; i7++) {
                int i8 = i7;
                dArr2[i8] = dArr2[i8] - (this.ExpF[i7] / zSum);
            }
            if (this.params.debugLvl > 1) {
                System.out.println("Sequence " + log + " logli " + d + " log(Zx) " + Math.log(zSum) + " Zx " + zSum);
            }
            i2++;
        }
        if (this.params.debugLvl > 2) {
            for (double d3 : dArr) {
                System.out.print(String.valueOf(d3) + " ");
            }
            System.out.println(" :x");
            for (int i9 = 0; i9 < dArr.length; i9++) {
                System.out.println(String.valueOf(this.featureGenerator.featureName(i9)) + " " + dArr2[i9] + " ");
            }
            System.out.println(" :g");
        }
        if (this.params.debugLvl > 0) {
            Util.printDbg("Iter " + this.icall + " log likelihood " + d + " norm(grad logli) " + norm(dArr2) + " norm(x) " + norm(dArr));
        }
        // Report the data-set size once, before the first LBFGS call has completed.
        if (this.icall == 0) {
            System.out.println("Number of training records" + i2);
        }
        return d;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Convenience overload that always recomputes the edge matrix (both reuse flags are {@code false})
     * and discards the returned flag. The feature scan must already have been started by the caller.
     */
    public static void computeLogMi(FeatureGenerator featureGenerator, double[] dArr, DoubleMatrix2D doubleMatrix2D, DoubleMatrix1D doubleMatrix1D, boolean z) {
        final boolean reuseEdgeMatrix = false;
        final boolean edgeMatrixReady = false;
        computeLogMi(featureGenerator, dArr, doubleMatrix2D, doubleMatrix1D, z, reuseEdgeMatrix, edgeMatrixReady);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Fills the node vector and, unless it is being reused, the edge matrix with weighted feature sums
     * for the position whose feature scan is currently open on {@code featureGenerator}.
     *
     * <p>Features with {@code yprev() < 0} contribute to {@code doubleMatrix1D[y]}; all others contribute
     * to {@code doubleMatrix2D[yprev][y]}. When {@code z2} and {@code z3} are both true the matrix is left
     * untouched and edge features are ignored.
     *
     * @param featureGenerator feature source, already positioned by the caller
     * @param dArr             weight vector indexed by feature index
     * @param doubleMatrix2D   edge matrix to fill; may be {@code null}
     * @param doubleMatrix1D   node vector to fill
     * @param z                if true, every filled entry is replaced by its exponential
     * @param z2               whether reusing a previously filled edge matrix is allowed
     * @param z3               whether the edge matrix has already been filled
     * @return true if the edge matrix was reused, or if at least one edge feature was written into it
     */
    public static boolean computeLogMi(FeatureGenerator featureGenerator, double[] dArr, DoubleMatrix2D doubleMatrix2D, DoubleMatrix1D doubleMatrix1D, boolean z, boolean z2, boolean z3) {
        final boolean reuseEdges = z2 && z3;
        final DoubleMatrix2D edgeScores = reuseEdges ? null : doubleMatrix2D;
        boolean edgesFilled = reuseEdges;

        if (edgeScores != null) {
            edgeScores.assign(0.0d);
        }
        doubleMatrix1D.assign(0.0d);

        // Accumulate weight * value for every feature firing at this position.
        while (featureGenerator.hasNext()) {
            final Feature feature = featureGenerator.next();
            final int index = feature.index();
            final int y = feature.y();
            final int yprev = feature.yprev();
            final float value = feature.value();

            if (yprev < 0) {
                final double nodeScore = doubleMatrix1D.getQuick(y) + (dArr[index] * value);
                doubleMatrix1D.setQuick(y, nodeScore);
                continue;
            }
            if (edgeScores == null) {
                continue;
            }
            final double edgeScore = edgeScores.getQuick(yprev, y) + (dArr[index] * value);
            edgeScores.setQuick(yprev, y, edgeScore);
            edgesFilled = true;
        }

        if (!z) {
            return edgesFilled;
        }

        // Exponentiate in place, highest index first; matrix rows are walked by the vector's index.
        int row = doubleMatrix1D.size();
        while (--row >= 0) {
            doubleMatrix1D.setQuick(row, expE(doubleMatrix1D.getQuick(row)));
            if (edgeScores == null) {
                continue;
            }
            int col = edgeScores.columns();
            while (--col >= 0) {
                edgeScores.setQuick(row, col, expE(edgeScores.getQuick(row, col)));
            }
        }
        return edgesFilled;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Convenience overload for position {@code i} of {@code dataSequence} that always recomputes the edge
     * matrix (both reuse flags are {@code false}) and discards the returned flag.
     */
    public static void computeLogMi(FeatureGenerator featureGenerator, double[] dArr, DataSequence dataSequence, int i, DoubleMatrix2D doubleMatrix2D, DoubleMatrix1D doubleMatrix1D, boolean z) {
        final boolean reuseEdgeMatrix = false;
        final boolean edgeMatrixReady = false;
        computeLogMi(featureGenerator, dArr, dataSequence, i, doubleMatrix2D, doubleMatrix1D, z, reuseEdgeMatrix, edgeMatrixReady);
    }

    /**
     * Starts the feature scan at position {@code i} of {@code dataSequence} and then fills the node
     * vector / edge matrix from it.
     *
     * @return the flag returned by the generator-based overload
     */
    static boolean computeLogMi(FeatureGenerator featureGenerator, double[] dArr, DataSequence dataSequence, int i, DoubleMatrix2D doubleMatrix2D, DoubleMatrix1D doubleMatrix1D, boolean z, boolean z2, boolean z3) {
        featureGenerator.startScanFeaturesAt(dataSequence, i);
        final boolean edgesFilled = computeLogMi(featureGenerator, dArr, doubleMatrix2D, doubleMatrix1D, z, z2, z3);
        return edgesFilled;
    }

    /**
     * Log-space variant of {@link #computeFunctionGradient(double[], double[])}: potentials are kept as
     * logs (computeLogMi is called without exponentiation) and combined with the
     * {@code RobustMath.logMult} / {@code RobustMath.logSumExp} helpers, so no scaling array is used.
     *
     * @param dArr  current weight vector (read only)
     * @param dArr2 output: overwritten with the gradient
     * @return the objective value: quadratic weight penalty plus the summed per-sequence log-likelihood
     */
    protected double computeFunctionGradientLL(double[] dArr, double[] dArr2) {
        // Start from the quadratic penalty: gradient = -w * invSigmaSquare, value = -w^2 * invSigmaSquare / 2.
        double d = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            try {
                dArr2[i] = (-1.0d) * dArr[i] * this.params.invSigmaSquare;
                d -= ((dArr[i] * dArr[i]) * this.params.invSigmaSquare) / 2.0d;
            } catch (Exception e) {
                // NOTE(review): any failure here dumps the work matrices and terminates the JVM with status 0.
                System.out.println("Alpha-i " + this.alpha_Y.toString());
                System.out.println("Ri " + this.Ri_Y.toString());
                System.out.println("Mi " + this.Mi_YY.toString());
                e.printStackTrace();
                System.exit(0);
            }
        }
        this.diter.startScan();
        if (this.featureGenCache != null) {
            this.featureGenCache.startDataScan();
        }
        // i2 counts the sequences processed.
        int i2 = 0;
        while (this.diter.hasNext()) {
            DataSequence next = this.diter.next();
            if (this.featureGenCache != null) {
                this.featureGenCache.nextDataIndex();
            }
            if (this.params.debugLvl > 1) {
                Util.printDbg("Read next seq: " + i2 + " logli " + d);
            }
            // Reset the per-sequence state; ExpF holds log-space sums, so it starts at RobustMath.LOG0.
            this.alpha_Y.assign(0.0d);
            for (int i3 = 0; i3 < dArr.length; i3++) {
                this.ExpF[i3] = RobustMath.LOG0;
            }
            // Grow the backward vectors to twice the sequence length when too short.
            if (this.beta_Y == null || this.beta_Y.length < next.length()) {
                this.beta_Y = new DenseDoubleMatrix1D[2 * next.length()];
                for (int i4 = 0; i4 < this.beta_Y.length; i4++) {
                    this.beta_Y[i4] = new DenseDoubleMatrix1D(this.numY);
                }
            }
            this.beta_Y[next.length() - 1].assign(0.0d);
            // Backward pass: derive beta[length - 1] from beta[length], from the last position down to 1.
            for (int length = next.length() - 1; length > 0; length--) {
                // i5 is never read (decompiler residue).
                int i5 = this.params.debugLvl;
                this.initMDone = computeLogMi(this.featureGenerator, dArr, next, length, this.Mi_YY, this.Ri_Y, false, this.reuseM, this.initMDone);
                // tmp = beta[length] + Ri, element-wise.
                this.tmp_Y.assign(this.beta_Y[length]);
                this.tmp_Y.assign(this.Ri_Y, sumFunc);
                RobustMath.logMult(this.Mi_YY, this.tmp_Y, this.beta_Y[length - 1], 1.0d, 0.0d, false, this.edgeGen);
            }
            // d2 accumulates the weighted sum of features that agree with the observed labels.
            double d2 = 0.0d;
            // Forward pass.
            for (int i6 = 0; i6 < next.length(); i6++) {
                this.initMDone = computeLogMi(this.featureGenerator, dArr, next, i6, this.Mi_YY, this.Ri_Y, false, this.reuseM, this.initMDone);
                // computeLogMi consumed the generator; restart the scan for the accumulation loop below.
                this.featureGenerator.startScanFeaturesAt(next, i6);
                if (i6 > 0) {
                    this.tmp_Y.assign(this.alpha_Y);
                    RobustMath.logMult(this.Mi_YY, this.tmp_Y, this.newAlpha_Y, 1.0d, 0.0d, true, this.edgeGen);
                    this.newAlpha_Y.assign(this.Ri_Y, sumFunc);
                } else {
                    this.newAlpha_Y.assign(this.Ri_Y);
                }
                while (this.featureGenerator.hasNext()) {
                    Feature next2 = this.featureGenerator.next();
                    int index = next2.index();
                    int y = next2.y();
                    int yprev = next2.yprev();
                    float value = next2.value();
                    // Feature matches the observed label (and previous label, for features with yprev >= 0).
                    if (next.y(i6) == y && ((i6 - 1 >= 0 && yprev == next.y(i6 - 1)) || yprev < 0)) {
                        dArr2[index] = dArr2[index] + value;
                        d2 += value * dArr[index];
                        if (this.params.debugLvl > 2) {
                            System.out.println("Feature fired " + index + " " + next2);
                        }
                    }
                    // Accumulate the log of the (unnormalised) expected value of the feature.
                    if (yprev < 0) {
                        this.ExpF[index] = RobustMath.logSumExp(this.ExpF[index], this.newAlpha_Y.get(y) + RobustMath.log(value) + this.beta_Y[i6].get(y));
                    } else {
                        this.ExpF[index] = RobustMath.logSumExp(this.ExpF[index], this.alpha_Y.get(yprev) + this.Ri_Y.get(y) + this.Mi_YY.get(yprev, y) + RobustMath.log(value) + this.beta_Y[i6].get(y));
                    }
                }
                this.alpha_Y.assign(this.newAlpha_Y);
                if (this.params.debugLvl > 2) {
                    System.out.println("Alpha-i " + this.alpha_Y.toString());
                    System.out.println("Ri " + this.Ri_Y.toString());
                    System.out.println("Mi " + this.Mi_YY.toString());
                    System.out.println("Beta-i " + this.beta_Y[i6].toString());
                }
            }
            // Sequence log-likelihood: matched score minus the log-sum-exp of the final alpha vector.
            double logSumExp = RobustMath.logSumExp(this.alpha_Y);
            double d3 = d2 - logSumExp;
            d += d3;
            // Subtract the expected feature values, normalised in log space.
            for (int i7 = 0; i7 < dArr2.length; i7++) {
                int i8 = i7;
                dArr2[i8] = dArr2[i8] - RobustMath.exp(this.ExpF[i7] - logSumExp);
            }
            if (this.params.debugLvl > 1) {
                System.out.println("Sequence " + d3 + " logli " + d + " log(Zx) " + logSumExp + " Zx " + Math.exp(logSumExp));
            }
            i2++;
        }
        if (this.params.debugLvl > 2) {
            for (double d4 : dArr) {
                System.out.print(String.valueOf(d4) + " ");
            }
            System.out.println(" :x");
            for (int i9 = 0; i9 < dArr.length; i9++) {
                System.out.print(String.valueOf(dArr2[i9]) + " ");
            }
            System.out.println(" :g");
        }
        if (this.params.debugLvl > 0) {
            Util.printDbg("Iteration " + this.icall + " log-likelihood " + d + " norm(grad logli) " + norm(dArr2) + " norm(x) " + norm(dArr));
        }
        return d;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Natural logarithm that never throws: when {@link #logE(double)} rejects the result, the failure is
     * printed to standard output together with a stack trace and {@code -Double.MAX_VALUE} is returned.
     *
     * @param d the value to take the logarithm of
     * @return {@code ln(d)}, or {@code -Double.MAX_VALUE} if the result is NaN or infinite
     */
    public static double log(double d) {
        double result;
        try {
            result = logE(d);
        } catch (Exception e) {
            System.out.println(e.getMessage());
            e.printStackTrace();
            result = -Double.MAX_VALUE;
        }
        return result;
    }

    /**
     * Natural logarithm that rejects non-finite results.
     *
     * @param d the value to take the logarithm of
     * @return {@code ln(d)}
     * @throws Exception if {@code ln(d)} is NaN or infinite (e.g. {@code d <= 0})
     */
    static double logE(double d) throws Exception {
        final double result = Math.log(d);
        final boolean usable = !Double.isNaN(result) && !Double.isInfinite(result);
        if (usable) {
            return result;
        }
        throw new Exception("Overflow error when taking log of " + d);
    }

    /**
     * Exponential via {@code RobustMath.exp} with overflow reporting: a NaN or infinite result is
     * reported on standard output (message plus stack trace, suggesting the log-space trainer) and
     * replaced by {@code Double.MAX_VALUE}.
     *
     * @param d the exponent
     * @return {@code RobustMath.exp(d)}, or {@code Double.MAX_VALUE} when that is NaN or infinite
     */
    static double expE(double d) {
        final double result = RobustMath.exp(d);
        if (Double.isNaN(result) || Double.isInfinite(result)) {
            // The exception is built only to obtain a message and a stack trace; it is never propagated.
            final Exception overflow = new Exception("Overflow error when taking exp of " + d + "\n Try running the CRF with the following option \"trainer ll\" to perform computations in the log-space.");
            System.out.println(overflow.getMessage());
            overflow.printStackTrace();
            return Double.MAX_VALUE;
        }
        return result;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Exponential via {@code RobustMath.exp} with overflow reporting: a NaN or infinite result is
     * reported on standard output (message plus stack trace, suggesting the feature values be
     * redesigned) and replaced by {@code Double.MAX_VALUE}.
     *
     * @param d the exponent
     * @return {@code RobustMath.exp(d)}, or {@code Double.MAX_VALUE} when that is NaN or infinite
     */
    public static double expLE(double d) {
        final double result = RobustMath.exp(d);
        if (Double.isNaN(result) || Double.isInfinite(result)) {
            // The exception is built only to obtain a message and a stack trace; it is never propagated.
            final Exception overflow = new Exception("Overflow error when taking exp of " + d + " you might need to redesign feature values so as to not reach such high values");
            System.out.println(overflow.getMessage());
            overflow.printStackTrace();
            return Double.MAX_VALUE;
        }
        return result;
    }
}
