package fig.prob;

import fig.basic.Fmt;
import fig.basic.ListUtils;
import java.util.Random;

/* loaded from: input_file:fig/prob/DiagMultGaussian.class */
/**
 * A multivariate Gaussian with a diagonal covariance matrix: one independent univariate Gaussian
 * per dimension, each parameterized by a mean and a variance.
 *
 * <p>The mean and variance arrays are stored by reference (never copied). {@link #getMean()} and
 * {@link #getVar()} return those same arrays, so mutations made through them, or through the
 * arrays passed to a constructor, are visible to this object.
 */
public class DiagMultGaussian implements Distrib<double[]> {
    /** Per-dimension means; its length defines {@link #dim()}. */
    private double[] mean;
    /** Per-dimension variances, aligned index-by-index with {@link #mean}. */
    private double[] var;

    /**
     * Creates a Gaussian with the given means and one variance shared by every dimension.
     *
     * @param mean per-dimension means (stored by reference)
     * @param variance variance used for every dimension
     */
    public DiagMultGaussian(double[] mean, double variance) {
        this.mean = mean;
        // Presumably ListUtils.newDouble(n, v) builds an n-element array filled with v.
        this.var = ListUtils.newDouble(mean.length, variance);
    }

    /**
     * Creates a Gaussian from explicit per-dimension means and variances.
     *
     * @param mean per-dimension means (stored by reference)
     * @param variance per-dimension variances (stored by reference)
     * @throws IllegalArgumentException if the two arrays have different lengths
     */
    public DiagMultGaussian(double[] mean, double[] variance) {
        // Reject mismatched arrays up front instead of failing (or silently truncating) later.
        if (mean.length != variance.length) {
            throw new IllegalArgumentException(String.format(
                "mean has %d dimensions but variance has %d", mean.length, variance.length));
        }
        this.mean = mean;
        this.var = variance;
    }

    /**
     * Creates a Gaussian whose i-th dimension copies the mean and variance of {@code components[i]}.
     *
     * @param components one univariate Gaussian per dimension
     */
    public DiagMultGaussian(Gaussian[] components) {
        int n = components.length;
        this.mean = new double[n];
        this.var = new double[n];
        for (int i = 0; i < n; i++) {
            this.mean[i] = components[i].getMean();
            this.var[i] = components[i].getVar();
        }
    }

    /**
     * Creates a Gaussian in which every dimension has the mean and variance of {@code component}.
     *
     * @param numDims number of dimensions
     * @param component univariate Gaussian replicated across all dimensions
     */
    public DiagMultGaussian(int numDims, Gaussian component) {
        this.mean = ListUtils.newDouble(numDims, component.getMean());
        this.var = ListUtils.newDouble(numDims, component.getVar());
    }

    /**
     * Returns the log-density of a point: the sum of the per-dimension univariate log-densities.
     *
     * @param x the point; must have exactly {@link #dim()} coordinates
     * @return log-density of {@code x}
     * @throws IllegalArgumentException if {@code x.length != dim()}
     */
    public double logProb(double[] x) {
        checkDim(x.length, "point");
        double total = 0.0d;
        for (int i = 0; i < dim(); i++) {
            total += Gaussian.logProb(this.mean[i], this.var[i], x[i]);
        }
        return total;
    }

    /**
     * Returns the log-probability of a batch of points summarized by sufficient statistics, summed
     * over dimensions.
     *
     * @param suffStats must be a {@code DiagMultGaussianSuffStats}; otherwise a
     *     {@link ClassCastException} is thrown
     * @return summed per-dimension log-probabilities
     */
    @Override // fig.prob.Distrib
    public double logProb(SuffStats suffStats) {
        DiagMultGaussianSuffStats stats = (DiagMultGaussianSuffStats) suffStats;
        double total = 0.0d;
        for (int i = 0; i < dim(); i++) {
            total += Gaussian.logProb(
                this.mean[i], this.var[i], stats.getSum(i), stats.getSumSq(i), stats.numPoints());
        }
        return total;
    }

    /** {@link Distrib} entry point; delegates to {@link #logProb(double[])}. */
    @Override // fig.prob.Distrib
    public double logProbObject(double[] x) {
        return logProb(x);
    }

    /**
     * Draws a sample by sampling each dimension independently.
     *
     * @param random source of randomness
     * @return a freshly allocated array of length {@link #dim()}
     */
    public double[] sample(Random random) {
        double[] point = new double[this.mean.length];
        for (int i = 0; i < dim(); i++) {
            point[i] = Gaussian.sample(random, this.mean[i], this.var[i]);
        }
        return point;
    }

    /** {@link Distrib} entry point; delegates to {@link #sample(Random)}. */
    @Override // fig.prob.Distrib
    public double[] sampleObject(Random random) {
        return sample(random);
    }

    /**
     * Returns the cross-entropy against another diagonal Gaussian, computed as the sum of the
     * per-dimension univariate cross-entropies.
     *
     * @param distrib must be a {@code DiagMultGaussian} with the same dimension; otherwise a
     *     {@link ClassCastException} is thrown
     * @return summed per-dimension cross-entropies
     * @throws IllegalArgumentException if the dimensions differ
     */
    @Override // fig.prob.Distrib
    public double crossEntropy(Distrib<double[]> distrib) {
        DiagMultGaussian other = (DiagMultGaussian) distrib;
        // Without this check a higher-dimensional `other` would be silently truncated.
        checkDim(other.dim(), "other distribution");
        double total = 0.0d;
        for (int i = 0; i < dim(); i++) {
            total += getComponent(i).crossEntropy(other.getComponent(i));
        }
        return total;
    }

    /**
     * Returns a new univariate Gaussian for dimension {@code i}.
     *
     * @param i dimension index
     * @return a new {@code Gaussian} built from this object's i-th mean and variance
     */
    public Gaussian getComponent(int i) {
        return new Gaussian(this.mean[i], this.var[i]);
    }

    /** Returns the backing mean array (not a copy). */
    public double[] getMean() {
        return this.mean;
    }

    /** Returns the backing variance array (not a copy). */
    public double[] getVar() {
        return this.var;
    }

    /** Returns the mean of dimension {@code i}. */
    public double getMean(int i) {
        return this.mean[i];
    }

    /** Returns the variance of dimension {@code i}. */
    public double getVar(int i) {
        return this.var[i];
    }

    /** Returns the number of dimensions, i.e. the length of the mean array. */
    public int dim() {
        return this.mean.length;
    }

    @Override
    public String toString() {
        return String.format("mean(%s),var(%s)", Fmt.D(this.mean), Fmt.D(this.var));
    }

    /**
     * Ensures an operand has the same dimension as this distribution.
     *
     * @param otherDim dimension of the operand being checked
     * @param what short description of the operand, used in the error message
     * @throws IllegalArgumentException if {@code otherDim != dim()}
     */
    private void checkDim(int otherDim, String what) {
        if (otherDim != dim()) {
            throw new IllegalArgumentException(
                String.format("%s has %d dimensions, expected %d", what, otherDim, dim()));
        }
    }
}
