forked from idrl/idrlnet
256 lines
9.3 KiB
Python
256 lines
9.3 KiB
Python
"""Define Computational graph"""
|
|
|
|
import sympy as sp
|
|
from typing import List, Dict, Union
|
|
from copy import copy
|
|
from collections import defaultdict
|
|
import networkx as nx
|
|
import matplotlib.pyplot as plt
|
|
import math
|
|
from idrlnet.variable import Variables
|
|
from idrlnet.node import Node
|
|
from idrlnet.header import logger, DIFF_SYMBOL
|
|
from idrlnet.pde import PdeNode
|
|
from idrlnet.net import NetNode
|
|
|
|
__all__ = ["ComputableNodeList", "Vertex", "VertexTaskPipeline"]

# Module-level sympy symbols. They are not used by the code in this module but
# remain importable from it.
x, y = sp.symbols("x y")

# Type alias for the nodes a VertexTaskPipeline can schedule. This must be the
# typing construct itself (not a list containing it) to be a valid annotation.
ComputableNodeList = List[Union[PdeNode, NetNode]]
|
|
|
|
|
|
class Vertex(Node):
    """A vertex of the computation graph built by ``VertexTaskPipeline``.

    A ``Vertex`` takes a shallow copy of the attribute dictionary of the
    wrapped ``node`` (a fresh ``Node()`` when none is given). It therefore
    carries the same attributes; the pipeline reads ``inputs``, ``outputs``,
    ``derivatives``, ``name`` and ``evaluate`` from them.

    Every vertex gets a unique, monotonically increasing ``index`` from a
    class-level counter. Equality and hashing are based solely on that index.

    Args:
        pre: Set of predecessor vertices (providers). A new empty set is used
            when omitted.
        next: Set of successor vertices (consumers). A new empty set is used
            when omitted.
        node: The node whose attributes are copied.
        ntype: Vertex kind. In this module "c" is used for computation
            vertices (the default), "r" for requested-output vertices and
            "d" for input-variable vertices.
    """

    # Source of unique vertex indices, shared across all instances.
    counter = 0

    def __init__(self, pre=None, next=None, node=None, ntype="c"):
        node = Node() if node is None else node
        # Adopt the wrapped node's attributes (shallow copy of its __dict__).
        self.__dict__ = node.__dict__.copy()
        self.index = type(self).counter
        type(self).counter += 1
        # Each default must be checked against its own argument. Fresh sets
        # avoid sharing state between vertices.
        self.pre = set() if pre is None else pre
        self.next = set() if next is None else next
        self.ntype = ntype
        assert self.ntype in ("d", "c", "r")

    def __eq__(self, other):
        # Vertices are identified by index only; defer to Python's default
        # handling for unrelated types instead of raising AttributeError.
        if not isinstance(other, Vertex):
            return NotImplemented
        return self.index == other.index

    def __hash__(self):
        # Consistent with __eq__: hash on the unique index.
        return self.index

    def __str__(self):
        info = (
            f"index: {self.index}\n"
            + f"pre: {[node.index for node in self.pre]}\n"
            + f"next: {[node.index for node in self.next]}\n"
        )
        return super().__str__() + info
