package fig.prob;

import fig.basic.Fmt;
import fig.basic.ListUtils;
import java.util.Random;

/* loaded from: input_file:fig/prob/MargMeanDiagMultGaussian.class */
/**
 * A distribution over {@link DiagMultGaussian} objects in which the mean vector is random
 * (governed by {@code meanDistrib}) while the per-dimension {@code varSpikes} values are fixed.
 *
 * <p>Posterior updates, marginal log-likelihoods and expected log-likelihoods are computed one
 * dimension at a time by delegating to the {@link MargMeanGaussian} returned from
 * {@link #getComponent(int)}, then recombined across dimensions.
 */
public class MargMeanDiagMultGaussian implements MargDistrib<DiagMultGaussian> {
    /** Distribution over the mean vector. */
    private DiagMultGaussian meanDistrib;

    /**
     * One fixed value per dimension; its length defines {@link #dim()}.
     * NOTE(review): held by reference (not copied) and handed to every sampled object.
     */
    private double[] varSpikes;

    /**
     * Builds the distribution from an explicit mean distribution and per-dimension spikes.
     *
     * @param diagMultGaussian distribution over the mean vector
     * @param dArr per-dimension fixed values; the array is stored as-is, not copied
     */
    public MargMeanDiagMultGaussian(DiagMultGaussian diagMultGaussian, double[] dArr) {
        this.meanDistrib = diagMultGaussian;
        this.varSpikes = dArr;
    }

    /**
     * Builds an {@code i}-dimensional distribution whose mean distribution is constructed from
     * {@code gaussian} and whose every spike equals {@code d}.
     *
     * @param i number of dimensions
     * @param gaussian passed to the {@code DiagMultGaussian(int, Gaussian)} constructor
     * @param d value used for every entry of {@code varSpikes}
     */
    public MargMeanDiagMultGaussian(int i, Gaussian gaussian, double d) {
        this.meanDistrib = new DiagMultGaussian(i, gaussian);
        this.varSpikes = ListUtils.newDouble(i, d);
    }

    /**
     * Returns the one-dimensional view of dimension {@code i}: that dimension's mean
     * distribution paired with its spike value.
     *
     * @param i dimension index
     * @return a new {@link MargMeanGaussian} for dimension {@code i}
     */
    public MargMeanGaussian getComponent(int i) {
        return new MargMeanGaussian(meanDistrib.getComponent(i), varSpikes[i]);
    }

    /**
     * Conditions every dimension independently on the matching component of the statistics and
     * reassembles the per-dimension posteriors into a new instance.
     *
     * @param suffStats must be a {@link DiagMultGaussianSuffStats}
     * @return the posterior distribution (a {@code MargMeanDiagMultGaussian})
     */
    @Override
    public MargDistrib getPosterior(SuffStats suffStats) {
        DiagMultGaussianSuffStats stats = (DiagMultGaussianSuffStats) suffStats;
        int n = dim();
        Gaussian[] postMeans = new Gaussian[n];
        double[] postSpikes = new double[n];
        for (int k = 0; k < n; k++) {
            MargMeanGaussian updated = getComponent(k).getPosterior((SuffStats) stats.getComponent(k));
            postMeans[k] = updated.getMeanDistrib();
            postSpikes[k] = updated.getVarSpike();
        }
        DiagMultGaussian posteriorMean = new DiagMultGaussian(postMeans);
        return new MargMeanDiagMultGaussian(posteriorMean, postSpikes);
    }

    /**
     * Sums the per-dimension marginal log-likelihoods of the statistics.
     *
     * @param suffStats must be a {@link DiagMultGaussianSuffStats} when {@code dim() > 0}
     * @return the summed log-likelihood ({@code 0.0} when there are no dimensions)
     */
    @Override
    public double margLogLikelihood(SuffStats suffStats) {
        double total = 0.0d;
        int n = dim();
        for (int k = 0; k < n; k++) {
            MargMeanGaussian component = getComponent(k);
            // The cast stays inside the loop: with zero dimensions the stats are never inspected.
            total += component.margLogLikelihood(((DiagMultGaussianSuffStats) suffStats).getComponent(k));
        }
        return total;
    }

    /**
     * Scores {@code suffStats2} under the posterior obtained by conditioning on {@code suffStats}.
     *
     * @param suffStats conditioning statistics; must be a {@link DiagMultGaussianSuffStats}
     * @param suffStats2 statistics to score under that posterior
     * @return the posterior's marginal log-likelihood of {@code suffStats2}
     */
    @Override
    public double predLogLikelihood(SuffStats suffStats, SuffStats suffStats2) {
        DiagMultGaussianSuffStats observed = (DiagMultGaussianSuffStats) suffStats;
        MargDistrib posterior = getPosterior(observed);
        return posterior.margLogLikelihood(suffStats2);
    }

    /**
     * Delegates directly to {@code meanDistrib.logProb}; {@code varSpikes} is not consulted.
     */
    @Override
    public double logProb(SuffStats suffStats) {
        return meanDistrib.logProb(suffStats);
    }

    /**
     * Scores only the mean vector of the given object under {@code meanDistrib}.
     */
    @Override
    public double logProbObject(DiagMultGaussian diagMultGaussian) {
        return meanDistrib.logProbObject(diagMultGaussian.getMean());
    }

    /**
     * Cross-entropy between the two mean distributions only; spikes on either side are ignored.
     *
     * @throws ClassCastException if {@code distrib} is not a {@code MargMeanDiagMultGaussian}
     */
    @Override
    public double crossEntropy(Distrib<DiagMultGaussian> distrib) {
        MargMeanDiagMultGaussian other = (MargMeanDiagMultGaussian) distrib;
        return meanDistrib.crossEntropy(other.meanDistrib);
    }

    /**
     * Sums the per-dimension expected log-likelihoods of the statistics.
     *
     * @param suffStats must be a {@link DiagMultGaussianSuffStats}
     * @return the summed expected log-likelihood
     */
    @Override
    public double expectedLogLikelihood(SuffStats suffStats) {
        DiagMultGaussianSuffStats stats = (DiagMultGaussianSuffStats) suffStats;
        double total = 0.0d;
        int n = dim();
        for (int k = 0; k < n; k++) {
            MargMeanGaussian component = getComponent(k);
            total += component.expectedLogLikelihood(stats.getComponent(k));
        }
        return total;
    }

    /**
     * Draws a mean from {@code meanDistrib} and pairs it with this instance's spike array.
     * NOTE(review): every returned object shares the same {@code varSpikes} array (no copy).
     */
    @Override
    public DiagMultGaussian sampleObject(Random random) {
        DiagMultGaussian drawn = new DiagMultGaussian(meanDistrib.sample(random), varSpikes);
        return drawn;
    }

    /** Returns the number of dimensions, i.e. the length of {@code varSpikes}. */
    public int dim() {
        return varSpikes.length;
    }

    @Override
    public String toString() {
        String spikes = Fmt.D(varSpikes);
        return String.format("mean(%s),var(%s)", meanDistrib, spikes);
    }
}
