package fig.prob;

import fig.basic.ListUtils;
import fig.basic.NumUtils;
import java.io.Serializable;
import java.util.Random;

/* loaded from: input_file:fig/prob/MargMultinomial.class */
/**
 * A multinomial distribution paired with a {@link DirichletInterface} prior over its parameters.
 *
 * <p>Some methods treat the prior directly as a distribution over multinomial parameters
 * ({@link #logProb}, {@link #logProbObject}, {@link #sampleObject}, {@link #expectedLog}).
 * Others compute likelihoods of sufficient statistics with the parameters marginalized under the
 * prior ({@link #margLogLikelihood}, {@link #predLogLikelihood}, {@link #expectedLogLikelihood}).
 *
 * <p>NOTE(review): {@link #predLogLikelihood}, {@link #getPosterior}, {@link #perturb} and
 * {@link #degeneratePerturb} cast the prior to {@link Dirichlet}. They (and
 * {@link #margLogLikelihood} for any non-{@link DegenerateDirichlet} prior) throw
 * {@link ClassCastException} for other {@code DirichletInterface} implementations.
 */
public class MargMultinomial implements MargDistrib<Multinomial>, Serializable {
    /*
     * Must be static final for Java serialization to honor it. The previous non-final
     * declaration was ignored, and the runtime used a computed UID instead. NOTE(review):
     * streams written by the old class carry that computed UID and will not deserialize
     * against this one.
     */
    private static final long serialVersionUID = 42L;

    /** Prior over the multinomial's parameter vector; fixed at construction. */
    private final DirichletInterface prior;

    /**
     * Creates a marginalized multinomial with the given prior.
     *
     * @param dirichletInterface prior over the multinomial parameters (stored by reference)
     */
    public MargMultinomial(DirichletInterface dirichletInterface) {
        this.prior = dirichletInterface;
    }

    /**
     * Returns the marginal log-likelihood of {@code suffStats}.
     *
     * <p>For a {@link DegenerateDirichlet} prior this is {@link #expectedLogLikelihood}.
     * Otherwise it is {@link #predLogLikelihood} with no previously observed statistics.
     *
     * @param suffStats observed statistics; must be a {@code MultinomialSuffStats}
     * @return the marginal log-likelihood
     */
    @Override
    public double margLogLikelihood(SuffStats suffStats) {
        return this.prior instanceof DegenerateDirichlet
            ? expectedLogLikelihood(suffStats)
            : predLogLikelihood(null, suffStats);
    }

    /**
     * Returns the log-likelihood of {@code suffStats2} given the prior updated by {@code suffStats}.
     *
     * <p>Computes {@code sum_i logGammaRatio(alpha_i + c_i, n_i) - logGammaRatio(A + C, N)}, where
     * {@code alpha_i}/{@code A} are the prior's per-component/total counts, {@code c_i}/{@code C}
     * come from {@code suffStats} (taken as 0 when it is {@code null}), and {@code n_i}/{@code N}
     * come from {@code suffStats2}. Presumably {@code logGammaRatio(a, n) = lgamma(a + n) -
     * lgamma(a)}, which makes this the Dirichlet-multinomial predictive probability without the
     * multinomial coefficient. TODO confirm against DirichletUtils.
     *
     * @param suffStats already-observed statistics, or {@code null} for none
     * @param suffStats2 statistics whose likelihood is computed; must be non-null
     * @return the predictive log-likelihood
     * @throws ClassCastException if the prior is not a {@link Dirichlet}
     */
    @Override
    public double predLogLikelihood(SuffStats suffStats, SuffStats suffStats2) {
        MultinomialSuffStats observed = (MultinomialSuffStats) suffStats;
        MultinomialSuffStats predicted = (MultinomialSuffStats) suffStats2;
        Dirichlet dirichlet = (Dirichlet) this.prior;
        double result = 0.0d;
        for (int i = 0; i < dirichlet.dim(); i++) {
            // Posterior pseudo-count for component i after absorbing the observed counts.
            double alpha = dirichlet.getAlpha(i) + (observed == null ? 0.0d : observed.getCount(i));
            result += DirichletUtils.logGammaRatio(alpha, predicted.getCount(i));
        }
        double totalAlpha =
            dirichlet.totalCount() + (observed == null ? 0.0d : observed.totalCount());
        return result - DirichletUtils.logGammaRatio(totalAlpha, predicted.totalCount());
    }

