/*
 * Decompiled with CFR 0.152.
 */
package org.ohdsi.metaAnalysis;

import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.hmc.CompoundGradient;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.JointGradient;
import dr.inference.loggers.Loggable;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.GradientProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.OperatorSchedule;
import dr.inference.operators.RandomWalkOperator;
import dr.inference.operators.SimpleOperatorSchedule;
import dr.inference.operators.hmc.HamiltonianMonteCarloOperator;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.inference.operators.hmc.MassPreconditioningOptions;
import dr.math.MathUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.ohdsi.likelihood.MultivariableCoxPartialLikelihood;
import org.ohdsi.mcmc.Analysis;
import org.ohdsi.mcmc.Runner;
import org.ohdsi.metaAnalysis.DataModel;
import org.ohdsi.metaAnalysis.MultivariatePrior;

public class MultivariableHierarchicalMetaAnalysis
implements Analysis {
    private final Likelihood likelihood;
    private final Likelihood prior;
    private final Likelihood joint;
    private final List<Parameter> parameters;
    private final OperatorSchedule schedule;

    public MultivariableHierarchicalMetaAnalysis(List<DataModel> dataModels, MultivariatePrior multivariatePrior, HierarchicalMetaAnalysisConfiguration cg) {
        MathUtils.setSeed((long)cg.seed);
        ArrayList<Likelihood> allDataLikelihoods = new ArrayList<Likelihood>();
        ArrayList<Object> allOperators = new ArrayList<Object>();
        ArrayList<Parameter> allParameters = new ArrayList<Parameter>();
        int analysisDim = dataModels.get(0).getCompoundParameter().getDimension();
        ArrayList<GradientWrtParameterProvider> likelihoodDerivativeList = cg.useHMC ? new ArrayList<GradientWrtParameterProvider>() : null;
        ArrayList<GradientWrtParameterProvider.ParameterWrapper> priorDerivativeList = cg.useHMC ? new ArrayList<GradientWrtParameterProvider.ParameterWrapper>() : null;
        MultivariateDistributionLikelihood mdl = (MultivariateDistributionLikelihood)multivariatePrior.getLikelihood(0);
        GradientProvider provider = cg.useHMC ? (GradientProvider)mdl.getDistribution() : null;
        for (DataModel singleAnalysis : dataModels) {
            if (analysisDim != singleAnalysis.getCompoundParameter().getDimension()) {
                throw new IllegalArgumentException("Mismatched regression dimensions");
            }
            allDataLikelihoods.add(singleAnalysis.getLikelihood());
            Parameter beta = singleAnalysis.getCompoundParameter();
            if (cg.useHMC) {
                likelihoodDerivativeList.add((GradientWrtParameterProvider)singleAnalysis.getLikelihood());
                priorDerivativeList.add(new GradientWrtParameterProvider.ParameterWrapper(provider, beta, (Likelihood)mdl));
                continue;
            }
            allParameters.add(beta);
            allOperators.add(new RandomWalkOperator(beta, null, 0.1, RandomWalkOperator.BoundaryCondition.reflecting, cg.operatorWeight * (double)beta.getDimension(), cg.mode));
        }
        if (cg.useHMC) {
            CompoundGradient priorGradient = new CompoundGradient(priorDerivativeList);
            CompoundGradient likelihoodGradient = new CompoundGradient(likelihoodDerivativeList);
            allParameters.add(priorGradient.getParameter());
            List<GradientWrtParameterProvider> jointGradientList = Arrays.asList(priorGradient, likelihoodGradient);
            JointGradient jointGradient = new JointGradient(jointGradientList);
            jointGradient.getGradientLogDensity();
            double stepSize = 1.8;
            int nSteps = 10;
            double randomStepFraction = 0.0;
            MassPreconditioningOptions.Default preconditioningOptions = new MassPreconditioningOptions.Default(10, 0, 0, 0, false, (Parameter)new Parameter.Default(0.01), (Parameter)new Parameter.Default(100.0));
            MassPreconditioner.Type preconditionerType = MassPreconditioner.Type.DIAGONAL;
            MassPreconditioner preconditioner = preconditionerType.factory(jointGradient, null, preconditioningOptions);
            HamiltonianMonteCarloOperator.Options runtimeOptions = new HamiltonianMonteCarloOperator.Options(1.8, 10, 0.0, preconditioningOptions, 0, 0.0, 10, 0.1, 0.8, HamiltonianMonteCarloOperator.InstabilityHandler.factory("reject"));
            allOperators.add(new HamiltonianMonteCarloOperator(AdaptationMode.ADAPTATION_ON, 0.2 * (double)priorDerivativeList.size(), jointGradient, jointGradient.getParameter(), null, null, runtimeOptions, preconditioner));
        }
        allParameters.addAll(multivariatePrior.getParameters());
        allOperators.addAll(multivariatePrior.getOperators(cg.operatorWeight, cg.mode));
        this.prior = multivariatePrior.getPrior();
        this.likelihood = new CompoundLikelihood(cg.threads, allDataLikelihoods);
        this.joint = new CompoundLikelihood(Arrays.asList(this.likelihood, this.prior));
        this.joint.setId("joint");
        this.parameters = allParameters;
        this.schedule = new SimpleOperatorSchedule(1000, 0.0);
        this.schedule.addOperators(allOperators);
    }

    @Override
    public List<Loggable> getLoggerColumns() {
        ArrayList<Loggable> columns = new ArrayList<Loggable>();
        columns.add((Loggable)this.likelihood);
        columns.add((Loggable)this.prior);
        columns.addAll(this.parameters);
        return columns;
    }

    @Override
    public Likelihood getJoint() {
        return this.joint;
    }

    @Override
    public OperatorSchedule getSchedule() {
        return this.schedule;
    }

    public static void main(String[] args) {
        int chainLength = 110000;
        int burnIn = 10000;
        int subSampleFrequency = 10;
        HierarchicalMetaAnalysisConfiguration cg = new HierarchicalMetaAnalysisConfiguration();
        List<DataModel> likelihoods = Arrays.asList(new MultivariableCoxPartialLikelihood((Parameter)new Parameter.Default(new double[]{-0.4608773, -0.1012988}), MultivariableCoxPartialLikelihood.exampleBladder()), new MultivariableCoxPartialLikelihood((Parameter)new Parameter.Default(new double[]{-0.5608773, -0.2012988}), MultivariableCoxPartialLikelihood.exampleBladder()), new MultivariableCoxPartialLikelihood((Parameter)new Parameter.Default(new double[]{-0.6608773, -0.3012988}), MultivariableCoxPartialLikelihood.exampleBladder()), new MultivariableCoxPartialLikelihood((Parameter)new Parameter.Default(new double[]{-0.7608773, -0.4012988}), MultivariableCoxPartialLikelihood.exampleBladder()));
        MultivariableHierarchicalMetaAnalysis analysis = new MultivariableHierarchicalMetaAnalysis(likelihoods, new MultivariatePrior.MultivariateNormal(likelihoods, cg), cg);
        Runner runner = new Runner(analysis, chainLength, burnIn, subSampleFrequency, cg.seed);
        runner.run();
        runner.processSamples();
        System.exit(0);
    }

    public static double[][] diagonalScaleMatrix(int dim, double diagonal) {
        double[][] scale = new double[dim][];
        int i = 0;
        while (i < dim) {
            scale[i] = new double[dim];
            scale[i][i] = diagonal;
            ++i;
        }
        return scale;
    }

    public static MatrixParameter diagonalMatrixParameter(String name, int dim, double diagonal) {
        Parameter[] columns = new Parameter[dim];
        int i = 0;
        while (i < dim) {
            columns[i] = new Parameter.Default(dim, 0.0);
            columns[i].setParameterValue(i, diagonal);
            ++i;
        }
        return new MatrixParameter(name, columns);
    }

    public static class HierarchicalMetaAnalysisConfiguration {
        public double tauShape = 1.0;
        public double tauScale = 1.0;
        public double tauDf = 3.0;
        public double muMean = 0.0;
        public double muSd = 2.0;
        public double startingMu = 0.0;
        public double startingTau = 1.0;
        AdaptationMode mode = AdaptationMode.ADAPTATION_ON;
        public double operatorWeight = 10.0;
        public long seed = 666L;
        public int threads = 1;
        public boolean useHMC = false;
    }
}