|
|
|
|
|
|
class VertexTaskPipeline:
    """Resolve requested variable names into an ordered list of computations.

    Given candidate computation nodes, the available input variables
    (``invar``) and the requested names (``req_names``), the constructor
    works backwards from every requested name. For each name it finds the
    node output (or input variable) that can provide it. It records these
    dependencies in a ``networkx.DiGraph`` (used by :meth:`display`) and
    stores an evaluation order in :attr:`evaluation_order_list`.

    Names are compared component-wise after splitting on ``DIFF_SYMBOL``. A
    name is served by the provider whose output is the longest component-wise
    prefix of it. Derivative names that end up missing at evaluation time are
    produced with ``Variables.differentiate_`` in :meth:`operation_order`.
    """

    # Upper bound on the number of vertices queued while resolving
    # requirements; guards against runaway (e.g. cyclic) resolution.
    MAX_STACK_ALLOWED = 100000

    @property
    def evaluation_order_list(self):
        """Vertices in the order they must be evaluated."""
        return self._evaluation_order_list

    @evaluation_order_list.setter
    def evaluation_order_list(self, evaluation_order_list):
        self._evaluation_order_list = evaluation_order_list

    def __init__(
        self, nodes: ComputableNodeList, invar: Variables, req_names: List[str]
    ):
        """Build the dependency graph and the evaluation order.

        Args:
            nodes: Candidate computation nodes; each is wrapped in a Vertex.
            invar: Available input variables; only its keys are used here.
            req_names: Names of the variables that must be computable.

        Raises:
            ValueError: If resolution queues ``MAX_STACK_ALLOWED`` or more
                vertices.
            Exception: If some requested (or transitively required) name can
                be provided neither by a node nor by an input variable.
        """
        self.nodes = nodes
        self.req_names = req_names
        # Names known to be computable; starts with the input variables.
        self.computable = set(invar.keys())

        graph_nodes = set(Vertex(node=node) for node in nodes)
        # Still-unresolved name -> vertices that need it.
        req_name_dict: Dict[str, List[Vertex]] = defaultdict(list)

        self.G = nx.DiGraph()
        # (provider name, consumer name) -> names passed along that edge; used
        # as edge labels by display(). The attribute keeps its original
        # (misspelled) name so existing callers keep working.
        self.egde_data = defaultdict(set)
        # Every vertex that has to run, pushed when it is found to be needed.
        required_stack = []
        for req_name in req_names:
            final_graph_node = self._make_requirement_vertex(req_name)
            graph_nodes.add(final_graph_node)
            req_name_dict[req_name].append(final_graph_node)
            required_stack.append(final_graph_node)

        logger.info("Constructing computation graph...")
        while req_name_dict:
            if len(required_stack) >= self.MAX_STACK_ALLOWED:
                raise ValueError(
                    "Computation graph construction exceeded "
                    f"MAX_STACK_ALLOWED={self.MAX_STACK_ALLOWED} queued vertices; "
                    "the node dependencies may be cyclic."
                )
            to_be_removed = set()
            to_be_added = defaultdict(list)
            for req_key, current_gn in req_name_dict.items():
                req_parts = tuple(req_key.split(DIFF_SYMBOL))
                match_score = -1
                match_gn = None
                # Pick the vertex whose output is the longest prefix of the
                # requested name; a vertex never provides a name to itself.
                for gn in graph_nodes:
                    if gn in current_gn:
                        continue
                    for output in gn.outputs:
                        score = self._prefix_score(output, req_parts)
                        if score > match_score:
                            match_score = score
                            match_gn = gn
                # An input variable wins only with a strictly longer prefix.
                for p_in in invar.keys():
                    score = self._prefix_score(p_in, req_parts)
                    if score > match_score:
                        match_score = score
                        match_gn = None
                        for sub_gn in current_gn:
                            self.G.add_edge(p_in, sub_gn.name)
                if match_score <= 0:
                    raise Exception("Can't be computed: " + req_key)
                if match_gn is not None:
                    required_stack.append(match_gn)
                    for sub_gn in current_gn:
                        logger.info(f"{sub_gn.name}.{req_key} <---- {match_gn.name}")
                        match_gn.next.add(sub_gn)
                        sub_gn.pre.add(match_gn)
                        self.egde_data[(match_gn.name, sub_gn.name)].add(req_key)
                        self.G.add_edge(match_gn.name, sub_gn.name)
                    # The provider's own inputs must now be resolved as well.
                    for p in (*match_gn.inputs, *match_gn.derivatives):
                        to_be_added[p].append(match_gn)
                to_be_removed.add(req_key)
            if not to_be_removed and req_name_dict:
                raise Exception("Can't be computed")
            for p in to_be_removed:
                req_name_dict.pop(p)
                self.computable.add(p)
            for k, v in to_be_added.items():
                req_name_dict[k].extend(v)

        # Providers are always pushed after the vertices that need them, so
        # popping from the end and keeping each vertex's first occurrence puts
        # providers before consumers. A set keeps the dedup O(1) per vertex.
        evaluation_order = []
        seen = set()
        while required_stack:
            gn = required_stack.pop()
            if gn in seen:
                continue
            seen.add(gn)
            evaluation_order.append(gn)
            self.computable.update(gn.outputs)
        self.evaluation_order_list = evaluation_order

        # Name -> vertex lookup used by display(); input variables get "d"
        # vertices so every graph node has an entry.
        self._graph_node_table = {node.name: node for node in graph_nodes}
        for key in invar:
            node = Vertex()
            node.name = key
            node.outputs = (key,)
            node.inputs = tuple()
            node.ntype = "d"
            self._graph_node_table[key] = node
        logger.info("Computation graph constructed.")

    @staticmethod
    def _make_requirement_vertex(req_name: str) -> "Vertex":
        """Create the "r" sink vertex that consumes one requested name."""
        final_graph_node = Vertex()
        if DIFF_SYMBOL in req_name:
            final_graph_node.derivatives = (req_name,)
            final_graph_node.inputs = tuple()
        else:
            final_graph_node.inputs = [req_name]
            final_graph_node.derivatives = tuple()
        final_graph_node.outputs = tuple()
        final_graph_node.name = f"<{req_name}>"
        final_graph_node.ntype = "r"
        # Identity: the requested values pass through unchanged.
        final_graph_node.evaluate = lambda x: x
        return final_graph_node

    @staticmethod
    def _prefix_score(candidate: str, req_parts: tuple) -> int:
        """Return the component count of ``candidate`` if it is a prefix of ``req_parts``, else -1."""
        parts = tuple(candidate.split(DIFF_SYMBOL))
        if len(parts) <= len(req_parts) and req_parts[: len(parts)] == parts:
            return len(parts)
        return -1

    def operation_order(self, invar: Variables):
        """Run every vertex in evaluation order, updating ``invar`` in place.

        Before a vertex runs, any of its ``derivatives`` missing from
        ``invar`` are computed with ``invar.differentiate_``. The vertex is
        then called with the subset of ``invar`` holding its inputs and
        derivatives, and the result is merged back into ``invar``.
        """
        for node in self.evaluation_order_list:
            if not set(node.derivatives).issubset(invar.keys()):
                invar.differentiate_(
                    independent_var=invar, required_derivatives=node.derivatives
                )
            invar.update(
                node.evaluate(
                    {**invar.subset(node.inputs), **invar.subset(node.derivatives)}
                )
            )

    def forward_pipeline(
        self, invar: Variables, req_names: List[str] = None
    ) -> Variables:
        """Compute the requested variables from ``invar``.

        Args:
            invar: Input variables. Evaluation runs on a shallow ``copy``.
            req_names: Names to return; defaults to the names this pipeline
                was built for.

        Returns:
            The subset of the evaluated variables named by ``req_names``.

        If some requested name is not known to be computable by this
        pipeline, a temporary pipeline is built for the given inputs and
        names instead.
        """
        if req_names is None or set(req_names).issubset(self.computable):
            outvar = copy(invar)
            self.operation_order(outvar)
            return outvar.subset(self.req_names if req_names is None else req_names)
        logger.info("The existing graph fails. Construct a temporary graph...")
        return VertexTaskPipeline(self.nodes, invar, req_names).forward_pipeline(
            invar
        )

    def to_json(self):
        """Placeholder: serialization is not implemented and returns None."""
        pass

    def display(self, filename: str = None):
        """Draw the dependency graph.

        Computation vertices are pink, requested outputs green and input
        variables blue; edges are labelled with the names they carry.

        Args:
            filename: If given, the figure is saved there; otherwise it is
                shown interactively.
        """
        _, ax = plt.subplots(1, 1, figsize=(8, 8))
        ax.axis("off")
        pos = nx.spring_layout(self.G, k=10 / (math.sqrt(self.G.order()) + 0.1))
        for ntype, color, alpha in (
            ("c", "pink", 0.5),
            ("r", "green", 0.3),
            ("d", "blue", 0.3),
        ):
            nx.draw_networkx_nodes(
                self.G,
                pos,
                nodelist=[
                    node
                    for node in self.G.nodes
                    if self._graph_node_table[node].ntype == ntype
                ],
                cmap=plt.get_cmap("jet"),
                node_size=1300,
                node_color=color,
                alpha=alpha,
            )
        nx.draw_networkx_edges(
            self.G, pos, edge_color="r", arrows=True, arrowsize=30, arrowstyle="-|>"
        )
        nx.draw_networkx_labels(self.G, pos)
        nx.draw_networkx_edge_labels(
            self.G,
            pos,
            edge_labels={k: ", ".join(v) for k, v in self.egde_data.items()},
            font_size=10,
        )
        if filename is None:
            plt.show()
        else:
            plt.savefig(filename)
        plt.close()
|