package net.sf.javaml.classification.svm;

import java.util.HashSet;
import java.util.Set;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.core.Verbose;
import net.sf.javaml.distance.DistanceMeasure;
import net.sf.javaml.distance.LinearKernel;

/* loaded from: input_file:net/sf/javaml/classification/svm/BinaryLinearSMO.class */
/**
 * Binary support vector machine trained with Platt's Sequential Minimal
 * Optimization, using the two-threshold ({@code bUp}/{@code bLow}) variant with
 * index sets I0..I4 and "Modification 1/2" from Keerthi et al.
 *
 * <p>Only class values {@code 0} and {@code 1} are accepted; they are mapped
 * internally to targets {@code -1} and {@code +1}. Each instance's weight scales
 * the box constraint, i.e. its multiplier is bounded by {@code C * weight}.
 *
 * <p>The decision function is {@code w . x - b}, where {@code w} is an explicit
 * weight vector accumulated during training and stored sparsely afterwards.
 * The supplied {@link DistanceMeasure} is only used to evaluate kernel values
 * during optimization, so it should behave as a linear (dot-product) kernel;
 * with any other kernel the learned {@code w} would not match the optimized
 * problem. NOTE(review): confirm callers only pass linear kernels.
 *
 * <p>Instances hold mutable, unsynchronized training state.
 */
public class BinaryLinearSMO extends Verbose implements Classifier {
    private static final long serialVersionUID = 1202307139516461728L;

    /** Complexity (box-constraint) parameter C. */
    private double m_C;
    /** Round-off epsilon used when comparing objective values and alpha changes. */
    private double m_eps = 1.0E-12d;
    /** Relative snap distance for clamping alphas exactly to 0 or C (1000 * Double.MIN_VALUE). */
    private double m_Del = 1000 * Double.MIN_VALUE;
    /** KKT optimality tolerance. */
    private double m_tol = 0.001d;
    /** Kernel used for k(x_i, x_j) during optimization; should be linear (see class doc). */
    private DistanceMeasure m_kernel = null;
    /** Lagrange multipliers (training only). */
    private double[] m_alpha;
    /** Threshold; the decision function subtracts it. */
    private double m_b;
    /** Current lower/upper threshold estimates and the indices attaining them. */
    private double m_bLow;
    private double m_bUp;
    private int m_iLow;
    private int m_iUp;
    /** Training set (training only; released afterwards). */
    private Dataset m_data;
    /** Dense weight vector (training only). */
    private double[] m_weights;
    /** Non-zero weights and their attribute indices (used after training). */
    private double[] m_sparseWeights;
    private int[] m_sparseIndices;
    /** Targets in {-1, +1} (training only). */
    private double[] m_class;
    /** Error cache F_i (training only). */
    private double[] m_errors;
    /** Index sets I0..I4 (training only). */
    private Set<Integer> m_I0;
    private Set<Integer> m_I1;
    private Set<Integer> m_I2;
    private Set<Integer> m_I3;
    private Set<Integer> m_I4;
    /** Indices with a positive multiplier; maintained but not otherwise read in this class. */
    private Set<Integer> m_supportVectors;
    /** Sum of instance weights of the last training set; computed but not otherwise read in this class. */
    private double m_sumOfWeights = 0.0d;

    /** Creates a classifier with {@code C = 1.0} and a {@link LinearKernel}. */
    public BinaryLinearSMO() {
        this(1.0d);
    }

    /**
     * Creates a classifier with the given complexity and a {@link LinearKernel}.
     *
     * @param c the complexity (box-constraint) parameter C
     */
    public BinaryLinearSMO(double c) {
        this(c, new LinearKernel());
    }

    /**
     * Creates a classifier with the given complexity and kernel.
     *
     * @param c the complexity (box-constraint) parameter C
     * @param kernel kernel evaluated during optimization; should be linear, since
     *     prediction uses an explicit linear weight vector
     */
    public BinaryLinearSMO(double c, DistanceMeasure kernel) {
        this.m_C = c;
        this.m_kernel = kernel;
    }

