package fig.prob;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:fig/prob/MultGaussian.class */
/**
 * Multivariate Gaussian distribution over {@code double[]} vectors, parameterised by a mean
 * vector and a covariance matrix held as Jama matrices.
 *
 * <p>The mean is stored as a dim-by-1 column matrix so it can be combined directly with the
 * column vectors produced while sampling and with the matrix products in {@link #logProb}.
 *
 * <p>The static caches ({@code stdNormal}, {@code zeroVector}, {@code identityMtx}) are plain
 * unsynchronised fields, and the cached arrays are handed out without copying; callers must
 * not modify them.
 */
public class MultGaussian implements Distrib<double[]> {
    /** Mean as a dim-by-1 column matrix. */
    private final Matrix mean;
    /** Covariance matrix (dim-by-dim). */
    private final Matrix covar;
    /** Cholesky factorisation of {@link #covar}; computed on first use by {@link #getChol()}. */
    private CholeskyDecomposition chol = null;
    /** Most recently built standard normal, reused while the requested dimension matches. */
    private static MultGaussian stdNormal;
    /** Most recently built zero vector, reused while the requested length matches. */
    private static double[] zeroVector;
    /** Most recently built identity matrix, reused while the requested size matches. */
    private static double[][] identityMtx;

    /**
     * Creates a Gaussian with the given mean and covariance.
     *
     * @param dArr  mean vector; its length fixes the number of rows of the stored mean
     * @param dArr2 covariance matrix as a 2-D array
     */
    public MultGaussian(double[] dArr, double[][] dArr2) {
        this.mean = new Matrix(dArr, dArr.length);
        // NOTE(review): Jama's Matrix(double[][]) is believed to wrap the array rather than
        // copy it, so later writes to dArr2 by the caller would show through — confirm.
        this.covar = new Matrix(dArr2);
    }

    /**
     * Returns the log-likelihood of the data summarised by {@code suffStats}.
     *
     * <p>The statistics object is read through {@code dim()}, {@code numPoints()},
     * {@code getMtxSum()} and {@code getMtxOuterProduct()}; presumably the last two are the
     * sum of the points and the sum of their outer products — confirm against
     * {@code MultGaussianSuffStats}.
     *
     * <p>The result is {@code numPoints * (c - 0.5*log det(covar)) - 0.5 * Q}, where
     * {@code Q = <outer, P> + numPoints * m'Pm - sum'Pm - m'P sum}, {@code P} is the inverse
     * covariance and {@code m} the mean.
     *
     * @param suffStats sufficient statistics; must be a {@code MultGaussianSuffStats}
     * @return the log-likelihood of the summarised points under this distribution
     * @throws ClassCastException if {@code suffStats} is not a {@code MultGaussianSuffStats}
     */
    @Override // fig.prob.Distrib
    public double logProb(SuffStats suffStats) {
        MultGaussianSuffStats stats = (MultGaussianSuffStats) suffStats;
        double numPoints = stats.numPoints();
        Matrix sum = stats.getMtxSum();
        Matrix precision = this.covar.inverse();
        // mean' * precision is needed by two of the quadratic terms; compute it once.
        Matrix meanTimesPrecision = this.mean.transpose().times(precision);

        // The density's normaliser involves log|covar|, not |covar| itself.
        // NOTE(review): if Gaussian.LOG_INV_SQRT_2_PI is log(1/sqrt(2*pi)), the leading 0.5
        // makes this constant half of the usual -(d/2)*log(2*pi). Left unchanged because the
        // constant's definition is not visible from this file — confirm.
        double perPointLogNormalizer =
                ((0.5d * Gaussian.LOG_INV_SQRT_2_PI) * stats.dim())
                        - (0.5d * Math.log(this.covar.det()));

        double quadratic = aggregatePtwiseProduct(stats.getMtxOuterProduct(), precision)
                + (meanTimesPrecision.times(this.mean).get(0, 0) * numPoints)
                - sum.transpose().times(precision).times(this.mean).get(0, 0)
                - meanTimesPrecision.times(sum).get(0, 0);

        // The quadratic term is already summed over all points, so the per-point normaliser
        // has to be counted once per point as well.
        return (numPoints * perPointLogNormalizer) - (0.5d * quadratic);
    }

    /**
     * Returns the log-probability of a single point by wrapping it in sufficient statistics
     * and delegating to {@link #logProb(SuffStats)}.
     *
     * @param dArr the point to score
     * @return the value of {@code logProb} for statistics built from {@code dArr}
     */
    @Override // fig.prob.Distrib
    public double logProbObject(double[] dArr) {
        return logProb(new MultGaussianSuffStats(dArr));
    }

    /** Returns the Cholesky factorisation of the covariance, computing and caching it on first use. */
    private CholeskyDecomposition getChol() {
        if (this.chol == null) {
            this.chol = this.covar.chol();
        }
        return this.chol;
    }

