/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.functions.SMO;
import weka.classifiers.functions.supportVector.Kernel;
import weka.classifiers.functions.supportVector.PolyKernel;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.Standardize;
import weka.filters.unsupervised.instance.SparseToNonSparse;

/**
 * Multi-instance classifier implementing Stuart Andrews' mi-SVM (maximum pattern margin
 * formulation of MIL), using {@link SMO} as the underlying SVM solver.
 *
 * <p>Each instance initially inherits the label of its bag. An SVM is trained on all
 * instances. Then the instances of every positive bag are relabelled according to the SVM
 * output, making sure at least one instance per positive bag stays positive. Training
 * repeats until the labels stop changing or {@link #getMaxIterations()} iterations have run.
 */
public class MISVM
extends Classifier
implements OptionHandler,
MultiInstanceCapabilitiesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = 7622231064035278145L;
    /** Sparse-to-non-sparse filter; not referenced by the methods of this class. */
    protected Filter m_SparseFilter = new SparseToNonSparse();
    /** The SVM built in the last training iteration. */
    protected SVM m_SVM;
    /** Kernel template; a copy is handed to every SVM that gets built. */
    protected Kernel m_kernel = new PolyKernel();
    /** Complexity constant C passed to the SVM. */
    protected double m_C = 1.0;
    /** Normalize/Standardize filter fitted on the propositional training data, or null. */
    protected Filter m_Filter = null;
    /** Selected filter type, one of FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE. */
    protected int m_filterType = 0;
    public static final int FILTER_NORMALIZE = 0;
    public static final int FILTER_STANDARDIZE = 1;
    public static final int FILTER_NONE = 2;
    /** Tags for the filter types; {@link #setFilterType} only accepts tags from this array. */
    public static final Tag[] TAGS_FILTER = new Tag[]{new Tag(0, "Normalize training data"), new Tag(1, "Standardize training data"), new Tag(2, "No normalization/standardization")};
    /** Upper bound on the number of relabelling iterations. */
    protected int m_MaxIterations = 500;
    /** Converts bags into single instances (one row per instance of each bag). */
    protected MultiInstanceToPropositional m_ConvertToProp = new MultiInstanceToPropositional();

    /**
     * Returns a description of the classifier for the GUI.
     *
     * @return the description, including the technical reference
     */
    public String globalInfo() {
        return "Implements Stuart Andrews' mi_SVM (Maximum pattern Margin Formulation of MIL). Applying weka.classifiers.functions.SMO to solve multiple instances problem.\nThe algorithm first assign the bag label to each instance in the bag as its initial class label.  After that applying SMO to compute SVM solution for all instances in positive bags And then reassign the class label of each instance in the positive bag according to the SVM result Keep on iteration until labels do not change anymore.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for this algorithm.
     *
     * @return the technical information describing the original paper
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "Stuart Andrews and Ioannis Tsochantaridis and Thomas Hofmann");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "2003");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "Support Vector Machines for Multiple-Instance Learning");
        technicalInformation.setValue(TechnicalInformation.Field.BOOKTITLE, "Advances in Neural Information Processing Systems 15");
        technicalInformation.setValue(TechnicalInformation.Field.PUBLISHER, "MIT Press");
        technicalInformation.setValue(TechnicalInformation.Field.PAGES, "561-568");
        return technicalInformation;
    }

    /**
     * Lists the available options.
     *
     * @return the superclass options, this classifier's options, then the kernel's options
     */
    public Enumeration listOptions() {
        Vector vector = new Vector();
        Enumeration enumeration = super.listOptions();
        while (enumeration.hasMoreElements()) {
            vector.addElement(enumeration.nextElement());
        }
        vector.addElement(new Option("\tThe complexity constant C. (default 1)", "C", 1, "-C <double>"));
        vector.addElement(new Option("\tWhether to 0=normalize/1=standardize/2=neither.\n\t(default: 0=normalize)", "N", 1, "-N <default 0>"));
        vector.addElement(new Option("\tThe maximum number of iterations to perform.\n\t(default: 500)", "I", 1, "-I <num>"));
        vector.addElement(new Option("\tThe Kernel to use.\n\t(default: weka.classifiers.functions.supportVector.PolyKernel)", "K", 1, "-K <classname and parameters>"));
        vector.addElement(new Option("", "", 0, "\nOptions specific to kernel " + this.getKernel().getClass().getName() + ":"));
        enumeration = this.getKernel().listOptions();
        while (enumeration.hasMoreElements()) {
            vector.addElement(enumeration.nextElement());
        }
        return vector.elements();
    }

    /**
     * Parses the options -C, -N, -I and -K, then passes the rest to the superclass.
     * Options -C, -N and -I fall back to their defaults when absent. When -K is absent,
     * the current kernel is kept.
     *
     * @param stringArray the options; recognised entries are consumed
     * @throws Exception if an option value cannot be parsed or the kernel cannot be created
     */
    public void setOptions(String[] stringArray) throws Exception {
        String string = Utils.getOption('C', stringArray);
        if (string.length() != 0) {
            this.setC(Double.parseDouble(string));
        } else {
            this.setC(1.0);
        }
        string = Utils.getOption('N', stringArray);
        if (string.length() != 0) {
            this.setFilterType(new SelectedTag(Integer.parseInt(string), TAGS_FILTER));
        } else {
            this.setFilterType(new SelectedTag(0, TAGS_FILTER));
        }
        string = Utils.getOption('I', stringArray);
        if (string.length() != 0) {
            this.setMaxIterations(Integer.parseInt(string));
        } else {
            this.setMaxIterations(500);
        }
        string = Utils.getOption('K', stringArray);
        String[] stringArray2 = Utils.splitOptions(string);
        if (stringArray2.length != 0) {
            string = stringArray2[0];
            // Kernel.forName receives the remaining kernel options; blank out the class name.
            stringArray2[0] = "";
            this.setKernel(Kernel.forName(string, stringArray2));
        }
        super.setOptions(stringArray);
    }

    /**
     * Returns the current settings so that {@link #setOptions} can restore them.
     *
     * @return the current options (-D if debugging, then -C, -N, -I and -K)
     */
    public String[] getOptions() {
        Vector<String> vector = new Vector<String>();
        if (this.getDebug()) {
            vector.add("-D");
        }
        vector.add("-C");
        vector.add("" + this.getC());
        vector.add("-N");
        vector.add("" + this.m_filterType);
        // -I is parsed by setOptions, so it must be emitted here too or the
        // max-iterations setting is lost on a get/set round trip.
        vector.add("-I");
        vector.add("" + this.getMaxIterations());
        vector.add("-K");
        vector.add("" + this.getKernel().getClass().getName() + " " + Utils.joinOptions(this.getKernel().getOptions()));
        return vector.toArray(new String[vector.size()]);
    }

    /** @return tip text for the kernel property */
    public String kernelTipText() {
        return "The kernel to use.";
    }

    /** @return the kernel template used for each SVM */
    public Kernel getKernel() {
        return this.m_kernel;
    }

    /** @param kernel the kernel template to use for each SVM */
    public void setKernel(Kernel kernel) {
        this.m_kernel = kernel;
    }

    /** @return tip text for the filterType property */
    public String filterTypeTipText() {
        return "The filter type for transforming the training data.";
    }

    /**
     * Sets the filter type. Tags that do not come from {@link #TAGS_FILTER} (compared by
     * reference) are ignored.
     *
     * @param selectedTag the selected filter type
     */
    public void setFilterType(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_FILTER) {
            this.m_filterType = selectedTag.getSelectedTag().getID();
        }
    }

    /** @return the current filter type as a tag from {@link #TAGS_FILTER} */
    public SelectedTag getFilterType() {
        return new SelectedTag(this.m_filterType, TAGS_FILTER);
    }

    /** @return tip text for the C property */
    public String cTipText() {
        return "The value for C.";
    }

    /** @return the complexity constant C */
    public double getC() {
        return this.m_C;
    }

    /** @param d the complexity constant C */
    public void setC(double d) {
        this.m_C = d;
    }

    /** @return tip text for the maxIterations property */
    public String maxIterationsTipText() {
        return "The maximum number of iterations to perform.";
    }

    /** @return the maximum number of relabelling iterations */
    public int getMaxIterations() {
        return this.m_MaxIterations;
    }

    /**
     * Sets the maximum number of iterations. Values below 1 are rejected with a message on
     * standard output, and the previous value is kept.
     *
     * @param n the maximum number of iterations
     */
    public void setMaxIterations(int n) {
        if (n < 1) {
            System.out.println("At least 1 iteration is necessary (provided: " + n + ")!");
        } else {
            this.m_MaxIterations = n;
        }
    }

    /**
     * Returns the capabilities for the bag-level (multi-instance) data.
     *
     * @return the supported attribute and class capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.disableAllClasses();
        capabilities.disableAllClassDependencies();
        capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return capabilities;
    }

    /**
     * Returns the capabilities for the data inside a bag. These are derived from an SVM
     * configured with a copy of the current kernel, with the class requirement replaced
     * by NO_CLASS.
     *
     * @return the capabilities for the relational (bag) data
     */
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities capabilities = null;
        try {
            SVM sVM = new SVM();
            sVM.setKernel(Kernel.makeCopy(this.getKernel()));
            capabilities = sVM.getCapabilities();
            capabilities.setOwner(this);
        }
        catch (Exception exception) {
            exception.printStackTrace();
        }
        if (capabilities == null) {
            // Previously an NPE: fall back to an empty capabilities set owned by this classifier.
            capabilities = new Capabilities(this);
        }
        capabilities.disableAllClasses();
        capabilities.enable(Capabilities.Capability.NO_CLASS);
        return capabilities;
    }

    /**
     * Trains the mi-SVM model.
     *
     * <p>Bags are flattened to single instances, each labelled with its bag's label. Then
     * the following steps repeat:
     * <ul>
     * <li>Train an SVM on the instances.</li>
     * <li>Relabel the instances of each positive bag by the sign of the SVM output.</li>
     * <li>If a positive bag ends up with no positive instance, mark the instance(s) with
     *     the largest output as positive.</li>
     * </ul>
     * Iteration stops when the labels are unchanged or the maximum iteration count is hit.
     *
     * @param instances the multi-instance training data (relational attribute at index 1)
     * @throws Exception if the data is not supported or building the SVM fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        int numBags = instances.numInstances();
        int[] bagSizes = new int[numBags];
        int[] bagLabels = new int[numBags];
        // Current label of every single instance, in propositional (bag-after-bag) order.
        Vector<Double> labels = new Vector<Double>();
        for (int bag = 0; bag < numBags; ++bag) {
            bagLabels[bag] = (int)instances.instance(bag).classValue();
            bagSizes[bag] = instances.instance(bag).relationalValue(1).numInstances();
            for (int j = 0; j < bagSizes[bag]; ++j) {
                labels.addElement(Double.valueOf(bagLabels[bag]));
            }
        }
        this.m_ConvertToProp.setWeightMethod(new SelectedTag(1, MultiInstanceToPropositional.TAGS_WEIGHTMETHOD));
        this.m_ConvertToProp.setInputFormat(instances);
        instances = Filter.useFilter(instances, this.m_ConvertToProp);
        // Drop attribute 0 (presumably the bag identifier added by the conversion — TODO confirm).
        instances.deleteAttributeAt(0);
        this.m_Filter = this.m_filterType == 1 ? new Standardize() : (this.m_filterType == 0 ? new Normalize() : null);
        if (this.m_Filter != null) {
            this.m_Filter.setInputFormat(instances);
            instances = Filter.useFilter(instances, this.m_Filter);
        }
        if (this.m_Debug) {
            System.out.println("\nIteration History...");
        }
        if (this.getDebug()) {
            System.out.println("\nstart building model ...");
        }
        Vector<Integer> maxIndices = new Vector<Integer>();
        Vector<Double> previousLabels;
        int iteration = 0;
        do {
            ++iteration;
            // Index of the last propositional instance visited so far.
            int current = -1;
            if (this.m_Debug) {
                System.out.println("=====================loop: " + iteration);
            }
            previousLabels = new Vector<Double>(labels);
            this.m_SVM = new SVM();
            this.m_SVM.setC(this.getC());
            this.m_SVM.setKernel(Kernel.makeCopy(this.getKernel()));
            // The data has already been filtered above, so SMO must not filter it again.
            // SMO.setFilterType ignores tags that do not come from SMO.TAGS_FILTER (it checks
            // the array by reference, just like setFilterType in this class). MISVM.TAGS_FILTER
            // would therefore be silently ignored, SMO would fall back to normalizing, and its
            // training space would no longer match the unfiltered inputs given to output().
            this.m_SVM.setFilterType(new SelectedTag(SMO.FILTER_NONE, SMO.TAGS_FILTER));
            this.m_SVM.buildClassifier(instances);
            for (int bag = 0; bag < numBags; ++bag) {
                if (bagLabels[bag] != 1) {
                    // Instances of negative bags stay negative; just skip over them.
                    current += bagSizes[bag];
                    continue;
                }
                if (this.m_Debug) {
                    System.out.println("--------------- " + bag + " ----------------");
                }
                double positiveCount = 0.0;
                for (int j = 0; j < bagSizes[bag]; ++j) {
                    ++current;
                    Instance inst = instances.instance(current);
                    double output = this.m_SVM.output(-1, inst);
                    if (output <= 0.0) {
                        if (inst.classValue() == 1.0) {
                            inst.setClassValue(0.0);
                            labels.set(current, Double.valueOf(0.0));
                            if (this.m_Debug) {
                                System.out.println(current + "- changed to 0");
                            }
                        }
                    } else if (inst.classValue() == 0.0) {
                        inst.setClassValue(1.0);
                        labels.set(current, Double.valueOf(1.0));
                        if (this.m_Debug) {
                            System.out.println(current + "+ changed to 1");
                        }
                    }
                    positiveCount += inst.classValue();
                }
                if (positiveCount != 0.0) {
                    continue;
                }
                // A positive bag needs at least one positive instance: pick the instance(s)
                // with the maximal SVM output (all ties are kept).
                double maxOutput = -Double.MAX_VALUE;
                maxIndices.clear();
                for (int k = current - bagSizes[bag] + 1; k <= current; ++k) {
                    double output = this.m_SVM.output(-1, instances.instance(k));
                    if (output > maxOutput) {
                        maxOutput = output;
                        maxIndices.clear();
                        maxIndices.add(Integer.valueOf(k));
                    } else if (output == maxOutput) {
                        maxIndices.add(Integer.valueOf(k));
                    }
                }
                for (Integer index : maxIndices) {
                    instances.instance(index).setClassValue(1.0);
                    labels.set(index, Double.valueOf(1.0));
                    if (this.m_Debug) {
                        System.out.println("##change to 1 ###outpput: " + maxOutput + " max_index: " + index + " bag: " + bag);
                    }
                }
            }
        } while (!labels.equals(previousLabels) && iteration < this.m_MaxIterations);
        if (this.getDebug()) {
            System.out.println("finish building model.");
        }
    }

    /**
     * Classifies a bag. It is predicted positive if at least one of its instances gets an
     * SVM output greater than zero.
     *
     * @param instance the bag to classify; must belong to a dataset (uses instance.dataset())
     * @return {P(class 0), P(class 1)}, either {1, 0} or {0, 1}
     * @throws Exception if filtering or SVM evaluation fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        Instances single = new Instances(instance.dataset(), 0);
        single.add(instance);
        single = Filter.useFilter(single, this.m_ConvertToProp);
        single.deleteAttributeAt(0);
        if (this.m_Filter != null) {
            single = Filter.useFilter(single, this.m_Filter);
        }
        double positiveCount = 0.0;
        for (int i = 0; i < single.numInstances(); ++i) {
            if (this.m_SVM.output(-1, single.instance(i)) > 0.0) {
                positiveCount += 1.0;
            }
        }
        double[] distribution = new double[2];
        distribution[0] = positiveCount == 0.0 ? 1.0 : 0.0;
        distribution[1] = 1.0 - distribution[0];
        return distribution;
    }

    /** @return the revision string */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.6 $");
    }

    /**
     * Command-line entry point.
     *
     * @param stringArray command-line options
     */
    public static void main(String[] stringArray) {
        MISVM.runClassifier(new MISVM(), stringArray);
    }

    /**
     * SMO subclass that exposes the raw decision value of its binary classifier
     * (index [0][1]).
     */
    private class SVM
    extends SMO {
        static final long serialVersionUID = -8325638229658828931L;

        protected SVM() {
        }

        /**
         * Returns the raw SVM output for an instance. The instance is passed straight to
         * SVMOutput, so it is not run through SMO's own preprocessing.
         *
         * @param n index passed to SVMOutput (-1 is used for instances given explicitly)
         * @param instance the instance to evaluate
         * @return the decision value of the binary classifier at index [0][1]
         * @throws Exception if the kernel evaluation fails
         */
        protected double output(int n, Instance instance) throws Exception {
            return this.m_classifiers[0][1].SVMOutput(n, instance);
        }

        public String getRevision() {
            return RevisionUtils.extract("$Revision: 1.6 $");
        }
    }
}

