/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.converter.mining;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.dmg.pmml.Field;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.regression.RegressionModelUtil;

public class MiningModelUtil {

    /**
     * Extracts the prediction of a model as a feature.
     *
     * <p>The prediction is taken to be the <em>last</em> {@link OutputField} declared in the
     * model's {@link Output} element. It is wrapped as a {@link ContinuousFeature} with a
     * {@code null} encoder.
     *
     * <p>Throws {@link IllegalArgumentException} if the model has no output fields.
     */
    private static final Function<Model, Feature> MODEL_PREDICTION = model -> {
        Output output = model.getOutput();
        if (output == null || !output.hasOutputFields()) {
            throw new IllegalArgumentException("Model does not declare any output fields");
        }
        OutputField outputField = Iterables.getLast(output.getOutputFields());
        return new ContinuousFeature(null, outputField);
    };

    private MiningModelUtil() {
    }

    /**
     * Chains {@code model} with a {@link RegressionModel} that post-processes its prediction.
     *
     * <p>The regression model uses the last output field of {@code model} as its only feature.
     * It has a coefficient of {@code 1.0}, a {@code null} intercept and the given
     * normalization method.
     *
     * @param model the model whose prediction is post-processed; must declare output fields
     * @param normalizationMethod the normalization method of the regression model
     * @param schema the schema used to build the regression model
     * @return a model-chain {@link MiningModel} of {@code [model, regressionModel]}
     * @throws IllegalArgumentException if {@code model} declares no output fields
     */
    public static MiningModel createRegression(Model model, RegressionModel.NormalizationMethod normalizationMethod, Schema schema) {
        Feature feature = MODEL_PREDICTION.apply(model);
        RegressionModel regressionModel = RegressionModelUtil.createRegression(model.getMathContext(), Collections.singletonList(feature), Collections.singletonList(1.0), null, normalizationMethod, schema);
        return createModelChain(Arrays.asList(model, regressionModel));
    }

    /**
     * Chains {@code model} with a binary logistic classification {@link RegressionModel}.
     *
     * <p>The regression model uses the last output field of {@code model} as its only feature,
     * with the given coefficient and intercept.
     *
     * @param model the model whose prediction feeds the classifier; must declare output fields
     * @param coefficient the coefficient applied to the model's prediction
     * @param intercept the intercept of the regression
     * @param normalizationMethod the normalization method of the regression model
     * @param hasProbabilityDistribution whether probability output fields should be created
     * @param schema the schema used to build the regression model
     * @return a model-chain {@link MiningModel} of {@code [model, regressionModel]}
     * @throws IllegalArgumentException if {@code model} declares no output fields
     */
    public static MiningModel createBinaryLogisticClassification(Model model, double coefficient, double intercept, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
        Feature feature = MODEL_PREDICTION.apply(model);
        RegressionModel regressionModel = RegressionModelUtil.createBinaryLogisticClassification(model.getMathContext(), Collections.singletonList(feature), Collections.singletonList(coefficient), intercept, normalizationMethod, hasProbabilityDistribution, schema);
        return createModelChain(Arrays.asList(model, regressionModel));
    }

