package fig.prob;

import fig.basic.Fmt;
import fig.basic.ListUtils;
import fig.basic.NumUtils;
import java.io.Serializable;
import java.util.Random;

/* loaded from: input_file:fig/prob/Dirichlet.class */
/**
 * Dirichlet distribution described by a vector of concentration parameters
 * ({@code alpha}) together with their cached sum ({@code totalCount}).
 *
 * <p>Instances never reassign their fields, but the parameter array is shared
 * by reference with callers (see {@link #Dirichlet(double[])} and
 * {@link #getAlpha()}), so the object is only as immutable as its callers
 * allow.
 */
public class Dirichlet implements DirichletInterface, Serializable {
    // Must be "static final long" to be honoured by Java serialization; a
    // non-final field is ignored and a default UID is computed instead.
    // NOTE(review): streams written while this field was non-final used the
    // computed default UID and will not deserialize against this version.
    private static final long serialVersionUID = 42L;

    /** Concentration parameters, one per dimension. Shared by reference, never copied. */
    private final double[] alpha;

    /** Sum of the concentration parameters, computed once at construction. */
    private final double totalCount;

    /**
     * Creates a symmetric Dirichlet.
     *
     * @param i number of dimensions
     * @param d concentration value used for every dimension (the parameter
     *          array is obtained from {@code ListUtils.newDouble(i, d)})
     */
    public Dirichlet(int i, double d) {
        this.alpha = ListUtils.newDouble(i, d);
        this.totalCount = d * i;
    }

    /**
     * Creates a Dirichlet from explicit concentration parameters.
     *
     * <p>The array is stored as-is, without copying. The total count is
     * computed here once, so later external changes to {@code dArr} would
     * leave {@link #totalCount()} stale.
     *
     * @param dArr concentration parameters
     */
    public Dirichlet(double[] dArr) {
        this.alpha = dArr;
        this.totalCount = ListUtils.sum(dArr);
    }

    /**
     * Log-density of this distribution at the given point.
     *
     * @param dArr point to evaluate; indexed in parallel with the parameters
     * @return the value of {@link #logProb(double[], double, double[])} for
     *         this instance's parameters
     */
    public double logProb(double[] dArr) {
        return logProb(this.alpha, this.totalCount, dArr);
    }

    /**
     * Log-density of a Dirichlet with the given parameters at a point:
     * {@code logGamma(d) - sum_i logGamma(dArr[i]) + sum_i (dArr[i] - 1) * log(dArr2[i])}.
     *
     * @param dArr  concentration parameters
     * @param d     sum of the concentration parameters (trusted, not recomputed)
     * @param dArr2 point to evaluate; must have at least {@code dArr.length} entries
     * @return the log-density, or {@code 0.0} when the shortcut below fires
     */
    public static double logProb(double[] dArr, double d, double[] dArr2) {
        // Shortcut: returns 0 whenever the total count equals the dimension.
        // NOTE(review): presumably meant for the all-ones (uniform) case, but
        // the test looks only at the sum, so e.g. {0.5, 1.5} also lands here;
        // and the uniform log-density is logGamma(K), which is nonzero for
        // K > 2. Left unchanged because callers may depend on the current
        // values - confirm the intent before altering.
        if (NumUtils.equals(d, dArr.length)) {
            return 0.0d;
        }
        double result = NumUtils.logGamma(d);
        for (int k = 0; k < dArr.length; k++) {
            result = (result - NumUtils.logGamma(dArr[k])) + ((dArr[k] - 1.0d) * Math.log(dArr2[k]));
        }
        return result;
    }

    /**
     * Not supported for this distribution.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // fig.prob.Distrib
    public double logProb(SuffStats suffStats) {
        // UnsupportedOperationException is a RuntimeException, so callers that
        // caught the previous raw RuntimeException continue to work.
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Same as {@link #logProb(double[])}. */
    @Override // fig.prob.Distrib
    public double logProbObject(double[] dArr) {
        return logProb(dArr);
    }

    /**
     * Draws one sample using this instance's parameters.
     *
     * @param random source of randomness
     * @return a newly allocated array with one entry per dimension
     */
    public double[] sample(Random random) {
        return sample(random, this.alpha);
    }

    /**
     * Draws one sample for the given parameters: one {@code Gamma.sample}
     * draw per dimension (second argument {@code dArr[i]}, third argument 1),
     * after which the vector is passed to {@code NumUtils.normalize}.
     *
     * @param random source of randomness
     * @param dArr   concentration parameters
     * @return a newly allocated array of length {@code dArr.length}
     */
    public static double[] sample(Random random, double[] dArr) {
        double[] draws = new double[dArr.length];
        for (int k = 0; k < dArr.length; k++) {
            draws[k] = Gamma.sample(random, dArr[k], 1.0d);
        }
        NumUtils.normalize(draws);
        return draws;
    }

