/*
 * Decompiled with CFR 0.152.
 */
package org.act.cat;

import com.google.common.primitives.Doubles;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.act.cat.CatInput;
import org.act.cat.CatInputStandard;
import org.act.cat.ItemScores;
import org.act.cat.PassageOrItemEligibilityOverall;
import org.act.sim.SimulationFunctions;
import org.act.testdef.Item;
import org.act.util.PrimitiveArraySet;
import org.act.util.PrimitiveArrays;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.distribution.LogNormalDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;

public class CatHelper {
    private static final String A_PARAM = Item.ColumnName.A_PARAM.getColName();
    private static final String B_PARAM = Item.ColumnName.B_PARAM.getColName();
    private static final String C_PARAM = Item.ColumnName.C_PARAM.getColName();
    private static final String A_PARAM_SE = Item.ColumnName.A_PARAM_SE.getColName();
    private static final String B_PARAM_SE = Item.ColumnName.B_PARAM_SE.getColName();
    private static final String C_PARAM_SE = Item.ColumnName.C_PARAM_SE.getColName();
    private static final String D_CONST = Item.ColumnName.D_CONSTANT.getColName();

    private CatHelper() {
    }

    public static RealMatrix getItemParams(PrimitiveArraySet itemPoolDataset) {
        double[] aPar = itemPoolDataset.getDoubleArrayCopy(A_PARAM);
        double[] bPar = itemPoolDataset.getDoubleArrayCopy(B_PARAM);
        double[] cPar = itemPoolDataset.getDoubleArrayCopy(C_PARAM);
        double[] dConst = itemPoolDataset.getDoubleArrayCopy(D_CONST);
        double[][] itemParData = new double[][]{aPar, bPar, cPar, dConst};
        RealMatrix itemPar = MatrixUtils.createRealMatrix((double[][])itemParData);
        return itemPar.transpose();
    }

    public static double[][][] getItemParamsSamples(PrimitiveArraySet itemPoolDataset, int sampleSize) {
        double[] aPar = itemPoolDataset.getDoubleArrayCopy(A_PARAM);
        double[] bPar = itemPoolDataset.getDoubleArrayCopy(B_PARAM);
        double[] cPar = itemPoolDataset.getDoubleArrayCopy(C_PARAM);
        double[] aParSE = itemPoolDataset.getDoubleArrayCopy(A_PARAM_SE);
        double[] bParSE = itemPoolDataset.getDoubleArrayCopy(B_PARAM_SE);
        double[] cParSE = itemPoolDataset.getDoubleArrayCopy(C_PARAM_SE);
        int itemNum = bPar.length;
        double[][][] itemParSamples = new double[itemNum][3][sampleSize];
        for (int i = 0; i < itemNum; ++i) {
            double logaMean = Math.log(aPar[i] / Math.sqrt(1.0 + Math.pow(aParSE[i], 2.0) / Math.pow(aPar[i], 2.0)));
            double logaSD = Math.sqrt(Math.log(1.0 + Math.pow(aParSE[i], 2.0) / Math.pow(aPar[i], 2.0)));
            LogNormalDistribution logNDistA = new LogNormalDistribution(logaMean, logaSD);
            itemParSamples[i][0] = logNDistA.sample(sampleSize);
            NormalDistribution normDistB = new NormalDistribution(bPar[i], bParSE[i]);
            itemParSamples[i][1] = normDistB.sample(sampleSize);
            NormalDistribution normDistC = new NormalDistribution(cPar[i], cParSE[i]);
            double[] tempSamples = normDistC.sample(sampleSize);
            ArrayList<Double> filteredLogitSampleList = new ArrayList<Double>();
            for (double value : tempSamples) {
                if (!(value >= 0.0) || !(value <= 1.0)) continue;
                filteredLogitSampleList.add(Math.log(value / (1.0 - value)));
            }
            double[] logitMeanSd = CatHelper.calMeanSD(filteredLogitSampleList);
            NormalDistribution normDistLogitC = new NormalDistribution(logitMeanSd[0], logitMeanSd[1]);
            double[] logitCSamples = normDistLogitC.sample(sampleSize);
            for (int n = 0; n < sampleSize; ++n) {
                itemParSamples[i][2][n] = 1.0 / (1.0 + Math.exp(-logitCSamples[n]));
            }
        }
        return itemParSamples;
    }

