/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.hmc;

import dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.DiagonalMatrix;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.List;

/**
 * Base class for gradient providers of diffusion-process parameters.
 *
 * <p>Holds the likelihood being differentiated, the bounds handed to the numerical
 * gradient check, and an {@link #offset} into a larger gradient array. Also provides
 * helpers that compare an analytic gradient against a finite-difference gradient.
 */
public abstract class AbstractDiffusionGradient
        implements GradientWrtParameterProvider, Reportable {

    private final Likelihood likelihood;
    /** Lower bound passed to the numerical gradient check / numeric function. */
    private final double lowerBound;
    /** Upper bound passed to the numerical gradient check / numeric function. */
    private final double upperBound;
    /** Index in a full gradient array at which this provider's entries begin. */
    protected int offset;

    /**
     * Note the argument order: the UPPER bound comes before the LOWER bound. This order is
     * kept for compatibility with existing callers.
     *
     * @param likelihood likelihood whose log density is differentiated
     * @param upperBound upper bound reported to the numerical gradient check
     * @param lowerBound lower bound reported to the numerical gradient check
     */
    AbstractDiffusionGradient(Likelihood likelihood, double upperBound, double lowerBound) {
        this.likelihood = likelihood;
        this.lowerBound = lowerBound;
        this.upperBound = upperBound;
        this.offset = 0;
    }

    /**
     * Extracts this provider's portion of an already-computed gradient.
     *
     * @param var1 gradient array from which this provider's entries are taken
     * @return the gradient restricted to this provider's parameter
     */
    public abstract double[] getGradientLogDensity(double[] var1);

    /** Returns the parameter object this gradient is reported against. */
    public abstract Parameter getRawParameter();

    /**
     * Sets the start index of this provider's entries within a full gradient array.
     *
     * @param n the new offset
     */
    public void setOffset(int n) {
        this.offset = n;
    }

    /** Returns the derivation parameter identifying which quantity is differentiated. */
    public abstract ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter getDerivationParameter();

    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    /**
     * Returns the parameter perturbed during numerical differentiation. Defaults to
     * {@link #getParameter()}; subclasses may override.
     */
    protected Parameter getNumericalParameter() {
        return this.getParameter();
    }

    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(this, this.lowerBound, this.upperBound, TOLERANCE);
    }

    /** Formats an analytic and a numeric gradient side by side for reporting. */
    String getReportString(double[] analytic, double[] numeric) {
        return this.getClass().getCanonicalName()
                + "\nanalytic: " + new Vector(analytic)
                + "\nnumeric: " + new Vector(numeric)
                + "\n";
    }

    /**
     * Formats an analytic gradient together with two numeric gradients (without and with a
     * Cholesky-based computation) for reporting.
     */
    String getReportString(double[] analytic, double[] numericNoCholesky, double[] numericWithCholesky) {
        return this.getClass().getCanonicalName()
                + "\nanalytic: " + new Vector(analytic)
                + "\nnumeric (no Cholesky): " + new Vector(numericNoCholesky)
                + "\nnumeric (with Cholesky): " + new Vector(numericWithCholesky)
                + "\n";
    }

    /**
     * Builds a function that writes its argument into the numerical parameter and returns
     * the resulting log-likelihood. This function is intended for finite-difference checks.
     *
     * <p>Side effect: each evaluation overwrites the numerical parameter's values. Callers
     * are responsible for restoring them afterwards.
     */
    MultivariateFunction getNumeric() {
        return new MultivariateFunction() {

            @Override
            public double evaluate(double[] argument) {
                Parameter numericalParameter = AbstractDiffusionGradient.this.getNumericalParameter();
                for (int i = 0; i < argument.length; ++i) {
                    numericalParameter.setParameterValue(i, argument[i]);
                }
                AbstractDiffusionGradient.this.likelihood.makeDirty();
                // Evaluate once so the logged value is exactly the value returned.
                double logLikelihood = AbstractDiffusionGradient.this.likelihood.getLogLikelihood();
                System.err.println("likelihood in numeric:" + logLikelihood);
                return logLikelihood;
            }

            @Override
            public int getNumArguments() {
                return AbstractDiffusionGradient.this.getDimension();
            }

            @Override
            public double getLowerBound(int n) {
                return AbstractDiffusionGradient.this.lowerBound;
            }

            @Override
            public double getUpperBound(int n) {
                return AbstractDiffusionGradient.this.upperBound;
            }
        };
    }

    /**
     * Computes a finite-difference gradient at the numerical parameter's current values and
     * formats it next to {@code analytic}.
     *
     * <p>The parameter's original values are restored, and the likelihood is marked dirty,
     * even if differentiation throws.
     *
     * @param analytic the analytic gradient to compare against
     * @return a report string containing both gradients
     */
    String checkNumeric(double[] analytic) {
        Parameter numericalParameter = this.getNumericalParameter();
        double[] savedValues = numericalParameter.getParameterValues();
        System.err.println("Numeric at: \n" + new Vector(savedValues));
        double[] numeric;
        try {
            // Differentiate at a copy so the restore source cannot be altered by the
            // finite-difference routine.
            numeric = NumericalDerivative.gradient(this.getNumeric(), savedValues.clone());
        } finally {
            for (int i = 0; i < savedValues.length; ++i) {
                numericalParameter.setParameterValue(i, savedValues[i]);
            }
            // The cached likelihood may reflect the last perturbed point; evaluate() does
            // the same makeDirty() after writing values.
            this.likelihood.makeDirty();
        }
        return this.getReportString(analytic, numeric);
    }

    /**
     * Gradient provider for a single parameter. Its entries are taken from the gradient
     * computed by a {@link BranchSpecificGradient}, starting at {@link #offset}.
     */
    public static class ParameterDiffusionGradient
            extends AbstractDiffusionGradient
            implements Reportable {

        /** Number of entries in the differentiated parameter. */
        protected final int dim;
        private final BranchSpecificGradient branchSpecificGradient;
        private final Parameter parameter;
        private final Parameter rawParameter;

        /**
         * @param branchSpecificGradient source of the full gradient array
         * @param likelihood likelihood being differentiated
         * @param parameter parameter whose gradient is returned; its dimension fixes {@link #dim}
         * @param rawParameter parameter reported by {@link #getRawParameter()}
         * @param upperBound upper bound for the numerical check (note: before lowerBound)
         * @param lowerBound lower bound for the numerical check
         */
        ParameterDiffusionGradient(BranchSpecificGradient branchSpecificGradient, Likelihood likelihood,
                                   Parameter parameter, Parameter rawParameter,
                                   double upperBound, double lowerBound) {
            super(likelihood, upperBound, lowerBound);
            this.parameter = parameter;
            this.rawParameter = rawParameter;
            this.branchSpecificGradient = branchSpecificGradient;
            this.dim = parameter.getDimension();
        }

        @Override
        public Parameter getParameter() {
            return this.parameter;
        }

        @Override
        public int getDimension() {
            return this.dim;
        }

        @Override
        public Parameter getRawParameter() {
            return this.rawParameter;
        }

        @Override
        public double[] getGradientLogDensity() {
            double[] fullGradient = this.branchSpecificGradient.getGradientLogDensity();
            return this.getGradientLogDensity(fullGradient);
        }

        @Override
        public double[] getGradientLogDensity(double[] fullGradient) {
            return this.extractGradient(fullGradient);
        }

        /** Copies {@code dim} entries of {@code fullGradient}, starting at {@code offset}. */
        private double[] extractGradient(double[] fullGradient) {
            double[] gradient = new double[this.dim];
            System.arraycopy(fullGradient, this.offset, gradient, 0, this.dim);
            return gradient;
        }

        /**
         * Returns the single derivation parameter reported by the branch-specific gradient.
         * Asserts (when assertions are enabled) that exactly one is present.
         */
        @Override
        public ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter getDerivationParameter() {
            List<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter> list =
                    this.branchSpecificGradient.getDerivationParameter();
            assert (list.size() == 1);
            return list.get(0);
        }

        @Override
        public String getReport() {
            return "Gradient." + this.rawParameter.getParameterName() + "\n" + super.getReport();
        }

        /**
         * Creates a gradient for a drift parameter. The numerical check uses bounds
         * (-infinity, +infinity), and the parameter itself serves as the raw parameter.
         */
        public static ParameterDiffusionGradient createDriftGradient(BranchSpecificGradient branchSpecificGradient,
                                                                     Likelihood likelihood, Parameter parameter) {
            return new ParameterDiffusionGradient(branchSpecificGradient, likelihood, parameter, parameter,
                    Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY);
        }

        /**
         * Creates a gradient with respect to the diagonal of a {@link DiagonalMatrix}. The
         * numerical check uses a lower bound of 0 and an upper bound of +infinity. The
         * matrix itself is reported as the raw parameter.
         */
        public static ParameterDiffusionGradient createDiagonalAttenuationGradient(BranchSpecificGradient branchSpecificGradient,
                                                                                   Likelihood likelihood,
                                                                                   MatrixParameterInterface matrixParameterInterface) {
            assert (matrixParameterInterface instanceof DiagonalMatrix) : "DiagonalAttenuationGradient can only be applied to a DiagonalMatrix.";
            DiagonalMatrix diagonalMatrix = (DiagonalMatrix) matrixParameterInterface;
            return new ParameterDiffusionGradient(branchSpecificGradient, likelihood,
                    diagonalMatrix.getDiagonalParameter(), diagonalMatrix,
                    Double.POSITIVE_INFINITY, 0.0);
        }
    }
}

