/*
 * Decompiled with CFR 0.152.
 */
package ru.itmo.ctlab.virgo.sgmwcs.solver;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.Collectors;
import ru.itmo.ctlab.virgo.sgmwcs.Signals;
import ru.itmo.ctlab.virgo.sgmwcs.graph.Edge;
import ru.itmo.ctlab.virgo.sgmwcs.graph.Graph;
import ru.itmo.ctlab.virgo.sgmwcs.graph.Node;
import ru.itmo.ctlab.virgo.sgmwcs.graph.Unit;

/**
 * Dijkstra-style shortest-path search over a {@link Graph} whose nodes and edges carry
 * signals (identified by integer ids in {@link Signals}).
 *
 * <p>The cost of a path is the negated sum of the weights of the negative signals collected
 * along it. A signal that is already part of the path's signal set is not charged a second
 * time, so the cost of a step depends on the signals gathered so far, not only on the edge.
 *
 * <p>Instances keep per-run state in fields and are not safe for concurrent use.
 */
class Dijkstra {
    private final Graph graph;
    private final Signals signals;

    /** Best known cost for every node reached by the current run. */
    private Map<Node, Double> d;

    /** Signal set gathered on the best known path to every node reached by the current run. */
    private Map<Unit, Set<Integer>> p;

    /** Memoised cost of a set of negative signals, keyed by that set. */
    private Map<Set<Integer>, Double> cache;

    /** Nodes whose settlement lets {@link #solve(Node)} stop early. */
    private Set<Node> dests;

    /** Signal set of the node being expanded; temporarily extended while scanning neighbours. */
    private Set<Integer> currentSignals;

    Dijkstra(Graph graph, Signals signals) {
        this.graph = graph;
        this.signals = signals;
        this.dests = new HashSet<>();
    }

    /**
     * Returns the cost of {@link #currentSignals}: the negated {@code weightSum} of its
     * negative-weight members, memoised per distinct set of negative signals.
     */
    private double currentWeight() {
        Set<Integer> negSets = signals
                .filter(currentSignals, s -> signals.weight((int) s) < 0.0)
                .collect(Collectors.toSet());
        return cache.computeIfAbsent(negSets, key -> -signals.weightSum(negSets));
    }

    /** Returns the best known cost of {@code n}, or {@link Double#MAX_VALUE} if it was not reached. */
    private double weight(Node n) {
        return d.getOrDefault(n, Double.MAX_VALUE);
    }

    /**
     * Adds every signal of {@code sets} that is not yet in {@link #currentSignals}, records the
     * newly added ids in {@code added} (so the caller can undo the change) and returns the cost
     * they contribute, i.e. the negated sum of their negative weights.
     */
    private double collectSignals(Iterable<Integer> sets, List<Integer> added) {
        double penalty = 0.0;
        for (int signal : sets) {
            if (!currentSignals.add(signal)) {
                continue;
            }
            added.add(signal);
            penalty -= Math.min(signals.weight(signal), 0.0);
        }
        return penalty;
    }

