/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees.lmt;

import java.util.Collections;
import java.util.Vector;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SimpleLinearRegression;
import weka.classifiers.trees.j48.ClassifierSplitModel;
import weka.classifiers.trees.j48.ModelSelection;
import weka.classifiers.trees.lmt.CompareNode;
import weka.classifiers.trees.lmt.LogisticBase;
import weka.classifiers.trees.lmt.ResidualModelSelection;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;

/**
 * A node of a logistic model tree.
 *
 * <p>Every node fits its own additive logistic model on the instances that reach it, using the
 * boosting routines inherited from {@link LogisticBase}. The simple regression functions fitted
 * at all ancestors are handed down as "higher" regressions. They are added into this node's class
 * scores (see {@link #getFs(Instance)}), so a leaf's model builds on the models along its path.
 *
 * <p>After growing, the tree is pruned with CART-style weakest-link (cost-complexity) pruning.
 * {@link #buildClassifier(Instances)} selects the complexity threshold by cross-validation.
 */
public class LMTNode extends LogisticBase {
    static final long serialVersionUID = 1862737145870398755L;

    /** Maximum number of boosting iterations, used for nodes and for {@link #tryLogistic}. */
    private static final int DEFAULT_MAX_ITERATIONS = 200;

    /** Normaliser for alpha values; callers in this class pass the root training-set size. */
    protected double m_totalInstanceWeight;
    /** Node id assigned by {@link #assignIDs(int)} for graph output. */
    protected int m_id;
    /** Leaf model number assigned by {@link #assignLeafModelNumbers(int)}; 0 for inner nodes. */
    protected int m_leafModelNum;
    /** Cost-complexity value of this node; {@code Double.MAX_VALUE} for leaves. */
    public double m_alpha;
    /** Training error of this node's own model, as reported by {@code Evaluation.incorrect()}. */
    public double m_numIncorrectModel;
    /** Training error of the subtree rooted here (sum over its current leaves). */
    public double m_numIncorrectTree;
    /** A node is only considered for splitting when it has more than this many instances. */
    protected int m_minNumInstances;
    /** Strategy used to choose the split at each node. */
    protected ModelSelection m_modelSelection;
    /** Filter converting nominal attributes to binary ones for this node's numeric data. */
    protected NominalToBinary m_nominalToBinary;
    /** Regression functions inherited from ancestors, indexed [class][function]. */
    protected SimpleLinearRegression[][] m_higherRegressions;
    /** Number of inherited regression functions per class. */
    protected int m_numHigherRegressions = 0;
    /** Number of folds used to cross-validate the pruning threshold. */
    protected static int m_numFoldsPruning = 5;
    /** Whether to determine the number of boosting iterations once, up front. */
    protected boolean m_fastRegression;
    /** Number of training instances at this node. */
    protected int m_numInstances;
    /** Split model selected for this node; only consulted while the node is not a leaf. */
    protected ClassifierSplitModel m_localModel;
    /** Child nodes or {@code null}; stays non-null on temporarily pruned nodes (see unprune). */
    protected LMTNode[] m_sons;
    /** Whether this node currently acts as a leaf. */
    protected boolean m_isLeaf;

    /**
     * Creates an (unbuilt) node.
     *
     * @param modelSelection strategy used to select splits
     * @param fixedNumIterations a positive value fixes the number of boosting iterations at every
     *     node; otherwise the number is chosen by AIC or cross-validation (or, with
     *     {@code fastRegression} and a negative value, once up front by {@link #tryLogistic})
     * @param fastRegression whether to determine the iteration count once for the whole tree
     * @param errorOnProbabilities passed through to the boosting error measure
     * @param minNumInstances nodes with at most this many instances are not split
     * @param weightTrimBeta weight-trimming parameter forwarded to {@code setWeightTrimBeta}
     * @param useAIC whether to use AIC to choose the number of boosting iterations
     */
    public LMTNode(ModelSelection modelSelection, int fixedNumIterations, boolean fastRegression,
            boolean errorOnProbabilities, int minNumInstances, double weightTrimBeta, boolean useAIC) {
        this.m_modelSelection = modelSelection;
        this.m_fixedNumIterations = fixedNumIterations;
        this.m_fastRegression = fastRegression;
        this.m_errorOnProbabilities = errorOnProbabilities;
        this.m_minNumInstances = minNumInstances;
        this.m_maxIterations = DEFAULT_MAX_ITERATIONS;
        this.setWeightTrimBeta(weightTrimBeta);
        this.setUseAIC(useAIC);
    }