    /**
     * Trains the classifier on {@code dataset}.
     *
     * <p>If only one class is present, a constant classifier predicting that
     * class is produced. If the dataset is empty, the model stays unbuilt, and
     * later classification throws {@link IllegalStateException}.
     *
     * @param dataset training data whose class values are all 0 or 1
     * @throws IllegalArgumentException if a class value other than 0 or 1 is found
     */
    @Override
    public void buildClassifier(Dataset dataset) {
        // Reset state from any previous training run.
        m_bUp = -1.0d;
        m_bLow = 1.0d;
        m_b = 0.0d;
        m_alpha = null;
        m_data = null;
        m_weights = null;
        m_errors = null;
        m_I0 = null;
        m_I1 = null;
        m_I2 = null;
        m_I3 = null;
        m_I4 = null;
        m_supportVectors = null;
        m_sparseWeights = null;
        m_sparseIndices = null;

        verbose("Storing the sum of weights...");
        m_sumOfWeights = 0.0d;
        for (int i = 0; i < dataset.size(); i++) {
            m_sumOfWeights += dataset.getInstance(i).getWeight();
        }

        verbose("Setting class values...");
        m_class = new double[dataset.size()];
        m_iUp = -1;
        m_iLow = -1;
        for (int i = 0; i < m_class.length; i++) {
            if (dataset.getInstance(i).getClassValue() == 0) {
                m_class[i] = -1.0d;
                m_iLow = i;
            } else if (dataset.getInstance(i).getClassValue() == 1) {
                m_class[i] = 1.0d;
                m_iUp = i;
            } else {
                throw new IllegalArgumentException(
                        "This should never happen! A binary SMO can only take 0 and 1 as class values.");
            }
        }

        verbose("Checking for missing classes...");
        if (m_iUp == -1 || m_iLow == -1) {
            if (m_iUp == -1 && m_iLow == -1) {
                // Empty dataset: nothing to learn, the model stays unbuilt.
                m_class = null;
                return;
            }
            // Single class: with no weights the output is -m_b, so this sign picks that class.
            m_b = (m_iUp != -1) ? -1.0d : 1.0d;
            m_sparseWeights = new double[0];
            m_sparseIndices = new int[0];
            m_class = null;
            return;
        }

        m_data = dataset;
        m_weights = new double[m_data.getInstance(0).size()];
        m_alpha = new double[m_data.size()];
        m_supportVectors = new HashSet<Integer>();
        m_I0 = new HashSet<Integer>();
        m_I1 = new HashSet<Integer>();
        m_I2 = new HashSet<Integer>();
        m_I3 = new HashSet<Integer>();
        m_I4 = new HashSet<Integer>();
        m_errors = new double[m_data.size()];
        m_errors[m_iLow] = 1.0d;
        m_errors[m_iUp] = -1.0d;
        // All alphas start at 0: positives belong to I1, negatives to I4.
        for (int i = 0; i < m_class.length; i++) {
            if (m_class[i] == 1.0d) {
                m_I1.add(Integer.valueOf(i));
            } else {
                m_I4.add(Integer.valueOf(i));
            }
        }

        verbose("Searching for support vectors...");
        int numChanged = 0;
        boolean examineAll = true;
        while (numChanged > 0 || examineAll) {
            numChanged = 0;
            if (examineAll) {
                for (int i = 0; i < m_alpha.length; i++) {
                    if (examineExample(i)) {
                        numChanged++;
                    }
                }
            } else {
                // Modification 1: sweep only non-bound examples, stopping early
                // once the thresholds show the non-bound subset is optimal.
                for (int i = 0; i < m_alpha.length; i++) {
                    if (m_alpha[i] > 0.0d && m_alpha[i] < m_C * m_data.getInstance(i).getWeight()) {
                        if (examineExample(i)) {
                            numChanged++;
                        }
                        if (m_bUp > m_bLow - 2.0d * m_tol) {
                            break;
                        }
                    }
                }
                // Modification 2: repeatedly optimize the worst-violating pair.
                // numChanged is reset, so the next outer pass always examines all.
                boolean progress = true;
                numChanged = 0;
                while (m_bUp < m_bLow - 2.0d * m_tol && progress) {
                    progress = takeStep(m_iUp, m_iLow, m_errors[m_iLow]);
                }
            }
            if (examineAll) {
                examineAll = false;
            } else if (numChanged == 0) {
                examineAll = true;
            }
        }

        m_b = (m_bLow + m_bUp) / 2.0d;

        // Release training-only state.
        m_errors = null;
        m_I4 = null;
        m_I3 = null;
        m_I2 = null;
        m_I1 = null;
        m_I0 = null;
        m_supportVectors = null;
        m_class = null;

        // Compact the dense weight vector into sparse form for prediction.
        double[] nonZeroWeights = new double[m_weights.length];
        int[] nonZeroIndices = new int[m_weights.length];
        int count = 0;
        for (int k = 0; k < m_weights.length; k++) {
            if (m_weights[k] != 0.0d) {
                nonZeroWeights[count] = m_weights[k];
                nonZeroIndices[count] = k;
                count++;
            }
        }
        m_sparseWeights = new double[count];
        m_sparseIndices = new int[count];
        System.arraycopy(nonZeroWeights, 0, m_sparseWeights, 0, count);
        System.arraycopy(nonZeroIndices, 0, m_sparseIndices, 0, count);
        m_weights = null;
        m_alpha = null;
        // Prediction does not need the training set; don't keep it reachable.
        m_data = null;
    }

