style: change to black style

zweien 2021-07-13 10:39:09 +08:00
parent 1d5984d9f0
commit f94494c43e
26 changed files with 1343 additions and 735 deletions
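The changes are mechanical: black normalizes string literals to double quotes, pads bare float literals (1. becomes 1.0, .0 becomes 0.0), wraps calls that exceed its default 88-character line limit, and leaves a trailing comma after the last argument of every call it wraps. A minimal before/after sketch, modeled on the solver calls in the examples below (illustrative only, not copied from any one file):

solver = sc.Solver(sample_domains=(interior(),), netnodes=[netnode], pdes=[pde], loading=True, max_iter=3000)

becomes

solver = sc.Solver(
    sample_domains=(interior(),),
    netnodes=[netnode],
    pdes=[pde],
    loading=True,
    max_iter=3000,
)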

View File

@@ -13,16 +13,16 @@
import os
import sys
sys.path.insert(0, os.path.abspath(".."))
# -- Project information -----------------------------------------------------
project = "idrlnet"
copyright = "2021, IDRL"
author = "IDRL"
# The full version, including alpha/beta/rc tags
release = "0.0.1-rc1"
# -- General configuration ---------------------------------------------------
@@ -34,37 +34,37 @@ extensions = [
    "sphinx.ext.mathjax",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "myst_parser",
    "sphinx.ext.autosectionlabel",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
source_suffix = {
    ".rst": "restructuredtext",
    ".txt": "markdown",
    ".md": "markdown",
}
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
# for MarkdownParser
from sphinx_markdown_parser.parser import MarkdownParser  # noqa
# def setup(app):

View File

@@ -3,9 +3,9 @@ import sympy as sp
import numpy as np
import matplotlib.pyplot as plt
x = sp.Symbol("x")
s = sp.Symbol("s")
f = sp.Function("f")(x)
geo = sc.Line1D(0, 5)
@@ -19,43 +19,49 @@ def interior():
@sc.datanode
def init():
    points = geo.sample_boundary(1, sieve=sp.Eq(x, 0))
    points["lambda_f"] = 1000 * np.ones_like(points["x"])
    constraints = {"f": 1}
    return points, constraints
@sc.datanode(name="InteriorInfer")
def infer():
    points = {"x": np.linspace(0, 5, 1000).reshape(-1, 1)}
    return points, {}
netnode = sc.get_net_node(inputs=("x",), outputs=("f",), name="net")
exp_lhs = sc.ExpressionNode(expression=f.diff(x) + f, name="lhs")
fs = sp.Symbol("fs")
exp_rhs = sc.Int1DNode(
    expression=sp.exp(s - x) * fs,
    var=s,
    lb=0,
    ub=x,
    expression_name="rhs",
    funs={"fs": {"eval": netnode, "input_map": {"x": "s"}, "output_map": {"f": "fs"}}},
    degree=10,
)
diff = sc.Difference(T="lhs", S="rhs", dim=1, time=False)
solver = sc.Solver(
    sample_domains=(interior(), init(), infer()),
    netnodes=[netnode],
    pdes=[exp_lhs, exp_rhs, diff],
    loading=True,
    max_iter=3000,
)
solver.solve()
points = solver.infer_step({"InteriorInfer": ["x", "f"]})
num_x = points["InteriorInfer"]["x"].detach().cpu().numpy().ravel()
num_f = points["InteriorInfer"]["f"].detach().cpu().numpy().ravel()
fig = plt.figure(figsize=(8, 4))
plt.plot(num_x, num_f)
plt.plot(num_x, np.exp(-num_x) * np.cosh(num_x))
plt.xlabel("x")
plt.ylabel("y")
plt.legend(["Prediction", "Exact"])
plt.savefig("ide.png", dpi=1000, bbox_inches="tight")
plt.show()

View File

@@ -4,63 +4,82 @@ import matplotlib.pyplot as plt
import matplotlib.tri as tri
import idrlnet.shortcut as sc
x = Symbol("x")
t_symbol = Symbol("t")
time_range = {t_symbol: (0, 1)}
geo = sc.Line1D(-1.0, 1.0)
@sc.datanode(name="burgers_equation")
def interior_domain():
    points = geo.sample_interior(
        10000, bounds={x: (-1.0, 1.0)}, param_ranges=time_range
    )
    constraints = {"burgers_u": 0}
    return points, constraints
@sc.datanode(name="t_boundary")
def init_domain():
    points = geo.sample_interior(100, param_ranges={t_symbol: 0.0})
    constraints = sc.Variables({"u": -sin(math.pi * x)})
    return points, constraints
@sc.datanode(name="x_boundary")
def boundary_domain():
    points = geo.sample_boundary(100, param_ranges=time_range)
    constraints = sc.Variables({"u": 0})
    return points, constraints
net = sc.get_net_node(
    inputs=(
        "x",
        "t",
    ),
    outputs=("u",),
    name="net1",
    arch=sc.Arch.mlp,
)
pde = sc.BurgersNode(u="u", v=0.01 / math.pi)
s = sc.Solver(
    sample_domains=(interior_domain(), init_domain(), boundary_domain()),
    netnodes=[net],
    pdes=[pde],
    max_iter=4000,
)
s.solve()
coord = s.infer_step(
    {
        "burgers_equation": ["x", "t", "u"],
        "t_boundary": ["x", "t"],
        "x_boundary": ["x", "t"],
    }
)
num_x = coord["burgers_equation"]["x"].cpu().detach().numpy().ravel()
num_t = coord["burgers_equation"]["t"].cpu().detach().numpy().ravel()
num_u = coord["burgers_equation"]["u"].cpu().detach().numpy().ravel()
init_x = coord["t_boundary"]["x"].cpu().detach().numpy().ravel()
init_t = coord["t_boundary"]["t"].cpu().detach().numpy().ravel()
boundary_x = coord["x_boundary"]["x"].cpu().detach().numpy().ravel()
boundary_t = coord["x_boundary"]["t"].cpu().detach().numpy().ravel()
triang_total = tri.Triangulation(num_t.flatten(), num_x.flatten())
u_pre = num_u.flatten()
fig = plt.figure(figsize=(15, 5))
ax1 = fig.add_subplot(221)
tcf = ax1.tricontourf(triang_total, u_pre, 100, cmap="jet")
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=10)
ax1.set_xlabel("$t$")
ax1.set_ylabel("$x$")
ax1.set_title("$u(x,t)$")
ax1.scatter(init_t, init_x, c="black", marker="x", s=8)
ax1.scatter(boundary_t, boundary_x, c="black", marker="x", s=8)
plt.xlim(0, 1)
plt.ylim(-1, 1)
plt.savefig("Burgers.png", dpi=500, bbox_inches="tight", pad_inches=0.02)

View File

@@ -3,59 +3,68 @@ import sympy as sp
import numpy as np
import idrlnet.shortcut as sc
x = sp.symbols("x")
Line = sc.Line1D(0, 1)
y = sp.Function("y")(x)
@sc.datanode(name="interior")
class Interior(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        return Line.sample_interior(1000), {"dddd_y": 0}
@sc.datanode(name="left_boundary1")
class LeftBoundary1(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        return Line.sample_boundary(100, sieve=(sp.Eq(x, 0))), {"y": 0}
@sc.datanode(name="left_boundary2")
class LeftBoundary2(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        return Line.sample_boundary(100, sieve=(sp.Eq(x, 0))), {"d_y": 0}
@sc.datanode(name="right_boundary1")
class RightBoundary1(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        return Line.sample_boundary(100, sieve=(sp.Eq(x, 1))), {"dd_y": 0}
@sc.datanode(name="right_boundary2")
class RightBoundary2(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        return Line.sample_boundary(100, sieve=(sp.Eq(x, 1))), {"ddd_y": 0}
@sc.datanode(name="infer")
class Infer(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        return {"x": np.linspace(0, 1, 1000).reshape(-1, 1)}, {}
net = sc.get_net_node(inputs=("x",), outputs=("y",), name="net", arch=sc.Arch.mlp)
pde1 = sc.ExpressionNode(
    name="dddd_y", expression=y.diff(x).diff(x).diff(x).diff(x) + 1
)
pde2 = sc.ExpressionNode(name="d_y", expression=y.diff(x))
pde3 = sc.ExpressionNode(name="dd_y", expression=y.diff(x).diff(x))
pde4 = sc.ExpressionNode(name="ddd_y", expression=y.diff(x).diff(x).diff(x))
solver = sc.Solver(
    sample_domains=(
        Interior(),
        LeftBoundary1(),
        LeftBoundary2(),
        RightBoundary1(),
        RightBoundary2(),
    ),
    netnodes=[net],
    pdes=[pde1, pde2, pde3, pde4],
    max_iter=2000,
)
solver.solve()
@@ -65,14 +74,14 @@ def exact(x):
solver.sample_domains = (Infer(),)
points = solver.infer_step({"infer": ["x", "y"]})
xs = points["infer"]["x"].detach().cpu().numpy().ravel()
y_pred = points["infer"]["y"].detach().cpu().numpy().ravel()
plt.plot(xs, y_pred, label="Pred")
y_exact = exact(xs)
plt.plot(xs, y_exact, label="Exact", linestyle="--")
plt.legend()
plt.xlabel("x")
plt.ylabel("w")
plt.savefig("Euler_beam.png", dpi=300, bbox_inches="tight")
plt.show()

View File

@@ -10,104 +10,121 @@ import matplotlib.pyplot as plt
L = float(pi)
geo = sc.Line1D(0, L)
t_symbol = Symbol("t")
x = Symbol("x")
time_range = {t_symbol: (0, 2 * L)}
c = 1.54
external_filename = "external_sample.csv"
def generate_observed_data():
    if os.path.exists(external_filename):
        return
    points = geo.sample_interior(
        density=20, bounds={x: (0, L)}, param_ranges=time_range, low_discrepancy=True
    )
    points["u"] = np.sin(points["x"]) * (
        np.sin(c * points["t"]) + np.cos(c * points["t"])
    )
    points["u"][np.random.choice(len(points["u"]), 10, replace=False)] = 3.0
    points = {k: v.ravel() for k, v in points.items()}
    points = pd.DataFrame.from_dict(points)
    points.to_csv("external_sample.csv", index=False)
generate_observed_data()
# @sc.datanode(name='wave_domain')
@sc.datanode(name="wave_domain", loss_fn="L1")
class WaveExternal(sc.SampleDomain):
    def __init__(self):
        points = pd.read_csv("external_sample.csv")
        self.points = {
            col: points[col].to_numpy().reshape(-1, 1) for col in points.columns
        }
        self.constraints = {"u": self.points.pop("u")}
    def sampling(self, *args, **kwargs):
        return self.points, self.constraints
@sc.datanode(name="wave_external")
class WaveEq(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        points = geo.sample_interior(
            density=1000, bounds={x: (0, L)}, param_ranges=time_range
        )
        constraints = {"wave_equation": 0.0}
        return points, constraints
@sc.datanode(name="center_infer")
class CenterInfer(sc.SampleDomain):
    def __init__(self):
        self.points = sc.Variables()
        self.points["t"] = np.linspace(0, 2 * L, 200).reshape(-1, 1)
        self.points["x"] = np.ones_like(self.points["t"]) * L / 2
        self.points["area"] = np.ones_like(self.points["t"])
    def sampling(self, *args, **kwargs):
        return self.points, {}
net = sc.get_net_node(
    inputs=(
        "x",
        "t",
    ),
    outputs=("u",),
    name="net1",
    arch=sc.Arch.mlp,
)
var_c = sc.get_net_node(inputs=("x",), outputs=("c",), arch=sc.Arch.single_var)
pde = sc.WaveNode(c="c", dim=1, time=True, u="u")
s = sc.Solver(
    sample_domains=(WaveExternal(), WaveEq()),
    netnodes=[net, var_c],
    pdes=[pde],
    # network_dir='square_network_dir',
    network_dir="network_dir",
    max_iter=5000,
)
s.solve()
_, ax = plt.subplots(1, 1, figsize=(8, 4))
coord = s.infer_step(domain_attr={"wave_domain": ["x", "t", "u"]})
num_t = coord["wave_domain"]["t"].cpu().detach().numpy().ravel()
num_u = coord["wave_domain"]["u"].cpu().detach().numpy().ravel()
ax.scatter(num_t, num_u, c="r", marker="o", label="predicted points")
print("true parameter c: {:.4f}".format(c))
predict_c = var_c.evaluate(torch.Tensor([[1.0]])).item()
print("predicted parameter c: {:.4f}".format(predict_c))
num_t = WaveExternal().sample_fn.points["t"].ravel()
num_u = WaveExternal().sample_fn.constraints["u"].ravel()
ax.scatter(num_t, num_u, c="b", marker="x", label="observed points")
s.sample_domains = (CenterInfer(),)
points = s.infer_step({"center_infer": ["t", "x", "u"]})
num_t = points["center_infer"]["t"].cpu().detach().numpy().ravel()
num_u = points["center_infer"]["u"].cpu().detach().numpy().ravel()
num_x = points["center_infer"]["x"].cpu().detach().numpy().ravel()
ax.plot(
    num_t, np.sin(num_x) * (np.sin(c * num_t) + np.cos(c * num_t)), c="k", label="exact"
)
ax.plot(num_t, num_u, "--", c="g", linewidth=4, label="predict")
ax.legend()
ax.set_xlabel("t")
ax.set_ylabel("u")
# ax.set_title(f'Square loss ($x=0.5L$, c={predict_c:.4f}))')
ax.set_title(f"L1 loss ($x=0.5L$, c={predict_c:.4f})")
ax.grid(True)
ax.set_xlim([-0.5, 6.5])
ax.set_ylim([-3.5, 4.5])
# plt.savefig('square.png', dpi=1000, bbox_inches='tight', pad_inches=0.02)
plt.savefig("L1.png", dpi=1000, bbox_inches="tight", pad_inches=0.02)
plt.show()
plt.close()

View File

@@ -9,26 +9,30 @@ import math
import idrlnet.shortcut as sc
x = sp.Symbol("x")
u = sp.Function("u")(x)
geo = sc.Line1D(-1, 0.5)
@sc.datanode(sigma=1000.0)
class Boundary(sc.SampleDomain):
    def __init__(self):
        self.points = geo.sample_boundary(
            1,
        )
        self.constraints = {"u": np.cosh(self.points["x"])}
    def sampling(self, *args, **kwargs):
        return self.points, self.constraints
@sc.datanode(loss_fn="L1")
class Interior(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        points = geo.sample_interior(10000)
        constraints = {
            "integral_dx": 0,
        }
        return points, constraints
@@ -36,8 +40,8 @@ class Interior(sc.SampleDomain):
class InteriorInfer(sc.SampleDomain):
    def __init__(self):
        self.points = sc.Variables()
        self.points["x"] = np.linspace(-1, 0.5, 1001, endpoint=True).reshape(-1, 1)
        self.points["area"] = np.ones_like(self.points["x"])
    def sampling(self, *args, **kwargs):
        return self.points, {}
@@ -46,8 +50,8 @@ class InteriorInfer(sc.SampleDomain):
# plot Intermediate results
class PlotReceiver(sc.Receiver):
    def __init__(self):
        if not os.path.exists("plot"):
            os.mkdir("plot")
        xx = np.linspace(-1, 0.5, 1001, endpoint=True)
        self.xx = xx
        angle = np.linspace(0, math.pi * 2, 100)
@@ -58,28 +62,30 @@ class PlotReceiver(sc.Receiver):
        zz_mesh = yy * np.sin(angle_mesh)
        fig = plt.figure(figsize=(8, 8))
        ax = fig.gca(projection="3d")
        ax.set_zlim3d(-1.25 - 1, 0.75 + 1)
        ax.set_ylim3d(-2, 2)
        ax.set_xlim3d(-2, 2)
        my_col = cm.cool((yy * np.ones_like(angle_mesh) - 1.0) / 0.6)
        ax.plot_surface(yy_mesh, zz_mesh, xx_mesh, facecolors=my_col)
        ax.view_init(elev=15.0, azim=0)
        ax.dist = 5
        plt.axis("off")
        plt.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0)
        plt.savefig(f"plot/p_exact.png")
        plt.show()
        plt.close()
        self.predict_history = []
    def receive_notify(self, obj: sc.Solver, message: Dict):
        if sc.Signal.SOLVE_START in message or (
            sc.Signal.TRAIN_PIPE_END in message and obj.global_step % 200 == 0
        ):
            print("plotting")
            points = s.infer_step({"InteriorInfer": ["x", "u"]})
            num_x = points["InteriorInfer"]["x"].detach().cpu().numpy().ravel()
            num_u = points["InteriorInfer"]["u"].detach().cpu().numpy().ravel()
            angle = np.linspace(0, math.pi * 2, 100)
            xx_mesh, angle_mesh = np.meshgrid(num_x, angle)
@@ -87,28 +93,28 @@ class PlotReceiver(sc.Receiver):
            zz_mesh = num_u * np.sin(angle_mesh)
            fig = plt.figure(figsize=(8, 8))
            ax = fig.gca(projection="3d")
            ax.set_zlim3d(-1.25 - 1, 0.75 + 1)
            ax.set_ylim3d(-2, 2)
            ax.set_xlim3d(-2, 2)
            my_col = cm.cool((num_u * np.ones_like(angle_mesh) - 1.0) / 0.6)
            ax.plot_surface(yy_mesh, zz_mesh, xx_mesh, facecolors=my_col)
            ax.view_init(elev=15.0, azim=0)
            ax.dist = 5
            plt.axis("off")
            plt.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0)
            plt.savefig(f"plot/p_{obj.global_step}.png")
            plt.show()
            plt.close()
            self.predict_history.append((num_u, obj.global_step))
        if sc.Signal.SOLVE_END in message:
            try:
                with open("result.pickle", "rb") as f:
                    self.predict_history = pickle.load(f)
            except:
                with open("result.pickle", "wb") as f:
                    pickle.dump(self.predict_history, f)
            for yy, step in self.predict_history:
                if step == 0:
@@ -116,28 +122,35 @@ class PlotReceiver(sc.Receiver):
                if step == 200:
                    plt.plot(yy, self.xx, label=f"iter={step}")
                if step == 800:
                    plt.plot(yy[::100], self.xx[::100], "-o", label=f"iter={step}")
            plt.plot(np.cosh(self.xx)[::100], self.xx[::100], "-x", label="exact")
            plt.plot([0, np.cosh(-1)], [-1, -1], "--", color="gray")
            plt.plot([0, np.cosh(0.5)], [0.5, 0.5], "--", color="gray")
            plt.legend()
            plt.xlim([0, 1.7])
            plt.xlabel("y")
            plt.ylabel("x")
            plt.savefig("iterations.png")
            plt.show()
            plt.close()
dx_exp = sc.ExpressionNode(
    expression=sp.Abs(u) * sp.sqrt((u.diff(x)) ** 2 + 1), name="dx"
)
net = sc.get_net_node(inputs=("x",), outputs=("u",), name="net", arch=sc.Arch.mlp)
integral = sc.ICNode("dx", dim=1, time=False)
s = sc.Solver(
    sample_domains=(Boundary(), Interior(), InteriorInfer()),
    netnodes=[net],
    init_network_dirs=["pretrain_network_dir"],
    pdes=[
        dx_exp,
        integral,
    ],
    max_iter=1500,
)
s.register_receiver(PlotReceiver())
s.solve()

View File

@@ -3,30 +3,34 @@ import numpy as np
import sympy as sp
import idrlnet.shortcut as sc
x = sp.Symbol("x")
geo = sc.Line1D(-1, 0.5)
@sc.datanode(loss_fn="L1")
class Interior(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        points = geo.sample_interior(100)
        constraints = {
            "u": (np.cosh(0.5) - np.cosh(-1)) / 1.5 * (x + 1.0) + np.cosh(-1)
        }
        return points, constraints
net = sc.get_net_node(inputs=("x",), outputs=("u",), name="net", arch=sc.Arch.mlp)
s = sc.Solver(
    sample_domains=(Interior(),),
    netnodes=[net],
    pdes=[],
    network_dir="pretrain_network_dir",
    max_iter=1000,
)
s.solve()
points = s.infer_step({"Interior": ["x", "u"]})
num_x = points["Interior"]["x"].detach().cpu().numpy().ravel()
num_u = points["Interior"]["u"].detach().cpu().numpy().ravel()
xx = np.linspace(-1, 0.5, 1000, endpoint=True)
yy = np.cosh(xx)

View File

@@ -4,18 +4,20 @@ import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
x, y = sp.symbols("x y")
temp = sp.Symbol("temp")
temp_range = {temp: (-0.2, 0.2)}
rec = sc.Rectangle((-1.0, -1.0), (1.0, 1.0))
@sc.datanode
class Right(sc.SampleDomain):
    # Because `name` is not specified, Right is automatically used as the datanode name
    def sampling(self, *args, **kwargs):
        points = rec.sample_boundary(
            1000, sieve=(sp.Eq(x, 1.0)), param_ranges=temp_range
        )
        constraints = sc.Variables({"T": 0.0})
        return points, constraints
@@ -23,16 +25,20 @@ class Right(sc.SampleDomain):
class Left(sc.SampleDomain):
    # Because `name` is not specified, Left is automatically used as the datanode name
    def sampling(self, *args, **kwargs):
        points = rec.sample_boundary(
            1000, sieve=(sp.Eq(x, -1.0)), param_ranges=temp_range
        )
        constraints = sc.Variables({"T": temp})
        return points, constraints
@sc.datanode(name="up_down")
class UpDownBoundaryDomain(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        points = rec.sample_boundary(
            1000, sieve=((x > -1.0) & (x < 1.0)), param_ranges=temp_range
        )
        constraints = sc.Variables({"normal_gradient_T": 0.0})
        return points, constraints
@@ -43,47 +49,53 @@ class HeatDomain(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        points = rec.sample_interior(self.points, param_ranges=temp_range)
        constraints = sc.Variables({"diffusion_T": 1.0})
        return points, constraints
net = sc.get_net_node(
    inputs=("x", "y", "temp"), outputs=("T",), name="net1", arch=sc.Arch.mlp
)
pde = sc.DiffusionNode(T="T", D=1.0, Q=0.0, dim=2, time=False)
grad = sc.NormalGradient("T", dim=2, time=False)
s = sc.Solver(
    sample_domains=(HeatDomain(), Left(), Right(), UpDownBoundaryDomain()),
    netnodes=[net],
    pdes=[pde, grad],
    max_iter=3000,
)
s.solve()
def infer_temp(temp_num, file_suffix=None):
    temp_range[temp] = temp_num
    s.set_domain_parameter("heat_domain", {"points": 10000})
    coord = s.infer_step({"heat_domain": ["x", "y", "T"]})
    num_x = coord["heat_domain"]["x"].cpu().detach().numpy().ravel()
    num_y = coord["heat_domain"]["y"].cpu().detach().numpy().ravel()
    num_Tp = coord["heat_domain"]["T"].cpu().detach().numpy().ravel()
    # Ground truth
    num_T = -(num_x + 1 + temp_num) * (num_x - 1.0) / 2
    fig, ax = plt.subplots(1, 3, figsize=(10, 3))
    triang_total = tri.Triangulation(num_x, num_y)
    ax[0].tricontourf(triang_total, num_Tp, 100, cmap="hot", vmin=-0.2, vmax=1.21 / 2)
    ax[0].axis("off")
    ax[0].set_title(f"prediction($T_l={temp_num:.2f}$)")
    ax[1].tricontourf(triang_total, num_T, 100, cmap="hot", vmin=-0.2, vmax=1.21 / 2)
    ax[1].axis("off")
    ax[1].set_title(f"ground truth($T_l={temp_num:.2f}$)")
    ax[2].tricontourf(
        triang_total, np.abs(num_T - num_Tp), 100, cmap="hot", vmin=0, vmax=1.21 / 2
    )
    ax[2].axis("off")
    ax[2].set_title("absolute error")
    if file_suffix is None:
        plt.savefig(f"poisson_{temp_num:.2f}.png", dpi=300, bbox_inches="tight")
        plt.show()
    else:
        plt.savefig(f"poisson_{file_suffix}.png", dpi=300, bbox_inches="tight")
        plt.show()

View File

@@ -4,24 +4,24 @@ import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
x, y = sp.symbols("x y")
rec = sc.Rectangle((-1.0, -1.0), (1.0, 1.0))
@sc.datanode
class LeftRight(sc.SampleDomain):
    # Because `name` is not specified, LeftRight is automatically used as the datanode name
    def sampling(self, *args, **kwargs):
        points = rec.sample_boundary(1000, sieve=((y > -1.0) & (y < 1.0)))
        constraints = {"T": 0.0}
        return points, constraints
@sc.datanode(name="up_down")
class UpDownBoundaryDomain(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        points = rec.sample_boundary(1000, sieve=((x > -1.0) & (x < 1.0)))
        constraints = {"normal_gradient_T": 0.0}
        return points, constraints
@@ -32,39 +32,51 @@ class HeatDomain(sc.SampleDomain):
    def sampling(self, *args, **kwargs):
        points = rec.sample_interior(self.points)
        constraints = {"diffusion_T": 1.0}
        return points, constraints
net = sc.get_net_node(
    inputs=(
        "x",
        "y",
    ),
    outputs=("T",),
    name="net1",
    arch=sc.Arch.mlp,
)
pde = sc.DiffusionNode(T="T", D=1.0, Q=0.0, dim=2, time=False)
grad = sc.NormalGradient("T", dim=2, time=False)
s = sc.Solver(
    sample_domains=(HeatDomain(), LeftRight(), UpDownBoundaryDomain()),
    netnodes=[net],
    pdes=[pde, grad],
    max_iter=1000,
)
s.solve()
# Inference
s.set_domain_parameter("heat_domain", {"points": 10000})
coord = s.infer_step({"heat_domain": ["x", "y", "T"]})
num_x = coord["heat_domain"]["x"].cpu().detach().numpy().ravel()
num_y = coord["heat_domain"]["y"].cpu().detach().numpy().ravel()
num_Tp = coord["heat_domain"]["T"].cpu().detach().numpy().ravel()
# Ground truth
num_T = -num_x * num_x / 2 + 0.5
fig, ax = plt.subplots(1, 3, figsize=(10, 3))
triang_total = tri.Triangulation(num_x, num_y)
ax[0].tricontourf(triang_total, num_Tp, 100, cmap="hot", vmin=0, vmax=0.5)
ax[0].axis("off")
ax[0].set_title("prediction")
ax[1].tricontourf(triang_total, num_T, 100, cmap="hot", vmin=0, vmax=0.5)
ax[1].axis("off")
ax[1].set_title("ground truth")
ax[2].tricontourf(
    triang_total, np.abs(num_T - num_Tp), 100, cmap="hot", vmin=0, vmax=0.5
)
ax[2].axis("off")
ax[2].set_title("absolute error")
plt.savefig("simple_poisson.png", dpi=300, bbox_inches="tight")

View File

@@ -1,15 +1,16 @@
import torch
# todo more careful check
GPU_ENABLED = True
if torch.cuda.is_available():
    try:
        _ = torch.Tensor([0.0, 0.0]).cuda()
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
        print("gpu available")
        GPU_ENABLED = True
    except:
        print("gpu not available")
        GPU_ENABLED = False
else:
    print("gpu not available")
    GPU_ENABLED = False

View File

@@ -15,14 +15,28 @@ def indicator(xn: torch.Tensor, *axis_bounds):
    i = 0
    lb, ub, lb_eq = axis_bounds[0]
    if lb_eq:
        indic = torch.logical_and(
            xn[:, i : i + 1] >= axis_bounds[0][0], axis_bounds[0][1] >= xn[:, i : i + 1]
        )
    else:
        indic = torch.logical_and(
            xn[:, i : i + 1] > axis_bounds[0][0], axis_bounds[0][1] >= xn[:, i : i + 1]
        )
    for i, (lb, ub, lb_eq) in enumerate(axis_bounds[1:]):
        if lb_eq:
            indic = torch.logical_and(
                indic,
                torch.logical_and(
                    xn[:, i + 1 : i + 2] >= lb, ub >= xn[:, i + 1 : i + 2]
                ),
            )
        else:
            indic = torch.logical_and(
                indic,
                torch.logical_and(
                    xn[:, i + 1 : i + 2] > lb, ub >= xn[:, i + 1 : i + 2]
                ),
            )
    return indic
@@ -34,8 +48,8 @@ class NetEval(torch.nn.Module):
        self.n_columns = len(self.columns) - 1
        self.n_rows = len(self.rows) - 1
        self.nets = []
        if "net_generator" in kwargs.keys():
            net_gen = kwargs.pop("net_generator")
        else:
            net_gen = lambda: mlp.MLP([n_inputs, 20, 20, 20, 20, n_outputs])
        for i in range(self.n_columns):
@@ -50,8 +64,18 @@ class NetEval(torch.nn.Module):
        y = 0
        for i in range(self.n_columns):
            for j in range(self.n_rows):
                y += (
                    indicator(
                        xn,
                        (
                            self.columns[i],
                            self.columns[i + 1],
                            True if i == 0 else False,
                        ),
                        (self.rows[j], self.rows[j + 1], True if j == 0 else False),
                    )
                    * self.nets[i][j](x)
                )
        return y
@@ -59,7 +83,10 @@ class Interface:
    def __init__(self, points1, points2, nr, outputs, i1, j1, i2, j2, overlap=0.2):
        x_min, x_max = min(points1[0], points2[0]), max(points1[0], points2[0])
        y_min, y_max = min(points1[1], points2[1]), max(points1[1], points2[1])
        self.geo = Rectangle(
            (x_min - overlap / 2, y_min - overlap / 2),
            (x_max + overlap / 2, y_max + overlap / 2),
        )
        self.nr = nr
        self.outputs = outputs
        self.i1 = i1
@@ -69,16 +96,26 @@ class Interface:
    def __call__(self, *args, **kwargs):
        points = self.geo.sample_boundary(self.nr)
        return points, {
            f"difference_{output}_{self.i1}_{self.j1}_{output}_{self.i2}_{self.j2}": 0
            for output in self.outputs
        }
class NetGridNode(NetNode):
    def __init__(
        self,
        inputs: Union[Tuple, List[str]],
        outputs: Union[Tuple, List[str]],
        x_segments: List[float] = None,
        y_segments: List[float] = None,
        z_segments: List[float] = None,
        t_segments: List[float] = None,
        columns: List[float] = None,
        rows: List[float] = None,
        *args,
        **kwargs,
    ):
        if columns is None:
            columns = []
        if rows is None:
@@ -87,8 +124,16 @@ class NetGridNode(NetNode):
        fixed = False
        self.columns = columns
        self.rows = rows
        self.main_net = NetEval(
            n_inputs=len(inputs),
            n_outputs=len(outputs),
            columns=columns,
            rows=rows,
            **kwargs,
        )
        super(NetGridNode, self).__init__(
            inputs, outputs, self.main_net, fixed, require_no_grad, *args, **kwargs
        )
    def get_grid(self, overlap, nr_points_per_interface_area=100):
        n_columns = self.main_net.n_columns
@@ -98,54 +143,119 @@ class NetGridNode(NetNode):
        constraints = []
        for i in range(n_columns):
            for j in range(n_rows):
                nn = NetNode(
                    inputs=self.inputs,
                    outputs=tuple(f"{output}_{i}_{j}" for output in self.outputs),
                    net=self.main_net.nets[i][j],
                    name=f"{self.name}[{i}][{j}]",
                )
                nn.is_reference = True
                netnodes.append(nn)
                if i > 0:
                    for output in self.outputs:
                        diff_Node = Difference(
                            f"{output}_{i - 1}_{j}",
                            f"{output}_{i}_{j}",
                            dim=2,
                            time=False,
                        )
                        eqs.append(diff_Node)
                    interface = Interface(
                        (self.columns[i], self.rows[j]),
                        (self.columns[i], self.rows[j + 1]),
                        nr_points_per_interface_area,
                        self.outputs,
                        i - 1,
                        j,
                        i,
                        j,
                        overlap=overlap,
                    )
                    constraints.append(
                        get_data_node(
                            interface, name=f"interface[{i - 1}][{j}]_[{i}][{j}]"
                        )
                    )
                if j > 0:
                    for output in self.outputs:
                        diff_Node = Difference(
                            f"{output}_{i}_{j - 1}",
                            f"{output}_{i}_{j}",
                            dim=2,
                            time=False,
                        )
                        eqs.append(diff_Node)
                    interface = Interface(
                        (self.columns[i], self.rows[j]),
                        (self.columns[i + 1], self.rows[j]),
                        nr_points_per_interface_area,
                        self.outputs,
                        i,
                        j - 1,
                        i,
                        j,
                        overlap=overlap,
                    )
                    constraints.append(
                        get_data_node(
                            interface, name=f"interface[{i}][{j - 1}]_[{i}][{j}]"
                        )
                    )
        return netnodes, eqs, constraints
def get_net_reg_grid_2d(
    inputs: Union[Tuple, List[str]],
    outputs: Union[Tuple, List[str]],
    name: str,
    columns: List[float],
    rows: List[float],
    **kwargs,
):
    if "overlap" in kwargs.keys():
        overlap = kwargs.pop("overlap")
    else:
        overlap = 0.2
    net = NetGridNode(
        inputs=inputs, outputs=outputs, columns=columns, rows=rows, name=name, **kwargs
    )
    nets, eqs, interfaces = net.get_grid(
        nr_points_per_interface_area=1000, overlap=overlap
    )
    nets.append(net)
    return nets, eqs, interfaces
def get_net_reg_grid(
    inputs: Union[Tuple, List[str]],
    outputs: Union[Tuple, List[str]],
    name: str,
    x_segments: List[float] = None,
    y_segments: List[float] = None,
    z_segments: List[float] = None,
    t_segments: List[float] = None,
    **kwargs,
):
    if "overlap" in kwargs.keys():
        overlap = kwargs.pop("overlap")
    else:
        overlap = 0.2
    net = NetGridNode(
        inputs=inputs,
        outputs=outputs,
        x_segments=x_segments,
        y_segments=y_segments,
        z_segments=z_segments,
        t_segments=t_segments,
        name=name,
        **kwargs,
    )
    nets, eqs, interfaces = net.get_grid(
        nr_points_per_interface_area=1000, overlap=overlap
    )
    nets.append(net)
    return nets, eqs, interfaces
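The reformatting above is behavior-preserving; the grid helpers keep their original signatures. A hypothetical call (the name, segment values, and overlap chosen here purely for illustration) would still read:

nets, eqs, interfaces = get_net_reg_grid_2d(
    inputs=("x", "y"),
    outputs=("u",),
    name="grid_net",
    columns=[-1.0, 0.0, 1.0],
    rows=[-1.0, 0.0, 1.0],
    overlap=0.2,
)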

View File

@@ -5,35 +5,40 @@ import math
import torch
from idrlnet.header import logger
__all__ = ["Activation", "Initializer", "get_activation_layer", "get_linear_layer"]
class Activation(enum.Enum):
    relu = "relu"
    silu = "silu"
    selu = "selu"
    sigmoid = "sigmoid"
    tanh = "tanh"
    swish = "swish"
    poly = "poly"
    sin = "sin"
    leaky_relu = "leaky_relu"
class Initializer(enum.Enum):
    Xavier_uniform = "Xavier_uniform"
    constant = "constant"
    kaiming_uniform = "kaiming_uniform"
    default = "default"
def get_linear_layer(
    input_dim: int,
    output_dim: int,
    weight_norm=False,
    initializer: Initializer = Initializer.Xavier_uniform,
    *args,
    **kwargs,
):
    layer = torch.nn.Linear(input_dim, output_dim)
    init_method = InitializerFactory.get_initializer(initializer=initializer, **kwargs)
    init_method(layer.weight)
    torch.nn.init.constant_(layer.bias, 0.0)
    if weight_norm:
        layer = torch.nn.utils.weight_norm(layer)
    return layer
@@ -81,8 +86,10 @@ class ActivationFactory:
        elif activation == Activation.silu:
            return Silu()
        else:
            logger.error(f"Activation {activation} is not supported!")
            raise NotImplementedError(
                "Activation " + activation.name + " is not supported"
            )
class Silu:
@@ -105,8 +112,12 @@ def leaky_relu(x, leak=0.1):
def triangle_wave(x):
    y = 0.0
    for i in range(3):
        y += (
            (-1.0) ** (i)
            * torch.sin(2.0 * math.pi * (2.0 * i + 1.0) * x)
            / (2.0 * i + 1.0) ** (2)
        )
    y = 0.5 * (8 / (math.pi ** 2) * y) + 0.5
    return y
@@ -139,11 +150,15 @@ class InitializerFactory:
        if initializer == Initializer.Xavier_uniform:
            return torch.nn.init.xavier_uniform_
        elif initializer == Initializer.constant:
            return lambda x: torch.nn.init.constant_(x, kwargs["constant"])
        elif initializer == Initializer.kaiming_uniform:
            return lambda x: torch.nn.init.kaiming_uniform_(
                x, mode="fan_in", nonlinearity="relu"
            )
        elif initializer == Initializer.default:
            return lambda x: x
        else:
            logger.error("initialization " + initializer.name + " is not supported")
            raise NotImplementedError(
                "initialization " + initializer.name + " is not supported"
            )
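Likewise for the layer helpers, only the call layout changes, not the interface. A hypothetical call to the reformatted get_linear_layer (dimensions chosen arbitrarily for illustration):

layer = get_linear_layer(
    input_dim=2,
    output_dim=64,
    weight_norm=True,
    initializer=Initializer.Xavier_uniform,
)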

View File

@ -3,7 +3,12 @@
import torch import torch
import math import math
from collections import OrderedDict from collections import OrderedDict
from idrlnet.architecture.layer import get_linear_layer, get_activation_layer, Initializer, Activation from idrlnet.architecture.layer import (
get_linear_layer,
get_activation_layer,
Initializer,
Activation,
)
from typing import List, Union, Tuple from typing import List, Union, Tuple
from idrlnet.header import logger from idrlnet.header import logger
from idrlnet.net import NetNode from idrlnet.net import NetNode
@ -28,25 +33,36 @@ class MLP(torch.nn.Module):
:param kwargs: :param kwargs:
""" """
def __init__(self, n_seq: List[int], activation: Union[Activation, List[Activation]] = Activation.swish, def __init__(
initialization: Initializer = Initializer.kaiming_uniform, self,
weight_norm: bool = True, name: str = 'mlp', *args, **kwargs): n_seq: List[int],
activation: Union[Activation, List[Activation]] = Activation.swish,
initialization: Initializer = Initializer.kaiming_uniform,
weight_norm: bool = True,
name: str = "mlp",
*args,
**kwargs,
):
super().__init__() super().__init__()
self.layers = OrderedDict() self.layers = OrderedDict()
current_activation = '' current_activation = ""
assert isinstance(n_seq, Activation) or isinstance(n_seq, list) assert isinstance(n_seq, Activation) or isinstance(n_seq, list)
for i in range(len(n_seq) - 1): for i in range(len(n_seq) - 1):
if isinstance(activation, list): if isinstance(activation, list):
current_activation = activation[i] current_activation = activation[i]
elif i < len(n_seq) - 2: elif i < len(n_seq) - 2:
current_activation = activation current_activation = activation
self.layers['{}_{}'.format(name, i)] = get_linear_layer(n_seq[i], n_seq[i + 1], weight_norm, initialization, self.layers["{}_{}".format(name, i)] = get_linear_layer(
*args, **kwargs) n_seq[i], n_seq[i + 1], weight_norm, initialization, *args, **kwargs
if (isinstance(activation, Activation) and i < len(n_seq) - 2) or isinstance(activation, list): )
if current_activation == 'none': if (
isinstance(activation, Activation) and i < len(n_seq) - 2
) or isinstance(activation, list):
if current_activation == "none":
continue continue
self.layers['{}_{}_activation'.format(name, i)] = get_activation_layer(current_activation, *args, self.layers["{}_{}_activation".format(name, i)] = get_activation_layer(
**kwargs) current_activation, *args, **kwargs
)
self.layers = torch.nn.ModuleDict(self.layers) self.layers = torch.nn.ModuleDict(self.layers)
def forward(self, x): def forward(self, x):
@@ -61,8 +77,15 @@ class MLP(torch.nn.Module):
class Siren(torch.nn.Module): class Siren(torch.nn.Module):
def __init__(self, n_seq: List[int], first_omega: float = 30.0, def __init__(
omega: float = 30.0, name: str = 'siren', *args, **kwargs): self,
n_seq: List[int],
first_omega: float = 30.0,
omega: float = 30.0,
name: str = "siren",
*args,
**kwargs,
):
super().__init__() super().__init__()
self.layers = OrderedDict() self.layers = OrderedDict()
self.first_omega = first_omega self.first_omega = first_omega
@@ -70,24 +93,37 @@ class Siren(torch.nn.Module):
assert isinstance(n_seq, str) or isinstance(n_seq, list) assert isinstance(n_seq, str) or isinstance(n_seq, list)
for i in range(len(n_seq) - 1): for i in range(len(n_seq) - 1):
if i == 0: if i == 0:
self.layers['{}_{}'.format(name, i)] = self.get_siren_layer(n_seq[i], n_seq[i + 1], True, first_omega) self.layers["{}_{}".format(name, i)] = self.get_siren_layer(
n_seq[i], n_seq[i + 1], True, first_omega
)
else: else:
self.layers['{}_{}'.format(name, i)] = self.get_siren_layer(n_seq[i], n_seq[i + 1], False, omega) self.layers["{}_{}".format(name, i)] = self.get_siren_layer(
n_seq[i], n_seq[i + 1], False, omega
)
if i < (len(n_seq) - 2): if i < (len(n_seq) - 2):
self.layers['{}_{}_activation'.format(name, i)] = get_activation_layer(Activation.sin, *args, **kwargs) self.layers["{}_{}_activation".format(name, i)] = get_activation_layer(
Activation.sin, *args, **kwargs
)
self.layers = torch.nn.ModuleDict(self.layers) self.layers = torch.nn.ModuleDict(self.layers)
@staticmethod @staticmethod
def get_siren_layer(input_dim: int, output_dim: int, is_first: bool, omega_0: float): def get_siren_layer(
input_dim: int, output_dim: int, is_first: bool, omega_0: float
):
layer = torch.nn.Linear(input_dim, output_dim) layer = torch.nn.Linear(input_dim, output_dim)
dim = input_dim dim = input_dim
if is_first: if is_first:
torch.nn.init.uniform_(layer.weight.data, -1.0 / dim, 1.0 / dim) torch.nn.init.uniform_(layer.weight.data, -1.0 / dim, 1.0 / dim)
else: else:
torch.nn.init.uniform_(layer.weight.data, -1.0 * math.sqrt(6.0 / dim) / omega_0, torch.nn.init.uniform_(
math.sqrt(6.0 / dim) / omega_0) layer.weight.data,
torch.nn.init.uniform_(layer.bias.data, -1 * math.sqrt(1 / dim), math.sqrt(1 / dim)) -1.0 * math.sqrt(6.0 / dim) / omega_0,
math.sqrt(6.0 / dim) / omega_0,
)
torch.nn.init.uniform_(
layer.bias.data, -1 * math.sqrt(1 / dim), math.sqrt(1 / dim)
)
return layer return layer
def forward(self, x): def forward(self, x):
@@ -113,7 +149,7 @@ class SingleVar(torch.nn.Module):
self.value = torch.nn.Parameter(torch.Tensor([initialization])) self.value = torch.nn.Parameter(torch.Tensor([initialization]))
def forward(self, x) -> torch.Tensor: def forward(self, x) -> torch.Tensor:
return x[:, :1] * 0. + self.value return x[:, :1] * 0.0 + self.value
def get_value(self) -> torch.Tensor: def get_value(self) -> torch.Tensor:
return self.value return self.value
@@ -135,7 +171,7 @@ class BoundedSingleVar(torch.nn.Module):
self.ub, self.lb = upper_bound, lower_bound self.ub, self.lb = upper_bound, lower_bound
def forward(self, x) -> torch.Tensor: def forward(self, x) -> torch.Tensor:
return x[:, :1] * 0. + self.layer(self.value) * (self.ub - self.lb) + self.lb return x[:, :1] * 0.0 + self.layer(self.value) * (self.ub - self.lb) + self.lb
def get_value(self) -> torch.Tensor: def get_value(self) -> torch.Tensor:
return self.layer(self.value) * (self.ub - self.lb) + self.lb return self.layer(self.value) * (self.ub - self.lb) + self.lb
@@ -144,18 +180,22 @@ class BoundedSingleVar(torch.nn.Module):
class Arch(enum.Enum): class Arch(enum.Enum):
"""Enumerate pre-defined neural networks.""" """Enumerate pre-defined neural networks."""
mlp = 'mlp' mlp = "mlp"
toy = 'toy' toy = "toy"
mlp_xl = 'mlp_xl' mlp_xl = "mlp_xl"
single_var = 'single_var' single_var = "single_var"
bounded_single_var = 'bounded_single_var' bounded_single_var = "bounded_single_var"
siren = 'siren' siren = "siren"
def get_net_node(inputs: Union[Tuple[str, ...], List[str]], outputs: Union[Tuple[str, ...], List[str]], def get_net_node(
arch: Arch = None, name=None, inputs: Union[Tuple[str, ...], List[str]],
*args, outputs: Union[Tuple[str, ...], List[str]],
**kwargs) -> NetNode: arch: Arch = None,
name=None,
*args,
**kwargs,
) -> NetNode:
"""Get a net node wrapping networks with pre-defined configurations """Get a net node wrapping networks with pre-defined configurations
:param inputs: Input symbols for the generated node. :param inputs: Input symbols for the generated node.
@@ -175,36 +215,65 @@ def get_net_node(inputs: Union[Tuple[str, ...], List[str]], outputs: Union[Tuple
:return: :return:
""" """
arch = Arch.mlp if arch is None else arch arch = Arch.mlp if arch is None else arch
if 'evaluate' in kwargs.keys(): if "evaluate" in kwargs.keys():
evaluate = kwargs.pop('evaluate') evaluate = kwargs.pop("evaluate")
else: else:
if arch == Arch.mlp: if arch == Arch.mlp:
seq = kwargs['seq'] if 'seq' in kwargs.keys() else [len(inputs), 20, 20, 20, 20, len(outputs)] seq = (
evaluate = MLP(n_seq=seq, activation=Activation.swish, initialization=Initializer.kaiming_uniform, kwargs["seq"]
weight_norm=True) if "seq" in kwargs.keys()
else [len(inputs), 20, 20, 20, 20, len(outputs)]
)
evaluate = MLP(
n_seq=seq,
activation=Activation.swish,
initialization=Initializer.kaiming_uniform,
weight_norm=True,
)
elif arch == Arch.toy: elif arch == Arch.toy:
evaluate = SimpleExpr("nothing") evaluate = SimpleExpr("nothing")
elif arch == Arch.mlp_xl or arch == 'fc': elif arch == Arch.mlp_xl or arch == "fc":
seq = kwargs['seq'] if 'seq' in kwargs.keys() else [len(inputs), 512, 512, 512, 512, 512, 512, len(outputs)] seq = (
evaluate = MLP(n_seq=seq, activation=Activation.silu, initialization=Initializer.kaiming_uniform, kwargs["seq"]
weight_norm=True) if "seq" in kwargs.keys()
else [len(inputs), 512, 512, 512, 512, 512, 512, len(outputs)]
)
evaluate = MLP(
n_seq=seq,
activation=Activation.silu,
initialization=Initializer.kaiming_uniform,
weight_norm=True,
)
elif arch == Arch.single_var: elif arch == Arch.single_var:
evaluate = SingleVar(initialization=kwargs.get('initialization', 1.)) evaluate = SingleVar(initialization=kwargs.get("initialization", 1.0))
elif arch == Arch.bounded_single_var: elif arch == Arch.bounded_single_var:
evaluate = BoundedSingleVar(lower_bound=kwargs['lower_bound'], upper_bound=kwargs['upper_bound']) evaluate = BoundedSingleVar(
lower_bound=kwargs["lower_bound"], upper_bound=kwargs["upper_bound"]
)
elif arch == Arch.siren: elif arch == Arch.siren:
seq = kwargs['seq'] if 'seq' in kwargs.keys() else [len(inputs), 512, 512, 512, 512, 512, 512, len(outputs)] seq = (
kwargs["seq"]
if "seq" in kwargs.keys()
else [len(inputs), 512, 512, 512, 512, 512, 512, len(outputs)]
)
evaluate = Siren(n_seq=seq) evaluate = Siren(n_seq=seq)
else: else:
logger.error(f'{arch} is not supported!') logger.error(f"{arch} is not supported!")
raise NotImplementedError(f'{arch} is not supported!') raise NotImplementedError(f"{arch} is not supported!")
nn = NetNode(inputs=inputs, outputs=outputs, net=evaluate, name=name, *args, **kwargs) nn = NetNode(
inputs=inputs, outputs=outputs, net=evaluate, name=name, *args, **kwargs
)
return nn return nn
def get_shared_net_node(shared_node: NetNode, inputs: Union[Tuple[str, ...], List[str]], def get_shared_net_node(
outputs: Union[Tuple[str, ...], List[str]], name=None, *args, shared_node: NetNode,
**kwargs) -> NetNode: inputs: Union[Tuple[str, ...], List[str]],
outputs: Union[Tuple[str, ...], List[str]],
name=None,
*args,
**kwargs,
) -> NetNode:
"""Construct a netnode, the net of which is shared by a given netnode. One can specify different inputs and outputs """Construct a netnode, the net of which is shared by a given netnode. One can specify different inputs and outputs
just like an independent netnode. However, the net parameters may have multiple references. Thus the step just like an independent netnode. However, the net parameters may have multiple references. Thus the step
operations during optimization should only be applied once. operations during optimization should only be applied once.
@@ -221,22 +290,29 @@ def get_shared_net_node(shared_node: NetNode, inputs: Union[Tuple[str, ...], Lis
:param kwargs: :param kwargs:
:return: :return:
""" """
nn = NetNode(inputs, outputs, shared_node.net, is_reference=True, name=name, *args, **kwargs) nn = NetNode(
inputs, outputs, shared_node.net, is_reference=True, name=name, *args, **kwargs
)
return nn return nn
def get_inter_name(length: int, prefix: str): def get_inter_name(length: int, prefix: str):
return [prefix + f'_{i}' for i in range(length)] return [prefix + f"_{i}" for i in range(length)]
class SimpleExpr(torch.nn.Module): class SimpleExpr(torch.nn.Module):
"""This class is for testing. One can override SimpleExper.forward to represent complex formulas.""" """This class is for testing. One can override SimpleExper.forward to represent complex formulas."""
def __init__(self, expr, name='expr'): def __init__(self, expr, name="expr"):
super().__init__() super().__init__()
self.evaluate = expr self.evaluate = expr
self.name = name self.name = name
self._placeholder = torch.nn.Parameter(torch.Tensor([0.0])) self._placeholder = torch.nn.Parameter(torch.Tensor([0.0]))
def forward(self, x): def forward(self, x):
return self._placeholder + x[:, :1] * x[:, :1] / 2 + x[:, 1:] * x[:, 1:] / 2 - self._placeholder return (
self._placeholder
+ x[:, :1] * x[:, :1] / 2
+ x[:, 1:] * x[:, 1:] / 2
- self._placeholder
)
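
A short usage sketch for the two factory functions above; the module path idrlnet.architecture.mlp, the symbol names, and the seq sizes are assumptions for illustration.

from idrlnet.architecture.mlp import get_net_node, get_shared_net_node, Arch

# a four-hidden-layer MLP mapping (x, y) -> u; `seq` overrides the default sizes
net_u = get_net_node(
    inputs=("x", "y"),
    outputs=("u",),
    arch=Arch.mlp,
    name="net_u",
    seq=[2, 20, 20, 20, 20, 1],
)

# a reference node reusing the same parameters for another output symbol;
# as the docstring notes, shared parameters must only be stepped once
net_v = get_shared_net_node(net_u, inputs=("x", "y"), outputs=("v",), name="net_v")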

View File

@@ -7,13 +7,13 @@ from torch.utils.tensorboard import SummaryWriter
from idrlnet.receivers import Receiver, Signal from idrlnet.receivers import Receiver, Signal
from idrlnet.variable import Variables from idrlnet.variable import Variables
__all__ = ['GradientReceiver', 'SummaryReceiver', 'HandleResultReceiver'] __all__ = ["GradientReceiver", "SummaryReceiver", "HandleResultReceiver"]
class GradientReceiver(Receiver): class GradientReceiver(Receiver):
"""Register the receiver to monitor gradient norm on the Tensorboard.""" """Register the receiver to monitor gradient norm on the Tensorboard."""
def receive_notify(self, solver: 'Solver', message): # noqa def receive_notify(self, solver: "Solver", message): # noqa
if not (Signal.TRAIN_PIPE_END in message): if not (Signal.TRAIN_PIPE_END in message):
return return
for netnode in solver.netnodes: for netnode in solver.netnodes:
@@ -23,9 +23,11 @@ class GradientReceiver(Receiver):
for p in model.parameters(): for p in model.parameters():
param_norm = p.grad.data.norm(2) param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2 total_norm += param_norm.item() ** 2
total_norm = total_norm ** (1. / 2) total_norm = total_norm ** (1.0 / 2)
assert isinstance(solver.receivers[0], SummaryWriter) assert isinstance(solver.receivers[0], SummaryWriter)
solver.summary_receiver.add_scalar('gradient/total_norm', total_norm, solver.global_step) solver.summary_receiver.add_scalar(
"gradient/total_norm", total_norm, solver.global_step
)
class SummaryReceiver(SummaryWriter, Receiver): class SummaryReceiver(SummaryWriter, Receiver):
@@ -34,15 +36,19 @@ class SummaryReceiver(SummaryWriter, Receiver):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
SummaryWriter.__init__(self, *args, **kwargs) SummaryWriter.__init__(self, *args, **kwargs)
def receive_notify(self, solver: 'Solver', message: Dict): # noqa def receive_notify(self, solver: "Solver", message: Dict): # noqa
if Signal.AFTER_COMPUTE_LOSS in message.keys(): if Signal.AFTER_COMPUTE_LOSS in message.keys():
loss_component = message[Signal.AFTER_COMPUTE_LOSS] loss_component = message[Signal.AFTER_COMPUTE_LOSS]
self.add_scalars('loss_overview', loss_component, solver.global_step) self.add_scalars("loss_overview", loss_component, solver.global_step)
for key, value in loss_component.items(): for key, value in loss_component.items():
self.add_scalar(f'loss_component/{key}', value, solver.global_step) self.add_scalar(f"loss_component/{key}", value, solver.global_step)
if Signal.TRAIN_PIPE_END in message.keys(): if Signal.TRAIN_PIPE_END in message.keys():
for i, optimizer in enumerate(solver.optimizers): for i, optimizer in enumerate(solver.optimizers):
self.add_scalar(f'optimizer/lr_{i}', optimizer.param_groups[0]['lr'], solver.global_step) self.add_scalar(
f"optimizer/lr_{i}",
optimizer.param_groups[0]["lr"],
solver.global_step,
)
class HandleResultReceiver(Receiver): class HandleResultReceiver(Receiver):
@@ -51,11 +57,13 @@ class HandleResultReceiver(Receiver):
def __init__(self, result_dir): def __init__(self, result_dir):
self.result_dir = result_dir self.result_dir = result_dir
def receive_notify(self, solver: 'Solver', message: Dict): # noqa def receive_notify(self, solver: "Solver", message: Dict): # noqa
if Signal.SOLVE_END in message.keys(): if Signal.SOLVE_END in message.keys():
samples = solver.sample_variables_from_domains() samples = solver.sample_variables_from_domains()
in_var, _, lambda_out = solver.generate_in_out_dict(samples) in_var, _, lambda_out = solver.generate_in_out_dict(samples)
pred_out_sample = solver.forward_through_all_graph(in_var, solver.outvar_dict_index) pred_out_sample = solver.forward_through_all_graph(
in_var, solver.outvar_dict_index
)
diff_out_sample = {key: Variables() for key in pred_out_sample} diff_out_sample = {key: Variables() for key in pred_out_sample}
results_path = pathlib.Path(self.result_dir) results_path = pathlib.Path(self.result_dir)
results_path.mkdir(exist_ok=True, parents=True) results_path.mkdir(exist_ok=True, parents=True)
@@ -65,7 +73,15 @@ class HandleResultReceiver(Receiver):
pred_out_sample[key][_key] = samples[key][_key] pred_out_sample[key][_key] = samples[key][_key]
diff_out_sample[key][_key] = samples[key][_key] diff_out_sample[key][_key] = samples[key][_key]
else: else:
diff_out_sample[key][_key] = pred_out_sample[key][_key] - samples[key][_key] diff_out_sample[key][_key] = (
samples[key].save(os.path.join(results_path, f'{key}_true'), ['vtu', 'np', 'csv']) pred_out_sample[key][_key] - samples[key][_key]
pred_out_sample[key].save(os.path.join(results_path, f'{key}_pred'), ['vtu', 'np', 'csv']) )
diff_out_sample[key].save(os.path.join(results_path, f'{key}_diff'), ['vtu', 'np', 'csv']) samples[key].save(
os.path.join(results_path, f"{key}_true"), ["vtu", "np", "csv"]
)
pred_out_sample[key].save(
os.path.join(results_path, f"{key}_pred"), ["vtu", "np", "csv"]
)
diff_out_sample[key].save(
os.path.join(results_path, f"{key}_diff"), ["vtu", "np", "csv"]
)
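
A wiring note, as a sketch only: GradientReceiver asserts that the first registered receiver is a SummaryWriter and writes through solver.summary_receiver, so it is meant to be appended after the default receivers of an already constructed Solver. The name solver and the module path idrlnet.callbacks are assumptions.

from idrlnet.callbacks import GradientReceiver  # module path assumed

solver.receivers.append(GradientReceiver())  # `solver` is an existing Solver (assumption)
solver.solve()  # gradient/total_norm curves now appear in TensorBoard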

View File

@@ -36,6 +36,7 @@ class DataNode(Node):
:param args: :param args:
:param kwargs: :param kwargs:
""" """
counter = 0 counter = 0
@property @property
@@ -87,18 +88,27 @@ class DataNode(Node):
try: try:
output_vars[key] = lambdify_np(value, input_vars)(**input_vars) output_vars[key] = lambdify_np(value, input_vars)(**input_vars)
except: except:
logger.error('unsupported constraints type.') logger.error("unsupported constraints type.")
raise ValueError('unsupported constraints type.') raise ValueError("unsupported constraints type.")
try: try:
return Variables({**input_vars, **output_vars}).to_torch_tensor_() return Variables({**input_vars, **output_vars}).to_torch_tensor_()
except: except:
return Variables({**input_vars, **output_vars}) return Variables({**input_vars, **output_vars})
def __init__(self, inputs: Union[Tuple[str, ...], List[str]], outputs: Union[Tuple[str, ...], List[str]], def __init__(
sample_fn: Callable, loss_fn: str = 'square', lambda_outputs: Union[Tuple[str, ...], List[str]] = None, self,
name=None, sigma=1.0, var_sigma=False, inputs: Union[Tuple[str, ...], List[str]],
*args, **kwargs): outputs: Union[Tuple[str, ...], List[str]],
sample_fn: Callable,
loss_fn: str = "square",
lambda_outputs: Union[Tuple[str, ...], List[str]] = None,
name=None,
sigma=1.0,
var_sigma=False,
*args,
**kwargs,
):
self.inputs: Union[Tuple, List[str]] = inputs self.inputs: Union[Tuple, List[str]] = inputs
self.outputs: Union[Tuple, List[str]] = outputs self.outputs: Union[Tuple, List[str]] = outputs
self.lambda_outputs = lambda_outputs self.lambda_outputs = lambda_outputs
@@ -113,13 +123,22 @@ class DataNode(Node):
self.loss_fn = loss_fn self.loss_fn = loss_fn
def __str__(self): def __str__(self):
str_list = ["DataNode properties:\n" str_list = [
"lambda_outputs: {}\n".format(self.lambda_outputs)] "DataNode properties:\n" "lambda_outputs: {}\n".format(self.lambda_outputs)
return super().__str__() + ''.join(str_list) ]
return super().__str__() + "".join(str_list)
def get_data_node(fun: Callable, name=None, loss_fn='square', sigma=1., var_sigma=False, *args, **kwargs) -> DataNode: def get_data_node(
""" Construct a datanode from sampling functions. fun: Callable,
name=None,
loss_fn="square",
sigma=1.0,
var_sigma=False,
*args,
**kwargs,
) -> DataNode:
"""Construct a datanode from sampling functions.
:param fun: Each call of the Callable object should return a sampling dict. :param fun: Each call of the Callable object should return a sampling dict.
:type fun: Callable :type fun: Callable
@@ -135,26 +154,56 @@ def get_data_node(fun: Callable, name=None, loss_fn='square', sigma=1., var_sigm
in_, out_ = fun() in_, out_ = fun()
inputs = list(in_.keys()) inputs = list(in_.keys())
outputs = list(out_.keys()) outputs = list(out_.keys())
lambda_outputs = list(filter(lambda x: x.startswith('lambda_'), outputs)) lambda_outputs = list(filter(lambda x: x.startswith("lambda_"), outputs))
outputs = list(filter(lambda x: not x.startswith('lambda_'), outputs)) outputs = list(filter(lambda x: not x.startswith("lambda_"), outputs))
name = (fun.__name__ if inspect.isfunction(fun) else type(fun).__name__) if name is None else name name = (
dn = DataNode(inputs=inputs, outputs=outputs, sample_fn=fun, lambda_outputs=lambda_outputs, loss_fn=loss_fn, (fun.__name__ if inspect.isfunction(fun) else type(fun).__name__)
name=name, sigma=sigma, var_sigma=var_sigma, *args, **kwargs) if name is None
else name
)
dn = DataNode(
inputs=inputs,
outputs=outputs,
sample_fn=fun,
lambda_outputs=lambda_outputs,
loss_fn=loss_fn,
name=name,
sigma=sigma,
var_sigma=var_sigma,
*args,
**kwargs,
)
return dn return dn
def datanode(_fun: Callable = None, name=None, loss_fn='square', sigma=1., var_sigma=False, **kwargs): def datanode(
_fun: Callable = None,
name=None,
loss_fn="square",
sigma=1.0,
var_sigma=False,
**kwargs,
):
"""As an alternative, decorate Callable classes as Datanode.""" """As an alternative, decorate Callable classes as Datanode."""
def wrap(fun): def wrap(fun):
if inspect.isclass(fun): if inspect.isclass(fun):
assert issubclass(fun, SampleDomain), f"{fun} should be subclass of .data.Sample" assert issubclass(
fun, SampleDomain
), f"{fun} should be subclass of .data.Sample"
fun = fun() fun = fun()
assert isinstance(fun, Callable) assert isinstance(fun, Callable)
@functools.wraps(fun) @functools.wraps(fun)
def wrapped_fun(): def wrapped_fun():
dn = get_data_node(fun, name=name, loss_fn=loss_fn, sigma=sigma, var_sigma=var_sigma, **kwargs) dn = get_data_node(
fun,
name=name,
loss_fn=loss_fn,
sigma=sigma,
var_sigma=var_sigma,
**kwargs,
)
return dn return dn
return wrapped_fun return wrapped_fun
@@ -163,9 +212,12 @@ def datanode(_fun: Callable = None, name=None, loss_fn='square', sigma=1., var_s
def get_data_nodes(funs: List[Callable], *args, **kwargs) -> Tuple[DataNode]: def get_data_nodes(funs: List[Callable], *args, **kwargs) -> Tuple[DataNode]:
if 'names' in kwargs: if "names" in kwargs:
names = kwargs.pop('names') names = kwargs.pop("names")
return tuple(get_data_node(fun, name=name, *args, **kwargs) for fun, name in zip(funs, names)) return tuple(
get_data_node(fun, name=name, *args, **kwargs)
for fun, name in zip(funs, names)
)
else: else:
return tuple(get_data_node(fun, *args, **kwargs) for fun in funs) return tuple(get_data_node(fun, *args, **kwargs) for fun in funs)
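
A minimal sketch of building a DataNode from a plain sampling function; the domain below is made up, the module path idrlnet.data is an assumption, and the constraint values are constants because DataNode.sample lambdifies them against the inputs.

import numpy as np
from idrlnet.data import get_data_node  # module path assumed

def boundary():
    x = np.zeros((100, 1))
    # first dict: inputs ("area" is later used as a per-point weight in the loss);
    # second dict: constraints, where keys starting with "lambda_" become pointwise weights
    return {"x": x, "area": np.ones_like(x)}, {"u": 1.0, "lambda_u": 10.0}

bc = get_data_node(boundary, name="boundary", loss_fn="square", sigma=1.0)
print(bc.inputs, bc.outputs, bc.lambda_outputs)  # ['x', 'area'] ['u'] ['lambda_u']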

View File

@@ -1,28 +1,42 @@
""" A simple factory for constructing Geometric Objects""" """ A simple factory for constructing Geometric Objects"""
from .geo import Geometry from .geo import Geometry
from .geo_obj import Line1D, Line, Tube2D, Rectangle, Circle, Plane, Tube3D, Box, Sphere, Cylinder, CircularTube, \ from .geo_obj import (
Triangle, Heart Line1D,
Line,
Tube2D,
Rectangle,
Circle,
Plane,
Tube3D,
Box,
Sphere,
Cylinder,
CircularTube,
Triangle,
Heart,
)
__all__ = ['GeometryBuilder'] __all__ = ["GeometryBuilder"]
class GeometryBuilder: class GeometryBuilder:
GEOMAP = {'Line1D': Line1D, GEOMAP = {
'Line': Line, "Line1D": Line1D,
'Rectangle': Rectangle, "Line": Line,
'Circle': Circle, "Rectangle": Rectangle,
'Channel2D': Tube2D, "Circle": Circle,
'Plane': Plane, "Channel2D": Tube2D,
'Sphere': Sphere, "Plane": Plane,
'Box': Box, "Sphere": Sphere,
'Channel': Tube3D, "Box": Box,
'Channel3D': Tube3D, "Channel": Tube3D,
'Cylinder': Cylinder, "Channel3D": Tube3D,
'CircularTube': CircularTube, "Cylinder": Cylinder,
'Triangle': Triangle, "CircularTube": CircularTube,
'Heart': Heart, "Triangle": Triangle,
} "Heart": Heart,
}
@staticmethod @staticmethod
def get_geometry(geo: str, **kwargs) -> Geometry: def get_geometry(geo: str, **kwargs) -> Geometry:
@@ -33,5 +47,7 @@ class GeometryBuilder:
:return: A geometry object with given kwargs. :return: A geometry object with given kwargs.
:rtype: Geometry :rtype: Geometry
""" """
assert geo in GeometryBuilder.GEOMAP.keys(), f'The geometry {geo} not implemented!' assert (
geo in GeometryBuilder.GEOMAP.keys()
), f"The geometry {geo} not implemented!"
return GeometryBuilder.GEOMAP[geo](**kwargs) return GeometryBuilder.GEOMAP[geo](**kwargs)
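
A usage sketch for the factory; get_geometry only looks up GEOMAP and forwards **kwargs, so the keyword names below are hypothetical placeholders for whatever Rectangle actually accepts, and the module path is an assumption.

from idrlnet.geo_utils.geo_builder import GeometryBuilder  # module path assumed

# "Rectangle" must be a GEOMAP key; point_1/point_2 are assumed keyword names
rect = GeometryBuilder.get_geometry("Rectangle", point_1=(0.0, 0.0), point_2=(1.0, 1.0))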

View File

@@ -10,7 +10,7 @@ from functools import reduce
import collections import collections
from sympy import Max, Min, Mul from sympy import Max, Min, Mul
__all__ = ['lambdify_np'] __all__ = ["lambdify_np"]
class WrapSympy: class WrapSympy:
@@ -20,10 +20,14 @@ class WrapSympy:
def _wrapper_guide(args): def _wrapper_guide(args):
func_1 = args[0] func_1 = args[0]
func_2 = args[1] func_2 = args[1]
cond_1 = (isinstance(func_1, WrapSympy) and not func_1.is_sympy) cond_1 = isinstance(func_1, WrapSympy) and not func_1.is_sympy
cond_2 = isinstance(func_2, WrapSympy) and not func_2.is_sympy cond_2 = isinstance(func_2, WrapSympy) and not func_2.is_sympy
cond_3 = (not isinstance(func_1, WrapSympy)) and isinstance(func_1, collections.Callable) cond_3 = (not isinstance(func_1, WrapSympy)) and isinstance(
cond_4 = (not isinstance(func_2, WrapSympy)) and isinstance(func_2, collections.Callable) func_1, collections.Callable
)
cond_4 = (not isinstance(func_2, WrapSympy)) and isinstance(
func_2, collections.Callable
)
return cond_1 or cond_2 or cond_3 or cond_4, func_1, func_2 return cond_1 or cond_2 or cond_3 or cond_4, func_1, func_2
@@ -111,8 +115,11 @@ def _try_float(fn):
def _constant_bool(boolean: bool): def _constant_bool(boolean: bool):
def fn(**x): def fn(**x):
return np.ones_like(next(iter(x.items()))[1], dtype=bool) if boolean else np.zeros_like( return (
next(iter(x.items()))[1], dtype=bool) np.ones_like(next(iter(x.items()))[1], dtype=bool)
if boolean
else np.zeros_like(next(iter(x.items()))[1], dtype=bool)
)
return fn return fn
@@ -128,7 +135,7 @@ def lambdify_np(f, r: Iterable):
if isinstance(r, dict): if isinstance(r, dict):
r = r.keys() r = r.keys()
if isinstance(f, WrapSympy) and f.is_sympy: if isinstance(f, WrapSympy) and f.is_sympy:
lambdify_f = lambdify([k for k in r], f, [PLACEHOLDER, 'numpy']) lambdify_f = lambdify([k for k in r], f, [PLACEHOLDER, "numpy"])
lambdify_f.input_keys = [k for k in r] lambdify_f.input_keys = [k for k in r]
return lambdify_f return lambdify_f
if isinstance(f, WrapSympy) and not f.is_sympy: if isinstance(f, WrapSympy) and not f.is_sympy:
@@ -141,30 +148,31 @@ def lambdify_np(f, r: Iterable):
if isinstance(f, float): if isinstance(f, float):
return _constant_float(f) return _constant_float(f)
else: else:
lambdify_f = lambdify([k for k in r], f, [PLACEHOLDER, 'numpy']) lambdify_f = lambdify([k for k in r], f, [PLACEHOLDER, "numpy"])
lambdify_f.input_keys = [k for k in r] lambdify_f.input_keys = [k for k in r]
return lambdify_f return lambdify_f
PLACEHOLDER = {'amin': lambda x: reduce(lambda y, z: np.minimum(y, z), x), PLACEHOLDER = {
'amax': lambda x: reduce(lambda y, z: np.maximum(y, z), x), "amin": lambda x: reduce(lambda y, z: np.minimum(y, z), x),
'Min': lambda *x: reduce(lambda y, z: np.minimum(y, z), x), "amax": lambda x: reduce(lambda y, z: np.maximum(y, z), x),
'Max': lambda *x: reduce(lambda y, z: np.maximum(y, z), x), "Min": lambda *x: reduce(lambda y, z: np.minimum(y, z), x),
'Heaviside': lambda x: np.heaviside(x, 0), "Max": lambda *x: reduce(lambda y, z: np.maximum(y, z), x),
'equal': lambda x, y: np.isclose(x, y), "Heaviside": lambda x: np.heaviside(x, 0),
'Xor': np.logical_xor, "equal": lambda x, y: np.isclose(x, y),
'cos': np.cos, "Xor": np.logical_xor,
'sin': np.sin, "cos": np.cos,
'tan': np.tan, "sin": np.sin,
'exp': np.exp, "tan": np.tan,
'sqrt': np.sqrt, "exp": np.exp,
'log': np.log, "sqrt": np.sqrt,
'sinh': np.sinh, "log": np.log,
'cosh': np.cosh, "sinh": np.sinh,
'tanh': np.tanh, "cosh": np.cosh,
'asin': np.arcsin, "tanh": np.tanh,
'acos': np.arccos, "asin": np.arcsin,
'atan': np.arctan, "acos": np.arccos,
'Abs': np.abs, "atan": np.arctan,
'DiracDelta': np.zeros_like, "Abs": np.abs,
} "DiracDelta": np.zeros_like,
}
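
A small sketch of what lambdify_np returns (the import path is an assumption): sympy expressions are compiled against the PLACEHOLDER table first and plain numpy second, and the resulting function remembers its input keys.

import numpy as np
import sympy as sp
from idrlnet.geo_utils.sympy_np import lambdify_np  # import path assumed

x, y = sp.symbols("x y")
f = lambdify_np(sp.Max(x, y) + sp.sin(x), ["x", "y"])
print(f.input_keys)                    # ['x', 'y']
print(f(x=np.zeros(3), y=np.ones(3)))  # Max is supplied by PLACEHOLDER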

View File

@@ -13,15 +13,15 @@ from idrlnet.header import logger, DIFF_SYMBOL
from idrlnet.pde import PdeNode from idrlnet.pde import PdeNode
from idrlnet.net import NetNode from idrlnet.net import NetNode
__all__ = ['ComputableNodeList', 'Vertex', 'VertexTaskPipeline'] __all__ = ["ComputableNodeList", "Vertex", "VertexTaskPipeline"]
x, y = sp.symbols('x y') x, y = sp.symbols("x y")
ComputableNodeList = [List[Union[PdeNode, NetNode]]] ComputableNodeList = [List[Union[PdeNode, NetNode]]]
class Vertex(Node): class Vertex(Node):
counter = 0 counter = 0
def __init__(self, pre=None, next=None, node=None, ntype='c'): def __init__(self, pre=None, next=None, node=None, ntype="c"):
node = Node() if node is None else node node = Node() if node is None else node
self.__dict__ = node.__dict__.copy() self.__dict__ = node.__dict__.copy()
self.index = type(self).counter self.index = type(self).counter
@@ -29,7 +29,7 @@ class Vertex(Node):
self.pre = pre if pre is not None else set() self.pre = pre if pre is not None else set()
self.next = next if pre is not None else set() self.next = next if pre is not None else set()
self.ntype = ntype self.ntype = ntype
assert self.ntype in ('d', 'c', 'r') assert self.ntype in ("d", "c", "r")
def __eq__(self, other): def __eq__(self, other):
return self.index == other.index return self.index == other.index
@@ -38,8 +38,11 @@
return self.index return self.index
def __str__(self): def __str__(self):
info = f"index: {self.index}\n" + f"pre: {[node.index for node in self.pre]}\n" \ info = (
+ f"next: {[node.index for node in self.next]}\n" f"index: {self.index}\n"
+ f"pre: {[node.index for node in self.pre]}\n"
+ f"next: {[node.index for node in self.next]}\n"
)
return super().__str__() + info return super().__str__() + info
@@ -54,7 +57,9 @@ class VertexTaskPipeline:
def evaluation_order_list(self, evaluation_order_list): def evaluation_order_list(self, evaluation_order_list):
self._evaluation_order_list = evaluation_order_list self._evaluation_order_list = evaluation_order_list
def __init__(self, nodes: ComputableNodeList, invar: Variables, req_names: List[str]): def __init__(
self, nodes: ComputableNodeList, invar: Variables, req_names: List[str]
):
self.nodes = nodes self.nodes = nodes
self.req_names = req_names self.req_names = req_names
self.computable = set(invar.keys()) self.computable = set(invar.keys())
@@ -74,14 +79,14 @@ class VertexTaskPipeline:
final_graph_node.inputs = [req_name] final_graph_node.inputs = [req_name]
final_graph_node.derivatives = tuple() final_graph_node.derivatives = tuple()
final_graph_node.outputs = tuple() final_graph_node.outputs = tuple()
final_graph_node.name = f'<{req_name}>' final_graph_node.name = f"<{req_name}>"
final_graph_node.ntype = 'r' final_graph_node.ntype = "r"
graph_nodes.add(final_graph_node) graph_nodes.add(final_graph_node)
req_name_dict[req_name].append(final_graph_node) req_name_dict[req_name].append(final_graph_node)
required_stack.append(final_graph_node) required_stack.append(final_graph_node)
final_graph_node.evaluate = lambda x: x final_graph_node.evaluate = lambda x: x
logger.info('Constructing computation graph...') logger.info("Constructing computation graph...")
while len(req_name_dict) > 0: while len(req_name_dict) > 0:
to_be_removed = set() to_be_removed = set()
to_be_added = defaultdict(list) to_be_added = defaultdict(list)
@@ -96,14 +101,20 @@ class VertexTaskPipeline:
continue continue
for output in gn.outputs: for output in gn.outputs:
output = tuple(output.split(DIFF_SYMBOL)) output = tuple(output.split(DIFF_SYMBOL))
if len(output) <= len(req_name) and req_name[:len(output)] == output and len( if (
output) > match_score: len(output) <= len(req_name)
and req_name[: len(output)] == output
and len(output) > match_score
):
match_score = len(output) match_score = len(output)
match_gn = gn match_gn = gn
for p_in in invar.keys(): for p_in in invar.keys():
p_in = tuple(p_in.split(DIFF_SYMBOL)) p_in = tuple(p_in.split(DIFF_SYMBOL))
if len(p_in) <= len(req_name) and req_name[:len(p_in)] == p_in and len( if (
p_in) > match_score: len(p_in) <= len(req_name)
and req_name[: len(p_in)] == p_in
and len(p_in) > match_score
):
match_score = len(p_in) match_score = len(p_in)
match_gn = None match_gn = None
for sub_gn in req_name_dict[DIFF_SYMBOL.join(req_name)]: for sub_gn in req_name_dict[DIFF_SYMBOL.join(req_name)]:
@@ -112,9 +123,13 @@ class VertexTaskPipeline:
raise Exception("Can't be computed: " + DIFF_SYMBOL.join(req_name)) raise Exception("Can't be computed: " + DIFF_SYMBOL.join(req_name))
elif match_gn is not None: elif match_gn is not None:
for sub_gn in req_name_dict[DIFF_SYMBOL.join(req_name)]: for sub_gn in req_name_dict[DIFF_SYMBOL.join(req_name)]:
logger.info(f'{sub_gn.name}.{DIFF_SYMBOL.join(req_name)} <---- {match_gn.name}') logger.info(
f"{sub_gn.name}.{DIFF_SYMBOL.join(req_name)} <---- {match_gn.name}"
)
match_gn.next.add(sub_gn) match_gn.next.add(sub_gn)
self.egde_data[(match_gn.name, sub_gn.name)].add(DIFF_SYMBOL.join(req_name)) self.egde_data[(match_gn.name, sub_gn.name)].add(
DIFF_SYMBOL.join(req_name)
)
required_stack.append(match_gn) required_stack.append(match_gn)
for sub_gn in req_name_dict[DIFF_SYMBOL.join(req_name)]: for sub_gn in req_name_dict[DIFF_SYMBOL.join(req_name)]:
sub_gn.pre.add(match_gn) sub_gn.pre.add(match_gn)
@@ -148,51 +163,91 @@ class VertexTaskPipeline:
node.name = key node.name = key
node.outputs = (key,) node.outputs = (key,)
node.inputs = tuple() node.inputs = tuple()
node.ntype = 'd' node.ntype = "d"
self._graph_node_table[key] = node self._graph_node_table[key] = node
logger.info('Computation graph constructed.') logger.info("Computation graph constructed.")
def operation_order(self, invar: Variables): def operation_order(self, invar: Variables):
for node in self.evaluation_order_list: for node in self.evaluation_order_list:
if not set(node.derivatives).issubset(invar.keys()): if not set(node.derivatives).issubset(invar.keys()):
invar.differentiate_(independent_var=invar, required_derivatives=node.derivatives) invar.differentiate_(
invar.update(node.evaluate({**invar.subset(node.inputs), **invar.subset(node.derivatives)})) independent_var=invar, required_derivatives=node.derivatives
)
invar.update(
node.evaluate(
{**invar.subset(node.inputs), **invar.subset(node.derivatives)}
)
)
def forward_pipeline(self, invar: Variables, req_names: List[str] = None) -> Variables: def forward_pipeline(
self, invar: Variables, req_names: List[str] = None
) -> Variables:
if req_names is None or set(req_names).issubset(set(self.computable)): if req_names is None or set(req_names).issubset(set(self.computable)):
outvar = copy(invar) outvar = copy(invar)
self.operation_order(outvar) self.operation_order(outvar)
return outvar.subset(self.req_names if req_names is None else req_names) return outvar.subset(self.req_names if req_names is None else req_names)
else: else:
logger.info('The existing graph fails. Construct a temporary graph...') logger.info("The existing graph fails. Construct a temporary graph...")
return VertexTaskPipeline(self.nodes, invar, req_names).forward_pipeline(invar) return VertexTaskPipeline(self.nodes, invar, req_names).forward_pipeline(
invar
)
def to_json(self): def to_json(self):
pass pass
def display(self, filename: str = None): def display(self, filename: str = None):
_, ax = plt.subplots(1, 1, figsize=(8, 8)) _, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.axis('off') ax.axis("off")
pos = nx.spring_layout(self.G, k=10 / (math.sqrt(self.G.order()) + 0.1)) pos = nx.spring_layout(self.G, k=10 / (math.sqrt(self.G.order()) + 0.1))
nx.draw_networkx_nodes(self.G, pos, nx.draw_networkx_nodes(
nodelist=list( self.G,
node for node in self.G.nodes if self._graph_node_table[node].ntype == 'c'), pos,
cmap=plt.get_cmap('jet'), nodelist=list(
node_size=1300, node_color="pink", alpha=0.5) node
nx.draw_networkx_nodes(self.G, pos, for node in self.G.nodes
nodelist=list( if self._graph_node_table[node].ntype == "c"
node for node in self.G.nodes if self._graph_node_table[node].ntype == 'r'), ),
cmap=plt.get_cmap('jet'), cmap=plt.get_cmap("jet"),
node_size=1300, node_color="green", alpha=0.3) node_size=1300,
nx.draw_networkx_nodes(self.G, pos, node_color="pink",
nodelist=list( alpha=0.5,
node for node in self.G.nodes if self._graph_node_table[node].ntype == 'd'), )
cmap=plt.get_cmap('jet'), nx.draw_networkx_nodes(
node_size=1300, node_color="blue", alpha=0.3) self.G,
nx.draw_networkx_edges(self.G, pos, edge_color='r', arrows=True, arrowsize=30, arrowstyle="-|>") pos,
nodelist=list(
node
for node in self.G.nodes
if self._graph_node_table[node].ntype == "r"
),
cmap=plt.get_cmap("jet"),
node_size=1300,
node_color="green",
alpha=0.3,
)
nx.draw_networkx_nodes(
self.G,
pos,
nodelist=list(
node
for node in self.G.nodes
if self._graph_node_table[node].ntype == "d"
),
cmap=plt.get_cmap("jet"),
node_size=1300,
node_color="blue",
alpha=0.3,
)
nx.draw_networkx_edges(
self.G, pos, edge_color="r", arrows=True, arrowsize=30, arrowstyle="-|>"
)
nx.draw_networkx_labels(self.G, pos) nx.draw_networkx_labels(self.G, pos)
nx.draw_networkx_edge_labels(self.G, pos, edge_labels={k: ", ".join(v) for k, v in self.egde_data.items()}, nx.draw_networkx_edge_labels(
font_size=10) self.G,
pos,
edge_labels={k: ", ".join(v) for k, v in self.egde_data.items()},
font_size=10,
)
if filename is None: if filename is None:
plt.show() plt.show()
else: else:

View File

@@ -14,7 +14,7 @@ class TestFun:
self.registered.append(self) self.registered.append(self)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
print(str(self.fun.__name__).center(50, '*')) print(str(self.fun.__name__).center(50, "*"))
self.fun() self.fun()
@staticmethod @staticmethod
@@ -36,7 +36,12 @@ def testmemo(fun):
testmemo.memo = set() testmemo.memo = set()
log_format = '[%(asctime)s] [%(levelname)s] %(message)s' log_format = "[%(asctime)s] [%(levelname)s] %(message)s"
handlers = [logging.FileHandler('train.log', mode='a'), logging.StreamHandler()] handlers = [logging.FileHandler("train.log", mode="a"), logging.StreamHandler()]
logging.basicConfig(format=log_format, level=logging.INFO, datefmt='%d-%b-%y %H:%M:%S', handlers=handlers) logging.basicConfig(
format=log_format,
level=logging.INFO,
datefmt="%d-%b-%y %H:%M:%S",
handlers=handlers,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -4,11 +4,11 @@ from idrlnet.node import Node
from typing import Tuple, List, Dict, Union from typing import Tuple, List, Dict, Union
from contextlib import ExitStack from contextlib import ExitStack
__all__ = ['NetNode'] __all__ = ["NetNode"]
class WrapEvaluate: class WrapEvaluate:
def __init__(self, binding_node: 'NetNode'): def __init__(self, binding_node: "NetNode"):
self.binding_node = binding_node self.binding_node = binding_node
def __call__(self, inputs): def __call__(self, inputs):
@@ -16,15 +16,23 @@ class WrapEvaluate:
if isinstance(inputs, dict): if isinstance(inputs, dict):
keep_type = dict keep_type = dict
inputs = torch.cat( inputs = torch.cat(
[torch.tensor(inputs[key], dtype=torch.float32) if not isinstance(inputs[key], torch.Tensor) else [
inputs[ torch.tensor(inputs[key], dtype=torch.float32)
key] for key in inputs], dim=1) if not isinstance(inputs[key], torch.Tensor)
else inputs[key]
for key in inputs
],
dim=1,
)
with ExitStack() as es: with ExitStack() as es:
if self.binding_node.require_no_grad: if self.binding_node.require_no_grad:
es.enter_context(torch.no_grad()) es.enter_context(torch.no_grad())
output_var = self.binding_node.net(inputs) output_var = self.binding_node.net(inputs)
if keep_type == dict: if keep_type == dict:
output_var = {outkey: output_var[:, i:i + 1] for i, outkey in enumerate(self.binding_node.outputs)} output_var = {
outkey: output_var[:, i : i + 1]
for i, outkey in enumerate(self.binding_node.outputs)
}
return output_var return output_var
@@ -63,9 +71,18 @@ class NetNode(Node):
def net(self, net): def net(self, net):
self._net = net self._net = net
def __init__(self, inputs: Union[Tuple, List[str]], outputs: Union[Tuple, List[str]], def __init__(
net: torch.nn.Module, fixed: bool = False, require_no_grad: bool = False, is_reference=False, self,
name=None, *args, **kwargs): inputs: Union[Tuple, List[str]],
outputs: Union[Tuple, List[str]],
net: torch.nn.Module,
fixed: bool = False,
require_no_grad: bool = False,
is_reference=False,
name=None,
*args,
**kwargs
):
self.is_reference = is_reference self.is_reference = is_reference
self.inputs: Union[Tuple, List[str]] = inputs self.inputs: Union[Tuple, List[str]] = inputs
self.outputs: Union[Tuple, List[str]] = outputs self.outputs: Union[Tuple, List[str]] = outputs
@@ -89,5 +106,5 @@ class NetNode(Node):
def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True): def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True):
return self.net.load_state_dict(state_dict, strict) return self.net.load_state_dict(state_dict, strict)
def state_dict(self, destination=None, prefix: str = '', keep_vars: bool = False): def state_dict(self, destination=None, prefix: str = "", keep_vars: bool = False):
return self.net.state_dict(destination, prefix, keep_vars) return self.net.state_dict(destination, prefix, keep_vars)
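
Besides the factories in architecture/mlp.py, a NetNode can wrap any torch module directly; a minimal sketch with illustrative module and symbol names. As WrapEvaluate above shows, inputs are concatenated column-wise before the forward call and the output is split back into one named column per entry of outputs.

import torch
from idrlnet.net import NetNode

node = NetNode(inputs=("x", "y"), outputs=("p",), net=torch.nn.Linear(2, 1), name="p_net")
ckpt = node.state_dict()   # delegates to the wrapped module
node.load_state_dict(ckpt)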

View File

@@ -5,7 +5,7 @@ from idrlnet.torch_util import torch_lambdify
from idrlnet.variable import Variables from idrlnet.variable import Variables
from idrlnet.header import DIFF_SYMBOL from idrlnet.header import DIFF_SYMBOL
__all__ = ['Node'] __all__ = ["Node"]
class Node(object): class Node(object):
@@ -58,7 +58,7 @@ class Node(object):
try: try:
return self._name return self._name
except: except:
self._name = 'Node' + str(id(self)) self._name = "Node" + str(id(self))
return self._name return self._name
@name.setter @name.setter
@@ -66,23 +66,33 @@ class Node(object):
self._name = name self._name = name
@classmethod @classmethod
def new_node(cls, name: str = None, tf_eq: Callable = None, free_symbols: List[str] = None, *args, def new_node(
**kwargs) -> 'Node': cls,
name: str = None,
tf_eq: Callable = None,
free_symbols: List[str] = None,
*args,
**kwargs
) -> "Node":
node = cls() node = cls()
node.evaluate = LambdaTorchFun(free_symbols, tf_eq, name) node.evaluate = LambdaTorchFun(free_symbols, tf_eq, name)
node.inputs = [x for x in free_symbols if DIFF_SYMBOL not in x] node.inputs = [x for x in free_symbols if DIFF_SYMBOL not in x]
node.derivatives = [x for x in free_symbols if DIFF_SYMBOL in x] node.derivatives = [x for x in free_symbols if DIFF_SYMBOL in x]
node.outputs = [name, ] node.outputs = [
name,
]
node.name = name node.name = name
return node return node
def __str__(self): def __str__(self):
str_list = ["Basic properties:\n", str_list = [
"name: {}\n".format(self.name), "Basic properties:\n",
"inputs: {}\n".format(self.inputs), "name: {}\n".format(self.name),
"derivatives: {}\n".format(self.derivatives), "inputs: {}\n".format(self.inputs),
"outputs: {}\n".format(self.outputs), ] "derivatives: {}\n".format(self.derivatives),
return ''.join(str_list) "outputs: {}\n".format(self.outputs),
]
return "".join(str_list)
class LambdaTorchFun: class LambdaTorchFun:

View File

@@ -6,7 +6,7 @@ from idrlnet.torch_util import _replace_derivatives
from idrlnet.header import DIFF_SYMBOL from idrlnet.header import DIFF_SYMBOL
from idrlnet.variable import Variables from idrlnet.variable import Variables
__all__ = ['PdeNode', 'ExpressionNode'] __all__ = ["PdeNode", "ExpressionNode"]
class PdeEvaluate: class PdeEvaluate:
@@ -18,8 +18,11 @@ class PdeEvaluate:
def __call__(self, inputs: Variables) -> Variables: def __call__(self, inputs: Variables) -> Variables:
result = Variables() result = Variables()
for node in self.binding_pde.sub_nodes: for node in self.binding_pde.sub_nodes:
sub_inputs = {k: v for k, v in Variables(inputs).items() if sub_inputs = {
k in node.inputs or k in node.derivatives} k: v
for k, v in Variables(inputs).items()
if k in node.inputs or k in node.derivatives
}
r = node.evaluate(sub_inputs) r = node.evaluate(sub_inputs)
result.update(r) result.update(r)
return result return result
@@ -53,9 +56,9 @@ class PdeNode(Node):
def __init__(self, suffix: str = "", **kwargs): def __init__(self, suffix: str = "", **kwargs):
if len(suffix) > 0: if len(suffix) > 0:
self.suffix = '[' + kwargs['suffix'] + ']' # todo: check prefix self.suffix = "[" + kwargs["suffix"] + "]" # todo: check prefix
else: else:
self.suffix = '' self.suffix = ""
self.name = type(self).__name__ + self.suffix self.name = type(self).__name__ + self.suffix
self.evaluate = PdeEvaluate(self) self.evaluate = PdeEvaluate(self)
@@ -77,8 +80,10 @@ class PdeNode(Node):
def __str__(self): def __str__(self):
subnode_str = "\n\n".join( subnode_str = "\n\n".join(
str(sub_node) + "Equation: \n" + str(self.equations[sub_node.name]) for sub_node in self.sub_nodes) str(sub_node) + "Equation: \n" + str(self.equations[sub_node.name])
return super().__str__() + "subnodes".center(30, '-') + '\n' + subnode_str for sub_node in self.sub_nodes
)
return super().__str__() + "subnodes".center(30, "-") + "\n" + subnode_str
# todo: test required # todo: test required

View File

@@ -6,20 +6,20 @@ from typing import Dict, List
class Signal(Enum): class Signal(Enum):
REGISTER = 'signal_register' REGISTER = "signal_register"
SOLVE_START = 'signal_solve_start' SOLVE_START = "signal_solve_start"
TRAIN_PIPE_START = 'signal_train_pipe_start' TRAIN_PIPE_START = "signal_train_pipe_start"
BEFORE_COMPUTE_LOSS = 'before_compute_loss' BEFORE_COMPUTE_LOSS = "before_compute_loss"
AFTER_COMPUTE_LOSS = 'compute_loss' AFTER_COMPUTE_LOSS = "compute_loss"
BEFORE_BACKWARD = 'signal_before_backward' BEFORE_BACKWARD = "signal_before_backward"
TRAIN_PIPE_END = 'signal_train_pipe_end' TRAIN_PIPE_END = "signal_train_pipe_end"
SOLVE_END = 'signal_solve_end' SOLVE_END = "signal_solve_end"
class Receiver(metaclass=abc.ABCMeta): class Receiver(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def receive_notify(self, obj: object, message: Dict): def receive_notify(self, obj: object, message: Dict):
raise NotImplementedError('Method receive_notify() not implemented!') raise NotImplementedError("Method receive_notify() not implemented!")
class Notifier: class Notifier:
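
The Receiver/Signal pair is the extension point behind the callbacks shown earlier; a minimal custom receiver might look like the sketch below (not part of the commit).

from typing import Dict
from idrlnet.header import logger
from idrlnet.receivers import Receiver, Signal

class LossPrinter(Receiver):
    """Hypothetical receiver: log every loss component after it is computed."""

    def receive_notify(self, solver, message: Dict):
        if Signal.AFTER_COMPUTE_LOSS in message:
            logger.info(f"loss components: {message[Signal.AFTER_COMPUTE_LOSS]}")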

View File

@@ -15,7 +15,7 @@ from idrlnet.variable import Variables, DomainVariables
from idrlnet.graph import VertexTaskPipeline from idrlnet.graph import VertexTaskPipeline
import idrlnet import idrlnet
__all__ = ['Solver'] __all__ = ["Solver"]
class Solver(Notifier, Optimizable): class Solver(Notifier, Optimizable):
@@ -65,20 +65,23 @@ class Solver(Notifier, Optimizable):
:param kwargs: :param kwargs:
""" """
def __init__(self, sample_domains: Tuple[Union[DataNode, SampleDomain], ...], def __init__(
netnodes: List[NetNode], self,
pdes: Optional[List] = None, sample_domains: Tuple[Union[DataNode, SampleDomain], ...],
network_dir: str = './network_dir', netnodes: List[NetNode],
summary_dir: Optional[str] = None, pdes: Optional[List] = None,
max_iter: int = 1000, network_dir: str = "./network_dir",
save_freq: int = 100, summary_dir: Optional[str] = None,
print_freq: int = 10, max_iter: int = 1000,
loading: bool = True, save_freq: int = 100,
init_network_dirs: Optional[List[str]] = None, print_freq: int = 10,
opt_config: Dict = None, loading: bool = True,
schedule_config: Dict = None, init_network_dirs: Optional[List[str]] = None,
result_dir='train_domain/results', opt_config: Dict = None,
**kwargs): schedule_config: Dict = None,
result_dir="train_domain/results",
**kwargs,
):
self.network_dir: str = network_dir self.network_dir: str = network_dir
self.domain_losses = {domain.name: domain.loss_fn for domain in sample_domains} self.domain_losses = {domain.name: domain.loss_fn for domain in sample_domains}
@@ -96,8 +99,16 @@ class Solver(Notifier, Optimizable):
self.save_freq = save_freq self.save_freq = save_freq
self.print_freq = print_freq self.print_freq = print_freq
try: try:
self.parse_configure(**{**({"opt_config": opt_config} if opt_config is not None else {}), self.parse_configure(
**({"schedule_config": schedule_config} if schedule_config is not None else {})}) **{
**({"opt_config": opt_config} if opt_config is not None else {}),
**(
{"schedule_config": schedule_config}
if schedule_config is not None
else {}
),
}
)
except Exception: except Exception:
logger.error("Optimizer configuration failed") logger.error("Optimizer configuration failed")
raise raise
@@ -109,7 +120,10 @@ class Solver(Notifier, Optimizable):
pass pass
self.sample_domains: Tuple[DataNode, ...] = sample_domains self.sample_domains: Tuple[DataNode, ...] = sample_domains
self.summary_dir = self.network_dir if summary_dir is None else summary_dir self.summary_dir = self.network_dir if summary_dir is None else summary_dir
self.receivers: List[Receiver] = [SummaryReceiver(self.summary_dir), HandleResultReceiver(result_dir)] self.receivers: List[Receiver] = [
SummaryReceiver(self.summary_dir),
HandleResultReceiver(result_dir),
]
@property @property
def network_dir(self): def network_dir(self):
@@ -136,12 +150,23 @@ class Solver(Notifier, Optimizable):
:return: A list of trainable parameters. :return: A list of trainable parameters.
:rtype: List[torch.nn.parameter.Parameter] :rtype: List[torch.nn.parameter.Parameter]
""" """
parameter_list = list(map(lambda _net_node: {'params': _net_node.net.parameters()}, parameter_list = list(
filter(lambda _net_node: not _net_node.is_reference and (not _net_node.fixed), map(
self.netnodes))) lambda _net_node: {"params": _net_node.net.parameters()},
filter(
lambda _net_node: not _net_node.is_reference
and (not _net_node.fixed),
self.netnodes,
),
)
)
if len(parameter_list) == 0: if len(parameter_list) == 0:
'''To make sure successful initialization of optimizers.''' """To make sure successful initialization of optimizers."""
parameter_list = [torch.nn.parameter.Parameter(data=torch.Tensor([0.]), requires_grad=True)] parameter_list = [
torch.nn.parameter.Parameter(
data=torch.Tensor([0.0]), requires_grad=True
)
]
logger.warning("No trainable parameters found!") logger.warning("No trainable parameters found!")
return parameter_list return parameter_list
@@ -158,15 +183,15 @@ class Solver(Notifier, Optimizable):
"""return sovler information, it will return components recursively""" """return sovler information, it will return components recursively"""
str_list = [] str_list = []
str_list.append("nets: \n") str_list.append("nets: \n")
str_list.append(''.join([str(net) for net in self.netnodes])) str_list.append("".join([str(net) for net in self.netnodes]))
str_list.append("domains: \n") str_list.append("domains: \n")
str_list.append(''.join([str(domain) for domain in self.sample_domains])) str_list.append("".join([str(domain) for domain in self.sample_domains]))
str_list.append('\n') str_list.append("\n")
str_list.append('optimizer config:\n') str_list.append("optimizer config:\n")
for i, _class in enumerate(type(self).mro()): for i, _class in enumerate(type(self).mro()):
if _class == Optimizable: if _class == Optimizable:
str_list.append(super(type(self).mro()[i - 1], self).__str__()) str_list.append(super(type(self).mro()[i - 1], self).__str__())
return ''.join(str_list) return "".join(str_list)
def set_param_ranges(self, param_ranges: Dict): def set_param_ranges(self, param_ranges: Dict):
for domain in self.sample_domains: for domain in self.sample_domains:
@@ -184,7 +209,7 @@ class Solver(Notifier, Optimizable):
for value in self.sample_domains: for value in self.sample_domains:
if value.name == name: if value.name == name:
return value return value
raise KeyError(f'domain {name} not exist!') raise KeyError(f"domain {name} not exist!")
def generate_computation_pipeline(self): def generate_computation_pipeline(self):
"""Generate computation pipeline for all domains. """Generate computation pipeline for all domains.
@@ -195,28 +220,40 @@ class Solver(Notifier, Optimizable):
self.vertex_pipelines = {} self.vertex_pipelines = {}
for domain_name, var in in_var.items(): for domain_name, var in in_var.items():
logger.info(f"Constructing computation graph for domain <{domain_name}>") logger.info(f"Constructing computation graph for domain <{domain_name}>")
self.vertex_pipelines[domain_name] = VertexTaskPipeline(self.netnodes + self.pdes, var, self.vertex_pipelines[domain_name] = VertexTaskPipeline(
self.outvar_dict_index[domain_name]) self.netnodes + self.pdes, var, self.outvar_dict_index[domain_name]
)
self.vertex_pipelines[domain_name].display( self.vertex_pipelines[domain_name].display(
os.path.join(self.network_dir, f'{domain_name}_{self.global_step}.png')) os.path.join(self.network_dir, f"{domain_name}_{self.global_step}.png")
)
def forward_through_all_graph(self, invar_dict: DomainVariables, def forward_through_all_graph(
req_outvar_dict_index: Dict[str, List[str]]) -> DomainVariables: self, invar_dict: DomainVariables, req_outvar_dict_index: Dict[str, List[str]]
) -> DomainVariables:
outvar_dict = {} outvar_dict = {}
for (key, req_outvar_names) in req_outvar_dict_index.items(): for (key, req_outvar_names) in req_outvar_dict_index.items():
outvar_dict[key] = self.vertex_pipelines[key].forward_pipeline(invar_dict[key], req_outvar_names) outvar_dict[key] = self.vertex_pipelines[key].forward_pipeline(
invar_dict[key], req_outvar_names
)
return outvar_dict return outvar_dict
def append_sample_domain(self, datanode): def append_sample_domain(self, datanode):
self.sample_domains = self.sample_domains + (datanode,) self.sample_domains = self.sample_domains + (datanode,)
def _generate_dict_index(self) -> None: def _generate_dict_index(self) -> None:
self.invar_dict_index = {domain.name: domain.inputs for domain in self.sample_domains} self.invar_dict_index = {
self.outvar_dict_index = {domain.name: domain.outputs for domain in self.sample_domains} domain.name: domain.inputs for domain in self.sample_domains
self.lambda_dict_index = {domain.name: domain.lambda_outputs for domain in self.sample_domains} }
self.outvar_dict_index = {
domain.name: domain.outputs for domain in self.sample_domains
}
self.lambda_dict_index = {
domain.name: domain.lambda_outputs for domain in self.sample_domains
}
def generate_in_out_dict(self, samples: DomainVariables) -> \ def generate_in_out_dict(
Tuple[DomainVariables, DomainVariables, DomainVariables]: self, samples: DomainVariables
) -> Tuple[DomainVariables, DomainVariables, DomainVariables]:
invar_dict = {} invar_dict = {}
for domain, variable in samples.items(): for domain, variable in samples.items():
inner = {} inner = {}
@@ -226,20 +263,40 @@ class Solver(Notifier, Optimizable):
invar_dict[domain] = inner invar_dict[domain] = inner
invar_dict = { invar_dict = {
domain: Variables({key: val for key, val in variable.items() if key in self.invar_dict_index[domain]}) for domain: Variables(
domain, variable in samples.items()} {
key: val
for key, val in variable.items()
if key in self.invar_dict_index[domain]
}
)
for domain, variable in samples.items()
}
outvar_dict = { outvar_dict = {
domain: Variables({key: val for key, val in variable.items() if key in self.outvar_dict_index[domain]}) for domain: Variables(
domain, variable in samples.items()} {
key: val
for key, val in variable.items()
if key in self.outvar_dict_index[domain]
}
)
for domain, variable in samples.items()
}
lambda_dict = { lambda_dict = {
domain: Variables({key: val for key, val in variable.items() if key in self.lambda_dict_index[domain]}) for domain: Variables(
domain, variable in samples.items()} {
key: val
for key, val in variable.items()
if key in self.lambda_dict_index[domain]
}
)
for domain, variable in samples.items()
}
return invar_dict, outvar_dict, lambda_dict return invar_dict, outvar_dict, lambda_dict
def solve(self): def solve(self):
"""After the solver instance is initialized, the method could be called to solve the entire problem. """After the solver instance is initialized, the method could be called to solve the entire problem."""
""" self.notify(self, message={Signal.SOLVE_START: "default"})
self.notify(self, message={Signal.SOLVE_START: 'default'})
while self.global_step < self.max_iter: while self.global_step < self.max_iter:
loss = self.train_pipe() loss = self.train_pipe()
if self.global_step % self.print_freq == 0: if self.global_step % self.print_freq == 0:
@@ -247,13 +304,13 @@ class Solver(Notifier, Optimizable):
if self.global_step % self.save_freq == 0: if self.global_step % self.save_freq == 0:
self.save() self.save()
logger.info("Training Stage Ends") logger.info("Training Stage Ends")
self.notify(self, message={Signal.SOLVE_END: 'default'}) self.notify(self, message={Signal.SOLVE_END: "default"})
def train_pipe(self): def train_pipe(self):
"""Sample once; calculate the loss once; backward propagation once """Sample once; calculate the loss once; backward propagation once
:return: None :return: None
""" """
self.notify(self, message={Signal.TRAIN_PIPE_START: 'defaults'}) self.notify(self, message={Signal.TRAIN_PIPE_START: "defaults"})
for opt in self.optimizers: for opt in self.optimizers:
opt.zero_grad() opt.zero_grad()
samples = self.sample_variables_from_domains() samples = self.sample_variables_from_domains()
@@ -263,7 +320,7 @@ class Solver(Notifier, Optimizable):
loss = self.compute_loss(in_var, pred_out_sample, true_out, lambda_out) loss = self.compute_loss(in_var, pred_out_sample, true_out, lambda_out)
except RuntimeError: except RuntimeError:
raise raise
self.notify(self, message={Signal.BEFORE_BACKWARD: 'defaults'}) self.notify(self, message={Signal.BEFORE_BACKWARD: "defaults"})
loss.backward() loss.backward()
for opt in self.optimizers: for opt in self.optimizers:
opt.step() opt.step()
@@ -271,40 +328,64 @@ class Solver(Notifier, Optimizable):
for scheduler in self.schedulers: for scheduler in self.schedulers:
scheduler.step(self.global_step) scheduler.step(self.global_step)
self.notify(self, message={Signal.TRAIN_PIPE_END: 'defaults'}) self.notify(self, message={Signal.TRAIN_PIPE_END: "defaults"})
return loss return loss
def compute_loss(self, in_var: DomainVariables, pred_out_sample: DomainVariables, def compute_loss(
true_out: DomainVariables, self,
lambda_out: DomainVariables) -> torch.Tensor: in_var: DomainVariables,
"""Compute the total loss in one epoch. pred_out_sample: DomainVariables,
true_out: DomainVariables,
""" lambda_out: DomainVariables,
) -> torch.Tensor:
"""Compute the total loss in one epoch."""
diff = dict() diff = dict()
for domain_name, domain_val in true_out.items(): for domain_name, domain_val in true_out.items():
if len(domain_val) == 0: if len(domain_val) == 0:
continue continue
diff[domain_name] = pred_out_sample[domain_name] - domain_val.to_torch_tensor_() diff[domain_name] = (
pred_out_sample[domain_name] - domain_val.to_torch_tensor_()
)
diff[domain_name].update(lambda_out[domain_name]) diff[domain_name].update(lambda_out[domain_name])
diff[domain_name].update(area=in_var[domain_name]['area']) diff[domain_name].update(area=in_var[domain_name]["area"])
for domain, var in diff.items(): for domain, var in diff.items():
lambda_diff = dict() lambda_diff = dict()
for constraint, _ in var.items(): for constraint, _ in var.items():
if 'lambda_' + constraint in in_var[domain].keys(): if "lambda_" + constraint in in_var[domain].keys():
lambda_diff['lambda_' + constraint] = in_var[domain]['lambda_' + constraint] lambda_diff["lambda_" + constraint] = in_var[domain][
"lambda_" + constraint
]
var.update(lambda_diff) var.update(lambda_diff)
self.loss_component = Variables( self.loss_component = Variables(
ChainMap( ChainMap(
*[diff[domain_name].weighted_loss(f"{domain_name}_loss", *[
loss_function=self.domain_losses[domain_name]) for diff[domain_name].weighted_loss(
domain_name, domain_val in f"{domain_name}_loss",
diff.items()])) loss_function=self.domain_losses[domain_name],
)
for domain_name, domain_val in diff.items()
]
)
)
self.notify(self, message={Signal.BEFORE_COMPUTE_LOSS: {**self.loss_component}}) self.notify(self, message={Signal.BEFORE_COMPUTE_LOSS: {**self.loss_component}})
loss = sum({domain_name: self.get_sample_domain(domain_name).sigma * self.loss_component[f"{domain_name}_loss"] for loss = sum(
domain_name in diff}.values()) {
self.notify(self, message={Signal.AFTER_COMPUTE_LOSS: {**self.loss_component, **{'total_loss': loss}}}) domain_name: self.get_sample_domain(domain_name).sigma
* self.loss_component[f"{domain_name}_loss"]
for domain_name in diff
}.values()
)
self.notify(
self,
message={
Signal.AFTER_COMPUTE_LOSS: {
**self.loss_component,
**{"total_loss": loss},
}
},
)
return loss return loss
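In plain terms, compute_loss turns each domain's residuals into a weighted loss (via the weighted_loss reduction defined on Variables further down) and then sums the per-domain losses scaled by each sample domain's sigma. A minimal numeric sketch of that final aggregation (the sigma values and loss names here are hypothetical):

import torch

loss_component = {"interior_loss": torch.tensor(0.12), "bc_loss": torch.tensor(0.03)}
sigma = {"interior": 1.0, "bc": 10.0}  # per-domain weights, as returned by get_sample_domain(...).sigma
total_loss = sum(sigma[d] * loss_component[f"{d}_loss"] for d in sigma)
print(total_loss)  # tensor(0.4200)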
def infer_step(self, domain_attr: Dict[str, List[str]]) -> DomainVariables: def infer_step(self, domain_attr: Dict[str, List[str]]) -> DomainVariables:
@ -323,40 +404,46 @@ class Solver(Notifier, Optimizable):
return {data_node.name: data_node.sample() for data_node in self.sample_domains} return {data_node.name: data_node.sample() for data_node in self.sample_domains}
def save(self): def save(self):
"""Save parameters of netnodes and the global step to `model.ckpt`. """Save parameters of netnodes and the global step to `model.ckpt`."""
""" save_path = os.path.join(self.network_dir, "model.ckpt")
save_path = os.path.join(self.network_dir, 'model.ckpt')
logger.info("save to path: {}".format(os.path.abspath(save_path))) logger.info("save to path: {}".format(os.path.abspath(save_path)))
save_dict = {f"{net_node.name}_dict": net_node.state_dict() for net_node in save_dict = {
filter(lambda _net: not _net.is_reference, self.netnodes)} f"{net_node.name}_dict": net_node.state_dict()
for net_node in filter(lambda _net: not _net.is_reference, self.netnodes)
}
for i, opt in enumerate(self.optimizers): for i, opt in enumerate(self.optimizers):
save_dict['optimizer_{}_dict'.format(i)] = opt.state_dict() save_dict["optimizer_{}_dict".format(i)] = opt.state_dict()
save_dict['global_step'] = self.global_step save_dict["global_step"] = self.global_step
torch.save(save_dict, save_path) torch.save(save_dict, save_path)
def init_load(self): def init_load(self):
for network_dir in self.init_network_dirs: for network_dir in self.init_network_dirs:
save_path = os.path.join(network_dir, 'model.ckpt') save_path = os.path.join(network_dir, "model.ckpt")
save_dict = torch.load(save_path) save_dict = torch.load(save_path)
for net_node in self.netnodes: for net_node in self.netnodes:
if f"{net_node.name}_dict" in save_dict.keys() and not net_node.is_reference: if (
f"{net_node.name}_dict" in save_dict.keys()
and not net_node.is_reference
):
net_node.load_state_dict(save_dict[f"{net_node.name}_dict"]) net_node.load_state_dict(save_dict[f"{net_node.name}_dict"])
logger.info(f"Successfully loading initialization {net_node.name}.") logger.info(f"Successfully loading initialization {net_node.name}.")
def load(self): def load(self):
"""Load parameters of netnodes and the global step from `model.ckpt`. """Load parameters of netnodes and the global step from `model.ckpt`."""
""" save_path = os.path.join(self.network_dir, "model.ckpt")
save_path = os.path.join(self.network_dir, 'model.ckpt')
if not idrlnet.GPU_ENABLED: if not idrlnet.GPU_ENABLED:
save_dict = torch.load(save_path, map_location=torch.device('cpu')) save_dict = torch.load(save_path, map_location=torch.device("cpu"))
else: else:
save_dict = torch.load(save_path) save_dict = torch.load(save_path)
# todo: save on CPU, load on GPU # todo: save on CPU, load on GPU
for i, opt in enumerate(self.optimizers): for i, opt in enumerate(self.optimizers):
opt.load_state_dict(save_dict['optimizer_{}_dict'.format(i)]) opt.load_state_dict(save_dict["optimizer_{}_dict".format(i)])
self.global_step = save_dict['global_step'] self.global_step = save_dict["global_step"]
for net_node in self.netnodes: for net_node in self.netnodes:
if f"{net_node.name}_dict" in save_dict.keys() and not net_node.is_reference: if (
f"{net_node.name}_dict" in save_dict.keys()
and not net_node.is_reference
):
net_node.load_state_dict(save_dict[f"{net_node.name}_dict"]) net_node.load_state_dict(save_dict[f"{net_node.name}_dict"])
logger.info(f"Successfully loading {net_node.name}.") logger.info(f"Successfully loading {net_node.name}.")
@ -364,27 +451,34 @@ class Solver(Notifier, Optimizable):
""" """
Call interfaces of ``Optimizable`` Call interfaces of ``Optimizable``
""" """
opt = self.optimizer_config['optimizer'] opt = self.optimizer_config["optimizer"]
if isinstance(opt, str) and opt in Optimizable.OPTIMIZER_MAP: if isinstance(opt, str) and opt in Optimizable.OPTIMIZER_MAP:
opt = Optimizable.OPTIMIZER_MAP[opt](self.trainable_parameters, opt = Optimizable.OPTIMIZER_MAP[opt](
**{k: v for k, v in self.optimizer_config.items() if k != 'optimizer'}) self.trainable_parameters,
**{k: v for k, v in self.optimizer_config.items() if k != "optimizer"},
)
elif isinstance(opt, Callable): elif isinstance(opt, Callable):
opt = opt opt = opt
else: else:
raise NotImplementedError( raise NotImplementedError(
'The optimizer is not implemented. You may use one of the following optimizer:\n' + '\n'.join( "The optimizer is not implemented. You may use one of the following optimizer:\n"
Optimizable.OPTIMIZER_MAP.keys()) + '\n Example: opt_config=dict(optimizer="Adam", lr=1e-3)') + "\n".join(Optimizable.OPTIMIZER_MAP.keys())
+ '\n Example: opt_config=dict(optimizer="Adam", lr=1e-3)'
)
lr_scheduler = self.schedule_config['scheduler'] lr_scheduler = self.schedule_config["scheduler"]
if isinstance(lr_scheduler, str) and lr_scheduler in Optimizable.SCHEDULE_MAP: if isinstance(lr_scheduler, str) and lr_scheduler in Optimizable.SCHEDULE_MAP:
lr_scheduler = Optimizable.SCHEDULE_MAP[lr_scheduler](opt, lr_scheduler = Optimizable.SCHEDULE_MAP[lr_scheduler](
**{k: v for k, v in self.schedule_config.items() if opt,
k != 'scheduler'}) **{k: v for k, v in self.schedule_config.items() if k != "scheduler"},
)
elif isinstance(lr_scheduler, Callable): elif isinstance(lr_scheduler, Callable):
lr_scheduler = lr_scheduler lr_scheduler = lr_scheduler
else: else:
raise NotImplementedError( raise NotImplementedError(
'The scheduler is not implemented. You may use one of the following scheduler:\n' + '\n'.join( "The scheduler is not implemented. You may use one of the following scheduler:\n"
Optimizable.SCHEDULE_MAP.keys()) + '\n Example: schedule_config=dict(scheduler="ExponentialLR", gamma=0.999') + "\n".join(Optimizable.SCHEDULE_MAP.keys())
+ '\n Example: schedule_config=dict(scheduler="ExponentialLR", gamma=0.999)'
)
self.optimizers = [opt] self.optimizers = [opt]
self.schedulers = [lr_scheduler] self.schedulers = [lr_scheduler]
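The two NotImplementedError messages above double as usage hints: both the optimizer and the scheduler may be given as a name from OPTIMIZER_MAP / SCHEDULE_MAP together with keyword arguments, which are forwarded unchanged to the underlying torch classes. For example (the learning-rate and gamma values are only illustrative):

opt_config = dict(optimizer="Adam", lr=1e-3)
schedule_config = dict(scheduler="ExponentialLR", gamma=0.999)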


@ -10,7 +10,7 @@ import torch
from idrlnet.header import DIFF_SYMBOL from idrlnet.header import DIFF_SYMBOL
from functools import reduce from functools import reduce
__all__ = ['integral', 'torch_lambdify'] __all__ = ["integral", "torch_lambdify"]
def integral_fun(x): def integral_fun(x):
@ -19,7 +19,7 @@ def integral_fun(x):
return x return x
integral = implemented_function('integral', lambda x: integral_fun(x)) integral = implemented_function("integral", lambda x: integral_fun(x))
def torch_lambdify(r, f, *args, **kwargs): def torch_lambdify(r, f, *args, **kwargs):
@ -41,27 +41,27 @@ def torch_lambdify(r, f, *args, **kwargs):
# todo: more functions # todo: more functions
TORCH_SYMPY_PRINTER = { TORCH_SYMPY_PRINTER = {
'sin': torch.sin, "sin": torch.sin,
'cos': torch.cos, "cos": torch.cos,
'tan': torch.tan, "tan": torch.tan,
'exp': torch.exp, "exp": torch.exp,
'sqrt': torch.sqrt, "sqrt": torch.sqrt,
'Abs': torch.abs, "Abs": torch.abs,
'tanh': torch.tanh, "tanh": torch.tanh,
'DiracDelta': torch.zeros_like, "DiracDelta": torch.zeros_like,
'Heaviside': lambda x: torch.heaviside(x, torch.tensor([0.])), "Heaviside": lambda x: torch.heaviside(x, torch.tensor([0.0])),
'amin': lambda x: reduce(lambda y, z: torch.minimum(y, z), x), "amin": lambda x: reduce(lambda y, z: torch.minimum(y, z), x),
'amax': lambda x: reduce(lambda y, z: torch.maximum(y, z), x), "amax": lambda x: reduce(lambda y, z: torch.maximum(y, z), x),
'Min': lambda *x: reduce(lambda y, z: torch.minimum(y, z), x), "Min": lambda *x: reduce(lambda y, z: torch.minimum(y, z), x),
'Max': lambda *x: reduce(lambda y, z: torch.maximum(y, z), x), "Max": lambda *x: reduce(lambda y, z: torch.maximum(y, z), x),
'equal': lambda x, y: torch.isclose(x, y), "equal": lambda x, y: torch.isclose(x, y),
'Xor': torch.logical_xor, "Xor": torch.logical_xor,
'log': torch.log, "log": torch.log,
'sinh': torch.sinh, "sinh": torch.sinh,
'cosh': torch.cosh, "cosh": torch.cosh,
'asin': torch.arcsin, "asin": torch.arcsin,
'acos': torch.arccos, "acos": torch.arccos,
'atan': torch.arctan, "atan": torch.arctan,
} }
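TORCH_SYMPY_PRINTER maps sympy function names to torch implementations so that a lambdified expression evaluates on tensors and stays on the autograd graph. A minimal sketch of the idea, assuming torch_lambdify is essentially a thin wrapper around sympy.lambdify with this mapping (the wrapper's exact behavior is not shown in this hunk):

import sympy as sp
import torch

x = sp.Symbol("x")
expr = sp.sin(x) * sp.exp(-x)

# use torch implementations for the sympy functions appearing in expr
fn = sp.lambdify([x], expr, modules=[{"sin": torch.sin, "exp": torch.exp}, "numpy"])

t = torch.linspace(0.0, 1.0, 5, requires_grad=True)
y = fn(t)           # a torch.Tensor on the autograd graph
y.sum().backward()  # gradients flow back to t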
@ -75,9 +75,12 @@ def _replace_derivatives(expr):
expr = expr.subs(deriv, Function(str(deriv))(*deriv.free_symbols)) expr = expr.subs(deriv, Function(str(deriv))(*deriv.free_symbols))
while True: while True:
try: try:
custom_fun = {_fun for _fun in expr.atoms(Function) if custom_fun = {
(_fun.class_key()[1] == 0) and (not _fun.class_key()[2] == 'integral') _fun
}.pop() for _fun in expr.atoms(Function)
if (_fun.class_key()[1] == 0)
and (not _fun.class_key()[2] == "integral")
}.pop()
new_symbol_name = str(custom_fun) new_symbol_name = str(custom_fun)
expr = expr.subs(custom_fun, Symbol(new_symbol_name)) expr = expr.subs(custom_fun, Symbol(new_symbol_name))
except KeyError: except KeyError:
@ -90,7 +93,10 @@ class UnderlineDerivativePrinter(StrPrinter):
return expr.func.__name__ return expr.func.__name__
def _print_Derivative(self, expr): def _print_Derivative(self, expr):
return "".join([str(expr.args[0].func)] + [order * (DIFF_SYMBOL + str(key)) for key, order in expr.args[1:]]) return "".join(
[str(expr.args[0].func)]
+ [order * (DIFF_SYMBOL + str(key)) for key, order in expr.args[1:]]
)
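The printer flattens a sympy Derivative into a single name by joining the function name with each differentiation variable, repeated according to its order, so derivatives can later be addressed as ordinary dictionary keys. A standalone sketch of the same construction (DIFF_SYMBOL is assumed to be "__" here; its actual value lives in idrlnet.header and is not visible in this hunk):

import sympy as sp

x, y = sp.symbols("x y")
f = sp.Function("f")(x, y)
d = sp.Derivative(f, x, (y, 2))

DIFF_SYMBOL = "__"  # assumed value, for illustration only
name = "".join(
    [str(d.args[0].func)]
    + [int(order) * (DIFF_SYMBOL + str(key)) for key, order in d.args[1:]]
)
print(name)  # f__x__y__y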
def sstr(expr, **settings): def sstr(expr, **settings):


@ -13,14 +13,14 @@ from collections import defaultdict
import pandas as pd import pandas as pd
from idrlnet.header import DIFF_SYMBOL from idrlnet.header import DIFF_SYMBOL
__all__ = ['Loss', 'Variables', 'DomainVariables', 'export_var'] __all__ = ["Loss", "Variables", "DomainVariables", "export_var"]
class Loss(enum.Enum): class Loss(enum.Enum):
"""Enumerate loss functions""" """Enumerate loss functions"""
L1 = 'L1' L1 = "L1"
square = 'square' square = "square"
class LossFunction: class LossFunction:
@ -35,56 +35,67 @@ class LossFunction:
raise NotImplementedError(f"loss function {loss_function} is not defined!") raise NotImplementedError(f"loss function {loss_function} is not defined!")
@staticmethod @staticmethod
def weighted_L1_loss(variables: 'Variables', name: str) -> 'Variables': def weighted_L1_loss(variables: "Variables", name: str) -> "Variables":
loss = 0. loss = 0.0
for key, val in variables.items(): for key, val in variables.items():
if key.startswith("lambda_") or key == 'area': if key.startswith("lambda_") or key == "area":
continue continue
elif "lambda_" + key in variables.keys(): elif "lambda_" + key in variables.keys():
loss += torch.sum((torch.abs(val)) * variables["lambda_" + key] * variables["area"]) loss += torch.sum(
(torch.abs(val)) * variables["lambda_" + key] * variables["area"]
)
else: else:
loss += torch.sum((torch.abs(val)) * variables["area"]) loss += torch.sum((torch.abs(val)) * variables["area"])
return Variables({name: loss}) return Variables({name: loss})
@staticmethod @staticmethod
def weighted_square_loss(variables: 'Variables', name: str) -> 'Variables': def weighted_square_loss(variables: "Variables", name: str) -> "Variables":
loss = 0. loss = 0.0
for key, val in variables.items(): for key, val in variables.items():
if key.startswith("lambda_") or key == 'area': if key.startswith("lambda_") or key == "area":
continue continue
elif "lambda_" + key in variables.keys(): elif "lambda_" + key in variables.keys():
loss += torch.sum((val ** 2) * variables["lambda_" + key] * variables["area"]) loss += torch.sum(
(val ** 2) * variables["lambda_" + key] * variables["area"]
)
else: else:
loss += torch.sum((val ** 2) * variables["area"]) loss += torch.sum((val ** 2) * variables["area"])
return Variables({name: loss}) return Variables({name: loss})
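Both reductions follow the same pattern: square (or take the absolute value of) each residual, multiply by its optional lambda_<name> weight and by the quadrature weight stored under "area", and sum over the sample points. A small numeric sketch with hypothetical values:

import torch

residual = torch.tensor([0.1, -0.2])  # residual of one constraint
lambda_w = torch.tensor([1.0, 2.0])   # per-point weight, i.e. lambda_<name>
area = torch.tensor([0.5, 0.5])       # quadrature weight per sample point

loss = torch.sum(residual ** 2 * lambda_w * area)
print(loss)  # tensor(0.0450) = 0.01*1.0*0.5 + 0.04*2.0*0.5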
class Variables(dict): class Variables(dict):
def __sub__(self, other: 'Variables') -> 'Variables': def __sub__(self, other: "Variables") -> "Variables":
return Variables( return Variables(
{key: (self[key] if key in self else 0) - (other[key] if key in other else 0) for key in {**self, **other}}) {
key: (self[key] if key in self else 0)
- (other[key] if key in other else 0)
for key in {**self, **other}
}
)
def weighted_loss(self, name: str, loss_function: Union[Loss, str]) -> 'Variables': def weighted_loss(self, name: str, loss_function: Union[Loss, str]) -> "Variables":
"""Regard the variable as residuals and reduce to a weighted_loss.""" """Regard the variable as residuals and reduce to a weighted_loss."""
return LossFunction.weighted_loss(variables=self, loss_function=loss_function, name=name) return LossFunction.weighted_loss(
variables=self, loss_function=loss_function, name=name
)
def subset(self, subset_keys: List[str]) -> 'Variables': def subset(self, subset_keys: List[str]) -> "Variables":
"""Construct a new variable with subset references""" """Construct a new variable with subset references"""
return Variables({name: self[name] for name in subset_keys if name in self}) return Variables({name: self[name] for name in subset_keys if name in self})
def to_torch_tensor_(self) -> 'Variables[str, torch.Tensor]': def to_torch_tensor_(self) -> "Variables[str, torch.Tensor]":
"""Convert the variables to torch.Tensor""" """Convert the variables to torch.Tensor"""
for key, val in self.items(): for key, val in self.items():
if not isinstance(val, torch.Tensor): if not isinstance(val, torch.Tensor):
self[key] = torch.Tensor(val) self[key] = torch.Tensor(val)
if (not key.startswith('lambda_')) and (not key == 'area'): if (not key.startswith("lambda_")) and (not key == "area"):
self[key].requires_grad_() self[key].requires_grad_()
return self return self
def to_ndarray_(self) -> 'Variables[str, np.ndarray]': def to_ndarray_(self) -> "Variables[str, np.ndarray]":
"""convert to a numpy based variables""" """convert to a numpy based variables"""
for key, val in self.items(): for key, val in self.items():
@ -92,7 +103,7 @@ class Variables(dict):
self[key] = val.detach().cpu().numpy() self[key] = val.detach().cpu().numpy()
return self return self
def to_ndarray(self) -> 'Variables[str, np.ndarray]': def to_ndarray(self) -> "Variables[str, np.ndarray]":
"""Return a new numpy based variables""" """Return a new numpy based variables"""
new_var = Variables() new_var = Variables()
@ -130,26 +141,36 @@ class Variables(dict):
variables[name] = var_t variables[name] = var_t
return variables return variables
def differentiate_one_step_(self: 'Variables', independent_var: 'Variables', required_derivatives: List[str]): def differentiate_one_step_(
self: "Variables", independent_var: "Variables", required_derivatives: List[str]
):
"""One order of derivatives will be computed towards the required_derivatives.""" """One order of derivatives will be computed towards the required_derivatives."""
required_derivatives = [d for d in required_derivatives if d not in self] required_derivatives = [d for d in required_derivatives if d not in self]
required_derivatives_set = set( required_derivatives_set = set(
tuple(required_derivative.split(DIFF_SYMBOL)) for required_derivative in required_derivatives) tuple(required_derivative.split(DIFF_SYMBOL))
for required_derivative in required_derivatives
)
dependent_var_set = set(tuple(dv.split(DIFF_SYMBOL)) for dv in self.keys()) dependent_var_set = set(tuple(dv.split(DIFF_SYMBOL)) for dv in self.keys())
computable_derivative_dict = defaultdict(set) computable_derivative_dict = defaultdict(set)
for dv, rd in itertools.product(dependent_var_set, required_derivatives_set): for dv, rd in itertools.product(dependent_var_set, required_derivatives_set):
if len(rd) > len(dv) and rd[:len(dv)] == dv and rd[:len(dv) + 1] not in dependent_var_set: if (
len(rd) > len(dv)
and rd[: len(dv)] == dv
and rd[: len(dv) + 1] not in dependent_var_set
):
computable_derivative_dict[rd[len(dv)]].add(DIFF_SYMBOL.join(dv)) computable_derivative_dict[rd[len(dv)]].add(DIFF_SYMBOL.join(dv))
derivative_variables = Variables() derivative_variables = Variables()
for key, value in computable_derivative_dict.items(): for key, value in computable_derivative_dict.items():
for v in value: for v in value:
f__x = torch.autograd.grad(self[v], f__x = torch.autograd.grad(
independent_var[key], self[v],
grad_outputs=torch.ones_like(self[v]), independent_var[key],
retain_graph=True, grad_outputs=torch.ones_like(self[v]),
create_graph=True, retain_graph=True,
allow_unused=True)[0] create_graph=True,
allow_unused=True,
)[0]
if f__x is not None: if f__x is not None:
f__x.requires_grad_() f__x.requires_grad_()
else: else:
@ -157,7 +178,9 @@ class Variables(dict):
derivative_variables[DIFF_SYMBOL.join([v, key])] = f__x derivative_variables[DIFF_SYMBOL.join([v, key])] = f__x
self.update(derivative_variables) self.update(derivative_variables)
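Under the hood each new derivative is a single torch.autograd.grad call whose result is stored under a key joined with DIFF_SYMBOL. A stripped-down sketch of one such step on plain tensors (variable names are hypothetical):

import torch

x = torch.linspace(0.0, 1.0, 5, requires_grad=True).reshape(-1, 1)
u = x ** 2  # stand-in for a network output depending on x

u__x = torch.autograd.grad(
    u,
    x,
    grad_outputs=torch.ones_like(u),
    retain_graph=True,
    create_graph=True,
    allow_unused=True,
)[0]
print(u__x[:, 0])  # 2*x: values 0.0, 0.5, 1.0, 1.5, 2.0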
def differentiate_(self: 'Variables', independent_var: 'Variables', required_derivatives: List[str]): def differentiate_(
self: "Variables", independent_var: "Variables", required_derivatives: List[str]
):
"""Derivatives will be computed towards the required_derivatives""" """Derivatives will be computed towards the required_derivatives"""
n_keys = 0 n_keys = 0
@ -168,8 +191,11 @@ class Variables(dict):
new_keys = len(self.keys()) new_keys = len(self.keys())
@staticmethod @staticmethod
def var_differentiate_one_step(dependent_var: 'Variables', independent_var: 'Variables', def var_differentiate_one_step(
required_derivatives: List[str]): dependent_var: "Variables",
independent_var: "Variables",
required_derivatives: List[str],
):
"""Perform one step of differentiate towards the required_derivatives""" """Perform one step of differentiate towards the required_derivatives"""
dependent_var.differentiate_one_step_(independent_var, required_derivatives) dependent_var.differentiate_one_step_(independent_var, required_derivatives)
@ -177,15 +203,15 @@ class Variables(dict):
def to_csv(self, filename: str) -> None: def to_csv(self, filename: str) -> None:
"""Export variable to csv""" """Export variable to csv"""
if not filename.endswith('.csv'): if not filename.endswith(".csv"):
filename += '.csv' filename += ".csv"
df = self.to_dataframe() df = self.to_dataframe()
df.to_csv(filename, index=False) df.to_csv(filename, index=False)
def to_vtu(self, filename: str, coordinates=None) -> None: def to_vtu(self, filename: str, coordinates=None) -> None:
"""Export variable to vtu""" """Export variable to vtu"""
coordinates = ['x', 'y', 'z'] if coordinates is None else coordinates coordinates = ["x", "y", "z"] if coordinates is None else coordinates
shape = 0 shape = 0
for axis in coordinates: for axis in coordinates:
if axis not in self.keys(): if axis not in self.keys():
@ -196,27 +222,29 @@ class Variables(dict):
if value.shape == (1, 1): if value.shape == (1, 1):
self[key] = np.ones(shape) * value self[key] = np.ones(shape) * value
self[key] = np.asarray(self[key], dtype=np.float64) self[key] = np.asarray(self[key], dtype=np.float64)
pointsToVTK(filename, pointsToVTK(
self[coordinates[0]][:, 0].copy(), filename,
self[coordinates[1]][:, 0].copy(), self[coordinates[0]][:, 0].copy(),
self[coordinates[2]][:, 0].copy(), self[coordinates[1]][:, 0].copy(),
data={key: value[:, 0].copy() for key, value in self.items()}) self[coordinates[2]][:, 0].copy(),
data={key: value[:, 0].copy() for key, value in self.items()},
)
def save(self, path, formats=None): def save(self, path, formats=None):
"""Export variable to various formats""" """Export variable to various formats"""
if formats is None: if formats is None:
formats = ['np', 'csv', 'vtu'] formats = ["np", "csv", "vtu"]
np_var = self.to_ndarray() np_var = self.to_ndarray()
if 'np' in formats: if "np" in formats:
np.savez(path, **np_var) np.savez(path, **np_var)
if 'csv' in formats: if "csv" in formats:
np_var.to_csv(path) np_var.to_csv(path)
if 'vtu' in formats: if "vtu" in formats:
np_var.to_vtu(filename=path) np_var.to_vtu(filename=path)
@staticmethod @staticmethod
def cat(*var_list) -> 'Variables': def cat(*var_list) -> "Variables":
"""todo: catenate in var list""" """todo: catenate in var list"""
return Variables() return Variables()
@ -224,12 +252,14 @@ class Variables(dict):
DomainVariables = Dict[str, Variables] DomainVariables = Dict[str, Variables]
def export_var(domain_var: DomainVariables, path='./inference_domain/results', formats=None): def export_var(
domain_var: DomainVariables, path="./inference_domain/results", formats=None
):
"""Export a dict of variables to ``csv``, ``vtu`` or ``npz``.""" """Export a dict of variables to ``csv``, ``vtu`` or ``npz``."""
if formats is None: if formats is None:
formats = ['csv', 'vtu', 'np'] formats = ["csv", "vtu", "np"]
path = pathlib.Path(path) path = pathlib.Path(path)
path.mkdir(exist_ok=True, parents=True) path.mkdir(exist_ok=True, parents=True)
for key in domain_var.keys(): for key in domain_var.keys():
domain_var[key].save(os.path.join(path, f'{key}'), formats) domain_var[key].save(os.path.join(path, f"{key}"), formats)
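A short usage sketch for the exporter (domain and variable names are hypothetical and the import path is assumed; "vtu" output additionally requires x/y/z coordinate arrays, as handled by to_vtu above):

import numpy as np
from idrlnet.variable import Variables, export_var  # assumed module path

var = Variables({
    "x": np.linspace(0.0, 1.0, 11).reshape(-1, 1),
    "u": np.zeros((11, 1)),
})
export_var({"interior": var}, path="./inference_domain/results", formats=["csv", "np"])
# writes ./inference_domain/results/interior.csv and interior.npz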