/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.constraints;

import cc.mallet.fst.SumLattice;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.BitSet;

/**
 * Base class for a collection of GE constraints in which each constraint is
 * attached to a single feature index and is evaluated against a single label
 * (see {@link #isOneStateConstraint()}).
 *
 * <p>Constraints are stored in a map keyed by feature index. In addition to
 * regular feature indices, a constraint stored under the key
 * {@code fv.getAlphabet().size()} is treated as active for every feature
 * vector, regardless of which features it contains (presumably a
 * default / catch-all constraint; confirm against the subclasses that
 * populate the map).
 *
 * <p>A {@link StateLabelMap} must be supplied (via the protected constructor
 * or {@link #setStateLabelMap(StateLabelMap)}) before any method that
 * translates state indices to label indices is called.
 */
public abstract class OneLabelGEConstraints implements GEConstraint {
    /**
     * Label index for which {@link #computeExpectations(ArrayList)} accumulates
     * nothing. NOTE(review): presumably this mirrors the start-label marker
     * defined by {@link StateLabelMap}; if so, reference that constant instead.
     */
    private static final int SKIPPED_LABEL_INDEX = -2;

    /** Constraints keyed by feature index. */
    protected TIntObjectHashMap<OneLabelGEConstraint> constraints;

    /** Maps state indices to label indices; may be unset until {@link #setStateLabelMap}. */
    protected StateLabelMap map;

    /**
     * Keys of the constraints that are active for the feature vector most
     * recently passed to {@link #preProcess(FeatureVector)}.
     */
    protected TIntArrayList cache;

    /** Creates an empty constraint set with no state-label map. */
    public OneLabelGEConstraints() {
        this.constraints = new TIntObjectHashMap<OneLabelGEConstraint>();
        this.cache = new TIntArrayList();
    }

    /**
     * Creates a constraint set backed by the given map (stored by reference,
     * not copied).
     *
     * @param constraints constraints keyed by feature index
     * @param map state-to-label mapping
     */
    protected OneLabelGEConstraints(TIntObjectHashMap<OneLabelGEConstraint> constraints, StateLabelMap map) {
        this.constraints = constraints;
        this.map = map;
        this.cache = new TIntArrayList();
    }

    /**
     * Registers a constraint. Parameter meanings are defined by the
     * implementing subclass (presumably feature index, target vector and
     * weight, matching the fields of {@link OneLabelGEConstraint}).
     */
    public abstract void addConstraint(int var1, double[] var2, double var3);

    /** @return always {@code true} */
    @Override
    public boolean isOneStateConstraint() {
        return true;
    }

    @Override
    public void setStateLabelMap(StateLabelMap map) {
        this.map = map;
    }

    /**
     * Records which constraints are active for {@code fv} so that subsequent
     * calls to {@link #getCompositeConstraintFeatureValue} can use them.
     *
     * @param fv feature vector for the position about to be scored
     */
    @Override
    public void preProcess(FeatureVector fv) {
        collectConstrainedFeatures(fv, this.cache);
    }

    /**
     * Counts constraint occurrences over {@code data} and reports which
     * instances contain at least one constrained position.
     *
     * <p>Each occurrence of a constrained feature adds one to that
     * constraint's {@code count}; the constraint keyed by the alphabet size,
     * if present, is counted once per position. Counts are accumulated on top
     * of their current values, they are not reset here.
     *
     * @param data instances whose data are {@link FeatureVectorSequence}s
     * @return bit set in which bit {@code i} is set when the {@code i}-th
     *         instance (in iteration order) triggers at least one constraint
     */
    @Override
    public BitSet preProcess(InstanceList data) {
        BitSet constrained = new BitSet(data.size());
        int instanceIndex = 0;
        for (Instance instance : data) {
            FeatureVectorSequence fvs = (FeatureVectorSequence) instance.getData();
            for (int ip = 0; ip < fvs.size(); ++ip) {
                FeatureVector fv = fvs.get(ip);
                for (int loc = 0; loc < fv.numLocations(); ++loc) {
                    int fi = fv.indexAtLocation(loc);
                    if (this.constraints.containsKey(fi)) {
                        this.constraints.get(fi).count += 1.0;
                        constrained.set(instanceIndex);
                    }
                }
                // The constraint keyed by the alphabet size fires at every position.
                int alwaysOnKey = fv.getAlphabet().size();
                if (this.constraints.containsKey(alwaysOnKey)) {
                    constrained.set(instanceIndex);
                    this.constraints.get(alwaysOnKey).count += 1.0;
                }
            }
            ++instanceIndex;
        }
        return constrained;
    }