    /**
     * Computes {@code w . x - b} for {@code instance}. It uses the sparse weights
     * after training and the dense weights while training is in progress.
     *
     * @throws IllegalStateException if no weights are available (the model is unbuilt)
     */
    private double svmOutput(Instance instance) {
        double sum = 0.0d;
        if (m_sparseWeights != null) {
            for (int k = 0; k < m_sparseWeights.length; k++) {
                sum += instance.getValue(m_sparseIndices[k]) * m_sparseWeights[k];
            }
        } else if (m_weights != null) {
            int size = instance.size();
            for (int k = 0; k < size; k++) {
                sum += m_weights[k] * instance.getValue(k);
            }
        } else {
            throw new IllegalStateException(
                    "The classifier has not been built (or was built on an empty dataset).");
        }
        return sum - m_b;
    }

    /**
     * Checks example {@code i2} for KKT violation against the current thresholds,
     * and if violated, runs a joint optimization step with a partner index.
     *
     * @return true if the multipliers changed
     */
    private boolean examineExample(int i2) {
        Integer key = Integer.valueOf(i2);
        double y2 = m_class[i2];
        double error2;
        if (m_I0.contains(key)) {
            // Non-bound examples have an up-to-date cached error.
            error2 = m_errors[i2];
        } else {
            error2 = (svmOutput(m_data.getInstance(i2)) + m_b) - y2;
            m_errors[i2] = error2;
            if ((m_I1.contains(key) || m_I2.contains(key)) && error2 < m_bUp) {
                m_bUp = error2;
                m_iUp = i2;
            } else if ((m_I3.contains(key) || m_I4.contains(key)) && error2 > m_bLow) {
                m_bLow = error2;
                m_iLow = i2;
            }
        }

        boolean inI0 = m_I0.contains(key);
        int i1 = -1;
        boolean optimal = true;
        if ((inI0 || m_I1.contains(key) || m_I2.contains(key)) && m_bLow - error2 > 2.0d * m_tol) {
            optimal = false;
            i1 = m_iLow;
        }
        if ((inI0 || m_I3.contains(key) || m_I4.contains(key)) && error2 - m_bUp > 2.0d * m_tol) {
            optimal = false;
            i1 = m_iUp;
        }
        if (optimal) {
            return false;
        }
        // For a non-bound example, pair it with whichever threshold it violates more.
        if (inI0) {
            i1 = (m_bLow - error2 > error2 - m_bUp) ? m_iLow : m_iUp;
        }
        if (i1 == -1) {
            throw new IllegalStateException("This should never happen!");
        }
        return takeStep(i1, i2, error2);
    }