    /**
     * Builds the pruned logistic model tree.
     *
     * <p>For each cross-validation fold, the method:
     * <ol>
     *   <li>grows a tree on the training part;</li>
     *   <li>records its weakest-link pruning sequence (alpha values);</li>
     *   <li>records the test-fold error rate after each pruning step.</li>
     * </ol>
     * It then grows the tree on all data and computes that tree's pruning sequence. For the
     * geometric midpoint of each pair of consecutive alphas, it sums the cross-validated errors.
     * Finally it prunes the full tree with the midpoint that scores best.
     *
     * @param instances training data
     * @throws Exception if building or evaluating any node fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        if (this.m_fastRegression && this.m_fixedNumIterations < 0) {
            this.m_fixedNumIterations = this.tryLogistic(instances);
        }

        // Cross-validate the pruning sequence.
        Instances cvData = new Instances(instances);
        cvData.stratify(m_numFoldsPruning);
        double[][] foldAlphas = new double[m_numFoldsPruning][];
        double[][] foldErrors = new double[m_numFoldsPruning][];
        int[] foldIterations = new int[m_numFoldsPruning];
        for (int fold = 0; fold < m_numFoldsPruning; ++fold) {
            Instances train = cvData.trainCV(m_numFoldsPruning, fold);
            Instances test = cvData.testCV(m_numFoldsPruning, fold);
            this.buildTree(train, null, train.numInstances(), 0.0);
            int numInnerNodes = this.getNumInnerNodes();
            foldAlphas[fold] = new double[numInnerNodes + 2];
            foldErrors[fold] = new double[numInnerNodes + 2];
            foldIterations[fold] = this.prune(foldAlphas[fold], foldErrors[fold], test);
        }

        // Grow the final tree on all data and compute its own pruning sequence.
        this.buildTree(instances, null, instances.numInstances(), 0.0);
        double[] treeAlphas = new double[this.getNumInnerNodes() + 2];
        int iterations = this.prune(treeAlphas, null, null);

        // Cross-validated error estimate for each midpoint between consecutive alphas.
        double[] treeErrors = new double[iterations + 1];
        for (int i = 0; i <= iterations; ++i) {
            double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i + 1]);
            double error = 0.0;
            for (int fold = 0; fold < m_numFoldsPruning; ++fold) {
                error += foldErrors[fold][lastStepAtMost(foldAlphas[fold], foldIterations[fold], alpha)];
            }
            treeErrors[i] = error;
        }

        // Scan from the most pruned tree downwards with a strict '<', so ties favour the
        // more heavily pruned tree.
        int best = -1;
        double bestError = Double.MAX_VALUE;
        for (int i = iterations; i >= 0; --i) {
            if (treeErrors[i] < bestError) {
                bestError = treeErrors[i];
                best = i;
            }
        }
        double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

        // The cross-validation pruning above only flagged nodes as leaves; restore them first.
        this.unprune();
        this.prune(bestAlpha);
        this.cleanup();
    }

    /**
     * Returns the index of the last pruning step in {@code alphas[0..lastStep]} whose alpha does
     * not exceed {@code alpha}.
     *
     * <p>The search is bounded by {@code lastStep}, not by the 1.0 end marker. Node alphas can be
     * infinite (see {@link #calculateAlphas()}), so a midpoint can be &ge; 1.0 or NaN. In that
     * case an unbounded scan would run off the array. Step 0 (the unpruned tree) is the fallback.
     */
    private static int lastStepAtMost(double[] alphas, int lastStep, double alpha) {
        int step = 0;
        while (step < lastStep && alphas[step + 1] <= alpha) {
            ++step;
        }
        return step;
    }

