/*
 * Decompiled with CFR 0.152.
 */
package weka.knowledgeflow.steps;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.Evaluation;
import weka.core.Instance;
import weka.core.Utils;
import weka.core.WekaException;
import weka.knowledgeflow.Data;
import weka.knowledgeflow.steps.BaseStep;
import weka.knowledgeflow.steps.KFStep;

/**
 * Knowledge Flow step that evaluates an incrementally trained classifier on a
 * stream of test instances.
 *
 * <p>For each incoming (classifier, test instance) pair it updates an
 * {@link Evaluation} over the whole stream and emits a chart data point. When
 * a charting window size greater than zero is configured, the charted
 * statistics are instead computed over a sliding window of the most recent
 * labelled instances. At the end of the stream a textual performance summary
 * is emitted on any "text" connection.
 */
@KFStep(name="IncrementalClassifierEvaluator", category="Evaluation", toolTipText="Evaluate the performance of incrementally training classifiers", iconPath="weka/gui/knowledgeflow/icons/IncrementalClassifierEvaluator.gif")
public class IncrementalClassifierEvaluator
extends BaseStep {
    private static final long serialVersionUID = -5951569492213633100L;
    /** Legend entries for the chart series; filled in when the first instance arrives. */
    protected List<String> m_dataLegend;
    /** Values of the current chart data point (one entry per legend entry). */
    protected double[] m_dataPoint;
    /** Chart payload that is updated in place and sent downstream for each instance. */
    protected Data m_chartData = new Data("chart");
    /** Smallest value charted so far for a numeric class. */
    protected double m_min = Double.MAX_VALUE;
    /**
     * Largest value charted so far for a numeric class. Starts at the most
     * negative finite double: {@code Double.MIN_VALUE} is the smallest positive
     * double and would never be exceeded by all-negative predictions.
     */
    protected double m_max = -Double.MAX_VALUE;
    /** How often to report progress (stored for configuration; not read in this class). */
    protected int m_statusFrequency = 2000;
    /** Number of instances processed in the current run (labelled or not). */
    protected int m_instanceCount;
    /** Whether to record predictions so per-class IR statistics can be output. */
    protected boolean m_outputInfoRetrievalStats;
    /** Evaluation over the entire stream; null until the first instance arrives. */
    protected Evaluation m_eval;
    /** Sliding window size for charting; values &lt;= 0 disable windowing. */
    protected int m_windowSize;
    /** Evaluation restricted to the instances currently in the sliding window. */
    protected Evaluation m_windowEval;
    /** Labelled instances in the sliding window, newest first. */
    protected LinkedList<Instance> m_window;
    /** Predicted distributions parallel to {@link #m_window}, newest first. */
    protected LinkedList<double[]> m_windowedPreds;
    /** True when the evaluators must be (re)created from the next instance. */
    protected boolean m_reset;
    /** Simple (unqualified) class name of the classifier being evaluated. */
    protected String m_classifierName;

    /**
     * Resets all per-run state so the step can be executed again from scratch.
     *
     * @throws WekaException if initialisation fails
     */
    @Override
    public void stepInit() throws WekaException {
        this.m_instanceCount = 0;
        this.m_dataPoint = new double[1];
        this.m_dataLegend = new ArrayList<String>();
        // Chart range must not leak from a previous run.
        this.m_min = Double.MAX_VALUE;
        this.m_max = -Double.MAX_VALUE;
        // Evaluators are rebuilt from the first instance of this run.
        this.m_eval = null;
        this.m_windowEval = null;
        if (this.m_windowSize > 0) {
            this.m_window = new LinkedList<Instance>();
            this.m_windowedPreds = new LinkedList<double[]>();
            this.getStepManager().logBasic("Chart output using windowed evaluation over " + this.m_windowSize + " instances");
        }
        this.m_reset = true;
    }

    /**
     * Accepts a single "incrementalClassifier" connection.
     *
     * @return the acceptable incoming connection types given the current connections
     */
    @Override
    public List<String> getIncomingConnectionTypes() {
        if (this.getStepManager().numIncomingConnections() == 0) {
            return Arrays.asList("incrementalClassifier");
        }
        return new ArrayList<String>();
    }

    /**
     * Offers "text" and "chart" outputs once an incremental classifier is connected.
     *
     * @return the outgoing connection types this step can currently produce
     */
    @Override
    public List<String> getOutgoingConnectionTypes() {
        ArrayList<String> result = new ArrayList<String>();
        if (this.getStepManager().numIncomingConnectionsOfType("incrementalClassifier") > 0) {
            result.add("text");
            result.add("chart");
        }
        return result;
    }

    /**
     * Processes one streamed (classifier, test instance) pair, or finalises the
     * evaluation when the end-of-stream marker arrives.
     *
     * @param data incoming data carrying "incrementalClassifier" and "aux_testInstance"
     * @throws WekaException if evaluation or downstream output fails
     */
    @Override
    public void processIncoming(Data data) throws WekaException {
        if (this.isStopRequested()) {
            return;
        }
        if (this.getStepManager().isStreamFinished(data)) {
            this.finishStream();
            return;
        }
        Classifier classifier = (Classifier)data.getPayloadElement("incrementalClassifier");
        Instance instance = (Instance)data.getPayloadElement("aux_testInstance");
        try {
            if (this.m_reset) {
                this.m_reset = false;
                this.initialiseEvaluation(classifier, instance);
            }
            this.getStepManager().throughputUpdateStart();
            ++this.m_instanceCount;
            double[] dist = classifier.distributionForInstance(instance);
            double pred = 0.0;
            if (!instance.classIsMissing()) {
                this.evaluateLabelledInstance(dist, instance);
            } else {
                pred = classifier.classifyInstance(instance);
            }
            if (instance.classIndex() >= 0) {
                this.updateChartData(instance, dist, pred);
                if (this.isStopRequested()) {
                    return;
                }
                this.getStepManager().throughputUpdateEnd();
                this.getStepManager().outputData(this.m_chartData.getConnectionName(), this.m_chartData);
            }
        }
        catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /** Signals end of throughput, drops window state and emits the textual summary. */
    private void finishStream() throws WekaException {
        this.getStepManager().throughputFinished(new Data("chart"));
        this.m_windowEval = null;
        this.m_window = null;
        this.m_windowedPreds = null;
        if (this.getStepManager().numOutgoingConnectionsOfType("text") == 0) {
            return;
        }
        if (this.m_eval == null) {
            // Nothing was evaluated in this run (empty stream); there is no summary to report.
            this.getStepManager().logBasic("No instances were evaluated - no textual results to output");
            return;
        }
        try {
            boolean nominalClass = this.m_eval.getHeader().classIndex() >= 0 && this.m_eval.getHeader().classAttribute().isNominal();
            StringBuilder results = new StringBuilder();
            results.append("=== Performance information ===\n\nScheme:   ").append(this.m_classifierName).append("\n")
                .append("Relation: ").append(this.m_eval.getHeader().relationName()).append("\n\n")
                .append(this.m_eval.toSummaryString());
            if (nominalClass && this.m_outputInfoRetrievalStats) {
                results.append("\n").append(this.m_eval.toClassDetailsString());
            }
            if (nominalClass) {
                results.append("\n").append(this.m_eval.toMatrixString());
            }
            Data textData = new Data("text");
            textData.setPayloadElement("text", results.toString());
            textData.setPayloadElement("aux_textTitle", "Results: " + this.m_classifierName);
            this.getStepManager().outputData(textData);
        }
        catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Creates the evaluators and chart legend from the first instance of a run.
     * The legend/point layout depends on whether the class is nominal and
     * whether this first instance is labelled.
     */
    private void initialiseEvaluation(Classifier classifier, Instance instance) throws Exception {
        String fullName = classifier.getClass().getName();
        this.m_classifierName = fullName.substring(fullName.lastIndexOf('.') + 1);
        this.m_eval = new Evaluation(instance.dataset());
        this.m_eval.useNoPriors();
        if (this.m_windowSize > 0) {
            this.m_windowEval = new Evaluation(instance.dataset());
            this.m_windowEval.useNoPriors();
        }
        this.m_dataLegend.clear();
        if (instance.classAttribute().isNominal()) {
            if (!instance.classIsMissing()) {
                this.m_dataPoint = new double[3];
                this.m_dataLegend.add("Accuracy");
                this.m_dataLegend.add("RMSE (prob)");
                this.m_dataLegend.add("Kappa");
            } else {
                this.m_dataPoint = new double[1];
                this.m_dataLegend.add("Confidence");
            }
        } else {
            this.m_dataPoint = new double[1];
            if (instance.classIsMissing()) {
                this.m_dataLegend.add("Prediction");
            } else {
                this.m_dataLegend.add("RMSE");
            }
        }
    }

    /** Adds a labelled prediction to the stream evaluation and, if enabled, the sliding window. */
    private void evaluateLabelledInstance(double[] dist, Instance instance) throws Exception {
        if (this.m_outputInfoRetrievalStats) {
            this.m_eval.evaluateModelOnceAndRecordPrediction(dist, instance);
        } else {
            this.m_eval.evaluateModelOnce(dist, instance);
        }
        if (this.m_windowSize <= 0) {
            return;
        }
        this.m_windowEval.evaluateModelOnce(dist, instance);
        this.m_window.addFirst(instance);
        this.m_windowedPreds.addFirst(dist);
        // Evict based on actual window occupancy: m_instanceCount also counts
        // unlabelled instances, which never enter the window.
        if (this.m_window.size() > this.m_windowSize) {
            Instance oldest = this.m_window.removeLast();
            double[] oldDist = this.m_windowedPreds.removeLast();
            // Re-evaluating with a negated weight is intended to cancel the
            // oldest instance's earlier contribution to the window statistics.
            double weight = oldest.weight();
            oldest.setWeight(-weight);
            try {
                this.m_windowEval.evaluateModelOnce(oldDist, oldest);
            } finally {
                // Always restore the caller-visible weight, even if evaluation fails.
                oldest.setWeight(weight);
            }
        }
    }

    /** Fills m_dataPoint and the chart payload for the current instance. */
    private void updateChartData(Instance instance, double[] dist, double pred) throws Exception {
        boolean labelled = !instance.classIsMissing();
        Evaluation chartEval = this.m_windowSize > 0 ? this.m_windowEval : this.m_eval;
        if (instance.classAttribute().isNominal()) {
            if (labelled) {
                // The point only has RMSE/kappa slots if the first instance was labelled.
                if (this.m_dataPoint.length > 2) {
                    this.m_dataPoint[1] = chartEval.rootMeanSquaredError();
                    this.m_dataPoint[2] = chartEval.kappa();
                }
                this.m_dataPoint[0] = 1.0 - chartEval.errorRate();
            } else {
                // Unlabelled: chart the confidence of the most probable class.
                this.m_dataPoint[0] = dist[Utils.maxIndex(dist)];
            }
            this.m_chartData.setPayloadElement("chart_min", 0.0);
            this.m_chartData.setPayloadElement("chart_max", 1.0);
        } else {
            double update = labelled ? chartEval.rootMeanSquaredError() : pred;
            this.m_dataPoint[0] = update;
            if (update > this.m_max) {
                this.m_max = update;
            }
            if (update < this.m_min) {
                this.m_min = update;
            }
            this.m_chartData.setPayloadElement("chart_min", labelled ? 0.0 : this.m_min);
            this.m_chartData.setPayloadElement("chart_max", this.m_max);
        }
        this.m_chartData.setPayloadElement("chart_legend", this.m_dataLegend);
        this.m_chartData.setPayloadElement("chart_data_point", this.m_dataPoint);
    }

    /**
     * Sets how often to report progress.
     *
     * @param s the status frequency
     */
    public void setStatusFrequency(int s) {
        this.m_statusFrequency = s;
    }

    /**
     * Gets how often to report progress.
     *
     * @return the status frequency
     */
    public int getStatusFrequency() {
        return this.m_statusFrequency;
    }

    /**
     * Tool tip text for the status frequency property.
     *
     * @return the tip text
     */
    public String statusFrequencyTipText() {
        return "How often to report progress to the status bar.";
    }

    /**
     * Sets whether per-class information retrieval statistics are output.
     *
     * @param i true to record predictions and output per-class statistics
     */
    public void setOutputPerClassInfoRetrievalStats(boolean i) {
        this.m_outputInfoRetrievalStats = i;
    }

    /**
     * Gets whether per-class information retrieval statistics are output.
     *
     * @return true if per-class statistics are output
     */
    public boolean getOutputPerClassInfoRetrievalStats() {
        return this.m_outputInfoRetrievalStats;
    }

    /**
     * Tool tip text for the per-class IR statistics property.
     *
     * @return the tip text
     */
    public String outputPerClassInfoRetrievalStatsTipText() {
        return "Output per-class info retrieval stats. If set to true, predictions get stored so that stats such as AUC can be computed. Note: this consumes some memory.";
    }

    /**
     * Sets the sliding window size used for charting (&lt;= 0 evaluates the whole stream).
     *
     * @param windowSize the window size
     */
    public void setChartingEvalWindowSize(int windowSize) {
        this.m_windowSize = windowSize;
    }

    /**
     * Gets the sliding window size used for charting.
     *
     * @return the window size
     */
    public int getChartingEvalWindowSize() {
        return this.m_windowSize;
    }

    /**
     * Tool tip text for the charting window size property.
     *
     * @return the tip text
     */
    public String chartingEvalWindowSizeTipText() {
        return "For charting only, specify a sliding window size over which to compute performance stats. <= 0 means eval on whole stream";
    }
}