    /** Delegates to the prior's {@code logProb} for the given statistics. */
    @Override
    public double logProb(SuffStats suffStats) {
        return this.prior.logProb(suffStats);
    }

    /** Returns the prior's log-density of {@code multinomial}'s probability vector. */
    @Override
    public double logProbObject(Multinomial multinomial) {
        return this.prior.logProbObject(multinomial.getProbs());
    }

    /**
     * Returns the cross entropy between this prior and {@code distrib}'s prior.
     *
     * @throws ClassCastException if {@code distrib} is not a {@code MargMultinomial}
     */
    @Override
    public double crossEntropy(Distrib<Multinomial> distrib) {
        return this.prior.crossEntropy(((MargMultinomial) distrib).prior);
    }

    /** Delegates to the prior's {@code expectedLog(i)} for component {@code i}. */
    public double expectedLog(int i) {
        return this.prior.expectedLog(i);
    }

    /** Delegates to the prior's {@code expectedLog()} for all components. */
    public double[] expectedLog() {
        return this.prior.expectedLog();
    }

    /** Returns the prior's dimension (number of multinomial components). */
    public int dim() {
        return this.prior.dim();
    }

    /**
     * Returns {@code sum_i counts[i] * prior.expectedLog(i)} over the counts in {@code suffStats}.
     *
     * @param suffStats must be a {@code MultinomialSuffStats}
     * @return the expected log-likelihood; checked with {@code NumUtils.assertIsFinite}
     */
    @Override
    public double expectedLogLikelihood(SuffStats suffStats) {
        double[] counts = ((MultinomialSuffStats) suffStats).getCounts();
        double d = 0.0d;
        for (int i = 0; i < counts.length; i++) {
            d += counts[i] * this.prior.expectedLog(i);
        }
        NumUtils.assertIsFinite(d);
        return d;
    }

    /** Returns the prior (by reference). */
    public DirichletInterface getPrior() {
        return this.prior;
    }

    /**
     * Returns a new instance whose prior is a {@link Dirichlet} with {@code alpha + counts}.
     *
     * @throws ClassCastException if the prior is not a {@link Dirichlet}
     */
    @Override
    public MargMultinomial getPosterior(SuffStats suffStats) {
        double[] counts = ((MultinomialSuffStats) suffStats).getCounts();
        return new MargMultinomial(new Dirichlet(ListUtils.add(((Dirichlet) this.prior).getAlpha(), counts)));
    }

    /** Returns a new instance wrapping the prior's {@code modeSpike()}. */
    public MargMultinomial modeSpike() {
        return new MargMultinomial(this.prior.modeSpike());
    }

    /**
     * Returns a new instance wrapping {@code Dirichlet.perturb(random)} of the prior.
     *
     * @throws ClassCastException if the prior is not a {@link Dirichlet}
     */
    public MargMultinomial perturb(Random random) {
        return new MargMultinomial(((Dirichlet) this.prior).perturb(random));
    }

    /**
     * Samples a parameter vector from the prior and wraps it in a {@link DegenerateDirichlet}.
     *
     * @throws ClassCastException if the prior is not a {@link Dirichlet}
     */
    public MargMultinomial degeneratePerturb(Random random) {
        return new MargMultinomial(new DegenerateDirichlet(((Dirichlet) this.prior).sample(random)));
    }

    /** Samples a probability vector from the prior and wraps it in a {@link Multinomial}. */
    @Override
    public Multinomial sampleObject(Random random) {
        return new Multinomial(this.prior.sampleObject(random));
    }

    @Override
    public String toString() {
        return String.format("MargMultinomial(%s)", this.prior);
    }
}
