package iitb2.CRF;

import cern.colt.function.IntDoubleFunction;
import cern.colt.function.IntIntDoubleFunction;
import cern.colt.list.DoubleArrayList;
import cern.colt.list.IntArrayList;
import cern.colt.list.ObjectArrayList;
import cern.colt.matrix.impl.DenseObjectMatrix1D;
import iitb2.CRF.Viterbi;

/**
 * Viterbi decoder that operates on the sparse score matrices {@code Mi}
 * ({@link LogSparseDoubleMatrix2D}) and {@code Ri} ({@link LogSparseDoubleMatrix1D}).
 *
 * <p>For every position {@code i} of a {@link DataSequence} a {@link Context} holds one
 * {@link Viterbi.Entry} per label. Entries at {@code i} are extended from the context
 * {@code ell} positions back, where the step lengths {@code ell} come from an {@link Iter}.
 * The default iterator yields only {@code ell == 1}.
 *
 * <p>Subclasses customise the search through the protected factory/hook methods:
 * {@link #getIter()}, {@link #newContext}, {@link #newContextUpdate()},
 * {@link #computeLogMi}, {@link #getCorrectScore} and {@link #finishContext}.
 */
public class SparseViterbi extends Viterbi {
    /** Per-position search state; grown on demand to {@code sequenceLength + 1} slots. */
    protected Context[] context;
    /** Per-label scores for the position being processed, filled by {@link #computeLogMi}. */
    protected LogSparseDoubleMatrix1D Ri;
    // NOTE(review): the four scratch fields below are allocated but never referenced in this
    // class; presumably subclasses use them. Confirm before removing.
    ObjectArrayList prevContext;
    IntArrayList validYs;
    IntArrayList validPrevYs;
    DoubleArrayList values;
    /** Visitor that fills {@link #context}; created in {@link #allocateScratch(int)}. */
    protected ContextUpdate contextUpdate;

    /**
     * Search state for one sequence position: a dense vector, indexed by label, of lazily
     * created {@link Viterbi.Entry} objects.
     */
    // The decompiler reported that this class's access modifier was widened from protected.
    public class Context extends DenseObjectMatrix1D {
        /** Sequence position this context belongs to. */
        protected int pos;
        /** Beam width used for entries created at positions other than 0. */
        protected int beamsize;

        /**
         * @param numLabels number of labels, i.e. the length of this vector
         * @param beamWidth beam width for entries at positions other than 0
         * @param position  sequence position of this context
         */
        public Context(int numLabels, int beamWidth, int position) {
            super(numLabels);
            this.pos = position;
            this.beamsize = beamWidth;
        }

        /**
         * Factory hook for the entries stored in this context.
         *
         * @param beamWidth beam width of the new entry
         * @param label     label index the entry is stored under
         * @param position  sequence position of the entry
         * @return a new {@link Viterbi.Entry}
         */
        protected Viterbi.Entry newEntry(int beamWidth, int label, int position) {
            return new Viterbi.Entry(beamWidth, label, position);
        }

        /**
         * Records a candidate for {@code label}, creating its entry on first use, and marks
         * the entry valid.
         *
         * @param label label index
         * @param prev  predecessor entry, or {@code null} when the candidate starts a path
         * @param score score contributed by this step
         */
        public void add(int label, Viterbi.Entry prev, float score) {
            if (getQuick(label) == null) {
                // Position 0 only ever receives start candidates (prev == null), so a beam
                // of 1 is used there.
                setQuick(label, newEntry(this.pos == 0 ? 1 : this.beamsize, label, this.pos));
            }
            Viterbi.Entry entry = getEntry(label);
            entry.valid = true;
            entry.add(prev, score);
        }

        /** Clears every entry allocated so far; the entry objects are kept for reuse. */
        public void clear() {
            for (int label = 0; label < size(); label++) {
                if (getQuick(label) != null) {
                    getEntry(label).clear();
                }
            }
        }

        /** Returns the entry stored under {@code label}, or {@code null} if none was created. */
        public Viterbi.Entry getEntry(int label) {
            return (Viterbi.Entry) getQuick(label);
        }

        /** Returns {@code true} if an entry exists for {@code label} and is flagged valid. */
        public boolean entryNotNull(int label) {
            return getQuick(label) != null && getEntry(label).valid;
        }

        /**
         * Adds a start candidate (no predecessor) for every label whose value in
         * {@code scores} is non-zero, using that value as the score.
         */
        void assign(LogSparseDoubleMatrix1D scores) {
            for (int label = 0; label < scores.size(); label++) {
                // NOTE(review): the test uses getQuick() but the score uses get(); kept as-is
                // in case the Log* matrix maps values differently in the two accessors.
                if (scores.getQuick(label) != 0.0d) {
                    add(label, null, (float) scores.get(label));
                }
            }
        }
    }

