package net.sf.javaml.clustering;

import net.sf.javaml.clustering.evaluation.SumOfSquaredErrors;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DatasetTools;
import net.sf.javaml.core.Instance;
import net.sf.javaml.core.SimpleDataset;
import net.sf.javaml.distance.DistanceMeasure;
import net.sf.javaml.distance.EuclideanDistance;

/* loaded from: input_file:net/sf/javaml/clustering/EMClustering.class */
/**
 * Expectation-Maximisation clustering with a mixture of axis-aligned (per-attribute independent)
 * normal distributions.
 *
 * <p>The mixture is initialised from the best of several k-means runs (the run with the lowest
 * sum of squared errors under the configured {@link DistanceMeasure}). It is then refined with
 * alternating E and M steps until the average log-likelihood improves by less than
 * {@value #CONVERGENCE_TOLERANCE} or the iteration limit is reached. If a step fails (for example
 * because a cluster's prior drops to zero), the model is re-initialised from k-means and training
 * continues.
 *
 * <p>{@link #executeClustering(Dataset)} performs a hard assignment. An instance is placed in a
 * cluster only when its posterior membership probability exceeds {@code clusterThreshold}
 * (0.75). Instances without such a dominant cluster therefore appear in no returned cluster.
 * Clusters that end up empty are removed from the result.
 *
 * <p>Instances hold mutable training state in unsynchronised fields and should not be shared
 * between threads.
 */
public class EMClustering implements Clusterer {

    /** log(sqrt(2*pi)), the constant term of the log normal density. */
    private static final double LOG_SQRT_2PI = Math.log(Math.sqrt(2.0d * Math.PI));

    /** Number of k-means runs tried when initialising; the run with the lowest SSE is used. */
    private static final int KMEANS_RESTARTS = 10;

    /** Second argument passed to each {@code KMeans} run (presumably its iteration limit). */
    private static final int KMEANS_ITERATIONS = 100;

    /** EM stops once the average log-likelihood improves by less than this between E steps. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-6d;

    /**
     * Per-cluster, per-attribute normal parameters. {@code [cluster][attribute][0]} is the mean,
     * {@code [1]} the standard deviation, and {@code [2]} the total weighted membership mass the
     * estimates were computed from (1.0 right after initialisation).
     */
    private double[][][] m_modelNormal;

    /** Global lower bound for standard deviations, used when no per-attribute bound is set. */
    private final double m_minStdDev;

    /** Optional per-attribute std-dev lower bounds (not assigned anywhere in this class). */
    private double[] m_minStdDevPerAtt;

    /** Posterior membership probabilities, indexed {@code [instance][cluster]}. */
    private double[][] m_weights;

    /** Mixing proportions of the clusters; normalised to sum to one. */
    private double[] m_priors;

    /** Dataset currently being clustered. */
    private Dataset m_theInstances;

    /** Number of clusters requested by the caller; every k-means initialisation asks for this. */
    private final int m_requested_clusters;

    /** Number of clusters in the current model (set from the chosen k-means result). */
    private int m_num_clusters;

    /** Number of attributes per instance, taken from the first instance of the dataset. */
    private int m_num_attribs;

    /** Maximum number of EM iterations per call. */
    private final int m_max_iterations;

    /** Distance used only to score the k-means initialisations. */
    private final DistanceMeasure dm;

    /** Minimum posterior probability for an instance to be assigned to a cluster. */
    private final double clusterThreshold;

    /**
     * Creates a clusterer requesting 4 clusters, at most 100 EM iterations, and Euclidean distance
     * for scoring the k-means initialisations.
     */
    public EMClustering() {
        this(4, 100, new EuclideanDistance());
    }

    /**
     * Creates a clusterer.
     *
     * @param numClusters     number of clusters requested from each k-means initialisation
     * @param maxIterations   maximum number of EM iterations
     * @param distanceMeasure distance used by the sum-of-squared-errors score that selects the
     *                        best k-means initialisation
     */
    public EMClustering(int numClusters, int maxIterations, DistanceMeasure distanceMeasure) {
        this.m_minStdDev = 1.0E-6d;
        this.m_theInstances = null;
        this.clusterThreshold = 0.75d;
        this.m_requested_clusters = numClusters;
        this.m_num_clusters = numClusters;
        this.m_max_iterations = maxIterations;
        this.dm = distanceMeasure;
    }