    /**
     * Runs the search from {@code u}, replacing the results of any previous run.
     *
     * <p>The search stops when the queue is exhausted or as soon as every node in the current
     * destination set has been polled from the queue. Afterwards {@link #distances()} holds the
     * cost of every node reached so far.
     *
     * @param u the start node
     */
    public void solve(Node u) {
        d = new HashMap<>();
        p = new HashMap<>();
        cache = new HashMap<>();
        currentSignals = new HashSet<>();

        Comparator<Node> byCost = Comparator.comparingDouble(this::weight);
        PriorityQueue<Node> queue = new PriorityQueue<>(byCost);

        d.put(u, 0.0);
        p.put(u, new HashSet<Integer>(signals.positiveUnitSets((Unit) u)));
        queue.add(u);

        List<Integer> addedN = new ArrayList<>();
        List<Integer> addedE = new ArrayList<>();
        Set<Node> visitedDests = new HashSet<>();

        Node cur;
        while ((cur = queue.poll()) != null) {
            // Stop once the node just polled completes the set of visited destinations.
            if (dests.contains(cur) && visitedDests.add(cur) && visitedDests.containsAll(dests)) {
                break;
            }

            // This aliases the set stored in p; every addition below is undone before moving on.
            currentSignals = p.getOrDefault(cur, new HashSet<>());
            double base = currentWeight();

            for (Node node : graph.neighborListOf(cur)) {
                // Candidates are derived from the fixed base instead of adding and subtracting
                // on a running total, so rounding error cannot build up between neighbours.
                double viaNode = base + collectSignals(signals.unitSets((Unit) node), addedN);

                for (Edge edge : graph.getAllEdges(node, cur)) {
                    double candidate = viaNode + collectSignals(signals.unitSets((Unit) edge), addedE);
                    if (candidate < weight(node)) {
                        // The queue orders by d, so the node has to leave the queue before its
                        // key changes and be re-inserted afterwards.
                        queue.remove(node);
                        d.put(node, candidate);
                        p.put(node, new HashSet<>(currentSignals));
                        queue.add(node);
                    }
                    currentSignals.removeAll(addedE);
                    addedE.clear();
                }

                currentSignals.removeAll(addedN);
                addedN.clear();
            }
        }
    }

    /**
     * Searches from one neighbour of {@code u} to the other and compares the signals gathered on
     * the path found with the signals of {@code u} and its incident edges.
     *
     * @param u a node; anything other than exactly two neighbours yields {@code false}
     * @return {@code false} if {@code u} does not have exactly two neighbours or if the path's
     *         signal set contains every negative signal of {@code u} and its edges; otherwise
     *         whether the path's signal set contains every positive signal of {@code u} and its
     *         edges that is not also a positive signal of the two neighbours
     */
    boolean solveNP(Node u) {
        List<Node> nbors = graph.neighborListOf(u);
        if (nbors.size() != 2) {
            return false;
        }
        Node first = nbors.get(0);
        Node second = nbors.get(1);

        // Start from a fresh destination set so that destinations left over from an earlier
        // call on this instance cannot keep the search from stopping at `second`.
        dests = new HashSet<>();
        dests.add(second);
        solve(first);

        Set<Integer> pathSignals = p.get(second);

        Set<Integer> neg = new HashSet<Integer>(signals.negativeUnitSets((Unit) u));
        neg.addAll(signals.negativeUnitSets(graph.edgesOf(u)));
        if (pathSignals.containsAll(neg)) {
            return false;
        }

        Set<Integer> pos = new HashSet<Integer>(signals.positiveUnitSets((Unit) u));
        pos.addAll(signals.positiveUnitSets(graph.edgesOf(u)));
        pos.removeAll(signals.positiveUnitSets(first, second));
        return pathSignals.containsAll(pos);
    }

    /**
     * Searches from {@code u} towards {@code neighbors} and collects the edges between
     * {@code u} and each neighbour whose negative signals are not all present on the path found
     * to that neighbour, after the signals of {@code u} and the neighbour themselves have been
     * removed from that path's signal set.
     *
     * <p>Side effect: the signal sets stored for the neighbours are modified in place.
     *
     * @param u         the start node
     * @param neighbors the destination nodes; each is expected to be reached by the search
     * @return the selected edges, possibly empty
     */
    Set<Edge> solveNE(Node u, List<Node> neighbors) {
        dests = new HashSet<>(neighbors);
        solve(u);

        Set<Edge> res = new HashSet<>();
        for (Node n : neighbors) {
            List<Edge> edges = graph.getAllEdges(n, u);
            Set<Integer> pathSignals = p.get(n);
            pathSignals.removeAll(signals.unitSets((Unit) u, (Unit) n));
            for (Edge e : edges) {
                if (!pathSignals.containsAll(signals.negativeUnitSets((Unit) e))) {
                    res.add(e);
                }
            }
        }
        return res;
    }

    /**
     * Returns the live cost map of the most recent run.
     *
     * @return node-to-cost map, or {@code null} if {@link #solve(Node)} has not run yet
     */
    Map<Node, Double> distances() {
        return d;
    }
}