    /**
     * Grows the (unpruned) subtree rooted at this node.
     *
     * @param instances training instances reaching this node
     * @param higherRegressions regression functions inherited from ancestors, or {@code null}
     *     at the root
     * @param totalInstanceWeight normaliser later used for alpha values (stored on every node)
     * @param numParameters number of parameters already used by ancestors' models
     * @throws Exception if boosting, filtering or split selection fails
     */
    public void buildTree(Instances instances, SimpleLinearRegression[][] higherRegressions,
            double totalInstanceWeight, double numParameters) throws Exception {
        this.m_totalInstanceWeight = totalInstanceWeight;
        this.m_train = new Instances(instances);
        this.m_isLeaf = true;
        this.m_sons = null;
        this.m_numInstances = this.m_train.numInstances();
        this.m_numClasses = this.m_train.numClasses();
        this.m_numericData = this.getNumericData(this.m_train);
        this.m_numericDataHeader = new Instances(this.m_numericData, 0);
        this.m_regressions = this.initRegressions();
        this.m_numRegressions = 0;
        this.m_higherRegressions = higherRegressions != null
                ? higherRegressions
                : new SimpleLinearRegression[this.m_numClasses][0];
        this.m_numHigherRegressions = this.m_higherRegressions[0].length;
        this.m_numParameters = numParameters;

        // Fit this node's own logistic model, if there are enough instances for it.
        if (this.m_numInstances >= m_numFoldsBoosting) {
            if (this.m_fixedNumIterations > 0) {
                this.performBoosting(this.m_fixedNumIterations);
            } else if (this.getUseAIC()) {
                this.performBoostingInfCriterion();
            } else {
                this.performBoostingCV();
            }
        }
        this.m_numParameters += this.m_numRegressions;
        this.m_regressions = this.selectRegressions(this.m_regressions);

        boolean split = false;
        if (this.m_numInstances > this.m_minNumInstances) {
            if (this.m_modelSelection instanceof ResidualModelSelection) {
                // Residual-based split selection works on the working responses and weights of
                // the current logistic model.
                double[][] probs = this.getProbs(this.getFs(this.m_numericData));
                double[][] ys = this.getYs(this.m_train);
                double[][] zs = this.getZs(probs, ys);
                double[][] ws = this.getWs(probs, ys);
                this.m_localModel =
                        ((ResidualModelSelection) this.m_modelSelection).selectModel(this.m_train, zs, ws);
            } else {
                this.m_localModel = this.m_modelSelection.selectModel(this.m_train);
            }
            split = this.m_localModel.numSubsets() > 1;
        }

        if (split) {
            this.m_isLeaf = false;
            Instances[] subsets = this.m_localModel.split(this.m_train);
            // Children inherit this node's regressions plus everything this node inherited.
            SimpleLinearRegression[][] inherited = this.mergeArrays(this.m_regressions, this.m_higherRegressions);
            this.m_sons = new LMTNode[this.m_localModel.numSubsets()];
            for (int i = 0; i < this.m_sons.length; ++i) {
                this.m_sons[i] = new LMTNode(this.m_modelSelection, this.m_fixedNumIterations,
                        this.m_fastRegression, this.m_errorOnProbabilities, this.m_minNumInstances,
                        this.getWeightTrimBeta(), this.getUseAIC());
                this.m_sons[i].buildTree(subsets[i], inherited, this.m_totalInstanceWeight, this.m_numParameters);
                subsets[i] = null; // drop the reference once the child holds its own copy
            }
        }
    }

    /**
     * Permanently prunes, weakest link first, every inner node whose alpha does not exceed
     * {@code alpha}. Pruned nodes lose their children.
     *
     * @param alpha cost-complexity threshold
     * @throws Exception if evaluating node models fails
     */
    public void prune(double alpha) throws Exception {
        CompareNode comparator = new CompareNode();
        this.modelErrors();
        this.treeErrors();
        this.calculateAlphas();
        Vector nodes = this.getNodes();
        while (!nodes.isEmpty()) {
            LMTNode weakest = (LMTNode) Collections.min(nodes, comparator);
            if (weakest.m_alpha > alpha) {
                break;
            }
            weakest.m_isLeaf = true;
            weakest.m_sons = null;
            this.treeErrors();
            this.calculateAlphas();
            nodes = this.getNodes();
        }
    }

