package fig.prob;

import fig.basic.NumUtils;
import java.util.Locale;
import java.util.Random;

/**
 * A Beta(alpha, beta) distribution over the unit interval.
 *
 * <p>Instances are immutable: both shape parameters are fixed at construction.
 */
public class Beta implements BetaInterface {
    /** Lower clamp used instead of exactly 0 when a point estimate must stay strictly inside (0, 1). */
    private static final double MIN_INTERIOR = 1.0E-8d;
    /** Upper clamp used instead of exactly 1 when a point estimate must stay strictly inside (0, 1). */
    private static final double MAX_INTERIOR = 0.99999999d;

    private final double alpha;
    private final double beta;

    /**
     * Creates a Beta distribution with the given shape parameters.
     *
     * @param alpha first shape parameter
     * @param beta second shape parameter
     */
    public Beta(double alpha, double beta) {
        this.alpha = alpha;
        this.beta = beta;
    }

    /**
     * Returns the log density of this distribution at {@code x}.
     *
     * @param x point at which to evaluate the density
     * @return log density; {@code -Infinity} outside [0, 1]
     */
    public double logProb(double x) {
        return logProb(this.alpha, this.beta, x);
    }

    /**
     * Returns the log of the Beta(alpha, beta) density at {@code x}:
     * {@code lnG(a+b) - lnG(a) - lnG(b) + (a-1) ln x + (b-1) ln(1-x)}.
     *
     * <p>Points outside [0, 1] lie outside the support and yield {@code -Infinity}. A shape
     * parameter equal to 1 contributes nothing at the matching endpoint, so e.g. Beta(1, b) at
     * x = 0 gives {@code ln b} rather than the NaN produced by {@code 0 * ln 0}. NaN input yields NaN.
     *
     * @param alpha first shape parameter
     * @param beta second shape parameter
     * @param x point at which to evaluate the density
     * @return the log density
     */
    public static double logProb(double alpha, double beta, double x) {
        if (x < 0.0d || x > 1.0d) {
            return Double.NEGATIVE_INFINITY;
        }
        return ((NumUtils.logGamma(alpha + beta) - NumUtils.logGamma(alpha)) - NumUtils.logGamma(beta))
                + logPowerTerm(alpha - 1.0d, x)
                + logPowerTerm(beta - 1.0d, 1.0d - x);
    }

    /**
     * Returns {@code exponent * ln(base)}. A zero exponent contributes 0 even when the base is 0,
     * avoiding {@code 0 * -Infinity = NaN}. A NaN base still propagates NaN.
     */
    private static double logPowerTerm(double exponent, double base) {
        if (exponent == 0.0d && !Double.isNaN(base)) {
            return 0.0d;
        }
        return exponent * Math.log(base);
    }

    /**
     * Not supported for this distribution.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // fig.prob.Distrib
    public double logProb(SuffStats suffStats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Boxed-argument form of {@link #logProb(double)}; a {@code null} argument throws NPE on unboxing. */
    @Override // fig.prob.Distrib
    public double logProbObject(Double x) {
        return logProb(x.doubleValue());
    }

    /**
     * Draws one sample from this distribution.
     *
     * @param random source of randomness
     * @return a sample in [0, 1]
     */
    public double sample(Random random) {
        return sample(random, this.alpha, this.beta);
    }

    /**
     * Draws one Beta(alpha, beta) sample as {@code X / (X + Y)} from two Gamma draws.
     *
     * <p>NOTE(review): this assumes {@code Gamma.sample(random, shape, 1.0)} draws from a
     * Gamma(shape, 1) distribution. Confirm against {@code Gamma}.
     *
     * @param random source of randomness
     * @param alpha first shape parameter
     * @param beta second shape parameter
     * @return the sample
     */
    public static double sample(Random random, double alpha, double beta) {
        double x = Gamma.sample(random, alpha, 1.0d);
        return x / (x + Gamma.sample(random, beta, 1.0d));
    }

    /** Boxed form of {@link #sample(Random)}; covariant return for {@code Distrib<Double>}. */
    @Override // fig.prob.Distrib
    public Double sampleObject(Random random) {
        return Double.valueOf(sample(random));
    }

