/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.inference;

import eu.amidst.core.distribution.ConditionalDistribution;
import eu.amidst.core.distribution.Distribution;
import eu.amidst.core.distribution.Multinomial;
import eu.amidst.core.distribution.UnivariateDistribution;
import eu.amidst.core.exponentialfamily.EF_Distribution;
import eu.amidst.core.exponentialfamily.EF_UnivariateDistribution;
import eu.amidst.core.exponentialfamily.SufficientStatistics;
import eu.amidst.core.inference.InferenceAlgorithm;
import eu.amidst.core.inference.messagepassing.VMP;
import eu.amidst.core.io.BayesianNetworkLoader;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.utils.LocalRandomGenerator;
import eu.amidst.core.utils.Serialization;
import eu.amidst.core.utils.Utils;
import eu.amidst.core.variables.Assignment;
import eu.amidst.core.variables.HashMapAssignment;
import eu.amidst.core.variables.Variable;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Approximate inference over a {@link BayesianNetwork} by importance sampling.
 *
 * <p>Variables are sampled in topological order of the sampling model. By default the sampling
 * model is a deep copy of the model itself, in which case this is likelihood weighting.
 * Evidence variables are clamped to their observed value. Each sample carries the log importance
 * weight {@code log p(x, e) - log q(x)}, where {@code q} covers only the non-evidence variables.
 * Posteriors and expectations are self-normalised weighted averages. The probability of evidence
 * is estimated as the mean importance weight.
 *
 * <p>When {@code keepDataOnMemory} is {@code true} (the default), {@link #runInference()} stores
 * the weighted samples and every query reuses them. Otherwise each query draws a fresh sample
 * stream from the configured seed.
 *
 * <p>This class does no synchronisation of its own mutable configuration.
 */
public class ImportanceSampling implements InferenceAlgorithm, Serializable {
    private static final long serialVersionUID = 8587756877237341367L;
    private BayesianNetwork model;
    private BayesianNetwork samplingModel;
    private boolean sameSamplingModel;
    private List<Variable> causalOrder;
    private int seed = 0;
    private int sampleSize = 10000;
    private boolean keepDataOnMemory = true;
    // Sample caches. These are transient because neither WeightedAssignment nor a Stream pipeline
    // is serializable; after deserialization runInference() must be called again.
    private transient List<WeightedAssignment> weightedSampleList;
    private transient Stream<WeightedAssignment> weightedSampleStream;
    private Assignment evidence;
    private boolean parallelMode = true;

    @Override
    public void setParallelMode(boolean parallelMode_) {
        this.parallelMode = parallelMode_;
    }

    @Override
    public void setSeed(int seed) {
        this.seed = seed;
    }

    /**
     * Sets the model to query. A deep copy is stored and also used as the sampling model.
     * Evidence and any previously drawn samples are discarded.
     */
    @Override
    public void setModel(BayesianNetwork model_) {
        this.samplingModel = this.model = Serialization.deepCopy(model_);
        this.causalOrder = Utils.getTopologicalOrder(this.model.getDAG());
        this.sameSamplingModel = true;
        this.evidence = null;
        this.weightedSampleList = null;
        this.weightedSampleStream = null;
    }

    /** Sets the evidence; a variable whose value is NaN is treated as unobserved. */
    @Override
    public void setEvidence(Assignment evidence_) {
        this.evidence = evidence_;
        this.weightedSampleList = null;
        this.weightedSampleStream = null;
    }

    /**
     * Sets a proposal (sampling) model that is distinct from the queried model. Its conditional
     * distributions are deep-copied. Samples drawn from any previous proposal are discarded,
     * because their weights no longer match.
     */
    public void setSamplingModel(BayesianNetwork samplingModel_) {
        this.samplingModel = new BayesianNetwork(samplingModel_.getDAG(), Serialization.deepCopy(samplingModel_.getConditionalDistributions()));
        this.causalOrder = Utils.getTopologicalOrder(this.samplingModel.getDAG());
        this.sameSamplingModel = this.samplingModel.equalBNs(this.model, 1.0E-10);
        this.weightedSampleList = null;
        this.weightedSampleStream = null;
    }

    public void setSampleSize(int sampleSize) {
        this.sampleSize = sampleSize;
    }

    public void setKeepDataOnMemory(boolean keepDataOnMemory) {
        this.keepDataOnMemory = keepDataOnMemory;
    }

    @Override
    public BayesianNetwork getOriginalModel() {
        return this.model;
    }

    public BayesianNetwork getSamplingModel() {
        return this.samplingModel;
    }

    /**
     * Estimates {@code log P(evidence)} as the log of the mean importance weight.
     *
     * <p>The mean is accumulated in log space (running log-sum-exp), so very small weights do not
     * underflow to zero. Samples whose log-weight is NaN or +Infinity are ignored.
     *
     * @throws IllegalStateException if no samples are available or none has a usable weight
     */
    @Override
    public double getLogProbabilityOfEvidence() {
        LogWeightAccumulator accumulator = this.currentWeightedSampleStream().collect(
                LogWeightAccumulator::new,
                (acc, ws) -> acc.add(ws.weight, 0.0),
                LogWeightAccumulator::merge);
        return accumulator.logMeanWeight();
    }

    /**
     * Returns the (unweighted) sampled assignments. In non-keep mode a fresh sample stream is
     * drawn on every call.
     */
    public Stream<Assignment> getSamples() {
        return this.currentWeightedSampleStream().map(ws -> ws.assignment);
    }

    /** Returns true when {@code var} has a non-NaN value in the current evidence. */
    private boolean isObserved(Variable var) {
        return this.evidence != null && !Double.isNaN(this.evidence.getValue(var));
    }

    /**
     * Draws one sample when the proposal equals the model (likelihood weighting). Non-evidence
     * variables are sampled from their conditionals. Evidence variables are clamped, and their
     * conditional log-probability is added to the weight.
     */
    private WeightedAssignment getWeightedAssignmentSameModel(Random random) {
        HashMapAssignment sample = new HashMapAssignment(this.model.getNumberOfVars());
        double logWeight = 0.0;
        for (Variable samplingVar : this.causalOrder) {
            UnivariateDistribution univariateSamplingDistribution = ((ConditionalDistribution) this.model.getConditionalDistribution(samplingVar)).getUnivariateDistribution(sample);
            double simulatedValue;
            if (this.isObserved(samplingVar)) {
                simulatedValue = this.evidence.getValue(samplingVar);
                logWeight += univariateSamplingDistribution.getLogProbability(simulatedValue);
            } else {
                simulatedValue = univariateSamplingDistribution.sample(random);
            }
            sample.setValue(samplingVar, simulatedValue);
        }
        return new WeightedAssignment(sample, logWeight);
    }

    /**
     * Draws one weighted sample. With a distinct proposal, the weight is
     * {@code log p(x, e)} under the model minus {@code log q(x)} under the sampling model,
     * where {@code q} only covers non-evidence variables.
     */
    private WeightedAssignment getWeightedAssignment(Random random) {
        if (this.sameSamplingModel) {
            return this.getWeightedAssignmentSameModel(random);
        }
        int numberOfVars = this.model.getNumberOfVars();
        HashMapAssignment samplingAssignment = new HashMapAssignment(numberOfVars);
        HashMapAssignment modelAssignment = new HashMapAssignment(numberOfVars);
        double logTarget = 0.0;
        double logProposal = 0.0;
        for (Variable samplingVar : this.causalOrder) {
            // Both models are expected to share variable IDs.
            Variable modelVar = this.model.getVariables().getVariableById(samplingVar.getVarID());
            UnivariateDistribution univariateModelDistribution = ((ConditionalDistribution) this.model.getConditionalDistribution(modelVar)).getUnivariateDistribution(modelAssignment);
            double simulatedValue;
            if (this.isObserved(samplingVar)) {
                simulatedValue = this.evidence.getValue(samplingVar);
            } else {
                UnivariateDistribution univariateSamplingDistribution = ((ConditionalDistribution) this.samplingModel.getConditionalDistribution(samplingVar)).getUnivariateDistribution(samplingAssignment);
                simulatedValue = univariateSamplingDistribution.sample(random);
                logProposal += univariateSamplingDistribution.getLogProbability(simulatedValue);
            }
            logTarget += univariateModelDistribution.getLogProbability(simulatedValue);
            modelAssignment.setValue(modelVar, simulatedValue);
            samplingAssignment.setValue(samplingVar, simulatedValue);
        }
        return new WeightedAssignment(samplingAssignment, logTarget - logProposal);
    }

    /**
     * Self-normalised estimate of {@code E[function(var) | evidence]}.
     *
     * <p>Weights are rescaled by a running maximum, so the result does not degenerate to 0/0
     * when every weight is tiny. Samples with a NaN or +Infinity log-weight, or a non-finite
     * {@code function} value, are ignored. Returns NaN if no sample contributes weight.
     */
    @Override
    public double getExpectedValue(Variable var, Function<Double, Double> function) {
        LogWeightAccumulator accumulator = this.currentWeightedSampleStream().collect(
                LogWeightAccumulator::new,
                (acc, ws) -> acc.add(ws.weight, function.apply(ws.assignment.getValue(var))),
                LogWeightAccumulator::merge);
        return accumulator.weightedMean();
    }

    /**
     * Estimates the posterior of {@code var} by moment matching.
     *
     * <p>The weighted average of the sufficient statistics, {@code sum_i w_i * SS(x_i) / sum_i w_i},
     * is set as the moment parameters of a fresh distribution of the variable's type. Weights
     * are shifted by the maximum finite log-weight before exponentiation, to avoid
     * overflow/underflow.
     *
     * @throws IllegalStateException if no samples are available or no sample has a finite weight
     */
    @Override
    public <E extends UnivariateDistribution> E getPosterior(Variable var) {
        Variable samplingVar = this.samplingModel.getVariables().getVariableByName(var.getName());
        EF_UnivariateDistribution efDistribution = (EF_UnivariateDistribution) ((UnivariateDistribution) samplingVar.newUnivariateDistribution()).toEFUnivariateDistribution();
        // Two passes are needed (max, then sum), so the samples are materialised locally in
        // non-keep mode instead of reusing a one-shot stream.
        List<WeightedAssignment> samples = this.materializedWeightedSamples();
        double maxLogWeight = samples.stream()
                .mapToDouble(ws -> ws.weight)
                .filter(Double::isFinite)
                .max()
                .orElseThrow(() -> new IllegalStateException("No sample has a finite importance weight"));
        Stream<WeightedAssignment> sampleStream = this.parallelMode ? samples.parallelStream() : samples.stream();
        WeightedStatistics total = sampleStream
                .map(ws -> {
                    double shiftedWeight = Math.exp(ws.weight - maxLogWeight);
                    SufficientStatistics statistics = efDistribution.getSufficientStatistics(ws.assignment);
                    statistics.multiplyBy(shiftedWeight);
                    return new WeightedStatistics(statistics, shiftedWeight);
                })
                .filter(wss -> Double.isFinite(wss.statistics.sum()))
                .reduce(WeightedStatistics::combine)
                .orElseThrow(() -> new IllegalStateException("No sample produced finite sufficient statistics"));
        // Self-normalisation: the exp(maxLogWeight) factor cancels between numerator and denominator.
        total.statistics.divideBy(total.weight);
        efDistribution.setMomentParameters(total.statistics);
        UnivariateDistribution posteriorDistribution = (UnivariateDistribution) efDistribution.toUnivariateDistribution();
        if (var.isMultinomial()) {
            Multinomial multinomial = (Multinomial) posteriorDistribution;
            // Guards against small rounding drift so the probabilities sum to one.
            multinomial.setProbabilities(Utils.normalize(multinomial.getProbabilities()));
        }
        @SuppressWarnings("unchecked") // caller chooses E; the runtime type comes from var's distribution type
        E result = (E) posteriorDistribution;
        return result;
    }

    /**
     * Returns the weighted samples for a query as a stream, parallel or sequential according
     * to {@code parallelMode}. Stored samples are used when {@code keepDataOnMemory} is set;
     * otherwise a fresh stream is drawn.
     */
    private Stream<WeightedAssignment> currentWeightedSampleStream() {
        Stream<WeightedAssignment> samples;
        if (this.keepDataOnMemory) {
            samples = this.requireStoredSamples().stream();
        } else {
            this.computeWeightedSampleStream(false);
            samples = this.weightedSampleStream;
        }
        return this.parallelMode ? samples.parallel() : samples.sequential();
    }

    /**
     * Returns the weighted samples as a list. In non-keep mode a fresh sample is drawn into a
     * local list, and the stored cache is left untouched.
     */
    private List<WeightedAssignment> materializedWeightedSamples() {
        if (this.keepDataOnMemory) {
            return this.requireStoredSamples();
        }
        return this.currentWeightedSampleStream().collect(Collectors.toList());
    }

    /** Returns the samples stored by runInference(), failing clearly if there are none. */
    private List<WeightedAssignment> requireStoredSamples() {
        if (this.weightedSampleList == null) {
            throw new IllegalStateException("No samples available: call runInference() after setting the model and evidence");
        }
        return this.weightedSampleList;
    }

    /**
     * Builds a lazy stream of {@code sampleSize} weighted samples and, if requested, collects it
     * into {@code weightedSampleList}.
     *
     * <p>NOTE(review): the generator is seeded, but in parallel mode the mapping from worker
     * thread to sample index may vary between runs. Confirm whether LocalRandomGenerator gives
     * reproducible parallel output.
     */
    private void computeWeightedSampleStream(boolean saveDataOnMemory_) {
        LocalRandomGenerator randomGenerator = new LocalRandomGenerator(this.seed);
        IntStream indices = IntStream.range(0, this.sampleSize);
        indices = this.parallelMode ? indices.parallel() : indices.sequential();
        this.weightedSampleStream = indices.mapToObj(i -> this.getWeightedAssignment(randomGenerator.current()));
        if (saveDataOnMemory_) {
            this.weightedSampleList = this.weightedSampleStream.collect(Collectors.toList());
            // The stream has just been consumed and must not be reused.
            this.weightedSampleStream = null;
        }
    }

    /**
     * Draws and stores the weighted samples when {@code keepDataOnMemory} is set. Otherwise
     * nothing is precomputed, and each query draws its own samples.
     */
    @Override
    public void runInference() {
        if (this.keepDataOnMemory) {
            this.computeWeightedSampleStream(true);
        }
    }

    public static void main(String[] args) throws IOException, ClassNotFoundException {
        BayesianNetwork bn = BayesianNetworkLoader.loadFromFile("./networks/dataWeka/asia.bn");
        System.out.println(bn.toString());
        VMP vmp = new VMP();
        vmp.setModel(bn);
        vmp.runInference();
        ImportanceSampling importanceSampling = new ImportanceSampling();
        importanceSampling.setModel(bn);
        importanceSampling.setParallelMode(true);
        importanceSampling.setSampleSize(100);
        importanceSampling.setSeed(57457);
        importanceSampling.setKeepDataOnMemory(true);
        importanceSampling.runInference();
        List<Variable> causalOrder = importanceSampling.causalOrder;
        for (Variable var : causalOrder) {
            System.out.println("Posterior (IS) of " + var.getName() + ":" + ((Distribution)importanceSampling.getPosterior(var)).toString());
            System.out.println("Posterior (VMP) of " + var.getName() + ":" + ((Distribution)vmp.getPosterior(var)).toString());
        }
        Variable variable1 = causalOrder.get(1);
        Variable variable2 = causalOrder.get(2);
        int var1value = 0;
        int var2value = 0;
        System.out.println();
        System.out.println("Evidence: Variable " + variable1.getName() + " = " + var1value + " and Variable " + variable2.getName() + " = " + var2value);
        System.out.println();
        HashMapAssignment assignment = new HashMapAssignment(2);
        assignment.setValue(variable1, var1value);
        assignment.setValue(variable2, var2value);
        importanceSampling.setEvidence(assignment);
        importanceSampling.runInference();
        for (Variable var : causalOrder) {
            System.out.println("Posterior of " + var.getName() + " (IS with Evidence) :" + ((Distribution)importanceSampling.getPosterior(var)).toString());
        }
        // println, not printf: the message is not a format string.
        System.out.println("Prob. of Evidence: " + Math.exp(importanceSampling.getLogProbabilityOfEvidence()));
    }

    /** A sampled assignment together with its log importance weight. */
    private static final class WeightedAssignment {
        private final HashMapAssignment assignment;
        private final double weight;

        WeightedAssignment(HashMapAssignment assignment_, double weight_) {
            this.assignment = assignment_;
            this.weight = weight_;
        }

        @Override
        public String toString() {
            StringBuilder str = new StringBuilder();
            str.append("[ ");
            for (Map.Entry<Variable, Double> entry : this.assignment.entrySet()) {
                str.append(entry.getKey().getName() + " = " + entry.getValue());
                str.append(", ");
            }
            str.append("Weight = " + this.weight + " ]");
            return str.toString();
        }
    }

    /**
     * Single-pass, mergeable accumulator of importance weights held in log space.
     *
     * <p>It tracks the running maximum log-weight {@code m}. {@code weightSum} stores
     * {@code sum exp(w_i - m)} and {@code weightedValueSum} stores {@code sum exp(w_i - m) * v_i}.
     * All exponentials therefore stay in [0, 1], and partial results from parallel workers can
     * be merged.
     */
    private static final class LogWeightAccumulator {
        private double maxLogWeight = Double.NEGATIVE_INFINITY;
        private double weightSum = 0.0;
        private double weightedValueSum = 0.0;
        private long count = 0;

        /** Adds one sample; samples with NaN/+Infinity log-weight or non-finite value are skipped. */
        void add(double logWeight, double value) {
            if (Double.isNaN(logWeight) || logWeight == Double.POSITIVE_INFINITY || !Double.isFinite(value)) {
                return;
            }
            this.count++;
            if (logWeight == Double.NEGATIVE_INFINITY) {
                return; // zero weight: counted in the mean, contributes nothing to the sums
            }
            if (logWeight > this.maxLogWeight) {
                double rescale = Math.exp(this.maxLogWeight - logWeight);
                this.weightSum *= rescale;
                this.weightedValueSum *= rescale;
                this.maxLogWeight = logWeight;
            }
            double shifted = Math.exp(logWeight - this.maxLogWeight);
            this.weightSum += shifted;
            this.weightedValueSum += shifted * value;
        }

        /** Folds {@code other} into this accumulator (combiner for parallel collection). */
        void merge(LogWeightAccumulator other) {
            this.count += other.count;
            double newMax = Math.max(this.maxLogWeight, other.maxLogWeight);
            if (newMax == Double.NEGATIVE_INFINITY) {
                return; // both sides carry only zero weights
            }
            double thisScale = Math.exp(this.maxLogWeight - newMax);
            double otherScale = Math.exp(other.maxLogWeight - newMax);
            this.weightSum = this.weightSum * thisScale + other.weightSum * otherScale;
            this.weightedValueSum = this.weightedValueSum * thisScale + other.weightedValueSum * otherScale;
            this.maxLogWeight = newMax;
        }

        /** log of the mean weight over all counted samples; -Infinity if every weight is zero. */
        double logMeanWeight() {
            if (this.count == 0) {
                throw new IllegalStateException("No sample has a usable importance weight");
            }
            return this.maxLogWeight + Math.log(this.weightSum / this.count);
        }

        /** Self-normalised weighted mean of the values; NaN when the total weight is zero. */
        double weightedMean() {
            return this.weightedValueSum / this.weightSum;
        }
    }

    /** A sufficient-statistics vector paired with the total (shifted) weight it was scaled by. */
    private static final class WeightedStatistics {
        private static final BinaryOperator<SufficientStatistics> ADD_STATISTICS = SufficientStatistics::sumVectorNonStateless;
        private final SufficientStatistics statistics;
        private final double weight;

        WeightedStatistics(SufficientStatistics statistics, double weight) {
            this.statistics = statistics;
            this.weight = weight;
        }

        /** Sums two partial results; the statistics vectors are freshly built per sample. */
        static WeightedStatistics combine(WeightedStatistics a, WeightedStatistics b) {
            return new WeightedStatistics(ADD_STATISTICS.apply(a.statistics, b.statistics), a.weight + b.weight);
        }
    }
}

