/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.classifiers.functions.Logistic;
import weka.core.Attribute;
import weka.core.AttributeStats;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;

/**
 * Meta-classifier that selects a threshold on the probability a base classifier
 * assigns to one "designated" class value, and rescales that probability so the
 * selected threshold maps to 0.5.
 *
 * <p>The threshold is either supplied manually ({@code -manual}) or chosen by
 * maximising a performance measure over a {@link ThresholdCurve} built from
 * predictions on the training set, on a single tuned fold, or from N-fold
 * cross-validation. Optionally, the threshold range observed during selection is
 * stretched to [0, 1]. Only binary class attributes are supported (see
 * {@link #getCapabilities()}).
 */
public class ThresholdSelector extends RandomizableSingleClassifierEnhancer
        implements OptionHandler, Drawable {

    static final long serialVersionUID = -1795038053239867444L;

    /** Range correction mode: no correction. */
    public static final int RANGE_NONE = 0;
    /** Range correction mode: map the min/max thresholds seen during selection to 0 and 1. */
    public static final int RANGE_BOUNDS = 1;
    /** Tags describing the range correction modes. */
    public static final Tag[] TAGS_RANGE = new Tag[] {
        new Tag(RANGE_NONE, "No range correction"),
        new Tag(RANGE_BOUNDS, "Correct based on min/max observed")
    };

    /** Evaluation mode: tune on predictions for the full training set. */
    public static final int EVAL_TRAINING_SET = 2;
    /** Evaluation mode: tune on a single train/test fold. */
    public static final int EVAL_TUNED_SPLIT = 1;
    /** Evaluation mode: tune on N-fold cross-validation predictions. */
    public static final int EVAL_CROSS_VALIDATION = 0;
    /** Tags describing the evaluation modes. */
    public static final Tag[] TAGS_EVAL = new Tag[] {
        new Tag(EVAL_TRAINING_SET, "Entire training set"),
        new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),
        new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")
    };

    /** Designated class: first class value. */
    public static final int OPTIMIZE_0 = 0;
    /** Designated class: second class value. */
    public static final int OPTIMIZE_1 = 1;
    /** Designated class: least frequent class value. */
    public static final int OPTIMIZE_LFREQ = 2;
    /** Designated class: most frequent class value. */
    public static final int OPTIMIZE_MFREQ = 3;
    /** Designated class: first value named like "yes"/"pos..."/"1", else least frequent. */
    public static final int OPTIMIZE_POS_NAME = 4;
    /** Tags describing the designated-class selection modes. */
    public static final Tag[] TAGS_OPTIMIZE = new Tag[] {
        new Tag(OPTIMIZE_0, "First class value"),
        new Tag(OPTIMIZE_1, "Second class value"),
        new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),
        new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),
        new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")
    };

    /** Measure: F-measure. */
    public static final int FMEASURE = 1;
    /** Measure: true positives plus true negatives. */
    public static final int ACCURACY = 2;
    /** Measure: true positives. */
    public static final int TRUE_POS = 3;
    /** Measure: true negatives. */
    public static final int TRUE_NEG = 4;
    /** Measure: true positive rate. */
    public static final int TP_RATE = 5;
    /** Measure: precision. */
    public static final int PRECISION = 6;
    /** Measure: recall. */
    public static final int RECALL = 7;
    /** Tags describing the selectable measures. */
    public static final Tag[] TAGS_MEASURE = new Tag[] {
        new Tag(FMEASURE, "FMEASURE"),
        new Tag(ACCURACY, "ACCURACY"),
        new Tag(TRUE_POS, "TRUE_POS"),
        new Tag(TRUE_NEG, "TRUE_NEG"),
        new Tag(TP_RATE, "TP_RATE"),
        new Tag(PRECISION, "PRECISION"),
        new Tag(RECALL, "RECALL")
    };

    /** Probability mapped to 1 by {@link #distributionForInstance(Instance)}. */
    protected double m_HighThreshold = 1.0;
    /** Probability mapped to 0 by {@link #distributionForInstance(Instance)}. */
    protected double m_LowThreshold = 0.0;
    /** Probability of the designated class that is mapped to 0.5. */
    protected double m_BestThreshold = -Double.MAX_VALUE;
    /** Measure value at the selected threshold; -Double.MAX_VALUE means "no model built yet". */
    protected double m_BestValue = -Double.MAX_VALUE;
    /** Number of folds for cross-validation and tuned-split evaluation. */
    protected int m_NumXValFolds = 3;
    /** Index of the class value whose probability is thresholded. */
    protected int m_DesignatedClass = 0;
    /** How the designated class is chosen (one of the OPTIMIZE_* constants). */
    protected int m_ClassMode = OPTIMIZE_POS_NAME;
    /** How predictions for threshold selection are produced (one of the EVAL_* constants). */
    protected int m_EvalMode = EVAL_TUNED_SPLIT;
    /** Range correction mode (one of the RANGE_* constants). */
    protected int m_RangeMode = RANGE_NONE;
    /** Measure maximised during threshold selection (one of the measure constants). */
    int m_nMeasure = FMEASURE;
    /** Whether a manual threshold overrides automatic selection. */
    protected boolean m_manualThreshold = false;
    /** Manual threshold value; a negative value means no manual threshold is set. */
    protected double m_manualThresholdValue = -1.0;
    /** A selected threshold is only accepted if its measure value exceeds this. */
    protected static final double MIN_VALUE = 0.05;

    /** Creates a selector wrapping a default {@link Logistic} base classifier. */
    public ThresholdSelector() {
        m_Classifier = new Logistic();
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return {@code "weka.classifiers.functions.Logistic"}
     */
    protected String defaultClassifierString() {
        return "weka.classifiers.functions.Logistic";
    }

    /**
     * Collects base-classifier predictions used to build the threshold curve.
     * Every mode (re)trains {@code m_Classifier} as a side effect of evaluation.
     *
     * @param instances the training data
     * @param mode one of {@link #EVAL_CROSS_VALIDATION}, {@link #EVAL_TUNED_SPLIT}
     *     or {@link #EVAL_TRAINING_SET}
     * @param numFolds fold count for the cross-validation and tuned-split modes
     *     (unused for {@link #EVAL_TRAINING_SET})
     * @return the predictions produced by {@link EvaluationUtils}
     * @throws Exception if training or evaluating the base classifier fails
     * @throws RuntimeException if {@code mode} is not recognised
     */
    protected FastVector getPredictions(Instances instances, int mode, int numFolds) throws Exception {
        EvaluationUtils evaluator = new EvaluationUtils();
        evaluator.setSeed(m_Seed);
        switch (mode) {
            case EVAL_TUNED_SPLIT: {
                Instances train = null;
                Instances test = null;
                Instances shuffled = new Instances(instances);
                Random random = new Random(m_Seed);
                shuffled.randomize(random);
                shuffled.stratify(numFolds);
                // Use the first fold whose train and test parts both contain the
                // designated class; if none qualifies, the last fold tried is used.
                for (int fold = 0; fold < numFolds; fold++) {
                    train = shuffled.trainCV(numFolds, fold, random);
                    test = shuffled.testCV(numFolds, fold);
                    if (checkForInstance(train) && checkForInstance(test)) {
                        break;
                    }
                }
                return evaluator.getTrainTestPredictions(m_Classifier, train, test);
            }
            case EVAL_TRAINING_SET:
                return evaluator.getTrainTestPredictions(m_Classifier, instances, instances);
            case EVAL_CROSS_VALIDATION:
                return evaluator.getCVPredictions(m_Classifier, instances, numFolds);
            default:
                throw new RuntimeException("Unrecognized evaluation mode");
        }
    }

    /**
     * Returns the tip text for the measure property.
     *
     * @return tip text for GUI display
     */
    public String measureTipText() {
        return "Sets the measure for determining the threshold.";
    }

    /**
     * Sets the measure maximised during threshold selection. Tags from any array
     * other than {@link #TAGS_MEASURE} are ignored.
     *
     * @param newMeasure the measure to use
     */
    public void setMeasure(SelectedTag newMeasure) {
        if (newMeasure.getTags() == TAGS_MEASURE) {
            m_nMeasure = newMeasure.getSelectedTag().getID();
        }
    }

    /**
     * Returns the measure maximised during threshold selection.
     *
     * @return the current measure
     */
    public SelectedTag getMeasure() {
        return new SelectedTag(m_nMeasure, TAGS_MEASURE);
    }

    /**
     * Builds a {@link ThresholdCurve} for the designated class from the given
     * predictions and records the threshold that maximises the selected measure.
     * The threshold is only adopted if its measure value exceeds
     * {@link #MIN_VALUE}. With {@link #RANGE_BOUNDS}, the min/max thresholds seen
     * on the curve become the low/high bounds used for rescaling.
     *
     * @param predictions predictions from {@link #getPredictions}
     * @throws IllegalStateException if the configured measure is not recognised
     */
    protected void findThreshold(FastVector predictions) {
        Instances curve = new ThresholdCurve().getCurve(predictions, m_DesignatedClass);
        if (curve.numInstances() == 0) {
            return;
        }

        int index1 = 0;
        int index2 = 0;
        switch (m_nMeasure) {
            case FMEASURE:
                index1 = curve.attribute("FMeasure").index();
                break;
            case TRUE_POS:
                index1 = curve.attribute("True Positives").index();
                break;
            case TRUE_NEG:
                index1 = curve.attribute("True Negatives").index();
                break;
            case TP_RATE:
                index1 = curve.attribute("True Positive Rate").index();
                break;
            case PRECISION:
                index1 = curve.attribute("Precision").index();
                break;
            case RECALL:
                index1 = curve.attribute("Recall").index();
                break;
            case ACCURACY:
                // Accuracy is scored as true positives + true negatives.
                index1 = curve.attribute("True Positives").index();
                index2 = curve.attribute("True Negatives").index();
                break;
            default:
                throw new IllegalStateException("Unrecognized measure: " + m_nMeasure);
        }
        int thresholdIndex = curve.attribute("Threshold").index();

        double low = 1.0;
        double high = 0.0;
        Instance best = curve.instance(0);
        double bestValue = measureValue(best, index1, index2);
        // NOTE(review): the first curve point takes part in the maximisation but
        // not in the low/high range bounds; kept as-is. Presumably it is a
        // sentinel point of ThresholdCurve -- confirm before changing.
        for (int i = 1; i < curve.numInstances(); i++) {
            Instance current = curve.instance(i);
            double value = measureValue(current, index1, index2);
            if (value > bestValue) {
                best = current;
                bestValue = value;
            }
            if (m_RangeMode == RANGE_BOUNDS) {
                double threshold = current.value(thresholdIndex);
                if (threshold < low) {
                    low = threshold;
                }
                if (threshold > high) {
                    high = threshold;
                }
            }
        }

        if (bestValue > MIN_VALUE) {
            m_BestThreshold = best.value(thresholdIndex);
            m_BestValue = bestValue;
        }
        if (m_RangeMode == RANGE_BOUNDS) {
            m_LowThreshold = low;
            m_HighThreshold = high;
        }
    }

    /** Returns the selection measure for one curve point (sum of both indices for ACCURACY). */
    private double measureValue(Instance point, int index1, int index2) {
        return m_nMeasure == ACCURACY ? point.value(index1) + point.value(index2) : point.value(index1);
    }

    /**
     * Lists the options of this classifier, followed by those of the superclass.
     *
     * @return an enumeration of {@link Option}s
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>(6);
        options.addElement(new Option("\tThe class for which threshold is determined. Valid values are:\n\t1, 2 (for first and second classes, respectively), 3 (for whichever\n\tclass is least frequent), and 4 (for whichever class value is most\n\tfrequent), and 5 (for the first class named any of \"yes\",\"pos(itive)\"\n\t\"1\", or method 3 if no matches). (default 5).", "C", 1, "-C <integer>"));
        options.addElement(new Option("\tNumber of folds used for cross validation. If just a\n\thold-out set is used, this determines the size of the hold-out set\n\t(default 3).", "X", 1, "-X <number of folds>"));
        options.addElement(new Option("\tSets whether confidence range correction is applied. This\n\tcan be used to ensure the confidences range from 0 to 1.\n\tUse 0 for no range correction, 1 for correction based on\n\tthe min/max values seen during threshold selection\n\t(default 0).", "R", 1, "-R <integer>"));
        options.addElement(new Option("\tSets the evaluation mode. Use 0 for\n\tevaluation using cross-validation,\n\t1 for evaluation using hold-out set,\n\tand 2 for evaluation on the\n\ttraining data (default 1).", "E", 1, "-E <integer>"));
        options.addElement(new Option("\tMeasure used for evaluation (default is FMEASURE).\n", "M", 1, "-M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL]"));
        options.addElement(new Option("\tSet a manual threshold to use. This option overrides\n\tautomatic selection and options pertaining to\n\tautomatic selection will be ignored.\n\t(default -1, i.e. do not use a manual threshold).", "manual", 1, "-manual <real>"));
        Enumeration inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            options.addElement((Option) inherited.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the options listed by {@link #listOptions()}, then passes the
     * remainder to the superclass. Every option that is absent is reset to its
     * default, including {@code -manual} (default: no manual threshold).
     *
     * @param options the option array; recognised options are consumed
     * @throws Exception if an option value is invalid
     */
    public void setOptions(String[] options) throws Exception {
        String manualString = Utils.getOption("manual", options);
        // Reset when absent so that a previously configured manual threshold does
        // not silently survive; a negative value also disables it.
        setManualThresholdValue(manualString.length() > 0 ? Double.parseDouble(manualString) : -1.0);

        String classString = Utils.getOption('C', options);
        if (classString.length() != 0) {
            // -C is 1-based on the command line, the tag IDs are 0-based.
            setDesignatedClass(new SelectedTag(Integer.parseInt(classString) - 1, TAGS_OPTIMIZE));
        } else {
            setDesignatedClass(new SelectedTag(OPTIMIZE_POS_NAME, TAGS_OPTIMIZE));
        }

        String evalString = Utils.getOption('E', options);
        if (evalString.length() != 0) {
            setEvaluationMode(new SelectedTag(Integer.parseInt(evalString), TAGS_EVAL));
        } else {
            setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL));
        }

        String rangeString = Utils.getOption('R', options);
        if (rangeString.length() != 0) {
            setRangeCorrection(new SelectedTag(Integer.parseInt(rangeString), TAGS_RANGE));
        } else {
            setRangeCorrection(new SelectedTag(RANGE_NONE, TAGS_RANGE));
        }

        String measureString = Utils.getOption('M', options);
        if (measureString.length() != 0) {
            setMeasure(new SelectedTag(measureString, TAGS_MEASURE));
        } else {
            setMeasure(new SelectedTag(FMEASURE, TAGS_MEASURE));
        }

        String foldsString = Utils.getOption('X', options);
        if (foldsString.length() != 0) {
            setNumXValFolds(Integer.parseInt(foldsString));
        } else {
            setNumXValFolds(3);
        }

        super.setOptions(options);
    }

    /**
     * Returns the current settings as an option array, followed by the
     * superclass options. The array is padded with empty strings to a fixed length.
     *
     * @return the current options
     */
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        // 12 = six option/value pairs (manual, C, X, E, R, M).
        String[] options = new String[superOptions.length + 12];
        int current = 0;
        if (m_manualThreshold) {
            options[current++] = "-manual";
            options[current++] = "" + getManualThresholdValue();
        }
        options[current++] = "-C";
        options[current++] = "" + (m_ClassMode + 1);
        options[current++] = "-X";
        options[current++] = "" + getNumXValFolds();
        options[current++] = "-E";
        options[current++] = "" + m_EvalMode;
        options[current++] = "-R";
        options[current++] = "" + m_RangeMode;
        options[current++] = "-M";
        options[current++] = "" + getMeasure().getSelectedTag().getReadable();
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        current += superOptions.length;
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /**
     * Returns the superclass capabilities with class support restricted to
     * binary class attributes.
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.enable(Capabilities.Capability.BINARY_CLASS);
        return result;
    }

    /**
     * Trains the base classifier and determines the threshold.
     *
     * <p>Instances with a missing class are dropped from a copy of the data. If
     * both class values are not present, the base classifier is trained without
     * threshold selection. Otherwise the designated class is chosen, and the
     * threshold is either the manual one or selected from evaluation
     * predictions. When the designated class occurs only once, selection falls
     * back to training-set evaluation.
     *
     * @param instances the training data
     * @throws Exception if the data is unsuitable or training fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        AttributeStats classStats = instances.attributeStats(instances.classIndex());

        // With threshold 0.5, low 0 and high 1, distributionForInstance is the identity.
        m_BestThreshold = m_manualThreshold ? m_manualThresholdValue : 0.5;
        m_BestValue = MIN_VALUE;
        m_HighThreshold = 1.0;
        m_LowThreshold = 0.0;

        if (classStats.distinctCount != 2) {
            System.err.println("Couldn't find examples of both classes. No adjustment.");
            m_Classifier.buildClassifier(instances);
            return;
        }

        m_DesignatedClass = designatedClassIndex(instances, classStats);

        if (m_manualThreshold) {
            m_Classifier.buildClassifier(instances);
            return;
        }

        int designatedCount = classStats.nominalCounts[m_DesignatedClass];
        if (designatedCount == 1) {
            System.err.println("Only 1 positive found: optimizing on training data");
            findThreshold(getPredictions(instances, EVAL_TRAINING_SET, 0));
        } else {
            int numFolds = Math.min(m_NumXValFolds, designatedCount);
            findThreshold(getPredictions(instances, m_EvalMode, numFolds));
            // NOTE(review): no rebuild for EVAL_TRAINING_SET, presumably because
            // getTrainTestPredictions(instances, instances) already trained
            // m_Classifier on the full data -- confirm against EvaluationUtils.
            if (m_EvalMode != EVAL_TRAINING_SET) {
                m_Classifier.buildClassifier(instances);
            }
        }
    }

    /** Picks the designated class index according to {@code m_ClassMode}. */
    private int designatedClassIndex(Instances instances, AttributeStats classStats) throws Exception {
        int[] counts = classStats.nominalCounts;
        switch (m_ClassMode) {
            case OPTIMIZE_0:
                return 0;
            case OPTIMIZE_1:
                return 1;
            case OPTIMIZE_POS_NAME: {
                Attribute classAttribute = instances.classAttribute();
                for (int i = 0; i < classAttribute.numValues(); i++) {
                    String name = classAttribute.value(i).toLowerCase();
                    if (name.startsWith("yes") || name.equals("1") || name.startsWith("pos")) {
                        return i;
                    }
                }
                // No "positive"-looking value name: fall back to the least frequent class.
                return counts[0] > counts[1] ? 1 : 0;
            }
            case OPTIMIZE_LFREQ:
                return counts[0] > counts[1] ? 1 : 0;
            case OPTIMIZE_MFREQ:
                return counts[0] > counts[1] ? 0 : 1;
            default:
                throw new Exception("Unrecognized class value selection mode");
        }
    }

    /**
     * Checks whether the data contains at least one instance of the designated class.
     *
     * @param instances the data to scan
     * @return true if some instance's class value equals {@code m_DesignatedClass}
     * @throws Exception declared for compatibility
     */
    private boolean checkForInstance(Instances instances) throws Exception {
        for (int i = 0; i < instances.numInstances(); i++) {
            if ((int) instances.instance(i).classValue() == m_DesignatedClass) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the base classifier's distribution with the designated class's
     * probability rescaled piecewise-linearly: the low bound maps to 0, the
     * selected threshold to 0.5 and the high bound to 1, and the result is
     * clipped to [0, 1]. For a two-element distribution the other class gets the
     * complement. The array returned by the base classifier is modified in place.
     *
     * @param instance the instance to classify
     * @return the adjusted class distribution
     * @throws Exception if the base classifier fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] dist = m_Classifier.distributionForInstance(instance);
        double prob = dist[m_DesignatedClass];
        double threshold = m_BestThreshold;
        // Widen the bounds so that low <= threshold <= high. Bounds that do not
        // bracket the threshold would make a denominator negative and invert the mapping.
        double low = Math.min(m_LowThreshold, threshold);
        double high = Math.max(m_HighThreshold, threshold);

        double scaled;
        if (prob == threshold) {
            // Explicit: when low == threshold the general formula is 0/0 = NaN.
            scaled = 0.5;
        } else if (prob > threshold) {
            scaled = 0.5 + (prob - threshold) / ((high - threshold) * 2.0);
        } else {
            scaled = (prob - low) / ((threshold - low) * 2.0);
        }
        // A zero-width side yields +/-Infinity above, which the clip maps to 1 / 0.
        if (scaled < 0.0) {
            scaled = 0.0;
        } else if (scaled > 1.0) {
            scaled = 1.0;
        }

        dist[m_DesignatedClass] = scaled;
        if (dist.length == 2) {
            dist[(m_DesignatedClass + 1) % 2] = 1.0 - scaled;
        }
        return dist;
    }

    /**
     * Returns a description of this classifier for the GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return "A metaclassifier that selects a mid-point threshold on the probability output by a Classifier. The midpoint threshold is set so that a given performance measure is optimized (by default the F-measure). Performance is measured either on the training data, a hold-out set or using cross-validation. In addition, the probabilities returned by the base learner can have their range expanded so that the output probabilities will reside between 0 and 1 (this is useful if the scheme normally produces probabilities in a very narrow range).";
    }

    /**
     * Returns the tip text for the designated-class property.
     *
     * @return tip text for GUI display
     */
    public String designatedClassTipText() {
        return "Sets the class value for which the optimization is performed. The options are: pick the first class value; pick the second class value; pick whichever class is least frequent; pick whichever class value is most frequent; pick the first class named any of \"yes\",\"pos(itive)\", \"1\", or the least frequent if no matches).";
    }

    /**
     * Returns how the designated class is chosen.
     *
     * @return the current designated-class mode
     */
    public SelectedTag getDesignatedClass() {
        return new SelectedTag(m_ClassMode, TAGS_OPTIMIZE);
    }

    /**
     * Sets how the designated class is chosen. Tags from any array other than
     * {@link #TAGS_OPTIMIZE} are ignored.
     *
     * @param newMode the designated-class mode
     */
    public void setDesignatedClass(SelectedTag newMode) {
        if (newMode.getTags() == TAGS_OPTIMIZE) {
            m_ClassMode = newMode.getSelectedTag().getID();
        }
    }

    /**
     * Returns the tip text for the evaluation-mode property.
     *
     * @return tip text for GUI display
     */
    public String evaluationModeTipText() {
        return "Sets the method used to determine the threshold/performance curve. The options are: perform optimization based on the entire training set (may result in overfitting); perform an n-fold cross-validation (may be time consuming); perform one fold of an n-fold cross-validation (faster but likely less accurate).";
    }

    /**
     * Sets the evaluation mode. Tags from any array other than
     * {@link #TAGS_EVAL} are ignored.
     *
     * @param newMode the evaluation mode
     */
    public void setEvaluationMode(SelectedTag newMode) {
        if (newMode.getTags() == TAGS_EVAL) {
            m_EvalMode = newMode.getSelectedTag().getID();
        }
    }

    /**
     * Returns the evaluation mode.
     *
     * @return the current evaluation mode
     */
    public SelectedTag getEvaluationMode() {
        return new SelectedTag(m_EvalMode, TAGS_EVAL);
    }

    /**
     * Returns the tip text for the range-correction property.
     *
     * @return tip text for GUI display
     */
    public String rangeCorrectionTipText() {
        return "Sets the type of prediction range correction performed. The options are: do not do any range correction; expand predicted probabilities so that the minimum probability observed during the optimization maps to 0, and the maximum maps to 1 (values outside this range are clipped to 0 and 1).";
    }

    /**
     * Sets the range-correction mode. Tags from any array other than
     * {@link #TAGS_RANGE} are ignored.
     *
     * @param newMode the range-correction mode
     */
    public void setRangeCorrection(SelectedTag newMode) {
        if (newMode.getTags() == TAGS_RANGE) {
            m_RangeMode = newMode.getSelectedTag().getID();
        }
    }

    /**
     * Returns the range-correction mode.
     *
     * @return the current range-correction mode
     */
    public SelectedTag getRangeCorrection() {
        return new SelectedTag(m_RangeMode, TAGS_RANGE);
    }

    /**
     * Returns the tip text for the number-of-folds property.
     *
     * @return tip text for GUI display
     */
    public String numXValFoldsTipText() {
        return "Sets the number of folds used during full cross-validation and tuned fold evaluation. This number will be automatically reduced if there are insufficient positive examples.";
    }

    /**
     * Returns the number of folds used for evaluation.
     *
     * @return the number of folds
     */
    public int getNumXValFolds() {
        return m_NumXValFolds;
    }

    /**
     * Sets the number of folds used for evaluation.
     *
     * @param newNumFolds the number of folds; must be at least 2
     * @throws IllegalArgumentException if {@code newNumFolds < 2}
     */
    public void setNumXValFolds(int newNumFolds) {
        if (newNumFolds < 2) {
            throw new IllegalArgumentException("Number of folds must be greater than 1");
        }
        m_NumXValFolds = newNumFolds;
    }

    /**
     * Delegates to the base classifier if it is {@link Drawable}.
     *
     * @return the base classifier's graph type, or 0 if it is not drawable
     */
    public int graphType() {
        if (m_Classifier instanceof Drawable) {
            return ((Drawable) m_Classifier).graphType();
        }
        return 0;
    }

    /**
     * Delegates to the base classifier if it is {@link Drawable}.
     *
     * @return the base classifier's graph description
     * @throws Exception if the base classifier cannot be graphed
     */
    public String graph() throws Exception {
        if (m_Classifier instanceof Drawable) {
            return ((Drawable) m_Classifier).graph();
        }
        throw new Exception("Classifier: " + getClassifierSpec() + " cannot be graphed");
    }

    /**
     * Returns the tip text for the manual-threshold property.
     *
     * @return tip text for GUI display
     */
    public String manualThresholdValueTipText() {
        return "Sets a manual threshold value to use. If this is set (non-negative value between 0 and 1), then all options pertaining to automatic threshold selection are ignored. ";
    }

    /**
     * Sets the manual threshold. A value in [0, 1] enables it; a negative value
     * disables it. The value is stored in every case, including the failing one.
     *
     * @param threshold the threshold, or a negative value to disable
     * @throws IllegalArgumentException if {@code threshold} is greater than 1
     */
    public void setManualThresholdValue(double threshold) throws Exception {
        m_manualThresholdValue = threshold;
        m_manualThreshold = threshold >= 0.0 && threshold <= 1.0;
        if (!m_manualThreshold && threshold >= 0.0) {
            throw new IllegalArgumentException("Threshold must be in the range 0..1.");
        }
    }

    /**
     * Returns the stored manual threshold value (negative when unset).
     *
     * @return the manual threshold value
     */
    public double getManualThresholdValue() {
        return m_manualThresholdValue;
    }

    /**
     * Describes the selected threshold and settings, followed by the base
     * classifier's own description.
     *
     * @return a textual description of the model
     */
    public String toString() {
        if (m_BestValue == -Double.MAX_VALUE) {
            return "ThresholdSelector: No model built yet.";
        }
        StringBuilder text = new StringBuilder();
        text.append("Threshold Selector.\nClassifier: ").append(m_Classifier.getClass().getName()).append("\n");
        text.append("Index of designated class: ").append(m_DesignatedClass).append("\n");
        if (m_manualThreshold) {
            text.append("User supplied threshold: ").append(m_BestThreshold).append("\n");
        } else {
            text.append("Evaluation mode: ");
            switch (m_EvalMode) {
                case EVAL_CROSS_VALIDATION:
                    text.append(m_NumXValFolds).append("-fold cross-validation");
                    break;
                case EVAL_TUNED_SPLIT:
                    text.append("tuning on 1/").append(m_NumXValFolds).append(" of the data");
                    break;
                default:
                    text.append("tuning on the training data");
            }
            text.append("\n");
            text.append("Threshold: ").append(m_BestThreshold).append("\n");
            text.append("Best value: ").append(m_BestValue).append("\n");
            if (m_RangeMode == RANGE_BOUNDS) {
                text.append("Expanding range [").append(m_LowThreshold).append(",")
                        .append(m_HighThreshold).append("] to [0, 1]\n");
            }
            text.append("Measure: ").append(getMeasure().getSelectedTag().getReadable()).append("\n");
        }
        text.append(m_Classifier.toString());
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.43 $");
    }

    /**
     * Command-line entry point.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        runClassifier(new ThresholdSelector(), argv);
    }
}