    /**
     * Fits the mixture to {@code dataset} and returns the non-empty hard clusters.
     *
     * @param dataset the data to cluster; must contain at least one instance
     * @return the clusters that received at least one instance whose membership probability
     *         exceeds the cluster threshold
     */
    @Override // net.sf.javaml.clustering.Clusterer
    public Dataset[] executeClustering(Dataset dataset) {
        this.m_theInstances = dataset;
        doEM();
        Dataset[] clusters = new Dataset[this.m_num_clusters];
        for (int c = 0; c < this.m_num_clusters; c++) {
            clusters[c] = new SimpleDataset();
        }
        for (int i = 0; i < dataset.size(); i++) {
            Instance instance = dataset.getInstance(i);
            double[] membership = distributionForInstance(instance);
            for (int c = 0; c < this.m_num_clusters; c++) {
                if (membership[c] > this.clusterThreshold) {
                    clusters[c].addInstance(instance);
                }
            }
        }
        return filter(clusters);
    }

    /** Returns a new array containing only the non-empty datasets, in their original order. */
    private Dataset[] filter(Dataset[] clusters) {
        int nonEmpty = 0;
        for (Dataset cluster : clusters) {
            if (cluster.size() > 0) {
                nonEmpty++;
            }
        }
        Dataset[] result = new Dataset[nonEmpty];
        int next = 0;
        for (Dataset cluster : clusters) {
            if (cluster.size() > 0) {
                result[next++] = cluster;
            }
        }
        return result;
    }

    /**
     * Initialises and trains the model on {@link #m_theInstances}.
     *
     * @return the average log-likelihood from the last E step
     */
    private double doEM() {
        this.m_num_attribs = this.m_theInstances.getInstance(0).size();
        EM_Init(this.m_theInstances);
        return iterate(this.m_theInstances);
    }

    /**
     * Initialises the mixture from the best k-means clustering of {@code dataset}.
     *
     * <p>Each cluster's means are its centroid. Its standard deviations are the cluster's own,
     * falling back to the whole dataset's and then to the minimum bound when they are too small.
     * Its prior is proportional to its size.
     */
    private void EM_Init(Dataset dataset) {
        Instance globalStdDev = DatasetTools.getStandardDeviation(dataset);
        Dataset[] best = bestKMeansClustering(dataset);

        this.m_num_clusters = best.length;
        this.m_weights = new double[dataset.size()][this.m_num_clusters];
        this.m_modelNormal = new double[this.m_num_clusters][this.m_num_attribs][3];
        this.m_priors = new double[this.m_num_clusters];

        for (int c = 0; c < this.m_num_clusters; c++) {
            Instance centroid = DatasetTools.getCentroid(best[c]);
            Instance clusterStdDev = DatasetTools.getStandardDeviation(best[c]);
            for (int a = 0; a < this.m_num_attribs; a++) {
                double minStdDev = minStdDevFor(a);
                double stdDev = clusterStdDev.getValue(a);
                if (stdDev < minStdDev) {
                    // Too narrow (e.g. a one-point cluster): borrow the dataset-wide spread.
                    stdDev = globalStdDev.getValue(a);
                    if (Double.isInfinite(stdDev) || stdDev < minStdDev) {
                        stdDev = minStdDev;
                    }
                }
                if (stdDev <= 0.0d) {
                    stdDev = this.m_minStdDev;
                }
                this.m_modelNormal[c][a][0] = centroid.getValue(a);
                this.m_modelNormal[c][a][1] = stdDev;
                this.m_modelNormal[c][a][2] = 1.0d;
            }
            this.m_priors[c] = best[c].size();
        }
        normalize(this.m_priors);
    }

    /**
     * Runs k-means {@value #KMEANS_RESTARTS} times and returns the clustering with the lowest sum
     * of squared errors.
     *
     * <p>Each run asks for the caller-requested number of clusters, never for a count reduced by
     * an earlier initialisation.
     *
     * @throws IllegalStateException if no run scored below {@code Double.MAX_VALUE}
     */
    private Dataset[] bestKMeansClustering(Dataset dataset) {
        Dataset[] best = null;
        double bestScore = Double.MAX_VALUE;
        for (int run = 0; run < KMEANS_RESTARTS; run++) {
            KMeans kMeans = new KMeans(this.m_requested_clusters, KMEANS_ITERATIONS);
            SumOfSquaredErrors sse = new SumOfSquaredErrors(this.dm);
            Dataset[] candidate = kMeans.executeClustering(dataset);
            double score = sse.score(candidate);
            if (score < bestScore) {
                bestScore = score;
                best = candidate;
            }
        }
        if (best == null) {
            throw new IllegalStateException(
                    "k-means initialisation failed: no run produced a usable sum of squared errors");
        }
        return best;
    }

