/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.stuff;

import dr.evomodel.stuff.GenPolyaUrnProcessPrior;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.distribution.MultivariateNormalDistributionModel;
import dr.inference.distribution.WishartGammalDistributionModel;
import dr.inference.model.CompoundParameter;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.distributions.WishartDistribution;
import dr.math.distributions.WishartStatistics;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

public class BaseDistPrecisionGibbsOperator
extends SimpleMCMCOperator
implements GibbsOperator {
    private static final String BASE_DIST_PREC_OPERATOR = "baseDistPrecisionGibbsOperator";
    public static final String WEIGHT = "weight";
    public static final String BASE_DIST_NUM = "baseDistNum";
    private double pathWeight = 1.0;
    private GenPolyaUrnProcessPrior gpuProcess;
    private final Parameter mean;
    private final MatrixParameterInterface precision;
    private final int baseDistNum;
    private final int dim;
    private double priorDf;
    private SymmetricMatrix priorInverseScaleMatrix;
    private WishartGammalDistributionModel priorModel = null;
    private static final boolean DEBUG = false;
    private double numberObservations;
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{AttributeRule.newDoubleRule("weight"), AttributeRule.newDoubleRule("baseDistNum", false), new ElementRule(MultivariateDistributionLikelihood.class, false), new ElementRule(GenPolyaUrnProcessPrior.class, false)};

        @Override
        public String getParserName() {
            return BaseDistPrecisionGibbsOperator.BASE_DIST_PREC_OPERATOR;
        }

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            double d = xMLObject.getDoubleAttribute(BaseDistPrecisionGibbsOperator.WEIGHT);
            int n = xMLObject.getIntegerAttribute(BaseDistPrecisionGibbsOperator.BASE_DIST_NUM);
            MultivariateDistributionLikelihood multivariateDistributionLikelihood = (MultivariateDistributionLikelihood)xMLObject.getChild(MultivariateDistributionLikelihood.class);
            GenPolyaUrnProcessPrior genPolyaUrnProcessPrior = (GenPolyaUrnProcessPrior)xMLObject.getChild(GenPolyaUrnProcessPrior.class);
            return new BaseDistPrecisionGibbsOperator(genPolyaUrnProcessPrior, (WishartStatistics)((Object)multivariateDistributionLikelihood.getDistribution()), n, d);
        }

        @Override
        public String getParserDescription() {
            return "This element returns a Gibbs sampler for the precision matrix of a GPU process multivariate normal base distribution.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    public BaseDistPrecisionGibbsOperator(GenPolyaUrnProcessPrior genPolyaUrnProcessPrior, WishartStatistics wishartStatistics, int n, double d) {
        this.gpuProcess = genPolyaUrnProcessPrior;
        MultivariateNormalDistributionModel multivariateNormalDistributionModel = (MultivariateNormalDistributionModel)genPolyaUrnProcessPrior.getParametricBaseDist().get(n);
        this.mean = multivariateNormalDistributionModel.getMeanParameter();
        this.precision = multivariateNormalDistributionModel.getPrecisionMatrixParameter();
        this.dim = this.mean.getDimension();
        this.baseDistNum = n;
        this.setupWishartStatistics(wishartStatistics);
        this.setWeight(d);
    }

    private void setupWishartStatistics(WishartStatistics wishartStatistics) {
        this.priorDf = wishartStatistics.getDF();
        this.priorInverseScaleMatrix = null;
        double[][] dArray = wishartStatistics.getScaleMatrix();
        if (dArray != null) {
            this.priorInverseScaleMatrix = new SymmetricMatrix(dArray).inverse();
        }
    }

    private void incrementOuterProduct(double[][] dArray, GenPolyaUrnProcessPrior genPolyaUrnProcessPrior) {
        double[] dArray2 = genPolyaUrnProcessPrior.getParametricBaseDist().get(this.baseDistNum).getMean();
        CompoundParameter compoundParameter = genPolyaUrnProcessPrior.getUniquelyRealizedParameters();
        int[] nArray = genPolyaUrnProcessPrior.getIsCatActive();
        this.numberObservations = 0.0;
        for (int i = 0; i < genPolyaUrnProcessPrior.maxCategoryCount; ++i) {
            int n;
            if (nArray[i] != 1) continue;
            double[] dArray3 = compoundParameter.getParameter(i).getParameterValues();
            for (n = 0; n < this.dim; ++n) {
                int n2 = n;
                dArray3[n2] = dArray3[n2] - dArray2[n];
            }
            for (n = 0; n < this.dim; ++n) {
                for (int j = n; j < this.dim; ++j) {
                    double[] dArray4 = dArray[n];
                    int n3 = j;
                    double d = dArray4[n3] + dArray3[n] * dArray3[j];
                    dArray4[n3] = d;
                    dArray[j][n] = d;
                }
            }
            this.numberObservations += 1.0;
        }
    }

    private double[][] getOperationScaleMatrixAndSetObservationCount() {
        double[][] dArray = new double[this.dim][this.dim];
        Matrix matrix = null;
        this.numberObservations = 0.0;
        this.incrementOuterProduct(dArray, this.gpuProcess);
        try {
            SymmetricMatrix symmetricMatrix = new SymmetricMatrix(dArray);
            if (this.pathWeight != 1.0) {
                symmetricMatrix = (SymmetricMatrix)symmetricMatrix.product(this.pathWeight);
            }
            if (this.priorInverseScaleMatrix != null) {
                symmetricMatrix = this.priorInverseScaleMatrix.add(symmetricMatrix);
            }
            matrix = symmetricMatrix.inverse();
        }
        catch (IllegalDimension illegalDimension) {
            illegalDimension.printStackTrace();
        }
        assert (matrix != null);
        return matrix.toComponents();
    }

    @Override
    public double doOperation() {
        double[][] dArray = this.getOperationScaleMatrixAndSetObservationCount();
        double d = this.numberObservations;
        double d2 = this.priorDf + d * this.pathWeight;
        double[][] dArray2 = WishartDistribution.nextWishart(d2, dArray);
        for (int i = 0; i < this.dim; ++i) {
            Parameter parameter = this.precision.getParameter(i);
            for (int j = 0; j < this.dim; ++j) {
                parameter.setParameterValueQuietly(j, dArray2[j][i]);
            }
        }
        this.precision.fireParameterChangedEvent();
        return 0.0;
    }

    @Override
    public void setPathParameter(double d) {
        if (d < 0.0 || d > 1.0) {
            throw new IllegalArgumentException("Illegal path weight of " + d);
        }
        this.pathWeight = d;
    }

    @Override
    public String getOperatorName() {
        return BASE_DIST_PREC_OPERATOR;
    }
}