    /**
     * Jointly optimizes multipliers {@code i1} and {@code i2}. It then updates the
     * index sets, the weight vector, the error cache and the thresholds.
     *
     * @param error2 current error F for {@code i2}
     * @return true if the multipliers changed
     */
    private boolean takeStep(int i1, int i2, double error2) {
        double c1 = m_C * m_data.getInstance(i1).getWeight();
        double c2 = m_C * m_data.getInstance(i2).getWeight();
        if (i1 == i2) {
            return false;
        }
        double alpha1Old = m_alpha[i1];
        double alpha2Old = m_alpha[i2];
        double y1 = m_class[i1];
        double y2 = m_class[i2];
        double error1 = m_errors[i1];
        double s = y1 * y2;

        // Feasible segment [low, high] for the new alpha2.
        double low;
        double high;
        if (y1 != y2) {
            low = Math.max(0.0d, alpha2Old - alpha1Old);
            high = Math.min(c2, c1 + alpha2Old - alpha1Old);
        } else {
            low = Math.max(0.0d, alpha1Old + alpha2Old - c1);
            high = Math.min(c2, alpha1Old + alpha2Old);
        }
        if (low >= high) {
            return false;
        }

        Instance inst1 = m_data.getInstance(i1);
        Instance inst2 = m_data.getInstance(i2);
        double k11 = m_kernel.calculateDistance(inst1, inst1);
        double k12 = m_kernel.calculateDistance(inst2, inst1);
        double k22 = m_kernel.calculateDistance(inst2, inst2);
        // Second derivative of the objective along the constraint line.
        double eta = 2.0d * k12 - k11 - k22;

        double a2;
        if (eta < 0.0d) {
            // Unconstrained maximum, clipped to the segment.
            a2 = alpha2Old - y2 * (error1 - error2) / eta;
            if (a2 < low) {
                a2 = low;
            } else if (a2 > high) {
                a2 = high;
            }
        } else {
            // Degenerate curvature: compare the objective at both segment endpoints.
            double v1 = svmOutput(inst1) + m_b - y1 * alpha1Old * k11 - y2 * alpha2Old * k12;
            double v2 = svmOutput(inst2) + m_b - y1 * alpha1Old * k12 - y2 * alpha2Old * k22;
            double gamma = alpha1Old + s * alpha2Old;
            double a1AtLow = gamma - s * low;
            double lowObj = a1AtLow + low
                    - 0.5d * k11 * a1AtLow * a1AtLow
                    - 0.5d * k22 * low * low
                    - s * k12 * a1AtLow * low
                    - y1 * a1AtLow * v1
                    - y2 * low * v2;
            double a1AtHigh = gamma - s * high;
            double highObj = a1AtHigh + high
                    - 0.5d * k11 * a1AtHigh * a1AtHigh
                    - 0.5d * k22 * high * high
                    - s * k12 * a1AtHigh * high
                    - y1 * a1AtHigh * v1
                    - y2 * high * v2;
            if (lowObj > highObj + m_eps) {
                a2 = low;
            } else if (lowObj < highObj - m_eps) {
                a2 = high;
            } else {
                a2 = alpha2Old;
            }
        }
        if (Math.abs(a2 - alpha2Old) < m_eps * (a2 + alpha2Old + m_eps)) {
            return false;
        }

        // Snap to the bounds to avoid precision problems.
        if (a2 > c2 - m_Del * c2) {
            a2 = c2;
        } else if (a2 <= m_Del * c2) {
            a2 = 0.0d;
        }
        double a1 = alpha1Old + s * (alpha2Old - a2);
        if (a1 > c1 - m_Del * c1) {
            a1 = c1;
        } else if (a1 <= m_Del * c1) {
            a1 = 0.0d;
        }

        updateIndexSets(i1, a1, y1, c1);
        updateIndexSets(i2, a2, y2, c2);

        // Incrementally update the linear weight vector w = sum_i y_i alpha_i x_i.
        double delta1 = y1 * (a1 - alpha1Old);
        double delta2 = y2 * (a2 - alpha2Old);
        for (int k = 0; k < inst1.size(); k++) {
            m_weights[k] += delta1 * inst1.getValue(k);
        }
        for (int k = 0; k < inst2.size(); k++) {
            m_weights[k] += delta2 * inst2.getValue(k);
        }

        // Update the cached errors of the other non-bound examples, then i1 and i2.
        for (Integer j : m_I0) {
            int idx = j.intValue();
            if (idx != i1 && idx != i2) {
                Instance instJ = m_data.getInstance(idx);
                m_errors[idx] = m_errors[idx]
                        + delta1 * m_kernel.calculateDistance(inst1, instJ)
                        + delta2 * m_kernel.calculateDistance(instJ, inst2);
            }
        }
        m_errors[i1] = m_errors[i1] + delta1 * k11 + delta2 * k12;
        m_errors[i2] = m_errors[i2] + delta1 * k12 + delta2 * k22;

        m_alpha[i1] = a1;
        m_alpha[i2] = a2;

        // Recompute thresholds over I0 plus the two changed examples.
        m_bLow = -Double.MAX_VALUE;
        m_bUp = Double.MAX_VALUE;
        m_iLow = -1;
        m_iUp = -1;
        for (Integer j : m_I0) {
            int idx = j.intValue();
            if (m_errors[idx] < m_bUp) {
                m_bUp = m_errors[idx];
                m_iUp = idx;
            }
            if (m_errors[idx] > m_bLow) {
                m_bLow = m_errors[idx];
                m_iLow = idx;
            }
        }
        considerBoundExampleForThresholds(i1);
        considerBoundExampleForThresholds(i2);
        if (m_iLow == -1 || m_iUp == -1) {
            throw new IllegalStateException("This should never happen!");
        }
        return true;
    }