    /**
     * Prunes the tree back to its root, weakest link first, and records the pruning sequence.
     *
     * <p>Nodes are only flagged as leaves; their children are kept so {@link #unprune()} can
     * restore them. (Nodes pruned inside {@link #calculateAlphas()} do lose their children.)
     *
     * @param alphas receives 0.0 at index 0, the alpha of each pruning step at indices
     *     1..returned value, and the end marker 1.0 after that
     * @param errors if non-null, receives the error rate on {@code test} before pruning (index 0)
     *     and after each step
     * @param test evaluation data; only used when {@code errors} is non-null
     * @return the number of pruning steps performed
     * @throws Exception if evaluation fails
     */
    public int prune(double[] alphas, double[] errors, Instances test) throws Exception {
        CompareNode comparator = new CompareNode();
        this.modelErrors();
        this.treeErrors();
        this.calculateAlphas();
        Vector nodes = this.getNodes();

        alphas[0] = 0.0;
        if (errors != null) {
            errors[0] = this.errorRateOn(test);
        }
        int step = 0;
        while (!nodes.isEmpty()) {
            LMTNode weakest = (LMTNode) Collections.min(nodes, comparator);
            weakest.m_isLeaf = true;
            alphas[++step] = weakest.m_alpha;
            if (errors != null) {
                errors[step] = this.errorRateOn(test);
            }
            this.treeErrors();
            this.calculateAlphas();
            nodes = this.getNodes();
        }
        alphas[step + 1] = 1.0;
        return step;
    }

    /** Error rate of the current tree on {@code test}. */
    private double errorRateOn(Instances test) throws Exception {
        Evaluation evaluation = new Evaluation(test);
        evaluation.evaluateModel(this, test);
        return evaluation.errorRate();
    }

    /** Turns every node that still has children back into an inner node. */
    protected void unprune() {
        if (this.m_sons == null) {
            return;
        }
        this.m_isLeaf = false;
        for (LMTNode son : this.m_sons) {
            son.unprune();
        }
    }

    /**
     * Fits a standalone logistic model on {@code instances} (after nominal-to-binary conversion).
     *
     * @return the number of regression functions it ends up with
     * @throws Exception if filtering or fitting fails
     */
    protected int tryLogistic(Instances instances) throws Exception {
        Instances data = new Instances(instances);
        NominalToBinary filter = new NominalToBinary();
        filter.setInputFormat(data);
        data = Filter.useFilter(data, filter);
        LogisticBase logistic = new LogisticBase(0, true, this.m_errorOnProbabilities);
        logistic.setMaxIterations(DEFAULT_MAX_ITERATIONS);
        logistic.setWeightTrimBeta(this.getWeightTrimBeta());
        logistic.setUseAIC(this.getUseAIC());
        logistic.buildClassifier(data);
        return logistic.getNumRegressions();
    }

    /** Returns the number of inner (non-leaf) nodes in this subtree. */
    public int getNumInnerNodes() {
        if (this.m_isLeaf) {
            return 0;
        }
        int count = 1;
        for (LMTNode son : this.m_sons) {
            count += son.getNumInnerNodes();
        }
        return count;
    }

    /**
     * Returns the leaf count used for alpha computation.
     *
     * <p>Sibling leaves that have no regression functions of their own (see {@link #hasModels()})
     * are counted only once. An inner node can therefore report a single leaf.
     */
    public int getNumLeaves() {
        if (this.m_isLeaf) {
            return 1;
        }
        int leaves = 0;
        int emptyLeaves = 0;
        for (LMTNode son : this.m_sons) {
            leaves += son.getNumLeaves();
            if (son.m_isLeaf && !son.hasModels()) {
                ++emptyLeaves;
            }
        }
        if (emptyLeaves > 1) {
            leaves -= emptyLeaves - 1;
        }
        return leaves;
    }

