/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.types;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.Graphs;
import gnu.trove.THashSet;
import java.util.Collections;
import java.util.Set;
import org._3pq.jgrapht.UndirectedGraph;
import org._3pq.jgrapht.alg.ConnectivityInspector;

/**
 * A {@link FactorGraph} that additionally tracks its pairwise structure: whenever a factor
 * over exactly two variables is added through {@link #addFactor(Factor)}, that factor's
 * variable set is recorded as an edge.
 */
public class UndirectedModel extends FactorGraph {

    /**
     * Variable sets of every two-variable factor added via {@link #addFactor(Factor)}.
     * Because this is a set, adding a second factor over the same pair records no new edge.
     */
    private Set edges = new THashSet();

    /** Creates an empty model. */
    public UndirectedModel() {
    }

    /**
     * Creates a model over the given variables.
     *
     * @param vars forwarded unchanged to {@link FactorGraph#FactorGraph(Variable[])}
     */
    public UndirectedModel(Variable[] vars) {
        super(vars);
    }

    /**
     * Creates an empty model.
     *
     * @param capacity forwarded unchanged to {@link FactorGraph#FactorGraph(int)}
     */
    public UndirectedModel(int capacity) {
        super(capacity);
    }

    /**
     * Returns a read-only view of the recorded edges.
     *
     * <p>The view reflects edges added later. Attempts to modify it throw
     * {@link UnsupportedOperationException}.
     *
     * @return unmodifiable view of the edge set
     */
    public Set getEdgeSet() {
        return Collections.unmodifiableSet(this.edges);
    }

    /**
     * Adds {@code factor} to the model and, if it spans exactly two variables, records its
     * variable set as an edge.
     *
     * @param factor the factor to add
     */
    @Override
    public void addFactor(Factor factor) {
        super.addFactor(factor);
        // Only pairwise factors define edges; unary and higher-order factors are not tracked.
        // NOTE(review): this stores the object returned by varSet() in a hash set; if that
        // object can be mutated afterwards, membership may go stale -- confirm it is stable.
        if (factor.varSet().size() == 2) {
            this.edges.add(factor.varSet());
        }
    }

    /**
     * Builds a binary Boltzmann machine as an {@code UndirectedModel}.
     *
     * <p>Each index {@code i} gets a fresh two-state {@link Variable} with a unary table
     * {@code {1, exp(biases[i])}}. For each pair {@code i < j} with {@code weights[i][j] != 0},
     * a pairwise table {@code {1, 1, 1, exp(weights[i][j])}} is added (presumably the last
     * entry is the both-on assignment; confirm against TableFactor's layout). Only the strict
     * upper triangle of {@code weights} is read; the diagonal and lower triangle are ignored.
     *
     * @param weights pairwise weights; must have {@code biases.length} rows, and every row
     *     that is read (all but the last) must be non-null with at least
     *     {@code biases.length} entries
     * @param biases per-variable biases
     * @return the newly built model
     * @throws IllegalArgumentException if the number of rows differs from
     *     {@code biases.length}, or a row that must be read is null or too short
     */
    public static UndirectedModel createBoltzmannMachine(double[][] weights, double[] biases) {
        if (weights.length != biases.length) {
            throw new IllegalArgumentException("Number of weights " + weights.length + " not equal to number of biases " + biases.length);
        }
        int numV = weights.length;
        // Fail fast with a descriptive message instead of an NPE/AIOOBE mid-construction.
        checkWeightRows(weights, numV);

        Variable[] vars = new Variable[numV];
        for (int i = 0; i < numV; ++i) {
            vars[i] = new Variable(2);
        }
        UndirectedModel mdl = new UndirectedModel(vars);
        for (int i = 0; i < numV; ++i) {
            mdl.addFactor(new TableFactor(vars[i], new double[]{1.0, Math.exp(biases[i])}));
            for (int j = i + 1; j < numV; ++j) {
                double w = weights[i][j];
                // A zero weight would give the constant table {1,1,1,1}, so no factor/edge is added.
                if (w == 0.0) {
                    continue;
                }
                mdl.addFactor(vars[i], vars[j], new double[]{1.0, 1.0, 1.0, Math.exp(w)});
            }
        }
        return mdl;
    }

    /** Rejects rows of {@code weights} that the builder would read but that are null or too short. */
    private static void checkWeightRows(double[][] weights, int numV) {
        // Row i is read at columns i+1..numV-1, so only rows 0..numV-2 are ever accessed.
        for (int i = 0; i + 1 < numV; ++i) {
            if (weights[i] == null) {
                throw new IllegalArgumentException("Row " + i + " of weights is null");
            }
            if (weights[i].length < numV) {
                throw new IllegalArgumentException("Row " + i + " of weights has length " + weights[i].length + "; expected at least " + numV);
            }
        }
    }

    /**
     * Reports whether {@code v1} and {@code v2} are joined by a path in this model's graph.
     *
     * <p>The graph is rebuilt via {@link Graphs#mdlToGraph} on every call, so repeated
     * queries on a large model may be expensive.
     *
     * @param v1 first variable
     * @param v2 second variable
     * @return {@code true} iff both variables are vertices of that graph and a path connects them
     */
    public boolean isConnected(Variable v1, Variable v2) {
        UndirectedGraph g = Graphs.mdlToGraph(this);
        // Check membership first so variables absent from the graph yield false
        // rather than being handed to pathExists.
        if (!g.containsVertex(v1) || !g.containsVertex(v2)) {
            return false;
        }
        return new ConnectivityInspector(g).pathExists(v1, v2);
    }
}

