/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.io.DataInput;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BaseNFeature;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.visitors.NaNAsMissingDecorator;
import org.jpmml.xgboost.BinaryLoadable;
import org.jpmml.xgboost.BinomialLogisticRegression;
import org.jpmml.xgboost.Dart;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.GBTree;
import org.jpmml.xgboost.GeneralizedLinearRegression;
import org.jpmml.xgboost.HingeClassification;
import org.jpmml.xgboost.JSONLoadable;
import org.jpmml.xgboost.LambdaMART;
import org.jpmml.xgboost.LinearRegression;
import org.jpmml.xgboost.LogisticRegression;
import org.jpmml.xgboost.MultinomialLogisticRegression;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.PoissonRegression;
import org.jpmml.xgboost.XGBoostDataInput;
import org.jpmml.xgboost.XGBoostEncoder;
import org.jpmml.xgboost.visitors.TreeModelCompactor;

/**
 * XGBoost learner: holds the learner-level parameters, the objective function and the
 * gradient booster loaded from an XGBoost model file, and converts them into PMML.
 *
 * <p>Models can be loaded either from the legacy binary format ({@link #loadBinary(XGBoostDataInput)})
 * or from the JSON format ({@link #loadJSON(JsonObject)}).
 */
public class Learner implements BinaryLoadable, JSONLoadable {
    // Base score; after loading it has been passed through obj.probToMargin (except for binary major version 0)
    private float base_score;
    private int num_feature;
    private int num_class;
    // Binary-format flags: non-zero means the corresponding optional section is present
    private int contain_extra_attrs;
    private int contain_eval_metrics;
    private int major_version;
    private int minor_version;
    private ObjFunction obj;
    private GBTree gbtree;
    // Only populated by the binary loader; not otherwise read in this class
    private Map<String, String> attributes = null;
    private String[] metrics = null;

    /**
     * Loads the learner from the binary model layout.
     *
     * <p>Supported major versions are 0 and 1. For major version 1 and newer the stored
     * base score is converted with {@link ObjFunction#probToMargin}.
     *
     * @throws IllegalArgumentException if the major version is not 0 or 1, or the objective/booster name is unknown
     * @throws IOException on read failure
     */
    @Override
    public void loadBinary(XGBoostDataInput input) throws IOException {
        this.base_score = input.readFloat();
        this.num_feature = input.readInt();
        this.num_class = input.readInt();
        this.contain_extra_attrs = input.readInt();
        this.contain_eval_metrics = input.readInt();
        this.major_version = input.readInt();
        this.minor_version = input.readInt();
        if (this.major_version < 0 || this.major_version > 1) {
            throw new IllegalArgumentException(this.major_version + "." + this.minor_version);
        }
        input.readReserved(27);

        String name_obj = input.readString();
        // num_class has already been read, which parseObjective needs for multi-class objectives
        this.obj = this.parseObjective(name_obj);
        if (this.major_version >= 1) {
            // "+ 0.0f" turns a -0.0f result into +0.0f
            this.base_score = this.obj.probToMargin(this.base_score) + 0.0f;
        }

        String name_gbm = input.readString();
        this.gbtree = this.parseGradientBooster(name_gbm);
        this.gbtree.loadBinary(input);

        if (this.contain_extra_attrs != 0) {
            this.attributes = input.readStringMap();
        }

        // The trailing sections below are only read for major version 0
        if (this.major_version >= 1) {
            return;
        }

        if (this.obj instanceof PoissonRegression) {
            try {
                // max_delta_step value; read only to advance past it
                input.readString();
            } catch (EOFException ignored) {
                // Best-effort: the value may be absent, so end of stream here is tolerated
            }
        }

        if (this.contain_eval_metrics != 0) {
            this.metrics = input.readStringVector();
        }
    }

    /**
     * Loads the learner from the JSON model layout.
     *
     * <p>Requires a {@code version} array of at least 1.3. The versions are compared as a
     * (major, minor) pair, so any major version above 1 is accepted.
     *
     * @throws IllegalArgumentException if the version is older than 1.3, or the objective/booster name is unknown
     */
    @Override
    public void loadJSON(JsonObject root) {
        JsonArray version = root.getAsJsonArray("version");
        this.major_version = version.get(0).getAsInt();
        this.minor_version = version.get(1).getAsInt();

        // Lexicographic (major, minor) >= (1, 3); checking minor alone would wrongly reject e.g. 2.0
        boolean supported = this.major_version > 1 || (this.major_version == 1 && this.minor_version >= 3);
        if (!supported) {
            throw new IllegalArgumentException(this.major_version + "." + this.minor_version);
        }

        JsonObject learner = root.getAsJsonObject("learner");

        JsonObject learnerModelParam = learner.getAsJsonObject("learner_model_param");
        this.base_score = learnerModelParam.getAsJsonPrimitive("base_score").getAsFloat();
        this.num_feature = learnerModelParam.getAsJsonPrimitive("num_feature").getAsInt();
        this.num_class = learnerModelParam.getAsJsonPrimitive("num_class").getAsInt();

        JsonObject objective = learner.getAsJsonObject("objective");
        String name_obj = objective.getAsJsonPrimitive("name").getAsString();
        // num_class must be set before this call (used for multi-class objectives)
        this.obj = this.parseObjective(name_obj);
        // "+ 0.0f" turns a -0.0f result into +0.0f
        this.base_score = this.obj.probToMargin(this.base_score) + 0.0f;

        JsonObject gradientBooster = learner.getAsJsonObject("gradient_booster");
        String name_gbm = gradientBooster.getAsJsonPrimitive("name").getAsString();
        this.gbtree = this.parseGradientBooster(name_gbm);
        this.gbtree.loadJSON(gradientBooster);
    }