    /**
     * Records, for every node in this subtree, the training error of that node's own model
     * in {@link #m_numIncorrectModel}.
     */
    public void modelErrors() throws Exception {
        Evaluation evaluation = new Evaluation(this.m_train);
        if (this.m_isLeaf) {
            evaluation.evaluateModel(this, this.m_train);
            this.m_numIncorrectModel = evaluation.incorrect();
            return;
        }
        // Temporarily act as a leaf so that distributionForInstance uses this node's own model.
        this.m_isLeaf = true;
        evaluation.evaluateModel(this, this.m_train);
        this.m_isLeaf = false;
        this.m_numIncorrectModel = evaluation.incorrect();
        for (LMTNode son : this.m_sons) {
            son.modelErrors();
        }
    }

    /** Recomputes {@link #m_numIncorrectTree} as the sum of model errors over current leaves. */
    public void treeErrors() {
        if (this.m_isLeaf) {
            this.m_numIncorrectTree = this.m_numIncorrectModel;
            return;
        }
        this.m_numIncorrectTree = 0.0;
        for (LMTNode son : this.m_sons) {
            son.treeErrors();
            this.m_numIncorrectTree += son.m_numIncorrectTree;
        }
    }

    /**
     * Computes the cost-complexity value alpha for every inner node.
     *
     * <p>Alpha is (model error - subtree error) / {@link #m_totalInstanceWeight} /
     * ({@link #getNumLeaves()} - 1). It is {@code Double.MAX_VALUE} for leaves. The result is
     * +Infinity when {@code getNumLeaves()} is 1. An inner node whose subtree does not reduce
     * the training error is pruned immediately and permanently.
     */
    public void calculateAlphas() throws Exception {
        if (this.m_isLeaf) {
            this.m_alpha = Double.MAX_VALUE;
            return;
        }
        double errorDiff = this.m_numIncorrectModel - this.m_numIncorrectTree;
        if (errorDiff <= 0.0) {
            this.m_isLeaf = true;
            this.m_sons = null;
            this.m_alpha = Double.MAX_VALUE;
            return;
        }
        errorDiff /= this.m_totalInstanceWeight;
        this.m_alpha = errorDiff / (double) (this.getNumLeaves() - 1);
        for (LMTNode son : this.m_sons) {
            son.calculateAlphas();
        }
    }

    /**
     * Concatenates two [class][function] regression arrays per class: all of {@code first},
     * then all of {@code second}. Both are assumed to have at least {@link #m_numClasses} rows.
     */
    protected SimpleLinearRegression[][] mergeArrays(SimpleLinearRegression[][] first,
            SimpleLinearRegression[][] second) {
        int numFirst = first[0].length;
        int numSecond = second[0].length;
        SimpleLinearRegression[][] merged = new SimpleLinearRegression[this.m_numClasses][numFirst + numSecond];
        for (int c = 0; c < this.m_numClasses; ++c) {
            System.arraycopy(first[c], 0, merged[c], 0, numFirst);
            System.arraycopy(second[c], 0, merged[c], numFirst, numSecond);
        }
        return merged;
    }

    /** Returns all current inner nodes of this subtree, in pre-order. */
    public Vector getNodes() {
        Vector nodes = new Vector();
        this.getNodes(nodes);
        return nodes;
    }

    /** Appends all current inner nodes of this subtree to {@code nodes}, in pre-order. */
    public void getNodes(Vector nodes) {
        if (this.m_isLeaf) {
            return;
        }
        nodes.add(this);
        for (LMTNode son : this.m_sons) {
            son.getNodes(nodes);
        }
    }

    /** Converts nominal attributes to binary ones (remembering the filter), then delegates. */
    protected Instances getNumericData(Instances instances) throws Exception {
        Instances data = new Instances(instances);
        this.m_nominalToBinary = new NominalToBinary();
        this.m_nominalToBinary.setInputFormat(data);
        data = Filter.useFilter(data, this.m_nominalToBinary);
        return super.getNumericData(data);
    }