    /**
     * Visitor that fills {@link #context} position by position.
     *
     * <p>It is applied to the non-zero cells of {@code Mi} (transitions) or of {@code Ri}
     * (start states). The fields {@link #i} and {@link #ell} carry the current position and
     * step length into the {@code apply} callbacks.
     */
    // The decompiler reported that this class's access modifier was widened from protected.
    public class ContextUpdate implements IntIntDoubleFunction, IntDoubleFunction {
        /** Position currently being filled. */
        protected int i;
        /** Distance back to the predecessor context for the current step. */
        protected int ell;
        /** Supplies the step lengths for each position. */
        protected Iter iter;

        protected ContextUpdate() {
        }

        /**
         * Transition visitor: extends the valid entry for {@code prevY} at position
         * {@code i - ell} into label {@code y} at position {@code i}.
         * The step is scored {@code Mi(prevY, y) + Ri(y)}.
         *
         * @return {@code d}, unchanged
         */
        public double apply(int prevY, int y, double d) {
            Context from = SparseViterbi.this.context[this.i - this.ell];
            if (from.entryNotNull(prevY)) {
                // Mi.get() is used rather than d so the score goes through the matrix accessor.
                float stepScore = (float) (SparseViterbi.this.Mi.get(prevY, y) + SparseViterbi.this.Ri.get(y));
                SparseViterbi.this.context[this.i].add(y, from.getEntry(prevY), stepScore);
            }
            return d;
        }

        /**
         * Start-state visitor: adds a path start for label {@code y} at position {@code i},
         * scored {@code Ri(y)}.
         *
         * @return {@code d}, unchanged
         */
        public double apply(int y, double d) {
            SparseViterbi.this.context[this.i].add(y, null, (float) SparseViterbi.this.Ri.get(y));
            return d;
        }

        /**
         * Runs the forward Viterbi pass over {@code dataSequence}, filling
         * {@code context[0 .. length-1]}.
         *
         * @param dataSequence     sequence to decode
         * @param lambda           array forwarded to {@link #computeLogMi}
         *                         (presumably the model weights; confirm)
         * @param calcCorrectScore when {@code true}, sum {@link #getCorrectScore} over every
         *                         (position, step) pair visited
         * @return the accumulated score, or {@code 0.0} when {@code calcCorrectScore} is false
         */
        double fillArray(DataSequence dataSequence, double[] lambda, boolean calcCorrectScore) {
            double correctScore = 0.0d;
            for (this.i = 0; this.i < dataSequence.length(); this.i++) {
                SparseViterbi.this.context[this.i].clear();
                this.iter.start(this.i, dataSequence);
                // A non-positive step length ends the steps for this position.
                while ((this.ell = this.iter.nextEll(this.i)) > 0) {
                    SparseViterbi.this.computeLogMi(dataSequence, this.i, this.ell, lambda);
                    if (this.i - this.ell < 0) {
                        // No predecessor context exists: seed path starts from Ri.
                        SparseViterbi.this.Ri.forEachNonZero(this);
                    } else {
                        SparseViterbi.this.Mi.forEachNonZero(this);
                    }
                    if (SparseViterbi.this.model.params.debugLvl > 1) {
                        System.out.println("Ri " + SparseViterbi.this.Ri);
                        System.out.println("Mi " + SparseViterbi.this.Mi);
                    }
                    if (calcCorrectScore) {
                        correctScore += SparseViterbi.this.getCorrectScore(dataSequence, this.i, this.ell);
                    }
                }
                SparseViterbi.this.finishContext(this.i);
            }
            return correctScore;
        }
    }

    /**
     * Supplies, for a given position, the sequence of step lengths used to reach predecessor
     * contexts. The default yields exactly one step of length 1.
     */
    // The decompiler reported that this class's access modifier was widened from protected.
    public class Iter {
        /** Next step length to hand out. */
        protected int ell;

        public Iter() {
        }

        /** Resets the iterator for position {@code i}. */
        protected void start(int i, DataSequence dataSequence) {
            this.ell = 1;
        }

        /**
         * Returns the current step length and then decrements it.
         * A returned value {@code <= 0} means there are no more steps.
         */
        protected int nextEll(int i) {
            return this.ell--;
        }
    }