    /** Reassigns {@code index} to the support-vector set and to I0..I4 for its new alpha. */
    private void updateIndexSets(int index, double alpha, double y, double c) {
        Integer key = Integer.valueOf(index);
        setMembership(m_supportVectors, key, alpha > 0.0d);
        setMembership(m_I0, key, !(alpha <= 0.0d || alpha >= c));
        setMembership(m_I1, key, y == 1.0d && alpha == 0.0d);
        setMembership(m_I2, key, y == -1.0d && alpha == c);
        setMembership(m_I3, key, y == 1.0d && alpha == c);
        setMembership(m_I4, key, y == -1.0d && alpha == 0.0d);
    }

    /** Adds {@code key} to {@code set} when {@code member} is true, otherwise removes it. */
    private static void setMembership(Set<Integer> set, Integer key, boolean member) {
        if (member) {
            set.add(key);
        } else {
            set.remove(key);
        }
    }

    /**
     * If {@code index} is not in I0, lets it update bLow (when in I3/I4) or
     * bUp (otherwise).
     */
    private void considerBoundExampleForThresholds(int index) {
        Integer key = Integer.valueOf(index);
        if (m_I0.contains(key)) {
            return;
        }
        if (m_I3.contains(key) || m_I4.contains(key)) {
            if (m_errors[index] > m_bLow) {
                m_bLow = m_errors[index];
                m_iLow = index;
            }
        } else if (m_errors[index] < m_bUp) {
            m_bUp = m_errors[index];
            m_iUp = index;
        }
    }

    /**
     * Predicts the class of {@code instance}.
     *
     * @return 0 if {@code w . x - b < 0}, otherwise 1
     * @throws IllegalStateException if the model has not been built
     */
    @Override
    public int classifyInstance(Instance instance) {
        return svmOutput(instance) < 0.0d ? 0 : 1;
    }

    /**
     * Returns a hard (one-hot) distribution over classes {0, 1}.
     *
     * @return an array of length 2 holding 1.0 at the predicted class and 0.0 elsewhere
     * @throws IllegalStateException if the model has not been built
     */
    @Override
    public double[] distributionForInstance(Instance instance) {
        double[] distribution = new double[2];
        distribution[classifyInstance(instance)] += 1.0d;
        return distribution;
    }
}