    /** Returns the standard-deviation lower bound for attribute {@code att}. */
    private double minStdDevFor(int att) {
        return this.m_minStdDevPerAtt != null ? this.m_minStdDevPerAtt[att] : this.m_minStdDev;
    }

    /**
     * Divides every element of {@code values} by their sum, in place.
     *
     * @throws IllegalArgumentException if the sum is NaN or zero
     */
    private static void normalize(double[] values) {
        double sum = 0.0d;
        for (double v : values) {
            sum += v;
        }
        normalize(values, sum);
    }

    /**
     * Divides every element of {@code values} by {@code sum}, in place.
     *
     * @throws IllegalArgumentException if {@code sum} is NaN or zero
     */
    private static void normalize(double[] values, double sum) {
        if (Double.isNaN(sum)) {
            throw new IllegalArgumentException("Can't normalize array. Sum is NaN.");
        }
        if (sum == 0.0d) {
            throw new IllegalArgumentException("Can't normalize array. Sum is zero.");
        }
        for (int i = 0; i < values.length; i++) {
            values[i] /= sum;
        }
    }

    /**
     * Re-estimates the cluster priors from the current membership probabilities.
     *
     * <p>Each instance contributes its weight to every cluster, scaled by its membership in that
     * cluster. The results are then normalised.
     */
    private void estimate_priors(Dataset dataset) {
        for (int c = 0; c < this.m_num_clusters; c++) {
            this.m_priors[c] = 0.0d;
        }
        for (int i = 0; i < dataset.size(); i++) {
            double weight = dataset.getInstance(i).getWeight();
            for (int c = 0; c < this.m_num_clusters; c++) {
                this.m_priors[c] += weight * this.m_weights[i][c];
            }
        }
        normalize(this.m_priors);
    }

    /** Returns log N(x; mean, stdDev), the log density of a normal distribution. */
    private double logNormalDens(double x, double mean, double stdDev) {
        double diff = x - mean;
        return ((-((diff * diff) / ((2.0d * stdDev) * stdDev))) - LOG_SQRT_2PI) - Math.log(stdDev);
    }

    /** Clears the mean, std-dev and mass accumulators of every cluster/attribute model. */
    private void new_estimators() {
        for (double[][] cluster : this.m_modelNormal) {
            for (double[] stats : cluster) {
                stats[0] = 0.0d;
                stats[1] = 0.0d;
                stats[2] = 0.0d;
            }
        }
    }

    /**
     * M step: re-estimates every cluster's per-attribute mean and standard deviation from the
     * membership probabilities computed by the last E step.
     */
    private void M(Dataset dataset) {
        new_estimators();
        Instance globalStdDev = DatasetTools.getStandardDeviation(dataset);

        // Accumulate weighted sums: [0] += w*x, [2] += w, [1] += w*x^2.
        for (int c = 0; c < this.m_num_clusters; c++) {
            for (int a = 0; a < this.m_num_attribs; a++) {
                double[] stats = this.m_modelNormal[c][a];
                for (int i = 0; i < dataset.size(); i++) {
                    Instance instance = dataset.getInstance(i);
                    double value = instance.getValue(a);
                    double weight = instance.getWeight();
                    double membership = this.m_weights[i][c];
                    stats[0] += value * weight * membership;
                    stats[2] += weight * membership;
                    stats[1] += value * value * weight * membership;
                }
            }
        }

        // Turn the sums into a mean and a bounded standard deviation.
        for (int a = 0; a < this.m_num_attribs; a++) {
            for (int c = 0; c < this.m_num_clusters; c++) {
                double[] stats = this.m_modelNormal[c][a];
                double mass = stats[2];
                if (mass <= 0.0d) {
                    // No membership mass: make the component effectively flat.
                    // NOTE(review): the mean slot receives m_minStdDev here; confirm intended.
                    stats[1] = Double.MAX_VALUE;
                    stats[0] = this.m_minStdDev;
                    continue;
                }
                double variance = (stats[1] - ((stats[0] * stats[0]) / mass)) / mass;
                if (variance < 0.0d) {
                    // Guard against tiny negative values from floating-point cancellation.
                    variance = 0.0d;
                }
                double minStdDev = minStdDevFor(a);
                double stdDev = Math.sqrt(variance);
                if (stdDev <= minStdDev) {
                    stdDev = globalStdDev.getValue(a);
                    if (stdDev <= minStdDev) {
                        stdDev = minStdDev;
                    }
                }
                if (stdDev <= 0.0d || Double.isInfinite(stdDev)) {
                    stdDev = this.m_minStdDev;
                }
                stats[1] = stdDev;
                stats[0] = stats[0] / mass;
            }
        }
    }