    /** Same as {@link #sample(Random)}. */
    @Override // fig.prob.Distrib
    public double[] sampleObject(Random random) {
        return sample(random);
    }

    /**
     * Cross-entropy between this distribution and another Dirichlet, summed
     * from the per-term contributions supplied by {@code DirichletUtils}.
     *
     * @param distrib the other distribution; must be a {@code Dirichlet} with
     *                at least as many dimensions as this one
     * @return the other distribution's total-count contribution plus one
     *         element contribution per dimension of this distribution
     * @throws ClassCastException if {@code distrib} is not a {@code Dirichlet}
     */
    @Override // fig.prob.Distrib
    public double crossEntropy(Distrib<double[]> distrib) {
        Dirichlet that = (Dirichlet) distrib;
        double total = DirichletUtils.thatTotalCountContrib(that.totalCount());
        for (int k = 0; k < this.alpha.length; k++) {
            total += DirichletUtils.elementContrib(this.alpha[k], that.alpha[k], totalCount());
        }
        return total;
    }

    /**
     * Expected log of a single component, as computed by
     * {@code DirichletUtils.expectedLog} from that component's parameter and
     * the total count.
     *
     * @param i component index
     */
    @Override // fig.prob.DirichletInterface
    public double expectedLog(int i) {
        return DirichletUtils.expectedLog(this.alpha[i], this.totalCount);
    }

    /**
     * Expected log of every component.
     *
     * @return a newly allocated array whose i-th entry is {@link #expectedLog(int)} of i
     */
    @Override // fig.prob.DirichletInterface
    public double[] expectedLog() {
        int n = dim();
        double[] result = new double[n];
        for (int k = 0; k < n; k++) {
            result[k] = expectedLog(k);
        }
        return result;
    }

    /**
     * Builds a {@code DegenerateDirichlet} from this distribution's mode.
     *
     * @return a new {@code DegenerateDirichlet} constructed from {@link #getMode()}
     */
    @Override // fig.prob.DirichletInterface
    public DirichletInterface modeSpike() {
        return new DegenerateDirichlet(getMode());
    }

    /**
     * Creates a new Dirichlet whose parameters are a fresh sample from this
     * one, passed through {@code ListUtils.mult} together with this
     * distribution's total count.
     *
     * @param random source of randomness
     * @return a new, independent {@code Dirichlet}
     */
    public Dirichlet perturb(Random random) {
        return new Dirichlet(ListUtils.mult(this.totalCount, sample(random)));
    }

    /**
     * Mean of the distribution: a copy of the parameters passed through
     * {@code NumUtils.normalize}.
     *
     * @return a newly allocated array; the internal parameters are untouched
     */
    @Override // fig.prob.DirichletInterface
    public double[] getMean() {
        double[] mean = this.alpha.clone();
        NumUtils.normalize(mean);
        return mean;
    }

    /**
     * Mode of the distribution: {@code alpha - 1} per component, floored at
     * {@code 1e-8}, then passed through {@code NumUtils.normalize}.
     *
     * @return the normalised, floored {@code alpha - 1} vector
     */
    @Override // fig.prob.DirichletInterface
    public double[] getMode() {
        double[] mode = ListUtils.add(this.alpha, -1.0d);
        for (int k = 0; k < mode.length; k++) {
            // Floor keeps components with alpha <= 1 from producing zero or
            // negative entries ahead of normalisation.
            mode[k] = Math.max(mode[k], 1.0E-8d);
        }
        NumUtils.normalize(mode);
        return mode;
    }

    /**
     * Returns the internal parameter array itself (not a copy).
     *
     * @return the live concentration-parameter array
     */
    public double[] getAlpha() {
        return this.alpha;
    }

    /**
     * Returns a single concentration parameter.
     *
     * @param i component index
     */
    @Override // fig.prob.DirichletInterface
    public double getAlpha(int i) {
        return this.alpha[i];
    }

    /** Returns the sum of the concentration parameters as computed at construction. */
    @Override // fig.prob.DirichletInterface
    public double totalCount() {
        return this.totalCount;
    }

    /** Returns the number of dimensions (length of the parameter array). */
    @Override // fig.prob.DirichletInterface
    public int dim() {
        return this.alpha.length;
    }

    /** Returns {@code "Dirichlet(...)"} with the parameters formatted by {@code Fmt.D}. */
    @Override
    public String toString() {
        return String.format("Dirichlet(%s)", Fmt.D(this.alpha));
    }
}