    /**
     * Combines one model per target category into a classification model chain.
     *
     * <p>The prediction of the i-th model (its last output field) becomes the only term of the
     * {@link RegressionTable} for the i-th category of the categorical label. The term has a
     * coefficient of {@code 1.0}. The resulting classification {@link RegressionModel} is
     * appended as the final segment of the chain.
     *
     * @param models one model per target category, in category order
     * @param normalizationMethod {@code null}, {@code NONE}, {@code SIMPLEMAX} or {@code SOFTMAX}
     * @param hasProbabilityDistribution whether probability output fields should be created
     * @param schema the schema; its label must be a {@link CategoricalLabel}
     * @return a model-chain {@link MiningModel} of {@code models} followed by the regression model
     * @throws IllegalArgumentException if the normalization method is unsupported, there are
     *     too few target categories, the models do not share one math context, or a model
     *     declares no output fields
     */
    public static MiningModel createClassification(List<? extends Model> models, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        // One model per target category
        SchemaUtil.checkSize(models.size(), categoricalLabel);

        int minCategories = getMinCategories(normalizationMethod);
        if (categoricalLabel.size() < minCategories) {
            throw new IllegalArgumentException("Normalization method " + normalizationMethod + " requires at least " + minCategories + " target categories, got " + categoricalLabel.size());
        }

        MathContext mathContext = null;
        List<RegressionTable> regressionTables = new ArrayList<>(categoricalLabel.size());
        for (int i = 0; i < categoricalLabel.size(); i++) {
            Model model = models.get(i);

            // A missing math context is treated as DOUBLE; all models must agree
            MathContext modelMathContext = model.getMathContext();
            if (modelMathContext == null) {
                modelMathContext = MathContext.DOUBLE;
            }
            if (mathContext == null) {
                mathContext = modelMathContext;
            } else if (!Objects.equals(mathContext, modelMathContext)) {
                throw new IllegalArgumentException("Models have inconsistent math contexts: " + mathContext + " and " + modelMathContext);
            }

            Feature feature = MODEL_PREDICTION.apply(model);
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(mathContext, Collections.singletonList(feature), Collections.singletonList(1.0), null)
                .setTargetCategory(categoricalLabel.getValue(i));
            regressionTables.add(regressionTable);
        }

        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
            .setNormalizationMethod(normalizationMethod)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext))
            .setOutput(hasProbabilityDistribution ? ModelUtil.createProbabilityOutput(mathContext, categoricalLabel) : null);

        // A concretely typed copy is needed: a List<? extends Model> cannot accept additions
        List<Model> segmentationModels = new ArrayList<>(models);
        segmentationModels.add(regressionModel);
        return createModelChain(segmentationModels);
    }

    /**
     * Returns the minimum number of target categories that
     * {@link #createClassification} accepts for the given normalization method.
     *
     * <p>{@code null} and {@code NONE} require three categories. Binary problems are presumably
     * expected to go through {@link #createBinaryLogisticClassification} instead;
     * NOTE(review): confirm.
     *
     * @throws IllegalArgumentException for any other normalization method
     */
    private static int getMinCategories(RegressionModel.NormalizationMethod normalizationMethod) {
        if (normalizationMethod == null) {
            return 3;
        }
        switch (normalizationMethod) {
            case NONE:
                return 3;
            case SIMPLEMAX:
            case SOFTMAX:
                return 2;
            default:
                throw new IllegalArgumentException("Unsupported normalization method " + normalizationMethod);
        }
    }

    /**
     * Creates a model chain whose missing-prediction treatment is {@code RETURN_MISSING}.
     *
     * @see #createModelChain(List, Segmentation.MissingPredictionTreatment)
     */
    public static MiningModel createModelChain(List<? extends Model> models) {
        return createModelChain(models, Segmentation.MissingPredictionTreatment.RETURN_MISSING);
    }

    /**
     * Wraps {@code models} in a {@link MiningModel} with a {@code MODEL_CHAIN} segmentation.
     *
     * <p>The mining schema of the result lists, as {@code TARGET} fields, the distinct names of
     * all {@code PREDICTED} or {@code TARGET} mining fields of the member models. The mining
     * function and (simplified) math context are taken from the last model.
     *
     * @param models the models to chain, in evaluation order; must not be empty
     * @param missingPredictionTreatment the missing-prediction treatment of the segmentation
     * @return the model-chain mining model
     * @throws IllegalArgumentException if {@code models} is empty
     */
    public static MiningModel createModelChain(List<? extends Model> models, Segmentation.MissingPredictionTreatment missingPredictionTreatment) {
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Model chain requires at least one model");
        }

        MiningSchema miningSchema = new MiningSchema();
        models.stream()
            .map(Model::getMiningSchema)
            .map(MiningSchema::getMiningFields)
            .flatMap(Collection::stream)
            .filter(MiningModelUtil::isTargetField)
            .map(MiningField::getName)
            .distinct()
            .map(name -> ModelUtil.createMiningField(name, MiningField.UsageType.TARGET))
            .forEach(miningField -> miningSchema.addMiningFields(miningField));

        Segmentation segmentation = createSegmentation(Segmentation.MultipleModelMethod.MODEL_CHAIN, models)
            .setMissingPredictionTreatment(missingPredictionTreatment);

        Model lastModel = models.get(models.size() - 1);
        return new MiningModel(lastModel.getMiningFunction(), miningSchema)
            .setMathContext(ModelUtil.simplifyMathContext(lastModel.getMathContext()))
            .setSegmentation(segmentation);
    }

    /**
     * Returns {@code true} if the mining field's usage type is {@code PREDICTED} or {@code TARGET}.
     */
    private static boolean isTargetField(MiningField miningField) {
        MiningField.UsageType usageType = miningField.getUsageType();
        return usageType == MiningField.UsageType.PREDICTED || usageType == MiningField.UsageType.TARGET;
    }

    /**
     * Creates an unweighted segmentation.
     *
     * @see #createSegmentation(Segmentation.MultipleModelMethod, List, List)
     */
    public static Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, List<? extends Model> models) {
        return createSegmentation(multipleModelMethod, models, null);
    }

    /**
     * Creates a segmentation with one always-true segment per model.
     *
     * <p>Segments get 1-based string ids in list order. A weight is set only when it is
     * non-null and not equal to one; a weight of one is omitted, presumably because it is the
     * PMML default.
     *
     * @param multipleModelMethod the multiple model method of the segmentation
     * @param models the segment models, in order
     * @param weights per-model weights, or {@code null} for none
     * @return the new segmentation
     * @throws IllegalArgumentException if {@code weights} is non-null and its size differs
     *     from {@code models}
     */
    public static Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, List<? extends Model> models, List<? extends Number> weights) {
        if (weights != null && models.size() != weights.size()) {
            throw new IllegalArgumentException("Expected " + models.size() + " weights, got " + weights.size());
        }

        List<Segment> segments = new ArrayList<>(models.size());
        for (int i = 0; i < models.size(); i++) {
            Model model = models.get(i);
            Number weight = (weights != null) ? weights.get(i) : null;

            Segment segment = new Segment(True.INSTANCE, model)
                .setId(String.valueOf(i + 1));
            if (weight != null && !ValueUtil.isOne(weight)) {
                segment.setWeight(weight);
            }
            segments.add(segment);
        }
        return new Segmentation(multipleModelMethod, segments);
    }

    /**
     * Returns the final model of {@code model}.
     *
     * <p>A {@link MiningModel} is resolved through {@link #getFinalModel(MiningModel)}. Any other
     * model is returned as is.
     */
    public static Model getFinalModel(Model model) {
        if (model instanceof MiningModel) {
            return getFinalModel((MiningModel) model);
        }
        return model;
    }

    /**
     * Returns the final model of a mining model.
     *
     * <p>For {@code MODEL_CHAIN} segmentations, the last segment's model is resolved
     * recursively. That segment must have a {@link True} predicate. For other multiple model
     * methods, except {@code SELECT_FIRST} and {@code SELECT_ALL}, the mining model itself is
     * returned.
     *
     * @throws IllegalArgumentException for {@code SELECT_FIRST} or {@code SELECT_ALL}
     *     segmentations, for an empty model chain, or when the last chain segment's predicate
     *     is not {@code True}
     */
    public static Model getFinalModel(MiningModel miningModel) {
        Segmentation segmentation = miningModel.getSegmentation();
        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();
        switch (multipleModelMethod) {
            case SELECT_FIRST:
            case SELECT_ALL:
                throw new IllegalArgumentException("Unsupported multiple model method " + multipleModelMethod);
            case MODEL_CHAIN: {
                List<Segment> segments = segmentation.getSegments();
                if (segments.isEmpty()) {
                    throw new IllegalArgumentException("Model chain has no segments");
                }
                Segment finalSegment = segments.get(segments.size() - 1);
                Predicate predicate = finalSegment.getPredicate();
                if (!(predicate instanceof True)) {
                    throw new IllegalArgumentException("Final segment of a model chain must have a True predicate");
                }
                return getFinalModel(finalSegment.getModel());
            }
            default:
                return miningModel;
        }
    }
}