    /**
     * Sums {@link OneLabelGEConstraint#getValue(int)} over the constraints
     * cached by the last {@link #preProcess(FeatureVector)} call, evaluated at
     * the label of the destination state.
     *
     * @param fv  not used by this implementation; the cached constraints are used instead
     * @param ip  not used by this implementation
     * @param si1 not used by this implementation
     * @param si2 state index whose label is looked up through the state-label map
     * @return the summed constraint value, {@code 0.0} when no constraint is cached
     */
    @Override
    public double getCompositeConstraintFeatureValue(FeatureVector fv, int ip, int si1, int si2) {
        int li2 = this.map.getLabelIndex(si2);
        double value = 0.0;
        for (int i = 0; i < this.cache.size(); ++i) {
            value += this.constraints.get(this.cache.getQuick(i)).getValue(li2);
        }
        return value;
    }

    @Override
    public abstract double getValue();

    /**
     * Replaces every constraint's expectation vector with a fresh all-zero
     * array of length {@code map.getNumLabels()}.
     */
    @Override
    public void zeroExpectations() {
        for (int fi : this.constraints.keys()) {
            this.constraints.get(fi).expectation = new double[this.map.getNumLabels()];
        }
    }

    /**
     * Accumulates, for every active constraint at every position of every
     * non-null lattice, {@code exp(gammas[ip + 1][s])} into the expectation
     * slot of the label that state {@code s} maps to. States whose label index
     * equals {@link #SKIPPED_LABEL_INDEX} contribute nothing.
     *
     * <p>Expectation arrays must already be allocated (see
     * {@link #zeroExpectations()}). A local scratch list is used, so the
     * {@link #cache} field filled by {@link #preProcess(FeatureVector)} is left
     * untouched.
     *
     * @param lattices lattices to accumulate from; {@code null} entries are skipped
     */
    @Override
    public void computeExpectations(ArrayList<SumLattice> lattices) {
        TIntArrayList active = new TIntArrayList();
        for (int i = 0; i < lattices.size(); ++i) {
            SumLattice lattice = lattices.get(i);
            if (lattice == null) {
                continue;
            }
            FeatureVectorSequence fvs = (FeatureVectorSequence) lattice.getInput();
            double[][] gammas = lattice.getGammas();
            for (int ip = 0; ip < fvs.size(); ++ip) {
                collectConstrainedFeatures(fvs.getFeatureVector(ip), active);
                if (active.size() == 0) {
                    // Nothing to update at this position; skip the per-state exp() work.
                    continue;
                }
                for (int s = 0; s < this.map.getNumStates(); ++s) {
                    int li = this.map.getLabelIndex(s);
                    if (li == SKIPPED_LABEL_INDEX) {
                        continue;
                    }
                    // presumably gammas holds log-domain values, hence the exp(); row ip + 1 pairs with position ip
                    double gammaProb = Math.exp(gammas[ip + 1][s]);
                    for (int j = 0; j < active.size(); ++j) {
                        OneLabelGEConstraint constraint = this.constraints.get(active.getQuick(j));
                        constraint.expectation[li] += gammaProb;
                    }
                }
            }
        }
    }

    /**
     * Clears {@code out} and fills it with the keys of all constraints active
     * for {@code fv}: every constrained feature index present in {@code fv}, in
     * location order, followed by the alphabet-size key when a constraint is
     * registered under it.
     */
    private void collectConstrainedFeatures(FeatureVector fv, TIntArrayList out) {
        out.resetQuick();
        for (int loc = 0; loc < fv.numLocations(); ++loc) {
            int fi = fv.indexAtLocation(loc);
            if (this.constraints.containsKey(fi)) {
                out.add(fi);
            }
        }
        if (this.constraints.containsKey(fv.getAlphabet().size())) {
            out.add(fv.getAlphabet().size());
        }
    }

    /**
     * A single constraint: a target vector, a weight, an occurrence count and
     * an expectation vector that is indexed by label index.
     */
    protected abstract class OneLabelGEConstraint {
        /** Target values supplied at construction (stored by reference). */
        protected double[] target;
        /** Accumulated expectation per label; {@code null} until {@link #zeroExpectations()} runs. */
        protected double[] expectation;
        /** Number of occurrences counted by {@link #preProcess(InstanceList)}. */
        protected double count;
        /** Weight supplied at construction. */
        protected double weight;

        /**
         * @param target target values for this constraint
         * @param weight weight of this constraint
         */
        public OneLabelGEConstraint(double[] target, double weight) {
            this.target = target;
            this.weight = weight;
            this.expectation = null;
            this.count = 0.0;
        }

        public double getCount() {
            return this.count;
        }

        public double[] getTarget() {
            return this.target;
        }

        public double[] getExpectation() {
            return this.expectation;
        }

        public double getWeight() {
            return this.weight;
        }

        /**
         * Value of this constraint for a label.
         *
         * @param var1 label index (this is what
         *             {@link #getCompositeConstraintFeatureValue} passes in)
         */
        public abstract double getValue(int var1);
    }
}