    /**
     * @param crf      model to decode with
     * @param beamsize value forwarded to {@link Viterbi}'s constructor
     */
    // The decompiler reported that this constructor's access modifier was widened from protected.
    public SparseViterbi(CRF crf, int beamsize) {
        super(crf, beamsize);
        this.prevContext = new ObjectArrayList();
        this.validYs = new IntArrayList();
        this.validPrevYs = new IntArrayList();
        this.values = new DoubleArrayList();
    }

    /**
     * Fills {@code Mi} and {@code Ri} for position {@code i}.
     * It starts the feature scan at {@code i}, then delegates to
     * {@link SparseTrainer#computeLogMi}.
     *
     * <p>The step length {@code ell} is ignored here; subclasses may use it.
     */
    protected void computeLogMi(DataSequence dataSequence, int i, int ell, double[] lambda) {
        this.model.featureGenerator.startScanFeaturesAt(dataSequence, i);
        SparseTrainer.computeLogMi(this.model.featureGenerator, lambda, this.Mi, this.Ri);
    }

    /** Factory hook for the step-length iterator used by {@link ContextUpdate}. */
    protected Iter getIter() {
        return new Iter();
    }

    /** Hook called after all steps for position {@code i} are processed; a no-op here. */
    protected void finishContext(int i) {
    }

    /**
     * Returns the score the current {@code Ri}/{@code Mi} assign to the labels stored in
     * {@code dataSequence} at position {@code i}.
     * This is {@code Ri(y(i))}, plus {@code Mi(y(i-1), y(i))} when {@code i > 0}.
     * The step length {@code ell} is ignored here.
     */
    protected double getCorrectScore(DataSequence dataSequence, int i, int ell) {
        double transition = i > 0 ? this.Mi.get(dataSequence.y(i - 1), dataSequence.y(i)) : 0.0d;
        return this.Ri.getQuick(dataSequence.y(i)) + transition;
    }

    /** Factory hook for the visitor that fills the contexts. */
    protected ContextUpdate newContextUpdate() {
        return new ContextUpdate();
    }

    /**
     * Allocates the score matrices, an empty context array, the final-solution entry and the
     * context-update visitor, all for {@code numY} labels.
     */
    @Override // iitb2.CRF.Viterbi
    protected void allocateScratch(int numY) {
        this.Mi = new LogSparseDoubleMatrix2D(numY, numY);
        this.Ri = new LogSparseDoubleMatrix1D(numY);
        this.context = new Context[0];
        this.finalSoln = new Viterbi.Entry(this.beamsize, 0, 0);
        this.contextUpdate = newContextUpdate();
        this.contextUpdate.iter = getIter();
    }

    /** Factory hook for per-position contexts; arguments as in {@link Context#Context}. */
    protected Context newContext(int numLabels, int beamWidth, int position) {
        return new Context(numLabels, beamWidth, position);
    }

    /**
     * Decodes {@code dataSequence}.
     * It allocates scratch space on first use and grows the context array as needed. It then
     * runs the forward pass and collects every valid entry at the last position into
     * {@code finalSoln}.
     *
     * @param dataSequence     sequence to decode
     * @param lambda           array forwarded to {@link #computeLogMi}
     * @param calcCorrectScore whether to accumulate {@link #getCorrectScore}
     * @return the accumulated correct score (0.0 when not requested)
     */
    @Override // iitb2.CRF.Viterbi
    public double viterbiSearch(DataSequence dataSequence, double[] lambda, boolean calcCorrectScore) {
        if (this.Mi == null) {
            allocateScratch(this.model.numY);
        }
        int needed = dataSequence.length() + 1;
        if (this.context.length < needed) {
            Context[] old = this.context;
            this.context = new Context[needed];
            // Keep the contexts already allocated; only the new tail needs fresh ones.
            System.arraycopy(old, 0, this.context, 0, old.length);
            for (int position = old.length; position < needed; position++) {
                this.context[position] = newContext(this.model.numY, this.beamsize, position);
            }
        }
        double correctScore = this.contextUpdate.fillArray(dataSequence, lambda, calcCorrectScore);
        this.finalSoln.clear();
        this.finalSoln.valid = true;
        int last = dataSequence.length() - 1;
        if (last >= 0) {
            Context lastContext = this.context[last];
            for (int label = 0; label < lastContext.size(); label++) {
                if (lastContext.entryNotNull(label)) {
                    this.finalSoln.add((Viterbi.Entry) lastContext.getQuick(label), 0.0f);
                }
            }
        }
        if (this.model.params.debugLvl > 1) {
            // NOTE(review): assumes finalSoln.get(0) is non-null (e.g. not an empty sequence).
            System.out.println("Score of best sequence " + this.finalSoln.get(0).score + " corrScore " + correctScore);
        }
        return correctScore;
    }
}
