package fig.prob;

import Jama.Matrix;
import java.util.Random;

/* loaded from: input_file:fig/prob/NormalInverseWishartDistrib.class */
/**
 * Normal-inverse-Wishart distribution over a (mean, covariance) pair.
 *
 * <p>Parameters: a prior mean {@code scriptV} (a {@code dim x 1} column vector), a strictly
 * positive {@code kappa} that scales the quadratic penalty on the mean, a square scale matrix
 * {@code delta} whose size defines {@code dim}, and degrees of freedom {@code nu > dim + 1}.
 *
 * <p>Only the unnormalised log density and simple accessors are implemented; the remaining
 * {@link Distrib} operations throw {@link UnsupportedOperationException}.
 */
public class NormalInverseWishartDistrib implements Distrib<NormalInverseWishart> {
    /** Degrees of freedom; validated to exceed {@code dim + 1}. */
    private final double nu;
    /** Square scale matrix; its size defines {@link #dim()}. */
    private final Matrix delta;
    /** Prior mean, used as a {@code dim x 1} column vector. */
    private final Matrix scriptV;
    /** Multiplier on the mean's quadratic penalty; validated to be strictly positive. */
    private final double kappa;

    /**
     * Creates the distribution from Jama matrices.
     *
     * @param kappa   strictly positive multiplier on the mean term
     * @param scriptV prior mean, expected to be a {@code dim x 1} column vector
     * @param nu      degrees of freedom; must exceed {@code dim + 1}
     * @param delta   square scale matrix
     * @throws IllegalArgumentException if {@code kappa} or {@code nu} is out of range (NaN
     *     included) or {@code delta} is not square
     */
    public NormalInverseWishartDistrib(double kappa, Matrix scriptV, double nu, Matrix delta) {
        // Negated comparisons so NaN is rejected as well (every comparison with NaN is false).
        if (!(kappa > 0.0d)) {
            throw new IllegalArgumentException("kappa " + kappa + " should be > 0");
        }
        int d = delta.getColumnDimension();
        if (delta.getRowDimension() != d) {
            throw new IllegalArgumentException(
                    "delta must be square, got " + delta.getRowDimension() + " x " + d);
        }
        // nu > d + 1 also keeps the divisor in expectedVariance() strictly positive.
        if (!(nu > d + 1)) {
            throw new IllegalArgumentException("nu " + nu + " should be > d + 1, d = " + d);
        }
        this.nu = nu;
        this.delta = delta;
        this.scriptV = scriptV;
        this.kappa = kappa;
    }

    /**
     * Convenience constructor taking plain arrays.
     *
     * @param kappa   strictly positive multiplier on the mean term
     * @param scriptV prior mean; wrapped as a column vector
     * @param nu      degrees of freedom; must exceed {@code dim + 1}
     * @param delta   square scale matrix as a rectangular 2-D array
     */
    public NormalInverseWishartDistrib(double kappa, double[] scriptV, double nu, double[][] delta) {
        this(kappa, new Matrix(scriptV, scriptV.length), nu, new Matrix(delta));
    }

