package fig.prob;

import Jama.Matrix;
import fig.basic.Exceptions;
import java.util.Random;

/* loaded from: input_file:fig/prob/MargMultGaussian.class */
/**
 * Marginal distribution over multivariate-Gaussian data whose (mean, covariance)
 * parameters are governed by a {@link NormalInverseWishartDistrib}.
 *
 * <p>Only the marginal/predictive likelihoods, the posterior update and parameter
 * sampling are implemented; the remaining {@code Distrib}/{@code MargDistrib}
 * operations throw {@code Exceptions.unimplemented}.
 */
public class MargMultGaussian implements MargDistrib<NormalInverseWishart> {
    /** Distribution over the Gaussian's parameters; assigned once in the constructor. */
    private final NormalInverseWishartDistrib meanVarDistrib;

    /**
     * Wraps the given parameter distribution.
     *
     * @param normalInverseWishartDistrib distribution over the Gaussian parameters;
     *                                    stored by reference, not copied
     */
    public MargMultGaussian(NormalInverseWishartDistrib normalInverseWishartDistrib) {
        this.meanVarDistrib = normalInverseWishartDistrib;
    }

    /**
     * Computes the marginal log-likelihood of the data summarised by {@code suffStats} as
     * {@code log prior(theta) + log lik(data | theta) - log posterior(theta)}, evaluated at
     * the fixed point theta = (zero vector, identity matrix).
     *
     * <p>NOTE(review): both Normal-Inverse-Wishart terms are obtained from
     * {@code unNormalizedLogProb}; confirm that the identity above still holds for
     * whatever that method leaves out.
     *
     * @param suffStats sufficient statistics of the data; must be a {@link MultGaussianSuffStats}
     * @return the value of the expression described above
     * @throws ClassCastException if {@code suffStats} is not a {@link MultGaussianSuffStats}
     */
    @Override // fig.prob.MargDistrib
    public double margLogLikelihood(SuffStats suffStats) {
        // Accumulate in the same order as the original expression: prior, likelihood, posterior.
        double logLikelihood = 0.0d;
        logLikelihood += logDensityAtStdParams(this.meanVarDistrib);
        logLikelihood += MultGaussian.getStdNormal(dim()).logProb((MultGaussianSuffStats) suffStats);
        logLikelihood -= logDensityAtStdParams(getPosterior(suffStats).meanVarDistrib);
        return logLikelihood;
    }

    /**
     * Evaluates {@code distrib.unNormalizedLogProb} at (zero vector, identity matrix),
     * both sized by this object's {@link #dim()}.
     */
    private double logDensityAtStdParams(NormalInverseWishartDistrib distrib) {
        return distrib.unNormalizedLogProb(
                MultGaussian.getZeroVector(dim()),
                MultGaussian.getIdentityMtx(dim()));
    }

    /**
     * Computes the log-likelihood of new data under the posterior obtained from
     * previously observed data.
     *
     * @param suffStats  sufficient statistics of the data already observed
     * @param suffStats2 sufficient statistics of the data to score
     * @return {@code getPosterior(suffStats).margLogLikelihood(suffStats2)}
     */
    @Override // fig.prob.MargDistrib
    public double predLogLikelihood(SuffStats suffStats, SuffStats suffStats2) {
        return getPosterior(suffStats).margLogLikelihood(suffStats2);
    }

