package fig.prob;

import fig.basic.Exceptions;
import fig.basic.Fmt;
import fig.basic.ListUtils;
import fig.basic.NumUtils;

/* loaded from: input_file:fig/prob/DirichletUtils.class */
/**
 * Static helpers for Dirichlet-related quantities, built on
 * {@code NumUtils.digamma} and {@code NumUtils.logGamma}.
 *
 * <p>The {@code fastExp*} methods replace {@code exp(digamma(x))} with a
 * table lookup; the {@code expExpectedLog} method computes the same ratio
 * directly.
 */
public class DirichletUtils {
    /** Inputs at or above this value bypass the lookup table in {@link #fastExpDigamma(double)}. */
    private static final double FAST_EXP_MAX_RANGE = 100.0d;

    /** Number of evenly spaced buckets covering the range {@code [0, FAST_EXP_MAX_RANGE)}. */
    private static final int FAST_EXP_BUCKET_COUNT = 1000000;

    /**
     * Lazily built lookup table for {@link #fastExpDigamma(double)}.
     *
     * <p>The table lives in its own holder class so that it is only built on
     * first use, and so that callers can never observe it partially filled:
     * class initialization completes before {@code BUCKETS} becomes readable.
     */
    private static final class FastExpDigammaTable {
        /** {@code BUCKETS[i] = exp(digamma(FAST_EXP_MAX_RANGE * i / BUCKETS.length))} for {@code i >= 1}. */
        static final double[] BUCKETS = build();

        private FastExpDigammaTable() {
        }

        private static double[] build() {
            double[] buckets = new double[FAST_EXP_BUCKET_COUNT];
            // Index 0 is intentionally skipped and keeps the array default of 0.0.
            for (int i = 1; i < buckets.length; i++) {
                buckets[i] = Math.exp(NumUtils.digamma((FAST_EXP_MAX_RANGE * i) / buckets.length));
            }
            return buckets;
        }
    }

    /**
     * Returns {@code digamma(d) - digamma(d2)}.
     *
     * <p>The failure message labels {@code d} as "count" and {@code d2} as
     * "totalCount".
     *
     * @param d  the count
     * @param d2 the total count
     * @return the difference of the two digamma values
     * @throws RuntimeException the exception produced by {@code Exceptions.bad}
     *         when the difference is not finite according to
     *         {@code NumUtils.isFinite} (exact exception type is defined by
     *         {@code Exceptions.bad} — not visible here)
     */
    public static double expectedLog(double d, double d2) {
        double result = NumUtils.digamma(d) - NumUtils.digamma(d2);
        if (!NumUtils.isFinite(result)) {
            throw Exceptions.bad("count=%f, totalCount=%f", Double.valueOf(d), Double.valueOf(d2));
        }
        return result;
    }

    /**
     * Returns {@code logGamma(d)}.
     *
     * @param d the argument passed to {@code NumUtils.logGamma}
     * @return {@code NumUtils.logGamma(d)}
     */
    public static double thatTotalCountContrib(double d) {
        return NumUtils.logGamma(d);
    }

    /**
     * Returns {@code (d2 - 1) * expectedLog(d, d3) - logGamma(d2)}.
     *
     * @param d  first argument forwarded to {@link #expectedLog(double, double)} (the count)
     * @param d2 value used both as the {@code (d2 - 1)} weight and as the {@code logGamma} argument
     * @param d3 second argument forwarded to {@link #expectedLog(double, double)} (the total count)
     * @return the combined term
     * @throws RuntimeException if {@link #expectedLog(double, double)} rejects {@code d} and {@code d3}
     */
    public static double elementContrib(double d, double d2, double d3) {
        double weightedExpectedLog = (d2 - 1.0d) * expectedLog(d, d3);
        return weightedExpectedLog - NumUtils.logGamma(d2);
    }

    /**
     * Returns {@code logGamma(d + d2) - logGamma(d)}.
     *
     * <p>When {@code d2} is exactly 1 the result is computed as
     * {@code Math.log(d)} instead of through two {@code logGamma} calls.
     *
     * @param d  the base argument
     * @param d2 the increment added to {@code d}
     * @return the log of the gamma ratio
     */
    public static double logGammaRatio(double d, double d2) {
        if (d2 == 1.0d) {
            return Math.log(d);
        }
        return NumUtils.logGamma(d + d2) - NumUtils.logGamma(d);
    }

