/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.IteratedSingleClassifierEnhancer;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.UnassignedClassException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

public class AdditiveRegression
extends IteratedSingleClassifierEnhancer
implements OptionHandler,
AdditionalMeasureProducer,
WeightedInstancesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -2368937577670527151L;
    /** Factor applied to every base model's prediction, both in training residuals and at prediction time. */
    protected double m_shrinkage = 1.0;
    /** Number of base models actually built by the last call to {@link #buildClassifier(Instances)}. */
    protected int m_NumIterationsPerformed;
    /** Initial model; its prediction is the starting point that the base models correct. {@code null} until built. */
    protected ZeroR m_zeroR;
    /** {@code false} when the training data held only the class attribute, in which case only ZeroR is used. */
    protected boolean m_SuitableData = true;

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, followed by the technical reference
     */
    public String globalInfo() {
        return " Meta classifier that enhances the performance of a regression base classifier. Each iteration fits a model to the residuals left by the classifier on the previous iteration. Prediction is accomplished by adding the predictions of each classifier. Reducing the shrinkage (learning rate) parameter helps prevent overfitting and has a smoothing effect but increases the learning time.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the underlying technique.
     *
     * @return the technical information for Friedman's report
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.TECHREPORT);
        result.setValue(TechnicalInformation.Field.AUTHOR, "J.H. Friedman");
        result.setValue(TechnicalInformation.Field.YEAR, "1999");
        result.setValue(TechnicalInformation.Field.TITLE, "Stochastic Gradient Boosting");
        result.setValue(TechnicalInformation.Field.INSTITUTION, "Stanford University");
        result.setValue(TechnicalInformation.Field.PS, "http://www-stat.stanford.edu/~jhf/ftp/stobst.ps");
        return result;
    }

    /** Creates an instance that uses a {@link DecisionStump} as the base learner. */
    public AdditiveRegression() {
        this(new DecisionStump());
    }

    /**
     * Creates an instance with the given base learner.
     *
     * @param classifier the regression base learner to boost
     */
    public AdditiveRegression(Classifier classifier) {
        this.m_Classifier = classifier;
    }

    /**
     * Returns the class name of the default base learner.
     *
     * @return the default base classifier class name
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Lists the shrinkage option followed by the options of the superclass.
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(1);
        newVector.addElement(new Option("\tSpecify shrinkage rate. (default = 1.0, ie. no shrinkage)\n", "S", 1, "-S"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Parses {@code -S <shrinkage>} and then hands the remaining options to the superclass.
     *
     * @param options the option array; consumed entries are blanked by {@code Utils.getOption}
     * @throws Exception if an option is malformed or unrecognised options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String optionString = Utils.getOption('S', options);
        if (optionString.length() != 0) {
            this.setShrinkage(Double.parseDouble(optionString));
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings: {@code -S <shrinkage>} followed by the superclass options.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-S");
        options.add("" + this.getShrinkage());
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Returns the GUI tip text for the shrinkage property.
     *
     * @return the tip text
     */
    public String shrinkageTipText() {
        return "Shrinkage rate. Smaller values help prevent overfitting and have a smoothing effect (but increase learning time). Default = 1.0, ie. no shrinkage.";
    }

    /**
     * Sets the shrinkage (learning rate) applied to each base model's prediction.
     *
     * @param l the shrinkage factor (not validated)
     */
    public void setShrinkage(double l) {
        this.m_shrinkage = l;
    }

    /**
     * Returns the shrinkage (learning rate) applied to each base model's prediction.
     *
     * @return the shrinkage factor
     */
    public double getShrinkage() {
        return this.m_shrinkage;
    }

    /**
     * Returns the superclass capabilities restricted to numeric and date classes.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        return result;
    }

    /**
     * Builds the model.
     *
     * <p>First fits a ZeroR model and replaces each class value with its residual. It then
     * repeatedly fits the next base model to the current residuals and replaces them with the
     * new (shrunken) residuals. It stops when the weighted residual sum of squares no longer
     * drops by more than {@code Utils.SMALL}, or when all configured iterations have been used.
     *
     * @param data the training data; instances with a missing class are ignored
     * @throws Exception if the data is not supported or a base learner fails or predicts missing
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        super.buildClassifier(data);
        this.getCapabilities().testWithFail(data);
        // Reset before any early return so a rebuild never reports a stale iteration count.
        this.m_NumIterationsPerformed = 0;
        Instances newData = new Instances(data);
        newData.deleteWithMissingClass();
        this.m_zeroR = new ZeroR();
        this.m_zeroR.buildClassifier(newData);
        if (newData.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            this.m_SuitableData = false;
            return;
        }
        this.m_SuitableData = true;
        newData = this.residualReplace(newData, this.m_zeroR, false);
        double sum = weightedSquaredClassSum(newData);
        if (this.m_Debug) {
            System.err.println("Sum of squared residuals (predicting the mean) : " + sum);
        }
        // Checking the length first avoids indexing an empty array when zero iterations are configured.
        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            double previousSum = sum;
            Classifier current = this.m_Classifiers[this.m_NumIterationsPerformed];
            current.buildClassifier(newData);
            newData = this.residualReplace(newData, current, true);
            sum = weightedSquaredClassSum(newData);
            if (this.m_Debug) {
                System.err.println("Sum of squared residuals : " + sum);
            }
            ++this.m_NumIterationsPerformed;
            // Written as a negated '>' so that a NaN sum also stops boosting, as before.
            if (!(previousSum - sum > Utils.SMALL)) {
                break;
            }
        }
    }

    /**
     * Returns the weighted sum of squared class values in {@code data}. After residual
     * replacement, this is the weighted residual sum of squares.
     */
    private static double weightedSquaredClassSum(Instances data) {
        double total = 0.0;
        for (int i = 0; i < data.numInstances(); ++i) {
            Instance inst = data.instance(i);
            double value = inst.classValue();
            total += inst.weight() * value * value;
        }
        return total;
    }

    /**
     * Predicts the class value of an instance.
     *
     * <p>The prediction is the ZeroR prediction plus the shrunken prediction of every base model
     * that was built.
     *
     * @param inst the instance to classify
     * @return the predicted numeric class value
     * @throws UnassignedClassException if a base learner predicts a missing value
     * @throws Exception if a base learner fails
     */
    @Override
    public double classifyInstance(Instance inst) throws Exception {
        double prediction = this.m_zeroR.classifyInstance(inst);
        if (!this.m_SuitableData) {
            return prediction;
        }
        for (int i = 0; i < this.m_NumIterationsPerformed; ++i) {
            double toAdd = this.m_Classifiers[i].classifyInstance(inst);
            if (Utils.isMissingValue(toAdd)) {
                throw new UnassignedClassException("AdditiveRegression: base learner predicted missing value.");
            }
            prediction += toAdd * this.getShrinkage();
        }
        return prediction;
    }

    /**
     * Returns a copy of {@code data} in which every class value is replaced by its residual
     * under {@code c}.
     *
     * @param data the instances whose class values are the current targets
     * @param c the model whose predictions are subtracted
     * @param useShrinkage whether to scale {@code c}'s prediction by the shrinkage first
     * @return a new dataset holding the residuals
     * @throws UnassignedClassException if {@code c} predicts a missing value
     * @throws Exception if {@code c} fails to classify an instance
     */
    private Instances residualReplace(Instances data, Classifier c, boolean useShrinkage) throws Exception {
        Instances newInst = new Instances(data);
        for (int i = 0; i < newInst.numInstances(); ++i) {
            double pred = c.classifyInstance(newInst.instance(i));
            if (Utils.isMissingValue(pred)) {
                throw new UnassignedClassException("AdditiveRegression: base learner predicted missing value.");
            }
            if (useShrinkage) {
                pred *= this.getShrinkage();
            }
            double residual = newInst.instance(i).classValue() - pred;
            newInst.instance(i).setClassValue(residual);
        }
        return newInst;
    }

    /**
     * Lists the names of the additional measures this classifier produces.
     *
     * @return an enumeration containing {@code "measureNumIterations"}
     */
    @Override
    public Enumeration<String> enumerateMeasures() {
        Vector<String> newVector = new Vector<String>(1);
        newVector.addElement("measureNumIterations");
        return newVector.elements();
    }

    /**
     * Returns the value of the named additional measure (the name is case-insensitive).
     *
     * @param additionalMeasureName the measure name
     * @return the measure value
     * @throws IllegalArgumentException if the measure is not supported
     */
    @Override
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.compareToIgnoreCase("measureNumIterations") == 0) {
            return this.measureNumIterations();
        }
        throw new IllegalArgumentException(additionalMeasureName + " not supported (AdditiveRegression)");
    }

    /**
     * Returns the number of base models built by the last training run.
     *
     * @return the number of iterations performed
     */
    public double measureNumIterations() {
        return this.m_NumIterationsPerformed;
    }

    /**
     * Returns a textual description of the model: the ZeroR model and each base model.
     *
     * @return the model description, or a notice if the classifier has not been built
     */
    @Override
    public String toString() {
        // m_zeroR is only assigned by buildClassifier, so it marks whether a model exists.
        // The configured iteration count (m_NumIterations) does not.
        if (this.m_zeroR == null) {
            return "Classifier hasn't been built yet!";
        }
        if (!this.m_SuitableData) {
            String shortName = this.getClass().getName().replaceAll(".*\\.", "");
            StringBuilder buf = new StringBuilder();
            buf.append(shortName + "\n");
            // The regex "." matches every character, producing an underline of equal length.
            buf.append(shortName.replaceAll(".", "=") + "\n\n");
            buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            buf.append(this.m_zeroR.toString());
            return buf.toString();
        }
        StringBuilder text = new StringBuilder();
        text.append("Additive Regression\n\n");
        text.append("ZeroR model\n\n" + this.m_zeroR + "\n\n");
        text.append("Base classifier " + this.getClassifier().getClass().getName() + "\n\n");
        text.append("" + this.m_NumIterationsPerformed + " models generated.\n");
        for (int i = 0; i < this.m_NumIterationsPerformed; ++i) {
            text.append("\nModel number " + i + "\n\n" + this.m_Classifiers[i] + "\n");
        }
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10470 $");
    }

    /**
     * Command-line entry point.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        AdditiveRegression.runClassifier(new AdditiveRegression(), argv);
    }
}

