174 lines
5.5 KiB
Python
174 lines
5.5 KiB
Python
from collections import deque
|
|
from typing import List, Set
|
|
|
|
|
|
class DiGraph:
    """Really simple unweighted directed graph data structure to track dependencies.

    The API is pretty much the same as networkx so if you add something just
    copy their API.

    Nodes may be any hashable object and each carries a dict of arbitrary
    attributes. Edges carry no data. The order in which nodes are first added
    is recorded and used by ``first_path``.
    """

    def __init__(self):
        # Dict of node -> dict of arbitrary attributes
        self._node = {}
        # Nested dict of node -> successor node -> True.
        # (didn't implement edge data)
        self._succ = {}
        # Nested dict of node -> predecessor node -> True.
        self._pred = {}

        # Keep track of the order in which nodes are added to
        # the graph (node -> insertion index, starting at 0).
        self._node_order = {}
        self._insertion_idx = 0

    def add_node(self, n, **kwargs):
        """Add a node to the graph.

        If ``n`` already exists, ``kwargs`` are merged into its existing
        attributes and its recorded insertion order is left unchanged.

        Args:
            n: the node. Can be any object that is a valid dict key.
            **kwargs: any attributes you want to attach to the node.
        """
        if n not in self._node:
            self._node[n] = kwargs
            self._succ[n] = {}
            self._pred[n] = {}
            self._node_order[n] = self._insertion_idx
            self._insertion_idx += 1
        else:
            self._node[n].update(kwargs)

    def add_edge(self, u, v):
        """Add an edge to graph between nodes ``u`` and ``v``

        ``u`` and ``v`` will be created if they do not already exist
        (``u`` first, so a brand-new ``u`` gets the earlier insertion index).
        Adding an edge that already exists is a no-op.
        """
        # add nodes
        self.add_node(u)
        self.add_node(v)

        # add the edge
        self._succ[u][v] = True
        self._pred[v][u] = True

    def successors(self, n):
        """Returns an iterator over successor nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._succ[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    def predecessors(self, n):
        """Returns an iterator over predecessors nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._pred[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    @property
    def edges(self):
        """Returns an iterator over all edges (u, v) in the graph"""
        for n, successors in self._succ.items():
            for succ in successors:
                yield n, succ

    @property
    def nodes(self):
        """Returns a dictionary of all nodes to their attributes.

        This is the graph's internal mapping, not a copy.
        """
        return self._node

    def __iter__(self):
        """Iterate over the nodes."""
        return iter(self._node)

    def __contains__(self, n):
        """Returns True if ``n`` is a node in the graph, False otherwise."""
        try:
            return n in self._node
        except TypeError:
            # Unhashable objects can never be nodes.
            return False

    def _reachable(self, src, next_nodes) -> Set[str]:
        """Breadth-first set of nodes reachable from ``src`` (inclusive) via ``next_nodes``."""
        # Seed with the node itself; ``set(src)`` / ``deque(src)`` would
        # iterate a string node and seed the walk with its characters.
        result = {src}
        working_set = deque([src])
        while working_set:
            cur = working_set.popleft()
            for n in next_nodes(cur):
                if n not in result:
                    result.add(n)
                    working_set.append(n)
        return result

    def forward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src (src included).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._reachable(src, self.successors)

    def backward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src in reverse direction (src included).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._reachable(src, self.predecessors)

    def all_paths(self, src: str, dst: str) -> str:
        """Returns the dot representation of the subgraph of all paths from src to dst.

        The subgraph contains every edge ``(u, v)`` such that ``u`` is
        reachable from ``src`` and ``v`` reaches ``dst``. When ``dst`` is not
        reachable from ``src`` the result is a dot graph with no edges.

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        result_graph = DiGraph()
        # First compute forward transitive closure of src (all things reachable from src).
        forward_reachable_from_src = self.forward_transitive_closure(src)

        if dst not in forward_reachable_from_src:
            # Return the same type (dot string) as the success path.
            return result_graph.to_dot()

        # Second walk the reverse dependencies of dst, adding each edge to
        # the output graph iff its source is also in forward_reachable_from_src.
        # we don't use backward_transitive_closure for optimization purposes
        working_set = deque([dst])
        visited = {dst}
        while working_set:
            cur = working_set.popleft()
            for n in self.predecessors(cur):
                if n in forward_reachable_from_src:
                    result_graph.add_edge(n, cur)
                    # Expand each node only once: re-enqueueing on every edge
                    # loops forever on cycles and blows up on diamond shapes.
                    if n not in visited:
                        visited.add(n)
                        working_set.append(n)

        return result_graph.to_dot()

    def first_path(self, dst: str) -> List[str]:
        """Returns a list of nodes that show the first path that resulted in dst being added to the graph.

        Starting at ``dst``, repeatedly steps to the predecessor with the
        smallest insertion index until a node with no predecessors is reached
        (or a node would be revisited, which happens on cycles). The path is
        returned root-first, ending with ``dst``.

        Raises:
            KeyError: if a non-falsy ``dst`` is not in the graph.
        """
        # Historical behaviour: a falsy dst that is not a node yields [].
        if not dst and dst not in self:
            return []

        path = []
        seen = set()
        cur = dst
        while True:
            path.append(cur)
            seen.add(cur)
            candidates = self._pred[cur]
            if not candidates:
                break
            # Earliest-inserted predecessor; indices are unique, so no ties.
            cur = min(candidates, key=self._node_order.__getitem__)
            if cur in seen:
                # Cycle in the earliest-predecessor chain; stop instead of looping.
                break

        return list(reversed(path))

    def to_dot(self) -> str:
        """Returns the dot representation of the graph.

        Only edges are emitted; nodes without any edge do not appear.

        Returns:
            A dot representation of the graph.
        """

        def _quote(node) -> str:
            # In DOT quoted IDs only \" is an escape, so escape embedded quotes.
            return '"' + str(node).replace('"', '\\"') + '"'

        edges = "\n".join(f"{_quote(f)} -> {_quote(t)};" for f, t in self.edges)
        return f"""\
digraph G {{
rankdir = LR;
node [shape=box];
{edges}
}}
"""
|