    /**
     * Returns the log density of {@code (mu, sigma)} up to an additive constant:
     * {@code -((nu + dim) / 2 + 1) * log|sigma| - tr(delta * sigma^-1) / 2
     * - kappa / 2 * (mu - scriptV)^T sigma^-1 (mu - scriptV)}.
     *
     * <p>Earlier versions only supported {@code sigma = I} (enforced by an {@code assert}) and
     * silently returned a wrong value for other matrices when assertions were disabled. For the
     * identity the result is unchanged.
     *
     * @param mu    mean vector of length {@code dim}
     * @param sigma {@code dim x dim} covariance; assumed symmetric positive definite
     * @return the unnormalised log density
     * @throws IllegalArgumentException if the dimensions disagree with {@code dim} or
     *     {@code det(sigma)} is not strictly positive
     */
    public double unNormalizedLogProb(double[] mu, double[][] sigma) {
        int d = dim();
        if (mu.length != d) {
            throw new IllegalArgumentException("mu has length " + mu.length + ", expected " + d);
        }
        Matrix sigmaMatrix = new Matrix(sigma);
        if (sigmaMatrix.getRowDimension() != d || sigmaMatrix.getColumnDimension() != d) {
            throw new IllegalArgumentException("sigma is " + sigmaMatrix.getRowDimension() + " x "
                    + sigmaMatrix.getColumnDimension() + ", expected " + d + " x " + d);
        }
        double det = sigmaMatrix.det();
        // A covariance must have a positive determinant; otherwise log|sigma| is undefined.
        if (!(det > 0.0d)) {
            throw new IllegalArgumentException("sigma must be positive definite, det = " + det);
        }
        Matrix sigmaInverse = sigmaMatrix.inverse();
        Matrix meanOffset = new Matrix(mu, mu.length).minus(this.scriptV);

        double logDetTerm = -((this.nu + d) / 2.0d + 1.0d) * Math.log(det);
        double traceTerm = -0.5d * this.delta.times(sigmaInverse).trace();
        double meanTerm = (-this.kappa / 2.0d) * norm(sigmaInverse, meanOffset);
        return logDetTerm + traceTerm + meanTerm;
    }

    /** Not implemented; always throws {@link UnsupportedOperationException}. */
    @Override
    public double logProb(SuffStats suffStats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public double logProbObject(NormalInverseWishart normalInverseWishart) {
        throw new UnsupportedOperationException("Not supported right now");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public NormalInverseWishart sampleObject(Random random) {
        throw new UnsupportedOperationException("Not supported right now");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public double crossEntropy(Distrib<NormalInverseWishart> distrib) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Evaluates the quadratic form {@code x^T a x}.
     *
     * <p>Shape preconditions are checked only by {@code assert} (active under {@code -ea}).
     *
     * @param a square {@code n x n} matrix
     * @param x {@code n x 1} column vector
     * @return the scalar {@code x^T a x}
     */
    public static double norm(Matrix a, Matrix x) {
        assert a.getColumnDimension() == a.getRowDimension();
        assert a.getColumnDimension() == x.getRowDimension();
        assert x.getColumnDimension() == 1;
        Matrix product = x.transpose().times(a).times(x);
        assert product.getColumnDimension() == 1;
        assert product.getRowDimension() == 1;
        return product.get(0, 0);
    }

    /**
     * Returns {@code delta * nu / (nu - dim - 1)}. The divisor is positive because the
     * constructor enforces {@code nu > dim + 1}.
     *
     * <p>NOTE(review): with {@code delta} used as the inverse-Wishart scale (as in
     * {@link #unNormalizedLogProb}), the usual mean of sigma is {@code delta / (nu - dim - 1)}.
     * The extra factor of {@code nu} suggests a different parameterisation. Confirm with callers
     * before treating this as E[sigma].
     *
     * @return a new matrix; {@code delta} itself is not modified
     */
    public Matrix expectedVariance() {
        return this.delta.times(this.nu / ((this.nu - dim()) - 1.0d));
    }

    /** Returns the dimension, taken from the column count of {@code delta}. */
    public int dim() {
        return this.delta.getColumnDimension();
    }

    /** Returns the scale matrix (the internal instance, not a copy). */
    public Matrix getDelta() {
        return this.delta;
    }

    /** Returns the mean-term multiplier {@code kappa}. */
    public double getKappa() {
        return this.kappa;
    }

    /** Returns the degrees of freedom {@code nu}. */
    public double getNu() {
        return this.nu;
    }

    /** Returns the prior mean (the internal instance, not a copy). */
    public Matrix getScriptV() {
        return this.scriptV;
    }

    @Override
    public String toString() {
        return "NIW(nu=" + this.nu + ", kappa=" + this.kappa + ")";
    }
}
