/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import java.util.TreeMap;
import weka.classifiers.Classifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Discriminative Multinomial Naive Bayes (DMNB) text classifier.
 *
 * <p>One {@link DNBBinary} one-vs-rest model is kept per class. For a two-class problem only
 * the model whose target is class 0 is trained and queried; for more classes every model is
 * trained and their scores are combined and normalised in
 * {@link #distributionForInstance(Instance)}.
 *
 * <p>Training is incremental: each instance updates the word counts with a weight equal to
 * its instance weight times the probability the current model assigns to the side the
 * instance does NOT belong to, so already well-classified instances contribute little.
 */
public class DMNBtext
extends Classifier
implements OptionHandler,
WeightedInstancesHandler,
TechnicalInformationHandler,
UpdateableClassifier {
    static final long serialVersionUID = 5932177450183457085L;
    /** Number of passes over the training data made by buildClassifier. */
    protected int m_NumIterations = 1;
    /** If true, only word presence (value > 0) is used; otherwise values are used as counts. */
    protected boolean m_BinaryWord = true;
    /** Number of classes seen at build time; -1 until the classifier is built. */
    int m_numClasses = -1;
    /** Empty copy of the training header; used to look up attribute names in toString(). */
    protected Instances m_headerInfo;
    /** One one-vs-rest model per class; only index 0 is consulted for two-class problems. */
    DNBBinary[] m_binaryClassifiers = null;

    /**
     * Returns a description of this classifier for display in the Weka GUI.
     *
     * @return the description, including the technical reference
     */
    public String globalInfo() {
        return "Class for building and using a Discriminative Multinomial Naive Bayes classifier. For more information see,\n\n" + this.getTechnicalInformation().toString() + "\n\n" + "The core equation for this classifier:\n\n" + "P[Ci|D] = (P[D|Ci] x P[Ci]) / P[D] (Bayes rule)\n\n" + "where Ci is class i and D is a document.";
    }

    /**
     * Returns the bibliographic reference for the method implemented here.
     *
     * @return the ICML 2008 paper reference
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Jiang Su,Harry Zhang,Charles X. Ling,Stan Matwin");
        result.setValue(TechnicalInformation.Field.YEAR, "2008");
        result.setValue(TechnicalInformation.Field.TITLE, "Discriminative Parameter Learning for Bayesian Networks");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "ICML 2008'");
        return result;
    }

    /**
     * Returns the data this classifier accepts: numeric attributes, a nominal class, and
     * missing class values (such instances are ignored during training).
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Builds the classifier from the given training data.
     *
     * <p>Instances with a missing class are discarded, one binary model is initialised per
     * class, and then {@code m_NumIterations} passes of {@link #updateClassifier(Instance)} are
     * made over the remaining instances in dataset order.
     *
     * @param data the training data; it is copied, not modified
     * @throws Exception if the data fails the capabilities check or an update fails
     */
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        Instances instances = new Instances(data);
        instances.deleteWithMissingClass();
        this.m_numClasses = instances.numClasses();
        this.m_headerInfo = new Instances(instances, 0);
        this.m_binaryClassifiers = new DNBBinary[this.m_numClasses];
        for (int i = 0; i < this.m_numClasses; ++i) {
            this.m_binaryClassifiers[i] = new DNBBinary();
            this.m_binaryClassifiers[i].setTargetClass(i);
            this.m_binaryClassifiers[i].initClassifier(instances);
        }
        for (int it = 0; it < this.m_NumIterations; ++it) {
            for (int i = 0; i < instances.numInstances(); ++i) {
                this.updateClassifier(instances.instance(i));
            }
        }
    }

    /**
     * Updates the model incrementally with a single instance.
     *
     * <p>Instances whose class value is missing carry no label information and are ignored.
     *
     * @param instance the training instance
     * @throws Exception if the update fails
     */
    public void updateClassifier(Instance instance) throws Exception {
        if (instance.classIsMissing()) {
            // Weka encodes a missing value as NaN, which compares unequal to every target
            // class; without this guard the instance would be trained as a negative example
            // for every one-vs-rest model.
            return;
        }
        if (this.m_numClasses == 2) {
            this.m_binaryClassifiers[0].updateClassifier(instance);
        } else {
            for (int i = 0; i < this.m_numClasses; ++i) {
                this.m_binaryClassifiers[i].updateClassifier(instance);
            }
        }
    }

    /**
     * Computes the class membership probabilities for an instance.
     *
     * <p>For two classes the binary model for class 0 is used directly. Otherwise each
     * one-vs-rest model's log-odds score is exponentiated (after subtracting the largest score
     * to avoid overflow) and the results are normalised to sum to one.
     *
     * @param instance the instance to classify
     * @return the predicted distribution over classes
     * @throws Exception if scoring fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_numClasses == 2) {
            return this.m_binaryClassifiers[0].distributionForInstance(instance);
        }
        double[] logDocGivenClass = new double[instance.numClasses()];
        for (int i = 0; i < this.m_numClasses; ++i) {
            logDocGivenClass[i] = this.m_binaryClassifiers[i].getLogProbForTargetClass(instance);
        }
        double max = logDocGivenClass[Utils.maxIndex(logDocGivenClass)];
        for (int i = 0; i < this.m_numClasses; ++i) {
            logDocGivenClass[i] = Math.exp(logDocGivenClass[i] - max);
        }
        try {
            Utils.normalize(logDocGivenClass);
        }
        catch (Exception e) {
            // Best effort: the largest entry is exp(0) = 1, so normalisation should only fail
            // when a score is NaN or infinite; the unnormalised values are then returned.
            e.printStackTrace();
        }
        return logDocGivenClass;
    }

    /**
     * Lists, per one-vs-rest model, every word with its log probability ratio, ordered by
     * decreasing absolute value.
     *
     * @return a textual description of the model
     */
    public String toString() {
        if (this.m_binaryClassifiers == null) {
            return "DMNBtext: No model built yet.";
        }
        StringBuffer result = new StringBuffer("");
        result.append("The log ratio of two conditional probabilities of a word w_i: log(p(w_i)|+)/p(w_i)|-)) in decent order based on their absolute values\n");
        result.append("Can be used to measure the discriminative power of each word.\n");
        if (this.m_numClasses == 2) {
            return result.append(this.m_binaryClassifiers[0].toString()).toString();
        }
        for (int i = 0; i < this.m_numClasses; ++i) {
            result.append(i + " against the rest classes\n");
            result.append(this.m_binaryClassifiers[i].toString() + "\n");
        }
        return result.toString();
    }

    /**
     * Parses options: {@code -I <num>} number of iterations, {@code -B <true|false>} binary
     * word mode. An absent option leaves the current value unchanged.
     *
     * @param options the option array
     * @throws Exception if an option value cannot be parsed
     */
    public void setOptions(String[] options) throws Exception {
        String iterations = Utils.getOption('I', options);
        if (iterations.length() != 0) {
            this.setNumIterations(Integer.parseInt(iterations));
        } else {
            this.setNumIterations(this.m_NumIterations);
        }
        String binary = Utils.getOption('B', options);
        if (binary.length() != 0) {
            this.setBinaryWord(Boolean.parseBoolean(binary));
        } else {
            this.setBinaryWord(this.m_BinaryWord);
        }
    }

    /**
     * Returns the current settings as an option array accepted by setOptions.
     *
     * @return the options {@code -I <num> -B <bool>}
     */
    public String[] getOptions() {
        String[] options = new String[4];
        int current = 0;
        options[current++] = "-I";
        options[current++] = "" + this.getNumIterations();
        options[current++] = "-B";
        options[current++] = "" + this.getBinaryWord();
        return options;
    }

    /** @return tip text for the numIterations property */
    public String numIterationsTipText() {
        return "The number of iterations that the classifier will scan the training data";
    }

    /** @param numIterations number of passes over the training data */
    public void setNumIterations(int numIterations) {
        this.m_NumIterations = numIterations;
    }

    /** @return number of passes over the training data */
    public int getNumIterations() {
        return this.m_NumIterations;
    }

    /** @return tip text for the binaryWord property */
    public String binaryWordTipText() {
        return " whether ingore the frequency information in data";
    }

    /** @param val true to use word presence only, false to use values as counts */
    public void setBinaryWord(boolean val) {
        this.m_BinaryWord = val;
    }

    /** @return whether only word presence is used */
    public boolean getBinaryWord() {
        return this.m_BinaryWord;
    }

    /** @return the revision string */
    public String getRevision() {
        return "$Revision: 1.0";
    }

    /**
     * Command-line entry point.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        DMNBtext c = new DMNBtext();
        DMNBtext.runClassifier(c, argv);
    }

    /**
     * One-vs-rest discriminatively weighted multinomial model. Side 0 is the target class,
     * side 1 is every other class. Non-static because it reads the enclosing classifier's
     * m_BinaryWord and m_headerInfo.
     */
    public class DNBBinary
    implements Serializable {
        /** Smoothed (weighted) word counts, indexed [side][attribute]. */
        private double[][] m_perWordPerClass;
        /** Smoothed (weighted) total word counts per side. */
        private double[] m_wordsPerClass;
        /** Index of the class attribute, skipped when iterating over words. */
        int m_classIndex = -1;
        /** Smoothed (weighted) instance counts per side. */
        private double[] m_classDistribution;
        private int m_numAttributes;
        private int m_targetClass = -1;
        /** Pseudo-count added to every word; set to ln(numAttributes) in initClassifier. */
        private double m_WordLaplace = 1.0;
        /** Per-word log ratio ln(count[0][w] / count[1][w]). */
        private double[] m_coefficient;
        /** ln(classDistribution[0] / classDistribution[1]). */
        private double m_classRatio;
        /** ln(wordsPerClass[0] / wordsPerClass[1]). */
        private double m_wordRatio;

        /**
         * Resets all counts to their smoothed priors for the given data header.
         *
         * @param instances data whose header defines the attributes and class index
         * @throws Exception declared for interface compatibility
         */
        public void initClassifier(Instances instances) throws Exception {
            this.m_numAttributes = instances.numAttributes();
            this.m_perWordPerClass = new double[2][this.m_numAttributes];
            this.m_coefficient = new double[this.m_numAttributes];
            this.m_wordsPerClass = new double[2];
            this.m_classDistribution = new double[2];
            this.m_WordLaplace = Math.log(this.m_numAttributes);
            this.m_classIndex = instances.classIndex();
            for (int c = 0; c < 2; ++c) {
                this.m_classDistribution[c] = 1.0;
                this.m_wordsPerClass[c] = this.m_WordLaplace * (double)this.m_numAttributes;
                Arrays.fill(this.m_perWordPerClass[c], this.m_WordLaplace);
            }
        }

        /**
         * Adds one instance to the counts, weighted by its instance weight times the current
         * probability of the side it does not belong to. Instances with a missing class
         * value are ignored.
         *
         * @param ins the training instance
         * @throws Exception if scoring the instance fails
         */
        public void updateClassifier(Instance ins) throws Exception {
            if (ins.classIsMissing()) {
                return;
            }
            int side = ins.value(ins.classIndex()) != (double)this.m_targetClass ? 1 : 0;
            double weight = (1.0 - this.distributionForInstance(ins)[side]) * ins.weight();
            for (int a = 0; a < ins.numValues(); ++a) {
                int word = ins.index(a);
                if (word == this.m_classIndex) continue;
                if (DMNBtext.this.m_BinaryWord) {
                    if (ins.valueSparse(a) > 0.0) {
                        this.m_wordsPerClass[side] += weight;
                        this.m_perWordPerClass[side][word] += weight;
                    }
                } else {
                    double t = ins.valueSparse(a) * weight;
                    this.m_wordsPerClass[side] += t;
                    this.m_perWordPerClass[side][word] += t;
                }
                this.m_coefficient[word] = Math.log(this.m_perWordPerClass[0][word] / this.m_perWordPerClass[1][word]);
            }
            this.m_wordRatio = Math.log(this.m_wordsPerClass[0] / this.m_wordsPerClass[1]);
            this.m_classDistribution[side] += weight;
            this.m_classRatio = Math.log(this.m_classDistribution[0] / this.m_classDistribution[1]);
        }

        /**
         * Returns the log-odds score of the target class versus the rest for an instance.
         *
         * @param ins the instance to score
         * @return the log-odds score
         * @throws Exception declared for interface compatibility
         */
        public double getLogProbForTargetClass(Instance ins) throws Exception {
            double probLog = this.m_classRatio;
            for (int a = 0; a < ins.numValues(); ++a) {
                if (ins.index(a) == this.m_classIndex) continue;
                if (DMNBtext.this.m_BinaryWord) {
                    if (!(ins.valueSparse(a) > 0.0)) continue;
                    probLog += this.m_coefficient[ins.index(a)] - this.m_wordRatio;
                    continue;
                }
                probLog += ins.valueSparse(a) * (this.m_coefficient[ins.index(a)] - this.m_wordRatio);
            }
            return probLog;
        }

        /**
         * Converts the log-odds score into a two-element distribution
         * {P(target), P(rest)} via the logistic function.
         *
         * @param instance the instance to score
         * @return the distribution over {target, rest}
         * @throws Exception if scoring fails
         */
        public double[] distributionForInstance(Instance instance) throws Exception {
            double[] probOfClassGivenDoc = new double[2];
            double ratio = this.getLogProbForTargetClass(instance);
            if (ratio > 709.0) {
                // exp() would overflow to infinity beyond ~709.78.
                probOfClassGivenDoc[0] = 1.0;
            } else {
                ratio = Math.exp(ratio);
                probOfClassGivenDoc[0] = ratio / (1.0 + ratio);
            }
            probOfClassGivenDoc[1] = 1.0 - probOfClassGivenDoc[0];
            return probOfClassGivenDoc;
        }

        /**
         * Lists every word with its coefficient, ordered by decreasing absolute coefficient;
         * words with equal absolute coefficients appear in attribute order.
         *
         * @return one "name: coefficient" line per word
         */
        public String toString() {
            StringBuffer result = new StringBuffer();
            result.append("\n");
            // Keyed by -|coefficient| so ascending iteration yields descending magnitudes.
            // Tied words are grouped under one key rather than overwriting each other.
            TreeMap<Double, StringBuffer> sort = new TreeMap<Double, StringBuffer>();
            for (int w = 0; w < this.m_numAttributes; ++w) {
                if (w == DMNBtext.this.m_headerInfo.classIndex()) continue;
                String val = DMNBtext.this.m_headerInfo.attribute(w).name() + ": " + this.m_coefficient[w];
                Double key = -1.0 * Math.abs(this.m_coefficient[w]);
                StringBuffer group = sort.get(key);
                if (group == null) {
                    sort.put(key, new StringBuffer(val));
                } else {
                    group.append("\n").append(val);
                }
            }
            for (StringBuffer group : sort.values()) {
                result.append(group);
                result.append("\n");
            }
            return result.toString();
        }

        /** @param targetClass index of the class treated as side 0 */
        public void setTargetClass(int targetClass) {
            this.m_targetClass = targetClass;
        }

        /** @return index of the class treated as side 0 */
        public int getTargetClass() {
            return this.m_targetClass;
        }
    }
}