    /**
     * Loads a binary model from a raw stream, handling the optional leading headers.
     *
     * <p>The stream is cast to {@link DataInput}, so callers must pass a stream that implements
     * it. Header probing relies on {@link InputStream#mark}/{@link InputStream#reset}. The stream
     * is closed when this method returns, because the wrapping {@link XGBoostDataInput} is
     * closed by try-with-resources.
     *
     * @param is      input stream; must also implement {@link DataInput}
     * @param charset charset passed to {@link XGBoostDataInput}
     * @throws IOException if the header offset is negative, if trailing bytes remain after a
     *                     header-less model, or on read failure
     */
    public <DIS extends InputStream> void loadBinary(DIS is, String charset) throws IOException {
        // Optional "CONFIG-offset:" header, followed by a long offset that must be non-negative
        boolean hasSerializationHeader = Learner.consumeHeader(is, "CONFIG-offset:");
        if (hasSerializationHeader) {
            long offset = ((DataInput) is).readLong();
            if (offset < 0L) {
                throw new IOException("Invalid CONFIG-offset value: " + offset);
            }
        }

        // Optional "binf" magic; it is consumed when present and carries no data
        Learner.consumeHeader(is, "binf");

        try (XGBoostDataInput input = new XGBoostDataInput(is, charset)) {
            this.loadBinary(input);

            // Without the serialization header the model must be the whole stream. With it, further
            // data may follow (presumably the section the CONFIG offset refers to — confirm).
            if (!hasSerializationHeader) {
                int eof = is.read();
                if (eof != -1) {
                    throw new IOException("Expected end of stream after model data");
                }
            }
        }
    }

    /**
     * Parses a JSON document from the stream and loads the learner from the object at {@code jsonPath}.
     *
     * <p>{@code jsonPath} is a dot-separated path whose first segment must be {@code "$"}
     * (the document root), e.g. {@code "$"} or {@code "$.model"}.
     *
     * @param is       input stream; it is closed when this method returns
     * @param charset  stream charset, or {@code null} for UTF-8
     * @param jsonPath dot-separated path to the learner object
     * @throws IllegalArgumentException if the first path segment is not {@code "$"}
     * @throws IOException if data remains after the JSON document, or on read failure
     */
    public void loadJSON(InputStream is, String charset, String jsonPath) throws IOException {
        JsonParser parser = new JsonParser();
        if (charset == null) {
            charset = "UTF-8";
        }
        try (InputStreamReader reader = new InputStreamReader(is, charset)) {
            JsonElement element = parser.parse(reader);
            JsonObject object = element.getAsJsonObject();

            String[] names = jsonPath.split("\\.");
            for (int i = 0; i < names.length; ++i) {
                String name = names[i];
                if (i == 0) {
                    if (!"$".equals(name)) {
                        throw new IllegalArgumentException(jsonPath);
                    }
                    continue;
                }
                object = object.getAsJsonObject(name);
            }
            this.loadJSON(object);

            int eof = is.read();
            if (eof != -1) {
                throw new IOException("Expected end of stream after JSON document");
            }
        }
    }

    /**
     * Builds the conversion schema: the label comes from the objective function and the
     * features from {@code featureMap}.
     *
     * @param targetField target field name, or {@code null} to use {@code "_target"}
     */
    public Schema encodeSchema(FieldName targetField, List<String> targetCategories, FeatureMap featureMap, XGBoostEncoder encoder) {
        if (targetField == null) {
            targetField = FieldName.create("_target");
        }
        Label label = this.obj.encodeLabel(targetField, targetCategories, (PMMLEncoder) encoder);
        List<Feature> features = featureMap.encodeFeatures((PMMLEncoder) encoder);
        return new Schema((PMMLEncoder) encoder, label, features);
    }

    /**
     * Returns a transformed copy of {@code schema} with features adapted for XGBoost.
     *
     * <p>Features are handled as follows:
     * <ul>
     *   <li>{@link BaseNFeature}, {@link BinaryFeature} and {@link MissingValueFeature} are kept as-is.</li>
     *   <li>Other features are converted to continuous features. Integer and float types are kept,
     *       and double is narrowed to float.</li>
     *   <li>Any other data type is rejected.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for a continuous feature whose data type is not integer, float or double
     */
    public Schema toXGBoostSchema(Schema schema) {
        Function<Feature, Feature> function = feature -> {
            if (feature instanceof BaseNFeature || feature instanceof BinaryFeature || feature instanceof MissingValueFeature) {
                return feature;
            }

            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            DataType dataType = continuousFeature.getDataType();
            switch (dataType) {
                case INTEGER:
                case FLOAT:
                    break;
                case DOUBLE:
                    continuousFeature = continuousFeature.toContinuousFeature(DataType.FLOAT);
                    break;
                default:
                    throw new IllegalArgumentException("Expected integer, float or double data type for continuous feature " + continuousFeature.getName() + ", got " + dataType.value() + " data type");
            }
            return continuousFeature;
        };
        return schema.toTransformedSchema(function);
    }