    public static RealMatrix getItemParamsForScoring(List<String> itemIds, List<String> itemsToAdminThisStage, RealMatrix itemParams) {
        int[] rowIndices = PrimitiveArrays.select(itemIds.toArray(new String[0]), itemsToAdminThisStage.toArray(new String[0]));
        int[] colIndices = new int[]{0, 1, 2, 3};
        return itemParams.getSubMatrix(rowIndices, colIndices);
    }

    public static CatInput createNextCatInput(CatInput catInput, ItemScores itemScores, List<String> itemsToAdminThisStage, int stage, Map<String, Integer> itemToPassageIndexMap, PassageOrItemEligibilityOverall passageOrItemEligibilityOverall, List<String> administeredItems, List<String> itemsToAdminister, List<String> shadowTest, double previousTheta, double previousThetaSE) {
        int[] itemScoresAllInt = ArrayUtils.addAll((int[])catInput.getItemScores().getItemScores(), (int[])itemScores.getItemScores());
        double[] respProbAll = ArrayUtils.addAll((double[])catInput.getItemScores().getRespProbs(), (double[])itemScores.getRespProbs());
        ItemScores allItemScores = new ItemScores(itemScoresAllInt, respProbAll);
        ArrayList<String> itemsAdminAll = new ArrayList<String>(catInput.getItemsAdmin());
        itemsAdminAll.addAll(itemsToAdminThisStage);
        List<Integer> administeredPassagesIndexSequence = SimulationFunctions.getPassageIndexOrderForAdministeredItems(itemToPassageIndexMap, itemsAdminAll);
        return new CatInputStandard.Builder().catConfig(catInput.getCatConfig()).testConfig(catInput.getTestConfig()).itemScores(allItemScores).itemsAdmin(itemsAdminAll).completedCount(itemsAdminAll.size()).adaptiveStage(stage).administeredPassageIndexSequence(administeredPassagesIndexSequence).passageOrItemEligibilityOverall(passageOrItemEligibilityOverall).itemsToAdminister(itemsToAdminister).shadowTest(shadowTest).previousTheta(previousTheta).previousThetaSe(previousThetaSE).build();
    }

    public static boolean[] getPreviousShadowBoolean(String[] itemIds, String[] previousShadowTest) {
        boolean[] previousShadowBoolean = new boolean[itemIds.length];
        if (previousShadowTest.length > 0) {
            int[] indicesPreviousShadow = PrimitiveArrays.select(itemIds, previousShadowTest);
            for (int idx = 0; idx < indicesPreviousShadow.length; ++idx) {
                int previousShadowItem = indicesPreviousShadow[idx];
                previousShadowBoolean[previousShadowItem] = true;
            }
        }
        return previousShadowBoolean;
    }

    public static List<Integer> getPreSelPassageSeqRowIndex(String[] itemIds, String[] passageIdsFromItemTable, String[] passageIdsFromPassageTable, String[] previousShadowTest) {
        int[] previousShadowTestIndices = PrimitiveArrays.select(itemIds, previousShadowTest);
        ArrayList<Integer> prePassageSeqRowIndex = new ArrayList<Integer>();
        for (int i = 0; i < previousShadowTestIndices.length; ++i) {
            String previousPassageString = passageIdsFromItemTable[previousShadowTestIndices[i]];
            if (previousPassageString.equalsIgnoreCase("none")) continue;
            int previousPassageIdx = PrimitiveArrays.select(passageIdsFromPassageTable, previousPassageString)[0];
            if (prePassageSeqRowIndex.isEmpty()) {
                prePassageSeqRowIndex.add(previousPassageIdx);
                continue;
            }
            if (((Integer)prePassageSeqRowIndex.get(prePassageSeqRowIndex.size() - 1)).equals(previousPassageIdx)) continue;
            prePassageSeqRowIndex.add(previousPassageIdx);
        }
        return prePassageSeqRowIndex;
    }

    public static double[] calMeanSD(List<Double> data) {
        double mean = 0.0;
        double sd = 0.0;
        for (int n = 0; n < data.size(); ++n) {
            mean += data.get(n).doubleValue();
        }
        mean /= (double)data.size();
        double sqrDiff = 0.0;
        for (int n = 0; n < data.size(); ++n) {
            sqrDiff += Math.pow(data.get(n) - mean, 2.0);
        }
        sd = Math.sqrt(sqrDiff / (double)(data.size() - 1));
        return new double[]{mean, sd};
    }

    public static double[] calMeanSD(double[] data) {
        return CatHelper.calMeanSD(Doubles.asList((double[])data));
    }
}

