package iitb2.CRF;

import cern.colt.matrix.impl.DenseDoubleMatrix1D;

/* loaded from: input_file:iitb2/CRF/NestedTrainer.class */
/**
 * Trainer for segment-based ("nested") CRFs over {@link SegmentDataSequence} records.
 *
 * <p>Computes the L2-penalized conditional log-likelihood of the training data and its gradient
 * using a backward ({@code beta_Y}) and a forward ({@code alpha_Y_Array}) pass over
 * segmentations whose segments are at most {@code maxMemory()} positions long. A training record
 * whose labelled segmentation contains a longer segment is skipped.
 */
class NestedTrainer extends Trainer {
    /**
     * Forward vectors, offset by one: entry {@code p + 1} accumulates scores of segmentations whose
     * last segment ends at position {@code p}; entry 0 is the start state. Grown on demand and
     * reused across records.
     */
    DenseDoubleMatrix1D[] alpha_Y_Array;

    public NestedTrainer(CrfParams crfParams) {
        super(crfParams);
    }

    /**
     * Computes the penalized log-likelihood and its gradient using exponentiated potentials.
     * Delegates to {@link #computeFunctionGradientLL} when {@code params.doScaling} is set.
     *
     * @param lambda current feature weights
     * @param grad   output; overwritten with the gradient of the returned value w.r.t. {@code lambda}
     * @return the penalized log-likelihood summed over all non-skipped training records
     */
    @Override // iitb2.CRF.Trainer
    protected double computeFunctionGradient(double[] lambda, double[] grad) {
        if (this.params.doScaling) {
            return computeFunctionGradientLL(lambda, grad);
        }
        try {
            FeatureGeneratorNested featureGen = (FeatureGeneratorNested) this.featureGenerator;
            double logli = applyL2Penalty(lambda, grad);
            this.diter.startScan();
            if (this.featureGenCache != null) {
                this.featureGenCache.startDataScan();
            }
            for (int numRecord = 0; this.diter.hasNext(); numRecord++) {
                SegmentDataSequence dataSeq = (SegmentDataSequence) this.diter.next();
                if (this.featureGenCache != null) {
                    this.featureGenCache.nextDataIndex();
                }
                if (this.params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                int seqLen = dataSeq.length();
                if (seqLen == 0) {
                    // Nothing to score; the passes below would index position -1.
                    continue;
                }
                // ExpF accumulates unnormalized expected feature counts; divided by Z at the end.
                for (int f = 0; f < lambda.length; f++) {
                    this.ExpF[f] = 0.0d;
                }
                int maxMemory = featureGen.maxMemory();
                growAlphaArray(seqLen);
                growBetaArray(seqLen);
                // computeFunctionGradientLL grows beta_Y without touching scale, so size scale
                // independently instead of only when beta_Y is reallocated here.
                if (this.scale == null || this.scale.length < seqLen) {
                    this.scale = new double[2 * seqLen];
                }

                // Backward pass: beta_Y[i] (indexed by the label of a segment ending at i)
                // accumulates the scores of all ways to segment positions i+1..seqLen-1.
                // No rescaling is performed, so every scale[] entry is 1.0.
                this.beta_Y[seqLen - 1].assign(1.0d);
                this.scale[seqLen - 1] = 1.0d;
                for (int i = seqLen - 2; i >= 0; i--) {
                    this.beta_Y[i].assign(0.0d);
                    this.scale[i] = 1.0d;
                    for (int ell = 1; ell <= maxMemory && i + ell < seqLen; ell++) {
                        featureGen.startScanFeaturesAt(dataSeq, i, i + ell);
                        this.initMDone = computeLogMi((FeatureGenerator) featureGen, lambda,
                                this.Mi_YY, this.Ri_Y, true, this.reuseM, this.initMDone);
                        this.tmp_Y.assign(this.beta_Y[i + ell]);
                        this.tmp_Y.assign(this.Ri_Y, multFunc);
                        this.Mi_YY.zMult(this.tmp_Y, this.beta_Y[i], 1.0d, 1.0d, false);
                    }
                }

                // Forward pass, also collecting the labelled segmentation's score and features.
                double seqScore = 0.0d;
                this.alpha_Y_Array[0].assign(1.0d);
                int segStart = 0;
                int segEnd = -1;
                boolean skipRecord = false;
                for (int i = 0; i < seqLen; i++) {
                    if (segEnd < i) {
                        segStart = i;
                        segEnd = dataSeq.getSegmentEnd(i);
                    }
                    if (segEnd - segStart + 1 > maxMemory) {
                        if (this.icall <= 1) {
                            System.out.println("Ignoring record with segment length greater than maxMemory " + numRecord);
                        }
                        // Must leave the loop: i is not advanced on this path, so without the
                        // break the same check would repeat forever.
                        skipRecord = true;
                        break;
                    }
                    DenseDoubleMatrix1D alphaI = this.alpha_Y_Array[i + 1];
                    alphaI.assign(0.0d);
                    // Product of the scale factors of the preceding (maxMemory - 1) positions.
                    double scaleProduct = 1.0d;
                    for (int k = Math.max(0, i - maxMemory + 1); k <= i - 1; k++) {
                        scaleProduct *= this.scale[k];
                    }
                    for (int ell = 1; ell <= maxMemory && i - ell >= -1; ell++) {
                        int prevEnd = i - ell; // end of the preceding segment; -1 means start
                        DenseDoubleMatrix1D alphaPrev = this.alpha_Y_Array[prevEnd + 1];
                        featureGen.startScanFeaturesAt(dataSeq, prevEnd, i);
                        this.initMDone = computeLogMi((FeatureGenerator) featureGen, lambda,
                                this.Mi_YY, this.Ri_Y, true, this.reuseM, this.initMDone);
                        // Rescan the same segment to visit its individual features.
                        featureGen.startScanFeaturesAt(dataSeq, prevEnd, i);
                        boolean isLabelledSegment = prevEnd + 1 == segStart && i == segEnd;
                        while (featureGen.hasNext()) {
                            Feature feature = featureGen.next();
                            int index = feature.index();
                            int y = feature.y();
                            int yprev = feature.yprev();
                            float value = feature.value();
                            if (isLabelledSegment && dataSeq.y(i) == y
                                    && ((prevEnd >= 0 && yprev == dataSeq.y(prevEnd)) || yprev < 0)) {
                                grad[index] += value;
                                seqScore += value * lambda[index];
                            }
                            if (yprev < 0) {
                                // Feature independent of the previous label: sum over all of them.
                                for (int r = 0; r < this.Mi_YY.rows(); r++) {
                                    this.ExpF[index] += alphaPrev.get(r) * this.Ri_Y.get(y)
                                            * this.Mi_YY.get(r, y) * value * this.beta_Y[i].get(y)
                                            / scaleProduct;
                                }
                            } else {
                                this.ExpF[index] += alphaPrev.get(yprev) * this.Ri_Y.get(y)
                                        * this.Mi_YY.get(yprev, y) * value * this.beta_Y[i].get(y)
                                        / scaleProduct;
                            }
                        }
                        this.Mi_YY.zMult(alphaPrev, this.tmp_Y, 1.0d, 0.0d, true);
                        this.tmp_Y.assign(this.Ri_Y, multFunc);
                        alphaI.assign(this.tmp_Y, sumFunc);
                    }
                    // NOTE(review): the previous code started this loop at an undefined variable
                    // `r0` (it did not compile); the index whose scale is divided out is used.
                    // Every scale[] entry is 1.0 in this method, so this rescaling is a no-op.
                    int firstRescaled = i + 1 - maxMemory;
                    if (firstRescaled >= 0) {
                        this.constMultiplier.multiplicator = 1.0d / this.scale[firstRescaled];
                        for (int j = firstRescaled; j <= i + 1; j++) {
                            this.alpha_Y_Array[j].assign(this.constMultiplier);
                        }
                    }
                    if (this.params.debugLvl > 2) {
                        System.out.println("Alpha-i " + alphaI.toString());
                        System.out.println("Ri " + this.Ri_Y.toString());
                        System.out.println("Mi " + this.Mi_YY.toString());
                        System.out.println("Beta-i " + this.beta_Y[i].toString());
                    }
                    if (this.params.debugLvl > 1) {
                        System.out.println(" pos " + i + " " + seqScore);
                    }
                }
                if (!skipRecord) {
                    double zSum = this.alpha_Y_Array[seqLen].zSum();
                    double seqLogli = seqScore - log(zSum);
                    for (int k = 0; k < seqLen + 1 - maxMemory; k++) {
                        seqLogli -= log(this.scale[k]);
                    }
                    logli += seqLogli;
                    for (int f = 0; f < grad.length; f++) {
                        grad[f] -= this.ExpF[f] / zSum;
                    }
                    if (this.params.debugLvl > 1) {
                        System.out.println("Sequence " + seqLogli + " " + logli + " " + zSum);
                        System.out.println("Last Alpha-i " + this.alpha_Y_Array[seqLen].toString());
                    }
                }
            }
            printIterationDebug(lambda, grad, logli);
            return logli;
        } catch (Exception e) {
            e.printStackTrace();
            // Non-zero status: exiting with 0 would report success after a training failure.
            System.exit(1);
            return 0.0d;
        }
    }

    /**
     * Computes the penalized log-likelihood and its gradient entirely in log space, combining
     * terms with {@code RobustMath.logSumExp} / {@code RobustMath.logMult}.
     *
     * @param lambda current feature weights
     * @param grad   output; overwritten with the gradient of the returned value w.r.t. {@code lambda}
     * @return the penalized log-likelihood summed over all non-skipped training records
     */
    @Override // iitb2.CRF.Trainer
    protected double computeFunctionGradientLL(double[] lambda, double[] grad) {
        try {
            FeatureGeneratorNested featureGen = (FeatureGeneratorNested) this.featureGenerator;
            double logli = applyL2Penalty(lambda, grad);
            this.diter.startScan();
            if (this.featureGenCache != null) {
                this.featureGenCache.startDataScan();
            }
            for (int numRecord = 0; this.diter.hasNext(); numRecord++) {
                SegmentDataSequence dataSeq = (SegmentDataSequence) this.diter.next();
                if (this.featureGenCache != null) {
                    this.featureGenCache.nextDataIndex();
                }
                if (this.params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                int seqLen = dataSeq.length();
                if (seqLen == 0) {
                    // Nothing to score; the passes below would index position -1.
                    continue;
                }
                // ExpF holds log expected feature counts, unnormalized until the end.
                for (int f = 0; f < lambda.length; f++) {
                    this.ExpF[f] = RobustMath.LOG0;
                }
                int maxMemory = featureGen.maxMemory();
                growAlphaArray(seqLen);
                growBetaArray(seqLen);

                // Backward pass in log space (see computeFunctionGradient for the recursion).
                this.beta_Y[seqLen - 1].assign(0.0d);
                for (int i = seqLen - 2; i >= 0; i--) {
                    this.beta_Y[i].assign(RobustMath.LOG0);
                    for (int ell = 1; ell <= maxMemory && i + ell < seqLen; ell++) {
                        featureGen.startScanFeaturesAt(dataSeq, i, i + ell);
                        this.initMDone = computeLogMi((FeatureGenerator) featureGen, lambda,
                                this.Mi_YY, this.Ri_Y, false, this.reuseM, this.initMDone);
                        this.tmp_Y.assign(this.beta_Y[i + ell]);
                        this.tmp_Y.assign(this.Ri_Y, sumFunc);
                        RobustMath.logMult(this.Mi_YY, this.tmp_Y, this.beta_Y[i], 1.0d, 1.0d, false, this.edgeGen);
                    }
                }

                // Forward pass in log space, also collecting the labelled segmentation's score.
                double seqScore = 0.0d;
                this.alpha_Y_Array[0].assign(0.0d);
                int segStart = 0;
                int segEnd = -1;
                boolean skipRecord = false;
                for (int i = 0; i < seqLen; i++) {
                    if (segEnd < i) {
                        segStart = i;
                        segEnd = dataSeq.getSegmentEnd(i);
                    }
                    if (segEnd - segStart + 1 > maxMemory) {
                        if (this.icall == 0) {
                            System.out.println("Ignoring record with segment length greater than maxMemory " + numRecord);
                        }
                        // Must leave the loop: i is not advanced on this path, so without the
                        // break the same check would repeat forever.
                        skipRecord = true;
                        break;
                    }
                    DenseDoubleMatrix1D alphaI = this.alpha_Y_Array[i + 1];
                    alphaI.assign(RobustMath.LOG0);
                    for (int ell = 1; ell <= maxMemory && i - ell >= -1; ell++) {
                        int prevEnd = i - ell; // end of the preceding segment; -1 means start
                        DenseDoubleMatrix1D alphaPrev = this.alpha_Y_Array[prevEnd + 1];
                        featureGen.startScanFeaturesAt(dataSeq, prevEnd, i);
                        this.initMDone = computeLogMi((FeatureGenerator) featureGen, lambda,
                                this.Mi_YY, this.Ri_Y, false, this.reuseM, this.initMDone);
                        // Rescan the same segment to visit its individual features.
                        featureGen.startScanFeaturesAt(dataSeq, prevEnd, i);
                        boolean isLabelledSegment = prevEnd + 1 == segStart && i == segEnd;
                        while (featureGen.hasNext()) {
                            Feature feature = featureGen.next();
                            int index = feature.index();
                            int y = feature.y();
                            int yprev = feature.yprev();
                            float value = feature.value();
                            if (isLabelledSegment && dataSeq.y(i) == y
                                    && ((prevEnd >= 0 && yprev == dataSeq.y(prevEnd)) || yprev < 0)) {
                                grad[index] += value;
                                seqScore += value * lambda[index];
                            }
                            if (yprev < 0 && prevEnd >= 0) {
                                for (int r = 0; r < this.Mi_YY.rows(); r++) {
                                    this.ExpF[index] = RobustMath.logSumExp(this.ExpF[index],
                                            alphaPrev.get(r) + this.Ri_Y.get(y) + this.Mi_YY.get(r, y)
                                                    + RobustMath.log(value) + this.beta_Y[i].get(y));
                                }
                            } else if (prevEnd < 0) {
                                // First segment of the sequence: no transition term.
                                this.ExpF[index] = RobustMath.logSumExp(this.ExpF[index],
                                        this.Ri_Y.get(y) + RobustMath.log(value) + this.beta_Y[i].get(y));
                            } else {
                                this.ExpF[index] = RobustMath.logSumExp(this.ExpF[index],
                                        alphaPrev.get(yprev) + this.Ri_Y.get(y) + this.Mi_YY.get(yprev, y)
                                                + RobustMath.log(value) + this.beta_Y[i].get(y));
                            }
                        }
                        if (prevEnd >= 0) {
                            RobustMath.logMult(this.Mi_YY, alphaPrev, this.tmp_Y, 1.0d, 0.0d, true, this.edgeGen);
                            this.tmp_Y.assign(this.Ri_Y, sumFunc);
                            RobustMath.logSumExp(alphaI, this.tmp_Y);
                        } else {
                            RobustMath.logSumExp(alphaI, this.Ri_Y);
                        }
                    }
                    if (this.params.debugLvl > 2) {
                        System.out.println("Alpha-i " + alphaI.toString());
                        System.out.println("Ri " + this.Ri_Y.toString());
                        System.out.println("Mi " + this.Mi_YY.toString());
                        System.out.println("Beta-i " + this.beta_Y[i].toString());
                    }
                    if (this.params.debugLvl > 1) {
                        System.out.println(" pos " + i + " " + seqScore);
                    }
                }
                if (!skipRecord) {
                    double logZ = RobustMath.logSumExp(this.alpha_Y_Array[seqLen]);
                    double seqLogli = seqScore - logZ;
                    logli += seqLogli;
                    for (int f = 0; f < grad.length; f++) {
                        grad[f] -= RobustMath.exp(this.ExpF[f] - logZ);
                    }
                    if (this.params.debugLvl > 1) {
                        System.out.println("Sequence " + seqLogli + " " + logli + " " + Math.exp(logZ));
                        System.out.println("Last Alpha-i " + this.alpha_Y_Array[seqLen].toString());
                    }
                }
            }
            printIterationDebug(lambda, grad, logli);
            return logli;
        } catch (Exception e) {
            e.printStackTrace();
            // Non-zero status: exiting with 0 would report success after a training failure.
            System.exit(1);
            return 0.0d;
        }
    }

    /**
     * Seeds {@code grad} with the gradient of the L2 penalty and returns the penalty term.
     *
     * @return {@code -sum(lambda[f]^2) * invSigmaSquare / 2}
     */
    private double applyL2Penalty(double[] lambda, double[] grad) {
        double penalty = 0.0d;
        for (int f = 0; f < lambda.length; f++) {
            grad[f] = (-1.0d) * lambda[f] * this.params.invSigmaSquare;
            penalty -= ((lambda[f] * lambda[f]) * this.params.invSigmaSquare) / 2.0d;
        }
        return penalty;
    }

    /** Ensures {@link #alpha_Y_Array} holds at least {@code seqLen + 1} vectors of size numY. */
    private void growAlphaArray(int seqLen) {
        if (this.alpha_Y_Array == null || this.alpha_Y_Array.length < seqLen + 1) {
            this.alpha_Y_Array = new DenseDoubleMatrix1D[2 * seqLen];
            for (int k = 0; k < this.alpha_Y_Array.length; k++) {
                this.alpha_Y_Array[k] = new DenseDoubleMatrix1D(this.numY);
            }
        }
    }

    /** Ensures {@code beta_Y} holds at least {@code seqLen} vectors of size numY. */
    private void growBetaArray(int seqLen) {
        if (this.beta_Y == null || this.beta_Y.length < seqLen) {
            this.beta_Y = new DenseDoubleMatrix1D[2 * seqLen];
            for (int k = 0; k < this.beta_Y.length; k++) {
                this.beta_Y[k] = new DenseDoubleMatrix1D(this.numY);
            }
        }
    }

    /** Prints the per-iteration debug output shared by both gradient methods. */
    private void printIterationDebug(double[] lambda, double[] grad, double logli) {
        if (this.params.debugLvl > 2) {
            for (double weight : lambda) {
                System.out.print(String.valueOf(weight) + " ");
            }
            System.out.println(" :x");
            for (int f = 0; f < lambda.length; f++) {
                System.out.print(String.valueOf(grad[f]) + " ");
            }
            System.out.println(" :g");
        }
        if (this.params.debugLvl > 0) {
            Util.printDbg("Iter " + this.icall + " loglikelihood " + logli + " gnorm " + norm(grad) + " xnorm " + norm(lambda));
        }
    }
}
