package iitb2.CRF;

import java.util.Vector;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:iitb2/CRF/CollinsTrainer.class */
public class CollinsTrainer extends Trainer {
    /** Beam width passed to {@code crf.getViterbi}; "beamSize" option, default 3. */
    int beamsize;
    /**
     * Relative margin: beam candidates scoring below {@code correctScore * (1 - beta)} are not
     * used for updates; "beta" option, default 0.05.
     */
    double beta;
    /**
     * If true, decoding and gold scoring use {@code lambda} (the running sum of weights);
     * otherwise they use the working weight vector. "UpdatedViterbi" option, default false.
     */
    boolean useUpdated;
    /**
     * "voted" option, default true.
     * NOTE(review): not read anywhere in this class — confirm whether another class uses it.
     */
    boolean voted;
    /**
     * Reusable gold-segment objects for {@link #getCorrectSoln}, indexed by segment start.
     * Grown on demand and overwritten on every call, so a returned chain is only valid until
     * the next call.
     */
    Soln[] solnPool;

    /**
     * Creates a trainer, reading optional settings from {@code params.miscOptions}.
     *
     * @param crfParams CRF parameters; the options "beamSize", "beta", "UpdatedViterbi" and
     *     "voted" are consulted when present
     * @throws NumberFormatException if "beamSize" or "beta" is present but not numeric
     */
    public CollinsTrainer(CrfParams crfParams) {
        super(crfParams);
        this.beamsize = 3;
        this.beta = 0.05d;
        this.useUpdated = false;
        this.voted = true;
        String opt = this.params.miscOptions.getProperty("beamSize");
        if (opt != null) {
            this.beamsize = Integer.parseInt(opt);
        }
        opt = this.params.miscOptions.getProperty("beta");
        if (opt != null) {
            this.beta = Double.parseDouble(opt);
        }
        opt = this.params.miscOptions.getProperty("UpdatedViterbi");
        if (opt != null) {
            this.useUpdated = opt.equalsIgnoreCase("true");
        }
        opt = this.params.miscOptions.getProperty("voted");
        if (opt != null) {
            this.voted = opt.equalsIgnoreCase("true");
        }
    }

    /**
     * Trains with Collins-style perceptron updates over a Viterbi beam.
     *
     * <p>{@code gradLogli} is reused as the working weight vector, and {@code lambda}
     * accumulates its value after every training sequence. For each sequence, beam solutions
     * that score within the {@link #beta} margin of the gold labelling, but are not identical
     * to it, become "violators". Segment by segment, from the last segment backwards:
     * <ul>
     *   <li>if any violator disagrees with the gold segment, the gold segment's features get
     *       +1;</li>
     *   <li>each violator's segments beyond the gold segment's start get {@code -1/#violators}.
     *   </li>
     * </ul>
     * Stops early after a pass with no errors, or after {@code params.maxIters} passes.
     *
     * @param crf model supplying the Viterbi searcher
     * @param dataIter training data
     * @param dArr initial weights, forwarded to {@code init}
     * @param evaluator not used by this trainer
     */
    @Override // iitb2.CRF.Trainer
    public void train(CRF crf, DataIter dataIter, double[] dArr, Evaluator evaluator) {
        init(crf, dataIter, dArr);
        double[] weights = this.gradLogli;
        Viterbi viterbi = crf.getViterbi(this.beamsize);
        for (int f = 0; f < this.lambda.length; f++) {
            weights[f] = 0.0d;
            this.lambda[f] = 0.0d;
        }
        Vector<Soln> violators = new Vector<Soln>();
        for (int iter = 0; iter < this.params.maxIters; iter++) {
            int numErrs = 0;
            this.diter.startScan();
            while (this.diter.hasNext()) {
                DataSequence seq = this.diter.next();
                double[] scoringWeights = this.useUpdated ? this.lambda : weights;
                viterbi.viterbiSearch(seq, scoringWeights, false);
                Soln correctSoln = getCorrectSoln(seq, scoringWeights);
                // An empty sequence has no gold segments and therefore nothing to correct.
                if (correctSoln != null) {
                    collectViolators(viterbi, correctSoln, violators);
                    if (!violators.isEmpty()) {
                        numErrs += updateOnViolators(correctSoln, violators, weights, seq);
                    }
                }
                // Running sum of the weight vector after every sequence.
                for (int f = 0; f < this.lambda.length; f++) {
                    this.lambda[f] += weights[f];
                }
            }
            if (this.params.debugLvl > 0) {
                Util.printDbg("Iteration " + iter + " numErrs " + numErrs);
            }
            if (numErrs == 0) {
                return;
            }
        }
    }

    /**
     * Fills {@code out} with the beam solutions from the last search whose score is at least
     * {@code correctSoln.score * (1 - beta)} and that differ from {@code correctSoln}.
     * Scanning stops at the first solution below the threshold.
     * NOTE(review): the early break assumes getBestSoln(k) is in non-increasing score order.
     */
    private void collectViolators(Viterbi viterbi, Soln correctSoln, Vector<Soln> out) {
        out.clear();
        double threshold = correctSoln.score * (1.0d - this.beta);
        int numSolutions = viterbi.numSolutions();
        for (int k = 0; k < numSolutions; k++) {
            Soln candidate = viterbi.getBestSoln(k);
            if (candidate.score < threshold) {
                break;
            }
            if (!isCorrect(candidate, correctSoln)) {
                out.add(candidate);
            }
        }
    }