    /**
     * Returns {@code DirichletUtils.expectedLog} for one component against the total count.
     * Presumably this is E[log p] or E[log(1-p)]; confirm in {@code DirichletUtils}.
     *
     * @param forAlpha {@code true} for the alpha component, {@code false} for the beta component
     * @return the expected-log value for the chosen component
     */
    @Override // fig.prob.BetaInterface
    public double expectedLog(boolean forAlpha) {
        double component = forAlpha ? this.alpha : this.beta;
        return DirichletUtils.expectedLog(component, totalCount());
    }

    /**
     * Combines {@code DirichletUtils} contributions: one from the other distribution's total
     * count, plus one per component pairing this and the other distribution's parameters.
     *
     * @param distrib the other distribution; must be a {@code Beta} (a ClassCastException is thrown otherwise)
     * @return the cross-entropy value as computed by {@code DirichletUtils}
     */
    @Override // fig.prob.Distrib
    public double crossEntropy(Distrib<Double> distrib) {
        Beta that = (Beta) distrib;
        double total = totalCount();
        return DirichletUtils.thatTotalCountContrib(that.totalCount())
                + DirichletUtils.elementContrib(this.alpha, that.alpha, total)
                + DirichletUtils.elementContrib(this.beta, that.beta, total);
    }

    /** @return the first shape parameter */
    @Override // fig.prob.BetaInterface
    public double getAlpha() {
        return this.alpha;
    }

    /** @return the second shape parameter */
    @Override // fig.prob.BetaInterface
    public double getBeta() {
        return this.beta;
    }

    /** @return the mean {@code alpha / (alpha + beta)} */
    @Override // fig.prob.BetaInterface
    public double getMean() {
        return this.alpha / (this.alpha + this.beta);
    }

    /**
     * Returns the mode.
     *
     * <p>When both parameters exceed 1 the mode is the interior point
     * {@code (alpha - 1) / (alpha + beta - 2)}. Otherwise the density peaks at an endpoint.
     * The method then returns a point just inside 1 if {@code alpha > beta}, else just inside 0.
     *
     * @return the mode, kept strictly inside (0, 1) in the endpoint case
     */
    @Override // fig.prob.BetaInterface
    public double getMode() {
        if (this.alpha <= 1.0d || this.beta <= 1.0d) {
            return this.alpha > this.beta ? MAX_INTERIOR : MIN_INTERIOR;
        }
        return (this.alpha - 1.0d) / ((this.alpha + this.beta) - 2.0d);
    }

    /** @return {@code alpha + beta} */
    @Override // fig.prob.BetaInterface
    public double totalCount() {
        return this.alpha + this.beta;
    }

    /** @return a {@code DegenerateBeta} located at {@link #getMode()} */
    @Override // fig.prob.BetaInterface
    public BetaInterface modeSpike() {
        return new DegenerateBeta(getMode());
    }

    /**
     * Returns a new {@code Beta} with the same total count.
     *
     * <p>The new mean is a sample from this distribution, clamped to
     * [{@code MIN_INTERIOR}, {@code MAX_INTERIOR}] so both new parameters stay positive.
     *
     * @param random source of randomness
     * @return the perturbed distribution
     */
    public BetaInterface perturb(Random random) {
        double mean = NumUtils.bound(sample(random), MIN_INTERIOR, MAX_INTERIOR);
        double total = totalCount();
        return new Beta(total * mean, total * (1.0d - mean));
    }

    /**
     * Returns a {@code DegenerateBeta} located at a sample drawn from this distribution.
     *
     * @param random source of randomness
     * @return the degenerate distribution
     */
    public BetaInterface degeneratePerturb(Random random) {
        return new DegenerateBeta(sample(random));
    }

    /** Formats with {@link Locale#ROOT} so the decimal separator never collides with the comma delimiter. */
    @Override
    public String toString() {
        return String.format(Locale.ROOT, "Beta(%.3f,%.3f)", Double.valueOf(this.alpha), Double.valueOf(this.beta));
    }
}