    /**
     * Class scores for one instance. These are this node's scores from the superclass plus the
     * centred contribution of every inherited regression function, scaled by (K-1)/K.
     */
    protected double[] getFs(Instance instance) throws Exception {
        double[] predictions = new double[this.m_numClasses];
        double[] fs = super.getFs(instance);
        for (int j = 0; j < this.m_numHigherRegressions; ++j) {
            double mean = 0.0;
            for (int c = 0; c < this.m_numClasses; ++c) {
                predictions[c] = this.m_higherRegressions[c][j].classifyInstance(instance);
                mean += predictions[c];
            }
            mean /= (double) this.m_numClasses;
            for (int c = 0; c < this.m_numClasses; ++c) {
                fs[c] += (predictions[c] - mean) * (double) (this.m_numClasses - 1) / (double) this.m_numClasses;
            }
        }
        return fs;
    }

    /** Whether this node fitted at least one regression function of its own. */
    public boolean hasModels() {
        return this.m_numRegressions > 0;
    }

    /** Class probabilities from this node's own logistic model, ignoring any children. */
    public double[] modelDistributionForInstance(Instance instance) throws Exception {
        Instance converted = (Instance) instance.copy();
        this.m_nominalToBinary.input(converted);
        converted = this.m_nominalToBinary.output();
        converted.setDataset(this.m_numericDataHeader);
        return this.probs(this.getFs(converted));
    }

