package fig.prob;

import java.util.Random;

/* loaded from: input_file:fig/prob/MargMeanGaussian.class */
/**
 * A distribution over {@link Gaussian}s whose mean is random and whose variance is fixed.
 *
 * <p>The mean of each generated Gaussian is governed by the Gaussian prior {@code meanDistrib},
 * and every generated Gaussian has the same variance {@code varSpike}. The posterior over the mean
 * given data is computed in closed form by {@link #getPosterior} using the known-variance Gaussian
 * update, which also lets {@link #margLogLikelihood} integrate the mean out analytically.
 */
public class MargMeanGaussian implements MargDistrib<Gaussian> {
    /**
     * Point counts below this threshold are treated as "no data" when forming the posterior, so
     * the update never divides by a (near-)zero count.
     */
    private static final double MIN_NUM_POINTS = 1.0E-200d;

    /** Prior over the mean of the generated Gaussians. */
    private final Gaussian meanDistrib;
    /** Fixed variance shared by every generated Gaussian. */
    private final double varSpike;

    /**
     * Creates the distribution.
     *
     * @param gaussian prior over the mean
     * @param d fixed variance of the generated Gaussians (presumably must be positive; not checked)
     */
    public MargMeanGaussian(Gaussian gaussian, double d) {
        this.meanDistrib = gaussian;
        this.varSpike = d;
    }

    /**
     * Computes the posterior over the mean given data summarized by {@code stats}.
     *
     * <p>With prior N(m, v), known variance s, and n points with sum S, the sample mean S/n has
     * variance s/n. The posterior is then
     * N((v*S/n + (s/n)*m) / (v + s/n), v*(s/n) / (v + s/n)).
     * When n is numerically zero, the n = 0 form of the same update is used instead, which does
     * not divide by n: the variance stays at the prior variance, and the mean reduces to the
     * prior mean when S is zero.
     */
    private Gaussian getMeanPosterior(GaussianSuffStats stats) {
        double priorMean = meanDistrib.getMean();
        double priorVar = meanDistrib.getVar();
        double n = stats.numPoints();
        if (n < MIN_NUM_POINTS) {
            double postMean = ((priorVar * stats.getSum()) + (varSpike * priorMean)) / varSpike;
            return new Gaussian(postMean, priorVar);
        }
        double sampleMean = stats.getSum() / n;
        // Variance of the sample mean under the known per-point variance.
        double sampleMeanVar = varSpike / n;
        double denom = priorVar + sampleMeanVar;
        double postMean = ((priorVar * sampleMean) + (sampleMeanVar * priorMean)) / denom;
        double postVar = (priorVar * sampleMeanVar) / denom;
        return new Gaussian(postMean, postVar);
    }

    /**
     * Returns the same family with the mean prior replaced by its posterior given the data.
     *
     * @param suffStats data summary; must be a {@code GaussianSuffStats}
     */
    @Override
    public MargMeanGaussian getPosterior(SuffStats suffStats) {
        return new MargMeanGaussian(getMeanPosterior((GaussianSuffStats) suffStats), varSpike);
    }

    /**
     * Returns log p(data) with the mean integrated out.
     *
     * <p>Uses Bayes' rule rearranged as log p(x) = log p(mu) + log p(x | mu) - log p(mu | x). This
     * holds for any mu, and here it is evaluated at mu = 0, so no integral is needed.
     *
     * @param suffStats data summary; must be a {@code GaussianSuffStats}
     */
    @Override
    public double margLogLikelihood(SuffStats suffStats) {
        GaussianSuffStats stats = (GaussianSuffStats) suffStats;
        final double mu = 0.0d;
        // NOTE(review): assumes static Gaussian.logProb takes (mean, var, x) — confirm.
        double logPrior = Gaussian.logProb(meanDistrib.getMean(), meanDistrib.getVar(), mu);
        double logLik = Gaussian.logProb(mu, varSpike, stats);
        double logPost = getMeanPosterior(stats).logProb(mu);
        return (logPrior + logLik) - logPost;
    }

    /**
     * Delegates to the mean prior's {@code logProb} on the same sufficient statistics.
     * NOTE(review): {@code varSpike} plays no part here — confirm this is intended.
     */
    @Override
    public double logProb(SuffStats suffStats) {
        return meanDistrib.logProb(suffStats);
    }

    /**
     * Scores only the mean of {@code gaussian} under the mean prior; its variance is not compared
     * against {@code varSpike}.
     */
    @Override
    public double logProbObject(Gaussian gaussian) {
        return meanDistrib.logProbObject(Double.valueOf(gaussian.getMean()));
    }

    /**
     * Cross entropy between the two mean priors; the other distribution's {@code varSpike} is
     * ignored.
     *
     * @throws ClassCastException if {@code distrib} is not a {@code MargMeanGaussian}
     */
    @Override
    public double crossEntropy(Distrib<Gaussian> distrib) {
        MargMeanGaussian other = (MargMeanGaussian) distrib;
        return meanDistrib.crossEntropy(other.meanDistrib);
    }

    /**
     * Returns the marginal log-likelihood of {@code suffStats2} under the posterior obtained from
     * {@code suffStats}.
     */
    @Override
    public double predLogLikelihood(SuffStats suffStats, SuffStats suffStats2) {
        return getPosterior(suffStats).margLogLikelihood(suffStats2);
    }

    /** Draws a mean from the prior and returns a Gaussian with that mean and {@code varSpike}. */
    @Override
    public Gaussian sampleObject(Random random) {
        return new Gaussian(meanDistrib.sample(random), varSpike);
    }

    /**
     * Returns E_mu[log p(data | mu)], the expected log-likelihood of the data under the mean
     * prior.
     *
     * <p>For n points, E[sum_i (x_i - mu)^2] = sumSq - 2 * sum * E[mu] + n * E[mu^2]. The
     * second-moment term must be scaled by the number of points, because every point contributes
     * its own mu^2.
     * NOTE(review): assumes {@code Gaussian.getSecondMoment()} is E[mu^2] = mean^2 + var — confirm.
     *
     * @param suffStats data summary; must be a {@code GaussianSuffStats}
     */
    @Override
    public double expectedLogLikelihood(SuffStats suffStats) {
        GaussianSuffStats stats = (GaussianSuffStats) suffStats;
        double n = stats.numPoints();
        double logNormalizer = n * (Gaussian.LOG_INV_SQRT_2_PI - (0.5d * Math.log(varSpike)));
        double expectedSqDev = (stats.getSumSq() - ((2.0d * stats.getSum()) * meanDistrib.getMean()))
                + (n * meanDistrib.getSecondMoment());
        return logNormalizer + (((-1.0d) / (2.0d * varSpike)) * expectedSqDev);
    }

    /** Returns the prior over the mean. */
    public Gaussian getMeanDistrib() {
        return meanDistrib;
    }

    /** Returns the fixed variance of the generated Gaussians. */
    public double getVarSpike() {
        return varSpike;
    }

    @Override
    public String toString() {
        return String.format("mean(%s),var(%.3f)", meanDistrib, varSpike);
    }
}
