package fig.prob;

import fig.basic.LogInfo;
import fig.basic.StatFig;
import fig.basic.TDoubleMap;
import java.util.Random;

/* loaded from: input_file:fig/prob/DistribUtils.class */
public class DistribUtils {
    /**
     * Small numeric constant. It is not referenced anywhere in this class.
     * NOTE(review): presumably used by callers as a tolerance/epsilon — confirm before changing.
     */
    public static final double margin = 1.0E-8d;

    /** Monte Carlo sample count used by the verification helpers when no count is supplied. */
    private static final int DEFAULT_NUM_SAMPLES = 100000;

    /**
     * Computes the predictive log-likelihood of the data summarized by {@code suffStats}, given the
     * data already summarized by {@code suffStats2}. It does this as the difference of two
     * marginal log-likelihoods:
     * {@code margLogLikelihood(suffStats2 + suffStats) - margLogLikelihood(suffStats2)}.
     *
     * <p>{@code suffStats2} is temporarily modified with {@code add(suffStats)} and then
     * {@code sub(suffStats)}. The {@code sub} runs in a {@code finally} block, so it is applied
     * even if {@code margLogLikelihood} throws. This assumes {@code sub} exactly undoes
     * {@code add}.
     *
     * @param margDistrib the marginal distribution used to score the statistics
     * @param suffStats statistics of the new data whose predictive likelihood is wanted
     * @param suffStats2 statistics of the conditioning data; restored before returning
     * @return the predictive log-likelihood (difference of the two marginal log-likelihoods)
     */
    public static double predLogLikelihood(MargDistrib margDistrib, SuffStats suffStats, SuffStats suffStats2) {
        double baseLogLikelihood = margDistrib.margLogLikelihood(suffStats2);
        suffStats2.add(suffStats);
        try {
            double combinedLogLikelihood = margDistrib.margLogLikelihood(suffStats2);
            return combinedLogLikelihood - baseLogLikelihood;
        } finally {
            // Undo the temporary accumulation so the caller's statistics are left as they were.
            suffStats2.sub(suffStats);
        }
    }

    /**
     * Computes KL(distrib || distrib2) as
     * {@code distrib.crossEntropy(distrib) - distrib.crossEntropy(distrib2)}.
     *
     * <p>Here {@code p.crossEntropy(q)} denotes the expected log-probability of {@code q} under
     * samples from {@code p}; that is the quantity {@link #verifyCrossEntropy} checks it against.
     * The difference is therefore E_p[log p] - E_p[log q].
     *
     * @param distrib the reference distribution p
     * @param distrib2 the approximating distribution q
     * @return the KL divergence of distrib2 from distrib
     */
    public static double KL(Distrib distrib, Distrib distrib2) {
        return distrib.crossEntropy(distrib) - distrib.crossEntropy(distrib2);
    }

    /**
     * Monte Carlo check of {@code distrib.crossEntropy(distrib2)} using the default sample count.
     *
     * @see #verifyCrossEntropy(Distrib, Distrib, int)
     */
    public static <T> void verifyCrossEntropy(Distrib<T> distrib, Distrib<T> distrib2) {
        verifyCrossEntropy(distrib, distrib2, DEFAULT_NUM_SAMPLES);
    }

    /**
     * Monte Carlo check of {@code distrib.crossEntropy(distrib2)}.
     *
     * <p>Draws {@code numSamples} objects from {@code distrib} and averages
     * {@code distrib2.logProbObject} over them. It prints the analytic value, the sample
     * estimate, and their difference, separated by spaces, to standard output. Nothing is
     * asserted; the output is meant for manual inspection.
     *
     * @param distrib distribution to sample from
     * @param distrib2 distribution whose log-probability is averaged
     * @param numSamples number of Monte Carlo samples; must be positive
     * @throws IllegalArgumentException if {@code numSamples <= 0}
     */
    public static <T> void verifyCrossEntropy(Distrib<T> distrib, Distrib<T> distrib2, int numSamples) {
        if (numSamples <= 0) {
            throw new IllegalArgumentException("numSamples must be positive: " + numSamples);
        }
        Random random = new Random();
        StatFig logProbStats = new StatFig();
        for (int s = 0; s < numSamples; s++) {
            logProbStats.add(distrib2.logProbObject(distrib.sampleObject(random)));
        }
        double analytic = distrib.crossEntropy(distrib2);
        double estimate = logProbStats.mean();
        System.out.println(analytic + " " + estimate + " " + (analytic - estimate));
    }

    /**
     * Monte Carlo check of {@code margDistrib.expectedLogLikelihood(suffStats)} using the default
     * sample count.
     *
     * @see #verifyExpectedLogLikelihood(MargDistrib, SuffStats, int)
     */
    public static void verifyExpectedLogLikelihood(MargDistrib margDistrib, SuffStats suffStats) {
        verifyExpectedLogLikelihood(margDistrib, suffStats, DEFAULT_NUM_SAMPLES);
    }