    /** Routes {@code instance} to a leaf and returns that leaf's class probabilities. */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_isLeaf) {
            return this.modelDistributionForInstance(instance);
        }
        return this.m_sons[this.m_localModel.whichSubset(instance)].distributionForInstance(instance);
    }

    /** Number of current leaves (unlike {@link #getNumLeaves()}, every leaf counts). */
    public int numLeaves() {
        if (this.m_isLeaf) {
            return 1;
        }
        int count = 0;
        for (LMTNode son : this.m_sons) {
            count += son.numLeaves();
        }
        return count;
    }

    /** Number of nodes in the current tree. */
    public int numNodes() {
        if (this.m_isLeaf) {
            return 1;
        }
        int count = 1;
        for (LMTNode son : this.m_sons) {
            count += son.numNodes();
        }
        return count;
    }

    /** Textual description of the tree structure, its size and all leaf models. */
    @Override
    public String toString() {
        this.assignLeafModelNumbers(0);
        try {
            StringBuffer text = new StringBuffer();
            if (this.m_isLeaf) {
                text.append(": ");
                text.append("LM_" + this.m_leafModelNum + ":" + this.getModelParameters());
            } else {
                this.dumpTree(0, text);
            }
            text.append("\n\nNumber of Leaves  : \t" + this.numLeaves() + "\n");
            text.append("\nSize of the Tree : \t" + this.numNodes() + "\n");
            text.append(this.modelsToString());
            return text.toString();
        } catch (Exception exception) {
            return "Can't print logistic model tree";
        }
    }

    /** "own/total (instances)": own and own+inherited regression counts, and instance count. */
    public String getModelParameters() {
        int total = this.m_numRegressions + this.m_numHigherRegressions;
        return this.m_numRegressions + "/" + total + " (" + this.m_numInstances + ")";
    }

    /** Appends the indented split structure below this inner node to {@code text}. */
    protected void dumpTree(int depth, StringBuffer text) throws Exception {
        for (int i = 0; i < this.m_sons.length; ++i) {
            LMTNode son = this.m_sons[i];
            text.append("\n");
            for (int j = 0; j < depth; ++j) {
                text.append("|   ");
            }
            text.append(this.m_localModel.leftSide(this.m_train));
            text.append(this.m_localModel.rightSide(i, this.m_train));
            if (son.m_isLeaf) {
                text.append(": ");
                text.append("LM_" + son.m_leafModelNum + ":" + son.getModelParameters());
            } else {
                son.dumpTree(depth + 1, text);
            }
        }
    }

    /**
     * Assigns consecutive pre-order ids, starting at {@code lastId + 1}, to this node and every
     * node that still has children.
     *
     * @return the last id assigned
     */
    public int assignIDs(int lastId) {
        int current = lastId + 1;
        this.m_id = current;
        if (this.m_sons != null) {
            for (LMTNode son : this.m_sons) {
                current = son.assignIDs(current);
            }
        }
        return current;
    }

    /**
     * Numbers the current leaves left to right, starting at {@code lastNumber + 1}. Inner nodes
     * get 0.
     *
     * @return the last number assigned
     */
    public int assignLeafModelNumbers(int lastNumber) {
        if (this.m_isLeaf) {
            this.m_leafModelNum = ++lastNumber;
            return lastNumber;
        }
        this.m_leafModelNum = 0;
        for (LMTNode son : this.m_sons) {
            lastNumber = son.assignLeafModelNumbers(lastNumber);
        }
        return lastNumber;
    }

    /**
     * Coefficients of this node's model, with every inherited regression folded in.
     * Slope and intercept are scaled by (K-1)/K.
     */
    protected double[][] getCoefficients() {
        double[][] coefficients = super.getCoefficients();
        double scale = (double) (this.m_numClasses - 1) / (double) this.m_numClasses;
        for (int c = 0; c < this.m_numClasses; ++c) {
            for (int j = 0; j < this.m_numHigherRegressions; ++j) {
                SimpleLinearRegression regression = this.m_higherRegressions[c][j];
                double slope = regression.getSlope();
                double intercept = regression.getIntercept();
                int attribute = regression.getAttributeIndex();
                coefficients[c][0] += scale * intercept;
                // Index 0 holds the intercept, so attribute coefficients are shifted by one.
                coefficients[c][attribute + 1] += scale * slope;
            }
        }
        return coefficients;
    }

    /** Textual form of every current leaf model, labelled LM_n. */
    public String modelsToString() {
        StringBuffer text = new StringBuffer();
        if (this.m_isLeaf) {
            text.append("LM_" + this.m_leafModelNum + ":" + super.toString());
        } else {
            for (LMTNode son : this.m_sons) {
                text.append("\n" + son.modelsToString());
            }
        }
        return text.toString();
    }

    /**
     * Returns the tree in Graphviz dot format.
     * NOTE(review): labels are not escaped, so quotes in attribute names would break the output.
     */
    public String graph() throws Exception {
        StringBuffer text = new StringBuffer();
        this.assignIDs(-1);
        this.assignLeafModelNumbers(0);
        text.append("digraph LMTree {\n");
        if (this.m_isLeaf) {
            text.append("N" + this.m_id + " [label=\"LM_" + this.m_leafModelNum + ":"
                    + this.getModelParameters() + "\" " + "shape=box style=filled");
            text.append("]\n");
        } else {
            text.append("N" + this.m_id + " [label=\"" + this.m_localModel.leftSide(this.m_train) + "\" ");
            text.append("]\n");
            this.graphTree(text);
        }
        return text.toString() + "}\n";
    }

    /** Appends dot edges and nodes for the subtree below this inner node. */
    private void graphTree(StringBuffer text) throws Exception {
        for (int i = 0; i < this.m_sons.length; ++i) {
            LMTNode son = this.m_sons[i];
            text.append("N" + this.m_id + "->" + "N" + son.m_id + " [label=\""
                    + this.m_localModel.rightSide(i, this.m_train).trim() + "\"]\n");
            if (son.m_isLeaf) {
                text.append("N" + son.m_id + " [label=\"LM_" + son.m_leafModelNum + ":"
                        + son.getModelParameters() + "\" " + "shape=box style=filled");
                text.append("]\n");
            } else {
                text.append("N" + son.m_id + " [label=\"" + son.m_localModel.leftSide(this.m_train) + "\" ");
                text.append("]\n");
                son.graphTree(text);
            }
        }
    }

    /** Releases training data held by this node and all current descendants. */
    public void cleanup() {
        super.cleanup();
        if (!this.m_isLeaf) {
            for (LMTNode son : this.m_sons) {
                son.cleanup();
            }
        }
    }
}

