package fig.prob;

import fig.basic.NumUtils;
import fig.basic.TDoubleMap;
import java.util.Random;

/* loaded from: input_file:fig/prob/MargSparseMultinomial.class */
/**
 * Marginal distribution over {@link SparseMultinomial} parameters under a sparse
 * Dirichlet-style prior. Most density and sampling operations delegate to the wrapped prior.
 */
public class MargSparseMultinomial implements MargDistrib<SparseMultinomial> {
    /** Prior over multinomial parameters; assigned once in the constructor. */
    private final SparseDirichletInterface prior;

    /**
     * Wraps the given prior.
     *
     * @param sparseDirichletInterface prior over the multinomial parameters
     */
    public MargSparseMultinomial(SparseDirichletInterface sparseDirichletInterface) {
        this.prior = sparseDirichletInterface;
    }

    /**
     * Returns the log marginal likelihood of {@code suffStats}.
     *
     * <p>A {@link DegenerateSparseDirichlet} prior is handled with
     * {@link #expectedLogLikelihood(SuffStats)}. Any other prior uses
     * {@link #predLogLikelihood(SuffStats, SuffStats)} with empty prior observations.
     */
    @Override
    public double margLogLikelihood(SuffStats suffStats) {
        if (this.prior instanceof DegenerateSparseDirichlet) {
            return expectedLogLikelihood(suffStats);
        }
        return predLogLikelihood(SparseMultinomialSuffStats.emptyStats, suffStats);
    }

    /**
     * Returns the predictive log likelihood of {@code suffStats2} given the already-observed
     * {@code suffStats}.
     *
     * <p>The result is the sum over entries {@code k} of {@code suffStats2} of
     * {@code logGammaRatio(alpha_k + n_k, m_k)}, minus
     * {@code logGammaRatio(alpha_total + n_total, m_total)}. Here {@code alpha} is the prior's
     * concentration, {@code n} counts from {@code suffStats}, and {@code m} values from
     * {@code suffStats2}.
     *
     * @param suffStats previously observed statistics (must be {@link SparseMultinomialSuffStats})
     * @param suffStats2 new statistics to score (must be {@link SparseMultinomialSuffStats})
     * @return predictive log likelihood
     * @throws ClassCastException if the prior is not a {@link SparseDirichlet} or either argument
     *     is not a {@link SparseMultinomialSuffStats}
     */
    @Override
    public double predLogLikelihood(SuffStats suffStats, SuffStats suffStats2) {
        SparseMultinomialSuffStats observed = (SparseMultinomialSuffStats) suffStats;
        SparseMultinomialSuffStats fresh = (SparseMultinomialSuffStats) suffStats2;
        SparseDirichlet dirichlet = (SparseDirichlet) this.prior;
        double total = 0.0d;
        // NOTE(review): the decompiled source wrote TDoubleMap<T>.Entry, but no T is in scope in
        // this class. Object assumes SparseMultinomialSuffStats iterates TDoubleMap<Object>
        // entries (consistent with expectedLog(Object)) — confirm against its declaration.
        for (TDoubleMap<Object>.Entry entry : fresh) {
            double concentration =
                    dirichlet.getConcentration(entry.getKey()) + observed.getCount(entry.getKey());
            total += DirichletUtils.logGammaRatio(concentration, entry.getValue());
        }
        return total - DirichletUtils.logGammaRatio(
                dirichlet.totalCount() + observed.totalCount(), fresh.totalCount());
    }

    /**
     * Returns a new marginal whose prior has the counts in {@code suffStats} added.
     *
     * @param suffStats observed statistics (must be {@link SparseMultinomialSuffStats})
     * @return posterior marginal distribution
     * @throws ClassCastException if the prior is not a {@link SparseDirichlet} or
     *     {@code suffStats} is not a {@link SparseMultinomialSuffStats}
     */
    @Override
    public MargDistrib<SparseMultinomial> getPosterior(SuffStats suffStats) {
        SparseDirichlet dirichlet = (SparseDirichlet) this.prior;
        return new MargSparseMultinomial(
                dirichlet.withExtraCounts((SparseMultinomialSuffStats) suffStats));
    }

    /** Delegates to the prior's {@code logProb}. */
    @Override
    public double logProb(SuffStats suffStats) {
        return this.prior.logProb(suffStats);
    }

    /** Scores the multinomial's probability vector under the prior. */
    @Override
    public double logProbObject(SparseMultinomial sparseMultinomial) {
        return this.prior.logProbObject(sparseMultinomial.getProbs());
    }

    /** Draws parameters from the prior and wraps them in a {@link SparseMultinomial}. */
    @Override
    public SparseMultinomial sampleObject(Random random) {
        return new SparseMultinomial(this.prior.sampleObject(random));
    }

    /** Delegates to the prior's {@code expectedLog} for the given key. */
    public double expectedLog(Object obj) {
        return this.prior.expectedLog(obj);
    }

    /**
     * Returns the cross entropy between this prior and the prior of {@code distrib}.
     *
     * @throws ClassCastException if {@code distrib} is not a {@link MargSparseMultinomial}
     */
    @Override
    public double crossEntropy(Distrib<SparseMultinomial> distrib) {
        MargSparseMultinomial other = (MargSparseMultinomial) distrib;
        return this.prior.crossEntropy(other.prior);
    }

    /**
     * Returns the sum over entries of {@code value * prior.expectedLog(key)}.
     *
     * @param suffStats statistics to score (must be {@link SparseMultinomialSuffStats})
     * @return expected log likelihood; checked finite via {@link NumUtils#assertIsFinite}
     */
    @Override
    public double expectedLogLikelihood(SuffStats suffStats) {
        double total = 0.0d;
        // NOTE(review): see predLogLikelihood regarding the Object type argument.
        for (TDoubleMap<Object>.Entry entry : (SparseMultinomialSuffStats) suffStats) {
            total += entry.getValue() * this.prior.expectedLog(entry.getKey());
        }
        NumUtils.assertIsFinite(total);
        return total;
    }

    /** Returns a marginal wrapping the prior's {@code modeSpike()}. */
    public MargSparseMultinomial modeSpike() {
        return new MargSparseMultinomial(this.prior.modeSpike());
    }

    /** Returns the prior's string form. */
    @Override
    public String toString() {
        return this.prior.toString();
    }
}