    /**
     * Monte Carlo check of {@code margDistrib.expectedLogLikelihood(suffStats)}.
     *
     * <p>Repeatedly samples a {@link Distrib} from {@code margDistrib} and averages its
     * {@code logProb(suffStats)}. It prints the analytic value, the sample estimate, and their
     * difference, separated by spaces, to standard output. Nothing is asserted.
     *
     * @param margDistrib distribution over distributions to sample from
     * @param suffStats statistics scored by each sampled distribution
     * @param numSamples number of Monte Carlo samples; must be positive
     * @throws IllegalArgumentException if {@code numSamples <= 0}
     */
    public static void verifyExpectedLogLikelihood(MargDistrib margDistrib, SuffStats suffStats, int numSamples) {
        if (numSamples <= 0) {
            throw new IllegalArgumentException("numSamples must be positive: " + numSamples);
        }
        Random random = new Random();
        StatFig logProbStats = new StatFig();
        for (int s = 0; s < numSamples; s++) {
            // Each sample drawn from the marginal distribution is itself a distribution.
            Distrib sampled = (Distrib) margDistrib.sampleObject(random);
            logProbStats.add(sampled.logProb(suffStats));
        }
        double analytic = margDistrib.expectedLogLikelihood(suffStats);
        double estimate = logProbStats.mean();
        System.out.println(analytic + " " + estimate + " " + (analytic - estimate));
    }

    /**
     * Runs the verification helpers on a fixed set of hand-picked distributions. It prints one
     * line per check (analytic value, Monte Carlo estimate, difference).
     *
     * <p>The final check compares {@code Gamma.expectedLog()} against the sample mean of
     * {@code log(sample)}.
     */
    public static void verifyPassed() {
        verifyCrossEntropy(new Gaussian(2.0d, 0.3d), new Gaussian(8.0d, 1.7d));
        verifyCrossEntropy(new Gamma(2.0d, 0.3d), new Gamma(8.0d, 1.7d));
        verifyCrossEntropy(new Dirichlet(10, 0.3d), new Dirichlet(10, 1.7d));

        TDoubleMap firstSparseParams = new TDoubleMap();
        firstSparseParams.put("A", 3.0d);
        firstSparseParams.put("B", 8.0d);
        firstSparseParams.put("C", 0.0d);
        TDoubleMap secondSparseParams = new TDoubleMap();
        secondSparseParams.put("A", 3.0d);
        // Bug fix: previously this wrote into the first map, clobbering its "B" entry.
        secondSparseParams.put("B", 0.0d);
        secondSparseParams.put("C", 1.0d);
        verifyCrossEntropy(new SparseDirichlet(10, 0.3d, firstSparseParams), new SparseDirichlet(10, 1.7d, secondSparseParams));

        verifyExpectedLogLikelihood(new MargMultinomial(new Dirichlet(5, 1.3d)), new MultinomialSuffStats(new double[]{4.0d, 21.0d, 0.3d, 2.0d, 4.0d}));

        TDoubleMap sparseCounts = new TDoubleMap();
        sparseCounts.put("A", 3.0d);
        sparseCounts.put("B", 8.0d);
        sparseCounts.put("C", 4.0d);
        verifyExpectedLogLikelihood(new MargSparseMultinomial(new SparseDirichlet(10, 1.3d, sparseCounts)), new SparseMultinomialSuffStats(sparseCounts));

        verifyExpectedLogLikelihood(new MargMeanGaussian(new Gaussian(0.0d, 1.0d), 1.0d), new GaussianSuffStats(0.0d, 0.0d, 1.0d));
        verifyExpectedLogLikelihood(new MargMeanGaussian(new Gaussian(2.0d, 10.0d), 0.7d), new GaussianSuffStats(2.0d, 10.0d, 0.3d), 1000000);
        verifyExpectedLogLikelihood(new MargMeanDiagMultGaussian(new DiagMultGaussian(new double[]{3.0d, 4.0d, -2.0d}, new double[]{0.7d, 1.5d, 4.4d}), new double[]{1.7d, 3.5d, 0.4d}), new DiagMultGaussianSuffStats(new double[]{1.0d, 1.0d, 0.0d}, new double[]{8.0d, 4.0d, 3.0d}, 0.37d), 1000000);

        Random random = new Random();
        Gamma gamma = new Gamma(2.0d, 0.3d);
        StatFig logSampleStats = new StatFig();
        for (int s = 0; s < 10000; s++) {
            logSampleStats.add(Math.log(gamma.sample(random)));
        }
        System.out.println(gamma.expectedLog() + " " + logSampleStats.mean());
    }

    /**
     * Ad-hoc diagnostic entry point. For two fixed Dirichlet distributions and for the results of
     * their {@code modeSpike()}, it logs {@code exp(expectedLog(k))} for components 2 and 3.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        LogInfo.init();
        LogInfo.msPerLine = 0;
        Dirichlet dirichlet = new Dirichlet(new double[]{10001.0d, 1.0d, 2.0d, 1.0d});
        Dirichlet dirichlet2 = new Dirichlet(new double[]{50001.0d, 1.0d, 6.0d, 1.0d});
        LogInfo.logs(Math.exp(dirichlet.expectedLog(2)) + " " + Math.exp(dirichlet.expectedLog(3)));
        LogInfo.logs(Math.exp(dirichlet2.expectedLog(2)) + " " + Math.exp(dirichlet2.expectedLog(3)));
        DirichletInterface modeSpike = dirichlet.modeSpike();
        DirichletInterface modeSpike2 = dirichlet2.modeSpike();
        LogInfo.logs(Math.exp(modeSpike.expectedLog(2)) + " " + Math.exp(modeSpike.expectedLog(3)));
        LogInfo.logs(Math.exp(modeSpike2.expectedLog(2)) + " " + Math.exp(modeSpike2.expectedLog(3)));
    }
}