    /**
     * Builds the posterior after observing the data summarised by {@code suffStats}.
     *
     * <p>With n = {@code numPoints()}, s = {@code getSum()} (as a column), S =
     * {@code getOuterProduct()}, and prior values kappa, nu, V = {@code getScriptV()},
     * D = {@code getDelta()}, the new parameters are:
     * <pre>
     *   kappa' = kappa + n
     *   nu'    = nu + n
     *   V'     = (kappa * V + s) / kappa'
     *   D'     = (nu * D + kappa * V V^T - kappa' * V' V'^T + S) / nu'
     * </pre>
     *
     * @param suffStats sufficient statistics of the data; must be a {@link MultGaussianSuffStats}
     * @return a new {@code MargMultGaussian} wrapping the updated parameter distribution
     * @throws ClassCastException if {@code suffStats} is not a {@link MultGaussianSuffStats}
     */
    @Override // fig.prob.MargDistrib
    public MargMultGaussian getPosterior(SuffStats suffStats) {
        MultGaussianSuffStats stats = (MultGaussianSuffStats) suffStats;
        // Jama's (double[], int) constructor packs the sum into a matrix with dim() rows.
        Matrix dataSum = new Matrix(stats.getSum(), stats.dim());
        Matrix dataOuterProduct = new Matrix(stats.getOuterProduct());

        double posteriorKappa = this.meanVarDistrib.getKappa() + stats.numPoints();
        double posteriorNu = this.meanVarDistrib.getNu() + stats.numPoints();

        Matrix posteriorScriptV = this.meanVarDistrib.getScriptV()
                .times(this.meanVarDistrib.getKappa())
                .plus(dataSum)
                .times(1.0d / posteriorKappa);

        Matrix weightedPriorDelta = this.meanVarDistrib.getDelta().times(this.meanVarDistrib.getNu());
        Matrix weightedPriorOuter = this.meanVarDistrib.getScriptV()
                .times(this.meanVarDistrib.getScriptV().transpose())
                .times(this.meanVarDistrib.getKappa());
        Matrix weightedPosteriorOuter = posteriorScriptV
                .times(posteriorScriptV.transpose())
                .times(posteriorKappa);

        Matrix posteriorDelta = weightedPriorDelta
                .plus(weightedPriorOuter)
                .minus(weightedPosteriorOuter)
                .plus(dataOuterProduct)
                .times(1.0d / posteriorNu);

        return new MargMultGaussian(
                new NormalInverseWishartDistrib(posteriorKappa, posteriorScriptV, posteriorNu, posteriorDelta));
    }

    /** Not implemented; always throws {@code Exceptions.unimplemented}. */
    @Override // fig.prob.Distrib
    public double logProb(SuffStats suffStats) {
        throw Exceptions.unimplemented;
    }

    /** Not implemented; always throws {@code Exceptions.unimplemented}. */
    @Override // fig.prob.Distrib
    public double logProbObject(NormalInverseWishart normalInverseWishart) {
        throw Exceptions.unimplemented;
    }

    /** Not implemented; always throws {@code Exceptions.unimplemented}. */
    @Override // fig.prob.Distrib
    public double crossEntropy(Distrib<NormalInverseWishart> distrib) {
        throw Exceptions.unimplemented;
    }

    /** Not implemented; always throws {@code Exceptions.unimplemented}. */
    @Override // fig.prob.MargDistrib
    public double expectedLogLikelihood(SuffStats suffStats) {
        throw Exceptions.unimplemented;
    }

    /**
     * Draws a parameter object by delegating to the wrapped distribution.
     *
     * @param random source of randomness passed through to the wrapped distribution
     * @return the object returned by {@code meanVarDistrib.sampleObject(random)}
     */
    @Override // fig.prob.Distrib
    public NormalInverseWishart sampleObject(Random random) {
        return this.meanVarDistrib.sampleObject(random);
    }

    /**
     * @return the dimensionality reported by the wrapped parameter distribution
     */
    public int dim() {
        return this.meanVarDistrib.dim();
    }

    /**
     * Small one-dimensional demo: builds a prior, feeds it 10000 samples drawn from a
     * {@code MultGaussian}, and prints element (0, 0) of the posterior's
     * {@code expectedVariance()}.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        Matrix priorScriptV = new Matrix(1, 1);
        priorScriptV.set(0, 0, 5.0d);
        Matrix priorDelta = new Matrix(1, 1);
        priorDelta.set(0, 0, 1.0d);
        // Argument order (kappa, scriptV, nu, delta) matches the call in getPosterior.
        NormalInverseWishartDistrib prior =
                new NormalInverseWishartDistrib(1.0d, priorScriptV, 4.0d, priorDelta);

        // Second argument is a 1x1 two-dimensional array (presumably the covariance — confirm
        // against MultGaussian); it must be a double[][], not a double[] holding a double[].
        MultGaussian source = new MultGaussian(new double[] {30.0d}, new double[][] {{1.0d}});

        MultGaussianSuffStats stats = new MultGaussianSuffStats(1);
        Random random = new Random();
        for (int i = 0; i < 10000; i++) {
            stats.add(source.sample(random));
        }

        MargMultGaussian posterior = new MargMultGaussian(prior).getPosterior((SuffStats) stats);
        System.out.println(posterior.meanVarDistrib.expectedVariance().get(0, 0));
    }
}