    /**
     * Draws one sample as {@code L * z + mean}, where {@code L} is the Cholesky factor of the
     * covariance and each entry of {@code z} comes from {@code SampleUtils.sampleGaussian}.
     *
     * @param random source of randomness passed to {@code SampleUtils.sampleGaussian}
     * @return a freshly allocated array of length {@link #dim()}
     */
    public double[] sample(Random random) {
        int n = dim();
        Matrix noise = new Matrix(n, 1);
        Matrix lower = getChol().getL();
        for (int row = 0; row < n; row++) {
            noise.set(row, 0, SampleUtils.sampleGaussian(random));
        }
        Matrix draw = lower.times(noise);
        draw.plusEquals(this.mean);
        return draw.getColumnPackedCopy();
    }

    /**
     * {@link Distrib} entry point for sampling; delegates to {@link #sample(Random)}.
     *
     * @param random source of randomness
     * @return a freshly allocated sample
     */
    @Override // fig.prob.Distrib
    public double[] sampleObject(Random random) {
        return sample(random);
    }

    /**
     * Cross entropy is not implemented for this distribution.
     *
     * @param distrib ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override // fig.prob.Distrib
    public double crossEntropy(Distrib<double[]> distrib) {
        throw new UnsupportedOperationException("unsupported");
    }

    /**
     * Demo: prints 10000 samples from a 2-D Gaussian with mean (1, 2) and covariance
     * [[1, 1], [1, 4]].
     *
     * @param strArr unused
     */
    public static void main(String[] strArr) {
        double[][] covariance = {
            {1.0d, 1.0d},
            {1.0d, 4.0d},
        };
        MultGaussian gaussian = new MultGaussian(new double[]{1.0d, 2.0d}, covariance);
        Random random = new Random();
        for (int i = 0; i < 10000; i++) {
            System.out.println(Arrays.toString(gaussian.sample(random)));
        }
    }

    /**
     * Sums the element-wise products of two equally sized matrices.
     *
     * @param matrix  first matrix
     * @param matrix2 second matrix; must have the same dimensions as {@code matrix}
     * @return the sum over all (row, column) of {@code matrix[r][c] * matrix2[r][c]}
     * @throws AssertionError if assertions are enabled and the dimensions differ
     */
    public static double aggregatePtwiseProduct(Matrix matrix, Matrix matrix2) {
        assert matrix.getRowDimension() == matrix2.getRowDimension();
        assert matrix.getColumnDimension() == matrix2.getColumnDimension();
        double total = 0.0d;
        for (int row = 0; row < matrix.getRowDimension(); row++) {
            for (int col = 0; col < matrix.getColumnDimension(); col++) {
                total += matrix.get(row, col) * matrix2.get(row, col);
            }
        }
        return total;
    }

    /**
     * Returns the dimensionality, taken from the row count of the covariance matrix.
     *
     * @return number of rows of the covariance matrix
     */
    public int dim() {
        return this.covar.getRowDimension();
    }

    /**
     * Returns a zero-mean, identity-covariance Gaussian of the given dimension. Only the most
     * recently requested dimension is cached; asking for another dimension replaces it.
     *
     * @param i requested dimension
     * @return a shared instance for dimension {@code i}
     */
    public static MultGaussian getStdNormal(int i) {
        if (stdNormal == null || stdNormal.dim() != i) {
            stdNormal = new MultGaussian(getZeroVector(i), getIdentityMtx(i));
        }
        return stdNormal;
    }

    /**
     * Returns an all-zero vector of the given length. The array is cached and shared; callers
     * must not modify it.
     *
     * @param i requested length
     * @return a shared zero vector of length {@code i}
     */
    public static double[] getZeroVector(int i) {
        if (zeroVector == null || zeroVector.length != i) {
            // A new double[] is already zero-filled.
            zeroVector = new double[i];
        }
        return zeroVector;
    }

    /**
     * Returns an identity matrix of the given size as a 2-D array. The array is cached and
     * shared; callers must not modify it.
     *
     * @param i requested number of rows and columns
     * @return a shared {@code i}-by-{@code i} identity matrix
     */
    public static double[][] getIdentityMtx(int i) {
        if (identityMtx == null || identityMtx.length != i) {
            // Off-diagonal entries keep the default 0.0; only the diagonal needs setting.
            double[][] eye = new double[i][i];
            for (int k = 0; k < i; k++) {
                eye[k][k] = 1.0d;
            }
            // Publish only once fully built.
            identityMtx = eye;
        }
        return identityMtx;
    }

    /**
     * Returns the mean vector.
     *
     * <p>The mean is stored as a column matrix, so it is read out column-packed; indexing the
     * first row of the backing array would yield only the first component.
     *
     * @return a new array of length {@link #dim()} holding the mean
     */
    public double[] getMean() {
        return this.mean.getColumnPackedCopy();
    }

    /**
     * Returns the covariance as a 2-D array, as exposed by {@code Matrix.getArray()}.
     *
     * @return the covariance array (NOTE(review): believed to be the matrix's internal
     *     storage rather than a copy — confirm against the Jama documentation)
     */
    public double[][] getCovar() {
        return this.covar.getArray();
    }

    /**
     * Returns the covariance matrix object held by this distribution.
     *
     * @return the covariance matrix (not a copy)
     */
    public Matrix getCovarMatrix() {
        return this.covar;
    }
}