    /**
     * Encodes the learner as a PMML document.
     *
     * <p>Recognized option: {@code "nan_as_missing"} (Boolean). When it is TRUE,
     * {@link NaNAsMissingDecorator} is applied to the result. The remaining options are
     * forwarded to {@link #encodeMiningModel}.
     */
    public PMML encodePMML(Map<String, ?> options, FieldName targetField, List<String> targetCategories, FeatureMap featureMap) {
        XGBoostEncoder encoder = new XGBoostEncoder();
        Boolean nanAsMissing = (Boolean) options.get("nan_as_missing");

        Schema schema = this.encodeSchema(targetField, targetCategories, featureMap, encoder);
        MiningModel miningModel = this.encodeMiningModel(options, schema);
        PMML pmml = encoder.encodePMML((Model) miningModel);

        if (Boolean.TRUE.equals(nanAsMissing)) {
            NaNAsMissingDecorator visitor = new NaNAsMissingDecorator();
            visitor.applyTo((Visitable) pmml);
        }
        return pmml;
    }

    /**
     * Encodes the gradient booster as a {@link MiningModel}.
     *
     * <p>Recognized options:
     * <ul>
     *   <li>{@code "compact"} (Boolean): when TRUE, {@link TreeModelCompactor} is applied.</li>
     *   <li>{@code "ntree_limit"} (Integer, may be absent): passed through unchanged to the booster.</li>
     * </ul>
     */
    public MiningModel encodeMiningModel(Map<String, ?> options, Schema schema) {
        Boolean compact = (Boolean) options.get("compact");
        Integer ntreeLimit = (Integer) options.get("ntree_limit");

        MiningModel miningModel = this.gbtree.encodeMiningModel(this.obj, this.base_score, ntreeLimit, schema)
                .setAlgorithmName("XGBoost (" + this.gbtree.getAlgorithmName() + ")");

        if (Boolean.TRUE.equals(compact)) {
            TreeModelCompactor visitor = new TreeModelCompactor();
            visitor.applyTo((Visitable) miningModel);
        }
        return miningModel;
    }

    /** Returns the number of features recorded in the model. */
    public int num_feature() {
        return this.num_feature;
    }

    /** Returns the number of classes recorded in the model. */
    public int num_class() {
        return this.num_class;
    }

    /** Returns the objective function parsed from the model. */
    public ObjFunction obj() {
        return this.obj;
    }

    /** Maps a booster name ({@code "gbtree"} or {@code "dart"}) to a fresh booster instance. */
    private GBTree parseGradientBooster(String name_gbm) {
        switch (name_gbm) {
            case "gbtree":
                return new GBTree();
            case "dart":
                return new Dart();
            default:
                throw new IllegalArgumentException(name_gbm);
        }
    }

    /**
     * Maps an objective name to an objective function. Multi-class objectives read
     * {@code this.num_class}, so it must be set before this is called.
     */
    private ObjFunction parseObjective(String name_obj) {
        switch (name_obj) {
            case "reg:linear":
            case "reg:squarederror":
            case "reg:squaredlogerror":
                return new LinearRegression();
            case "reg:logistic":
                return new LogisticRegression();
            case "reg:gamma":
            case "reg:tweedie":
                return new GeneralizedLinearRegression();
            case "count:poisson":
                return new PoissonRegression();
            case "binary:hinge":
                return new HingeClassification();
            case "binary:logistic":
                return new BinomialLogisticRegression();
            case "rank:map":
            case "rank:ndcg":
            case "rank:pairwise":
                return new LambdaMART();
            case "multi:softmax":
            case "multi:softprob":
                return new MultinomialLogisticRegression(this.num_class);
            default:
                throw new IllegalArgumentException(name_obj);
        }
    }

    /**
     * Consumes {@code header} (as UTF-8 bytes) from the stream if the stream starts with it.
     * Otherwise the stream is reset to where it was.
     *
     * <p>The stream must implement {@link DataInput}. Rewinding on a mismatch relies on
     * {@link InputStream#mark}/{@link InputStream#reset}. If fewer bytes than the header
     * length remain, {@link DataInput#readFully} throws {@link EOFException}.
     *
     * @return {@code true} if the header was present and consumed
     */
    private static <DIS extends InputStream> boolean consumeHeader(DIS is, String header) throws IOException {
        byte[] headerBytes = header.getBytes(StandardCharsets.UTF_8);
        byte[] buffer = new byte[headerBytes.length];

        is.mark(buffer.length);
        ((DataInput) is).readFully(buffer);

        boolean matches = Arrays.equals(headerBytes, buffer);
        if (!matches) {
            is.reset();
        }
        return matches;
    }
}