    /**
     * Table-based approximation of {@code exp(digamma(d))}.
     *
     * <p>For {@code d >= 100} the value {@code d - 0.5} is returned. Smaller
     * inputs are rounded to the nearest of 1,000,000 evenly spaced buckets
     * over {@code [0, 100)} and the precomputed bucket value is returned.
     * Inputs that round to bucket 0 yield {@code 0.0}.
     *
     * <p>Negative inputs are rejected only when assertions are enabled.
     *
     * @param d the argument, expected to be non-negative
     * @return the approximate value of {@code exp(digamma(d))}
     */
    public static double fastExpDigamma(double d) {
        assert !(d < 0.0d) : d;
        if (d >= FAST_EXP_MAX_RANGE) {
            return d - 0.5d;
        }
        double[] buckets = FastExpDigammaTable.BUCKETS;
        int index = (int) (((buckets.length * d) / FAST_EXP_MAX_RANGE) + 0.5d);
        // Values just below the range limit can round up to buckets.length; clamp to the last bucket.
        if (index >= buckets.length) {
            index = buckets.length - 1;
        }
        return buckets[index];
    }

    /**
     * Returns a new array whose i-th entry is
     * {@code fastExpDigamma(dArr[i]) / fastExpDigamma(sum(dArr))}.
     *
     * <p>The input array is not modified. No check is made for a zero
     * denominator: if {@code fastExpDigamma(sum(dArr))} is {@code 0.0} the
     * entries are the result of floating-point division by zero. Use
     * {@link #fastExpExpectedLogMut(double[])} when that case must be detected.
     *
     * @param dArr the input values
     * @return a new array of the same length holding the ratios
     */
    public static double[] fastExpExpectedLog(double[] dArr) {
        int n = dArr.length;
        double[] result = new double[n];
        double denominator = fastExpDigamma(ListUtils.sum(dArr));
        for (int i = 0; i < n; i++) {
            result[i] = fastExpDigamma(dArr[i]) / denominator;
        }
        return result;
    }

    /**
     * In-place variant of {@link #fastExpExpectedLog(double[])}.
     *
     * <p>If {@code fastExpDigamma(sum(dArr))} is {@code 0.0} the array is left
     * untouched and {@code false} is returned. Otherwise every element is
     * overwritten with {@code fastExpDigamma(dArr[i]) / fastExpDigamma(sum)},
     * where the sum is taken before any element is overwritten.
     *
     * @param dArr the values to transform in place
     * @return {@code true} if the array was overwritten, {@code false} if the
     *         denominator was zero
     */
    public static boolean fastExpExpectedLogMut(double[] dArr) {
        double denominator = fastExpDigamma(ListUtils.sum(dArr));
        if (denominator == 0.0d) {
            return false;
        }
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = fastExpDigamma(dArr[i]) / denominator;
        }
        return true;
    }

    /**
     * Returns a new array whose i-th entry is
     * {@code exp(digamma(dArr[i])) / exp(digamma(sum(dArr)))}, computed
     * directly with {@code Math.exp} and {@code NumUtils.digamma} rather than
     * through the lookup table.
     *
     * @param dArr the input values; not modified
     * @return a new array of the same length holding the ratios
     */
    public static double[] expExpectedLog(double[] dArr) {
        int n = dArr.length;
        double[] result = new double[n];
        double denominator = Math.exp(NumUtils.digamma(ListUtils.sum(dArr)));
        for (int i = 0; i < n; i++) {
            result[i] = Math.exp(NumUtils.digamma(dArr[i])) / denominator;
        }
        return result;
    }

    /**
     * Small demo: adds 1 to each entry of a fixed count vector, prints the
     * result of {@link #expExpectedLog(double[])} on it together with its sum
     * and normalized form, and prints the normalized raw counts labelled "MLE".
     *
     * @param strArr command-line arguments; unused
     */
    public static void main(String[] strArr) {
        double[] counts = {3.0d, 2.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 0.0d, 0.0d, 0.0d, 0.0d};
        double[] ones = ListUtils.newDouble(counts.length, 1.0d);
        System.out.println(counts.length);
        double[] weights = expExpectedLog(ListUtils.add(counts, ones));
        System.out.println(Fmt.D(weights));
        System.out.println("sum = " + ListUtils.sum(weights));
        System.out.println("norm = " + Fmt.D(norm(weights)));
        System.out.println("MLE = " + Fmt.D(norm(counts)));
    }

    /**
     * Returns a copy of {@code dArr} after passing the copy through
     * {@code NumUtils.normalize}; the argument itself is not modified.
     *
     * @param dArr the values to normalize
     * @return the normalized copy
     */
    static double[] norm(double[] dArr) {
        double[] copy = dArr.clone();
        NumUtils.normalize(copy);
        return copy;
    }
}