    /**
     * E step: computes the weighted average log-likelihood of {@code dataset} under the current
     * model.
     *
     * @param dataset       the instances to score
     * @param updateWeights when true, also stores each instance's posterior membership in
     *                      {@link #m_weights} and re-estimates the priors
     * @return the weight-averaged log-likelihood (NaN if the total weight is zero)
     */
    private double E(Dataset dataset, boolean updateWeights) {
        double weightedLogLikelihood = 0.0d;
        double totalWeight = 0.0d;
        for (int i = 0; i < dataset.size(); i++) {
            Instance instance = dataset.getInstance(i);
            weightedLogLikelihood += instance.getWeight() * logDensityForInstance(instance);
            totalWeight += instance.getWeight();
            if (updateWeights) {
                this.m_weights[i] = distributionForInstance(instance);
            }
        }
        if (updateWeights) {
            estimate_priors(dataset);
        }
        return weightedLogLikelihood / totalWeight;
    }

    /**
     * Returns the log of the mixture density at {@code instance}.
     *
     * <p>The per-cluster joint log densities are combined with the log-sum-exp trick (shifting by
     * the maximum) to avoid underflow.
     */
    private double logDensityForInstance(Instance instance) {
        double[] logJoint = logJointDensitiesForInstance(instance);
        double max = logJoint[maxIndex(logJoint)];
        double sum = 0.0d;
        for (double lj : logJoint) {
            sum += Math.exp(lj - max);
        }
        return max + Math.log(sum);
    }

    /** Returns the posterior membership probability of {@code instance} in each cluster. */
    private double[] distributionForInstance(Instance instance) {
        return logs2probs(logJointDensitiesForInstance(instance));
    }

    /** Returns the index of the first maximum element (0 for an empty array). */
    private static int maxIndex(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Converts log-scale values to probabilities that sum to one, shifting by the maximum. */
    private static double[] logs2probs(double[] logs) {
        double max = logs[maxIndex(logs)];
        double sum = 0.0d;
        double[] probs = new double[logs.length];
        for (int i = 0; i < logs.length; i++) {
            probs[i] = Math.exp(logs[i] - max);
            sum += probs[i];
        }
        normalize(probs, sum);
        return probs;
    }

    /**
     * Returns log(prior_c) + log p(instance | c) for each cluster c.
     *
     * @throws IllegalArgumentException ("Cluster empty!") if any prior is not positive
     */
    private double[] logJointDensitiesForInstance(Instance instance) {
        double[] logJoint = logDensityPerClusterForInstance(instance);
        for (int c = 0; c < logJoint.length; c++) {
            double prior = this.m_priors[c];
            if (prior <= 0.0d) {
                throw new IllegalArgumentException("Cluster empty!");
            }
            logJoint[c] += Math.log(prior);
        }
        return logJoint;
    }

    /** Returns log p(instance | c) for each cluster c, treating attributes as independent. */
    private double[] logDensityPerClusterForInstance(Instance instance) {
        double[] logDensities = new double[this.m_num_clusters];
        for (int c = 0; c < this.m_num_clusters; c++) {
            double sum = 0.0d;
            for (int a = 0; a < this.m_num_attribs; a++) {
                double[] stats = this.m_modelNormal[c][a];
                sum += logNormalDens(instance.getValue(a), stats[0], stats[1]);
            }
            logDensities[c] = sum;
        }
        return logDensities;
    }

    /**
     * Alternates E and M steps until the average log-likelihood improves by less than
     * {@value #CONVERGENCE_TOLERANCE} or {@link #m_max_iterations} iterations have run.
     *
     * <p>If a step throws, the model is re-initialised from k-means and the iteration slot is
     * consumed. The convergence baseline is also cleared, so the fresh model is never compared
     * against a log-likelihood computed for the discarded one.
     *
     * @return the average log-likelihood from the last successful E step (0 if none succeeded)
     */
    private double iterate(Dataset dataset) {
        double logLikelihood = 0.0d;
        boolean haveBaseline = false;
        for (int iteration = 0; iteration < this.m_max_iterations; iteration++) {
            try {
                double previous = logLikelihood;
                logLikelihood = E(dataset, true);
                if (haveBaseline && logLikelihood - previous < CONVERGENCE_TOLERANCE) {
                    break;
                }
                haveBaseline = true;
                M(dataset);
            } catch (Exception e) {
                System.err.println("Restarting after training failure");
                e.printStackTrace();
                EM_Init(this.m_theInstances);
                haveBaseline = false;
            }
        }
        return logLikelihood;
    }
}