    /**
     * Walks the gold chain backwards and applies the Collins update at each mismatching gold
     * segment. {@code violators} is mutated: each entry is rewound along its
     * {@code prevSoln} chain in step with the gold walk.
     *
     * @return the number of gold segments at which an update was made
     */
    private int updateOnViolators(Soln correctSoln, Vector<Soln> violators, double[] weights,
            DataSequence seq) {
        int numErrs = 0;
        double penalty = -1.0d / violators.size();
        for (Soln gold = correctSoln; gold != null; gold = gold.prevSoln) {
            int boundary = gold.prevPos();
            if (anyDiffers(gold, violators)) {
                numErrs++;
                updateWeights(gold, 1.0d, weights, seq);
                for (int k = 0; k < violators.size(); k++) {
                    // Penalize the violator's segments that end after the gold segment starts.
                    for (Soln s = violators.elementAt(k); s != null && s.pos > boundary;
                            s = s.prevSoln) {
                        updateWeights(s, penalty, weights, seq);
                    }
                }
            }
            for (int k = 0; k < violators.size(); k++) {
                violators.set(k, rewindTo(violators.elementAt(k), boundary));
            }
        }
        return numErrs;
    }

    /** Returns true if any violator is exhausted (null) or not equal to {@code gold}. */
    private static boolean anyDiffers(Soln gold, Vector<Soln> violators) {
        for (int k = 0; k < violators.size(); k++) {
            Soln s = violators.elementAt(k);
            if (s == null || !gold.equals(s)) {
                return true;
            }
        }
        return false;
    }

    /** Follows {@code prevSoln} links until reaching null or a segment with {@code pos <= boundary}. */
    private static Soln rewindTo(Soln soln, int boundary) {
        while (soln != null && soln.pos > boundary) {
            soln = soln.prevSoln;
        }
        return soln;
    }

    /**
     * Returns true if both chains have equal length and are pairwise {@code equals},
     * compared from the last segment backwards.
     */
    boolean isCorrect(Soln soln, Soln soln2) {
        Soln a = soln;
        Soln b = soln2;
        while (a != null && b != null) {
            if (!a.equals(b)) {
                return false;
            }
            a = a.prevSoln;
            b = b.prevSoln;
        }
        return a == null && b == null;
    }

    /**
     * Returns the end position of the segment starting at {@code i}. This implementation uses
     * one-token segments (returns {@code i}); subclasses may override.
     */
    int getSegmentEnd(DataSequence dataSequence, int i) {
        return i;
    }

    /** Positions {@code featureGenerator} to scan the features at {@code soln.pos}. */
    void startFeatureGenerator(FeatureGenerator featureGenerator, DataSequence dataSequence, Soln soln) {
        featureGenerator.startScanFeaturesAt(dataSequence, soln.pos);
    }

    /**
     * Adds {@code d * value} to {@code dArr[index]} for every feature at {@code soln} that
     * matches its label and its previous label (see {@link #featureMatches}).
     */
    void updateWeights(Soln soln, double d, double[] dArr, DataSequence dataSequence) {
        startFeatureGenerator(this.featureGenerator, dataSequence, soln);
        while (this.featureGenerator.hasNext()) {
            Feature f = this.featureGenerator.next();
            if (featureMatches(f, soln)) {
                dArr[f.index()] += d * f.value();
            }
        }
    }

    /**
     * Returns true if feature {@code f} applies to segment {@code soln}: its {@code y} equals
     * the segment label and either its {@code yprev} is negative or it equals the previous
     * segment's label.
     */
    private static boolean featureMatches(Feature f, Soln soln) {
        if (soln.label != f.y()) {
            return false;
        }
        int yprev = f.yprev();
        return (soln.prevPos() >= 0 && yprev == soln.prevSoln.label) || yprev < 0;
    }

    /**
     * Builds the gold segmentation of {@code dataSequence} as a {@code prevSoln}-linked chain.
     * Each node's score is the cumulative score under {@code dArr}.
     *
     * @return the last segment of the chain, or null for an empty sequence; nodes come from
     *     {@link #solnPool} and are overwritten by the next call
     */
    Soln getCorrectSoln(DataSequence dataSequence, double[] dArr) {
        int len = dataSequence.length();
        if (this.solnPool == null || this.solnPool.length < len) {
            this.solnPool = new Soln[len];
            for (int i = 0; i < len; i++) {
                this.solnPool[i] = new Soln(0, 0);
            }
        }
        Soln prev = null;
        int start = 0;
        while (start < len) {
            int end = getSegmentEnd(dataSequence, start);
            Soln seg = this.solnPool[start];
            seg.pos = end;
            seg.label = dataSequence.y(start);
            seg.prevSoln = prev;
            seg.score = prev == null ? 0.0f : prev.score;
            startFeatureGenerator(this.featureGenerator, dataSequence, seg);
            while (this.featureGenerator.hasNext()) {
                Feature f = this.featureGenerator.next();
                if (featureMatches(f, seg)) {
                    seg.score = (float) (seg.score + (dArr[f.index()] * f.value()));
                }
            }
            prev = seg;
            start = end + 1;
        }
        return prev;
    }
}
