/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.functions.SMO;
import weka.classifiers.functions.supportVector.Kernel;
import weka.classifiers.functions.supportVector.PolyKernel;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.Standardize;
import weka.filters.unsupervised.instance.SparseToNonSparse;

/**
 * Multi-instance classifier implementing Stuart Andrews' mi-SVM.
 *
 * <p>Every instance initially receives the label of its bag. An SMO-based SVM
 * is then trained on all instances, and the labels of the instances inside
 * positive bags are re-assigned from the SVM output. Each positive bag always
 * keeps at least one positive instance. Training repeats until the instance
 * labels stop changing or {@link #getMaxIterations()} iterations have run.
 */
public class MISVM
extends Classifier
implements OptionHandler,
MultiInstanceCapabilitiesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = 7622231064035278145L;
    /** Sparse-to-non-sparse filter; not referenced elsewhere in this class. */
    protected Filter m_SparseFilter = new SparseToNonSparse();
    /** The SVM built in the last training iteration; used for prediction. */
    protected SVM m_SVM;
    /** Kernel template; a copy of it is handed to every internal SVM. */
    protected Kernel m_kernel = new PolyKernel();
    /** The complexity constant C passed to the internal SVM. */
    protected double m_C = 1.0;
    /** Normalize/standardize filter fitted on the training data, or null for none. */
    protected Filter m_Filter = null;
    /** Which of FILTER_NORMALIZE / FILTER_STANDARDIZE / FILTER_NONE is applied. */
    protected int m_filterType = FILTER_NORMALIZE;
    public static final int FILTER_NORMALIZE = 0;
    public static final int FILTER_STANDARDIZE = 1;
    public static final int FILTER_NONE = 2;
    /** Tags describing the filter types; required by {@link #setFilterType}. */
    public static final Tag[] TAGS_FILTER = new Tag[]{new Tag(FILTER_NORMALIZE, "Normalize training data"), new Tag(FILTER_STANDARDIZE, "Standardize training data"), new Tag(FILTER_NONE, "No normalization/standardization")};
    /** Upper bound on the number of relabel-and-retrain iterations. */
    protected int m_MaxIterations = 500;
    /** Flattens each bag into its individual (propositional) instances. */
    protected MultiInstanceToPropositional m_ConvertToProp = new MultiInstanceToPropositional();

    /**
     * Returns a description of the classifier for display in the GUI.
     *
     * @return the description, including the technical reference
     */
    public String globalInfo() {
        return "Implements Stuart Andrews' mi_SVM (Maximum pattern Margin Formulation of MIL). Applying weka.classifiers.functions.SMO to solve multiple instances problem.\nThe algorithm first assign the bag label to each instance in the bag as its initial class label.  After that applying SMO to compute SVM solution for all instances in positive bags And then reassign the class label of each instance in the positive bag according to the SVM result Keep on iteration until labels do not change anymore.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns bibliographic information about the paper this classifier is based on.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Stuart Andrews and Ioannis Tsochantaridis and Thomas Hofmann");
        result.setValue(TechnicalInformation.Field.YEAR, "2003");
        result.setValue(TechnicalInformation.Field.TITLE, "Support Vector Machines for Multiple-Instance Learning");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Advances in Neural Information Processing Systems 15");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "MIT Press");
        result.setValue(TechnicalInformation.Field.PAGES, "561-568");
        return result;
    }

    /**
     * Lists the available options.
     *
     * <p>The list contains the superclass options, this classifier's own
     * options (-C, -N, -I, -K) and the options of the current kernel.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> result = new Vector<Option>();
        Enumeration enm = super.listOptions();
        while (enm.hasMoreElements()) {
            result.addElement((Option) enm.nextElement());
        }
        result.addElement(new Option("\tThe complexity constant C. (default 1)", "C", 1, "-C <double>"));
        result.addElement(new Option("\tWhether to 0=normalize/1=standardize/2=neither.\n\t(default: 0=normalize)", "N", 1, "-N <default 0>"));
        result.addElement(new Option("\tThe maximum number of iterations to perform.\n\t(default: 500)", "I", 1, "-I <num>"));
        result.addElement(new Option("\tThe Kernel to use.\n\t(default: weka.classifiers.functions.supportVector.PolyKernel)", "K", 1, "-K <classname and parameters>"));
        result.addElement(new Option("", "", 0, "\nOptions specific to kernel " + this.getKernel().getClass().getName() + ":"));
        enm = this.getKernel().listOptions();
        while (enm.hasMoreElements()) {
            result.addElement((Option) enm.nextElement());
        }
        return result.elements();
    }

    /**
     * Parses the given option array.
     *
     * <p>Missing -C, -N and -I options are reset to their defaults. The kernel
     * is only replaced when -K is present.
     *
     * @param options the options to parse
     * @throws Exception if an option value cannot be parsed or the kernel cannot be created
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String tmpStr = Utils.getOption('C', options);
        this.setC(tmpStr.length() != 0 ? Double.parseDouble(tmpStr) : 1.0);

        tmpStr = Utils.getOption('N', options);
        int filter = tmpStr.length() != 0 ? Integer.parseInt(tmpStr) : FILTER_NORMALIZE;
        this.setFilterType(new SelectedTag(filter, TAGS_FILTER));

        tmpStr = Utils.getOption('I', options);
        this.setMaxIterations(tmpStr.length() != 0 ? Integer.parseInt(tmpStr) : 500);

        tmpStr = Utils.getOption('K', options);
        String[] tmpOptions = Utils.splitOptions(tmpStr);
        if (tmpOptions.length != 0) {
            // First token is the kernel class name, the rest are its own options.
            tmpStr = tmpOptions[0];
            tmpOptions[0] = "";
            this.setKernel(Kernel.forName(tmpStr, tmpOptions));
        }
        super.setOptions(options);
    }

    /**
     * Returns the current settings as an option array.
     *
     * <p>Every option accepted by {@link #setOptions} is included.
     *
     * @return the current options
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.getDebug()) {
            result.add("-D");
        }
        result.add("-C");
        result.add("" + this.getC());
        result.add("-N");
        result.add("" + this.m_filterType);
        // -I is accepted by setOptions/listOptions, so it must be emitted here
        // as well; otherwise the iteration limit is lost on an options round-trip.
        result.add("-I");
        result.add("" + this.getMaxIterations());
        result.add("-K");
        result.add(this.getKernel().getClass().getName() + " " + Utils.joinOptions(this.getKernel().getOptions()));
        return result.toArray(new String[result.size()]);
    }

    /** @return tip text for the kernel property */
    public String kernelTipText() {
        return "The kernel to use.";
    }

    /** @return the kernel template used for the internal SVMs */
    public Kernel getKernel() {
        return this.m_kernel;
    }

    /** @param value the kernel template used for the internal SVMs */
    public void setKernel(Kernel value) {
        this.m_kernel = value;
    }

    /** @return tip text for the filterType property */
    public String filterTypeTipText() {
        return "The filter type for transforming the training data.";
    }

    /**
     * Sets how the training data is transformed.
     *
     * <p>Tags that were not built from {@link #TAGS_FILTER} are ignored.
     *
     * @param newType the filter type
     */
    public void setFilterType(SelectedTag newType) {
        if (newType.getTags() == TAGS_FILTER) {
            this.m_filterType = newType.getSelectedTag().getID();
        }
    }

    /** @return the current filter type as a tag of {@link #TAGS_FILTER} */
    public SelectedTag getFilterType() {
        return new SelectedTag(this.m_filterType, TAGS_FILTER);
    }

    /** @return tip text for the C property */
    public String cTipText() {
        return "The value for C.";
    }

    /** @return the complexity constant C */
    public double getC() {
        return this.m_C;
    }

    /** @param v the complexity constant C */
    public void setC(double v) {
        this.m_C = v;
    }

    /** @return tip text for the maxIterations property */
    public String maxIterationsTipText() {
        return "The maximum number of iterations to perform.";
    }

    /** @return the maximum number of relabel-and-retrain iterations */
    public int getMaxIterations() {
        return this.m_MaxIterations;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * <p>Values below 1 are rejected with a message on stdout, and the previous
     * value is kept.
     *
     * @param value the new maximum number of iterations
     */
    public void setMaxIterations(int value) {
        if (value < 1) {
            System.out.println("At least 1 iteration is necessary (provided: " + value + ")!");
        } else {
            this.m_MaxIterations = value;
        }
    }

    /**
     * Returns the capabilities of this classifier.
     *
     * <p>It accepts multi-instance data only, with a binary (possibly missing)
     * class.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.enable(Capabilities.Capability.BINARY_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return result;
    }

    /**
     * Returns the capabilities required of the instances inside a bag.
     *
     * <p>These are the capabilities of the internal SVM with the configured
     * kernel, with class support replaced by NO_CLASS.
     *
     * @return the multi-instance capabilities, never null
     */
    @Override
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result;
        try {
            SVM classifier = new SVM();
            classifier.setKernel(Kernel.makeCopy(this.getKernel()));
            result = classifier.getCapabilities();
            result.setOwner(this);
        }
        catch (Exception e) {
            e.printStackTrace();
            // Fall back to an empty capability set instead of dereferencing null below.
            result = new Capabilities(this);
        }
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    /**
     * Trains the mi-SVM model.
     *
     * @param train the multi-instance training data
     * @throws Exception if the data is not supported, or filtering or SVM training fails
     */
    @Override
    public void buildClassifier(Instances train) throws Exception {
        this.getCapabilities().testWithFail(train);
        train = new Instances(train);
        train.deleteWithMissingClass();

        // Each instance starts out with the label of the bag that contains it.
        int numBags = train.numInstances();
        int[] bagSize = new int[numBags];
        int[] classes = new int[numBags];
        Vector<Double> instLabels = new Vector<Double>();
        for (int h = 0; h < numBags; h++) {
            classes[h] = (int) train.instance(h).classValue();
            bagSize[h] = train.instance(h).relationalValue(1).numInstances();
            for (int i = 0; i < bagSize[h]; i++) {
                instLabels.addElement(Double.valueOf(classes[h]));
            }
        }

        // Flatten the bags into single instances (in bag order), then drop
        // attribute 0 (presumably the bag identifier).
        this.m_ConvertToProp.setWeightMethod(new SelectedTag(1, MultiInstanceToPropositional.TAGS_WEIGHTMETHOD));
        this.m_ConvertToProp.setInputFormat(train);
        train = Filter.useFilter(train, this.m_ConvertToProp);
        train.deleteAttributeAt(0);

        if (this.m_filterType == FILTER_STANDARDIZE) {
            this.m_Filter = new Standardize();
        } else if (this.m_filterType == FILTER_NORMALIZE) {
            this.m_Filter = new Normalize();
        } else {
            this.m_Filter = null;
        }
        if (this.m_Filter != null) {
            this.m_Filter.setInputFormat(train);
            train = Filter.useFilter(train, this.m_Filter);
        }

        if (this.getDebug()) {
            System.out.println("\nIteration History...");
            System.out.println("\nstart building model ...");
        }

        Vector<Integer> maxIndex = new Vector<Integer>();
        Vector<Double> preInstLabels;
        int loopNum = 0;
        do {
            ++loopNum;
            if (this.getDebug()) {
                System.out.println("=====================loop: " + loopNum);
            }
            preInstLabels = new Vector<Double>(instLabels);

            this.m_SVM = new SVM();
            this.m_SVM.setC(this.getC());
            this.m_SVM.setKernel(Kernel.makeCopy(this.getKernel()));
            // The data is already filtered above, and output() evaluates the
            // SVM on these same instances, so SMO must not filter again. SMO
            // only accepts tags built from its own TAGS_FILTER array, so using
            // this class's TAGS_FILTER here would be silently ignored.
            this.m_SVM.setFilterType(new SelectedTag(SMO.FILTER_NONE, SMO.TAGS_FILTER));
            this.m_SVM.buildClassifier(train);

            // Index of the last processed flattened instance.
            int index = -1;
            for (int h = 0; h < numBags; h++) {
                if (classes[h] != 1) {
                    // Negative bags keep all-negative labels; just skip them.
                    index += bagSize[h];
                    continue;
                }
                if (this.getDebug()) {
                    System.out.println("--------------- " + h + " ----------------");
                }
                double sum = 0.0;
                for (int i = 0; i < bagSize[h]; i++) {
                    Instance inst = train.instance(++index);
                    double output = this.m_SVM.output(-1, inst);
                    if (output <= 0.0) {
                        if (inst.classValue() == 1.0) {
                            inst.setClassValue(0.0);
                            instLabels.set(index, Double.valueOf(0.0));
                            if (this.getDebug()) {
                                System.out.println(String.valueOf(index) + "- changed to 0");
                            }
                        }
                    } else if (inst.classValue() == 0.0) {
                        inst.setClassValue(1.0);
                        instLabels.set(index, Double.valueOf(1.0));
                        if (this.getDebug()) {
                            System.out.println(String.valueOf(index) + "+ changed to 1");
                        }
                    }
                    sum += inst.classValue();
                }
                if (sum == 0.0) {
                    // A positive bag needs at least one positive instance: mark
                    // the instance(s) with the highest SVM output as positive.
                    double maxOutput = -Double.MAX_VALUE;
                    maxIndex.clear();
                    for (int j = index - bagSize[h] + 1; j <= index; j++) {
                        double output = this.m_SVM.output(-1, train.instance(j));
                        if (maxOutput < output) {
                            maxOutput = output;
                            maxIndex.clear();
                            maxIndex.add(j);
                        } else if (maxOutput == output) {
                            maxIndex.add(j);
                        }
                    }
                    for (Integer i2 : maxIndex) {
                        train.instance(i2).setClassValue(1.0);
                        instLabels.set(i2, Double.valueOf(1.0));
                        if (this.getDebug()) {
                            System.out.println("##change to 1 ###outpput: " + maxOutput + " max_index: " + i2 + " bag: " + h);
                        }
                    }
                }
            }
        } while (!instLabels.equals(preInstLabels) && loopNum < this.m_MaxIterations);

        if (this.getDebug()) {
            System.out.println("finish building model.");
        }
    }

    /**
     * Computes the class distribution for a bag.
     *
     * <p>The bag is classified positive if at least one of its instances has a
     * positive SVM output; otherwise it is negative.
     *
     * @param exmp the bag to classify
     * @return {1, 0} for a negative bag, {0, 1} for a positive one
     * @throws Exception if filtering or SVM evaluation fails
     */
    @Override
    public double[] distributionForInstance(Instance exmp) throws Exception {
        Instances testData = new Instances(exmp.dataset(), 0);
        testData.add(exmp);
        testData = Filter.useFilter(testData, this.m_ConvertToProp);
        testData.deleteAttributeAt(0);
        if (this.m_Filter != null) {
            testData = Filter.useFilter(testData, this.m_Filter);
        }
        double sum = 0.0;
        for (int j = 0; j < testData.numInstances(); j++) {
            double output = this.m_SVM.output(-1, testData.instance(j));
            sum += output <= 0.0 ? 0.0 : 1.0;
        }
        double[] distribution = new double[2];
        distribution[0] = sum == 0.0 ? 1.0 : 0.0;
        distribution[1] = 1.0 - distribution[0];
        return distribution;
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9144 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        MISVM.runClassifier(new MISVM(), argv);
    }

    /**
     * SMO subclass that exposes the raw SVM output of the binary machine for
     * class pair (0, 1).
     */
    private class SVM
    extends SMO {
        static final long serialVersionUID = -8325638229658828931L;

        protected SVM() {
        }

        /**
         * Returns the raw SVM output for an instance.
         *
         * <p>No SMO-internal filtering is applied to {@code inst}.
         *
         * @param index the index passed through to SVMOutput (callers here pass
         *        -1; presumably meaning "not a training instance")
         * @param inst the instance to evaluate
         * @return the SVM output
         * @throws Exception if evaluation fails
         */
        protected double output(int index, Instance inst) throws Exception {
            return this.m_classifiers[0][1].SVMOutput(index, inst);
        }

        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 9144 $");
        }
    }
}

