From f94494c43e7cd4115f147d86c910e184e4fcdc22 Mon Sep 17 00:00:00 2001 From: zweien <278954153@qq.com> Date: Tue, 13 Jul 2021 10:39:09 +0800 Subject: [PATCH] style: change to black style --- docs/conf.py | 30 +- examples/Volterra_IDE/volterra_ide.py | 64 ++-- examples/burgers_equation/burgers_equation.py | 77 +++-- examples/euler_beam/euler_beam.py | 67 ++-- .../inverse_wave_equation.py | 111 ++++--- .../minimal_surface_of_revolution.py | 97 +++--- .../minimal_surface_of_revolution_pretrain.py | 28 +- .../parameterized_poisson.py | 80 ++--- examples/simple_poisson/simple_poisson.py | 70 +++-- idrlnet/__init__.py | 11 +- idrlnet/architecture/grid.py | 198 +++++++++--- idrlnet/architecture/layer.py | 67 ++-- idrlnet/architecture/mlp.py | 184 +++++++---- idrlnet/callbacks.py | 44 ++- idrlnet/data.py | 96 ++++-- idrlnet/geo_utils/geo_builder.py | 54 ++-- idrlnet/geo_utils/sympy_np.py | 68 +++-- idrlnet/graph.py | 139 ++++++--- idrlnet/header.py | 13 +- idrlnet/net.py | 37 ++- idrlnet/node.py | 32 +- idrlnet/pde.py | 19 +- idrlnet/receivers.py | 18 +- idrlnet/solver.py | 288 ++++++++++++------ idrlnet/torch_util.py | 60 ++-- idrlnet/variable.py | 126 +++++--- 26 files changed, 1343 insertions(+), 735 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 13afc9f..f2aaf9a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,16 +13,16 @@ import os import sys -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) # -- Project information ----------------------------------------------------- -project = 'idrlnet' -copyright = '2021, IDRL' -author = 'IDRL' +project = "idrlnet" +copyright = "2021, IDRL" +author = "IDRL" # The full version, including alpha/beta/rc tags -release = '0.0.1-rc1' +release = "0.0.1-rc1" # -- General configuration --------------------------------------------------- @@ -34,37 +34,37 @@ extensions = [ "sphinx.ext.mathjax", "sphinx.ext.napoleon", "sphinx.ext.viewcode", - 'myst_parser', - 'sphinx.ext.autosectionlabel', + "myst_parser", + "sphinx.ext.autosectionlabel", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] source_suffix = { - '.rst': 'restructuredtext', - '.txt': 'markdown', - '.md': 'markdown', + ".rst": "restructuredtext", + ".txt": "markdown", + ".md": "markdown", } # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # for MarkdownParser -from sphinx_markdown_parser.parser import MarkdownParser # noqa +from sphinx_markdown_parser.parser import MarkdownParser # noqa # def setup(app): diff --git a/examples/Volterra_IDE/volterra_ide.py b/examples/Volterra_IDE/volterra_ide.py index a3b62a4..8c4317d 100644 --- a/examples/Volterra_IDE/volterra_ide.py +++ b/examples/Volterra_IDE/volterra_ide.py @@ -3,9 +3,9 @@ import sympy as sp import numpy as np import matplotlib.pyplot as plt -x = sp.Symbol('x') -s = sp.Symbol('s') -f = sp.Function('f')(x) +x = sp.Symbol("x") +s = sp.Symbol("s") +f = sp.Function("f")(x) geo = sc.Line1D(0, 5) @@ -19,43 +19,49 @@ def interior(): @sc.datanode def init(): points = geo.sample_boundary(1, sieve=sp.Eq(x, 0)) - points['lambda_f'] = 1000 * np.ones_like(points['x']) - constraints = {'f': 1} + points["lambda_f"] = 1000 * np.ones_like(points["x"]) + constraints = {"f": 1} return points, constraints -@sc.datanode(name='InteriorInfer') +@sc.datanode(name="InteriorInfer") def infer(): - points = {'x': np.linspace(0, 5, 1000).reshape(-1, 1)} + points = {"x": np.linspace(0, 5, 1000).reshape(-1, 1)} return points, {} -netnode = sc.get_net_node(inputs=('x',), outputs=('f',), name='net') -exp_lhs = sc.ExpressionNode(expression=f.diff(x) + f, name='lhs') +netnode = sc.get_net_node(inputs=("x",), outputs=("f",), name="net") +exp_lhs = sc.ExpressionNode(expression=f.diff(x) + f, name="lhs") -fs = sp.Symbol('fs') -exp_rhs = sc.Int1DNode(expression=sp.exp(s - x) * fs, var=s, lb=0, ub=x, expression_name='rhs', - funs={'fs': {'eval': netnode, - 'input_map': {'x': 's'}, - 'output_map': {'f': 'fs'}}}, - degree=10) -diff = sc.Difference(T='lhs', S='rhs', dim=1, time=False) +fs = sp.Symbol("fs") +exp_rhs = sc.Int1DNode( + expression=sp.exp(s - x) * fs, + var=s, + lb=0, + ub=x, + expression_name="rhs", + funs={"fs": {"eval": netnode, "input_map": {"x": "s"}, "output_map": {"f": "fs"}}}, + degree=10, +) +diff = sc.Difference(T="lhs", S="rhs", dim=1, time=False) -solver = sc.Solver(sample_domains=(interior(), init(), infer()), - netnodes=[netnode], - pdes=[exp_lhs, exp_rhs, diff], - loading=True, - max_iter=3000) +solver = sc.Solver( + sample_domains=(interior(), init(), infer()), + netnodes=[netnode], + pdes=[exp_lhs, exp_rhs, diff], + loading=True, + max_iter=3000, +) solver.solve() -points = solver.infer_step({'InteriorInfer': ['x', 'f']}) -num_x = points['InteriorInfer']['x'].detach().cpu().numpy().ravel() -num_f = points['InteriorInfer']['f'].detach().cpu().numpy().ravel() +points = solver.infer_step({"InteriorInfer": ["x", "f"]}) +num_x = points["InteriorInfer"]["x"].detach().cpu().numpy().ravel() +num_f = points["InteriorInfer"]["f"].detach().cpu().numpy().ravel() -fig = plt.figure(figsize=(8,4)) +fig = plt.figure(figsize=(8, 4)) plt.plot(num_x, num_f) plt.plot(num_x, np.exp(-num_x) * np.cosh(num_x)) -plt.xlabel('x') -plt.ylabel('y') -plt.legend(['Prediction', 'Exact']) -plt.savefig('ide.png', dpi=1000, bbox_inches='tight') +plt.xlabel("x") +plt.ylabel("y") +plt.legend(["Prediction", "Exact"]) +plt.savefig("ide.png", dpi=1000, bbox_inches="tight") plt.show() diff --git a/examples/burgers_equation/burgers_equation.py b/examples/burgers_equation/burgers_equation.py index 6758d11..fcfe41e 100644 --- a/examples/burgers_equation/burgers_equation.py +++ b/examples/burgers_equation/burgers_equation.py @@ -4,63 +4,82 @@ import matplotlib.pyplot as plt import matplotlib.tri as tri import idrlnet.shortcut as sc -x = Symbol('x') -t_symbol = Symbol('t') +x = Symbol("x") +t_symbol = Symbol("t") time_range = {t_symbol: (0, 1)} -geo = sc.Line1D(-1., 1.) +geo = sc.Line1D(-1.0, 1.0) -@sc.datanode(name='burgers_equation') +@sc.datanode(name="burgers_equation") def interior_domain(): - points = geo.sample_interior(10000, bounds={x: (-1., 1.)}, param_ranges=time_range) - constraints = {'burgers_u': 0} + points = geo.sample_interior( + 10000, bounds={x: (-1.0, 1.0)}, param_ranges=time_range + ) + constraints = {"burgers_u": 0} return points, constraints -@sc.datanode(name='t_boundary') +@sc.datanode(name="t_boundary") def init_domain(): points = geo.sample_interior(100, param_ranges={t_symbol: 0.0}) - constraints = sc.Variables({'u': -sin(math.pi * x)}) + constraints = sc.Variables({"u": -sin(math.pi * x)}) return points, constraints @sc.datanode(name="x_boundary") def boundary_domain(): points = geo.sample_boundary(100, param_ranges=time_range) - constraints = sc.Variables({'u': 0}) + constraints = sc.Variables({"u": 0}) return points, constraints -net = sc.get_net_node(inputs=('x', 't',), outputs=('u',), name='net1', arch=sc.Arch.mlp) -pde = sc.BurgersNode(u='u', v=0.01 / math.pi) -s = sc.Solver(sample_domains=(interior_domain(), init_domain(), boundary_domain()), - netnodes=[net], pdes=[pde], max_iter=4000) +net = sc.get_net_node( + inputs=( + "x", + "t", + ), + outputs=("u",), + name="net1", + arch=sc.Arch.mlp, +) +pde = sc.BurgersNode(u="u", v=0.01 / math.pi) +s = sc.Solver( + sample_domains=(interior_domain(), init_domain(), boundary_domain()), + netnodes=[net], + pdes=[pde], + max_iter=4000, +) s.solve() -coord = s.infer_step({'burgers_equation': ['x', 't', 'u'], 't_boundary': ['x', 't'], - 'x_boundary': ['x', 't']}) -num_x = coord['burgers_equation']['x'].cpu().detach().numpy().ravel() -num_t = coord['burgers_equation']['t'].cpu().detach().numpy().ravel() -num_u = coord['burgers_equation']['u'].cpu().detach().numpy().ravel() +coord = s.infer_step( + { + "burgers_equation": ["x", "t", "u"], + "t_boundary": ["x", "t"], + "x_boundary": ["x", "t"], + } +) +num_x = coord["burgers_equation"]["x"].cpu().detach().numpy().ravel() +num_t = coord["burgers_equation"]["t"].cpu().detach().numpy().ravel() +num_u = coord["burgers_equation"]["u"].cpu().detach().numpy().ravel() -init_x = coord['t_boundary']['x'].cpu().detach().numpy().ravel() -init_t = coord['t_boundary']['t'].cpu().detach().numpy().ravel() -boundary_x = coord['x_boundary']['x'].cpu().detach().numpy().ravel() -boundary_t = coord['x_boundary']['t'].cpu().detach().numpy().ravel() +init_x = coord["t_boundary"]["x"].cpu().detach().numpy().ravel() +init_t = coord["t_boundary"]["t"].cpu().detach().numpy().ravel() +boundary_x = coord["x_boundary"]["x"].cpu().detach().numpy().ravel() +boundary_t = coord["x_boundary"]["t"].cpu().detach().numpy().ravel() triang_total = tri.Triangulation(num_t.flatten(), num_x.flatten()) u_pre = num_u.flatten() fig = plt.figure(figsize=(15, 5)) ax1 = fig.add_subplot(221) -tcf = ax1.tricontourf(triang_total, u_pre, 100, cmap='jet') +tcf = ax1.tricontourf(triang_total, u_pre, 100, cmap="jet") tc_bar = plt.colorbar(tcf) tc_bar.ax.tick_params(labelsize=10) -ax1.set_xlabel('$t$') -ax1.set_ylabel('$x$') -ax1.set_title('$u(x,t)$') -ax1.scatter(init_t, init_x, c='black', marker='x', s=8) -ax1.scatter(boundary_t, boundary_x, c='black', marker='x', s=8) +ax1.set_xlabel("$t$") +ax1.set_ylabel("$x$") +ax1.set_title("$u(x,t)$") +ax1.scatter(init_t, init_x, c="black", marker="x", s=8) +ax1.scatter(boundary_t, boundary_x, c="black", marker="x", s=8) plt.xlim(0, 1) plt.ylim(-1, 1) -plt.savefig('Burgers.png', dpi=500, bbox_inches='tight', pad_inches=0.02) +plt.savefig("Burgers.png", dpi=500, bbox_inches="tight", pad_inches=0.02) diff --git a/examples/euler_beam/euler_beam.py b/examples/euler_beam/euler_beam.py index ac805f5..8962298 100644 --- a/examples/euler_beam/euler_beam.py +++ b/examples/euler_beam/euler_beam.py @@ -3,59 +3,68 @@ import sympy as sp import numpy as np import idrlnet.shortcut as sc -x = sp.symbols('x') +x = sp.symbols("x") Line = sc.Line1D(0, 1) -y = sp.Function('y')(x) +y = sp.Function("y")(x) -@sc.datanode(name='interior') +@sc.datanode(name="interior") class Interior(sc.SampleDomain): def sampling(self, *args, **kwargs): - return Line.sample_interior(1000), {'dddd_y': 0} + return Line.sample_interior(1000), {"dddd_y": 0} -@sc.datanode(name='left_boundary1') +@sc.datanode(name="left_boundary1") class LeftBoundary1(sc.SampleDomain): def sampling(self, *args, **kwargs): - return Line.sample_boundary(100, sieve=(sp.Eq(x, 0))), {'y': 0} + return Line.sample_boundary(100, sieve=(sp.Eq(x, 0))), {"y": 0} -@sc.datanode(name='left_boundary2') +@sc.datanode(name="left_boundary2") class LeftBoundary2(sc.SampleDomain): def sampling(self, *args, **kwargs): - return Line.sample_boundary(100, sieve=(sp.Eq(x, 0))), {'d_y': 0} + return Line.sample_boundary(100, sieve=(sp.Eq(x, 0))), {"d_y": 0} -@sc.datanode(name='right_boundary1') +@sc.datanode(name="right_boundary1") class RightBoundary1(sc.SampleDomain): def sampling(self, *args, **kwargs): - return Line.sample_boundary(100, sieve=(sp.Eq(x, 1))), {'dd_y': 0} + return Line.sample_boundary(100, sieve=(sp.Eq(x, 1))), {"dd_y": 0} -@sc.datanode(name='right_boundary2') +@sc.datanode(name="right_boundary2") class RightBoundary2(sc.SampleDomain): def sampling(self, *args, **kwargs): - return Line.sample_boundary(100, sieve=(sp.Eq(x, 1))), {'ddd_y': 0} + return Line.sample_boundary(100, sieve=(sp.Eq(x, 1))), {"ddd_y": 0} -@sc.datanode(name='infer') +@sc.datanode(name="infer") class Infer(sc.SampleDomain): def sampling(self, *args, **kwargs): - return {'x': np.linspace(0, 1, 1000).reshape(-1, 1)}, {} + return {"x": np.linspace(0, 1, 1000).reshape(-1, 1)}, {} -net = sc.get_net_node(inputs=('x',), outputs=('y',), name='net', arch=sc.Arch.mlp) +net = sc.get_net_node(inputs=("x",), outputs=("y",), name="net", arch=sc.Arch.mlp) -pde1 = sc.ExpressionNode(name='dddd_y', expression=y.diff(x).diff(x).diff(x).diff(x) + 1) -pde2 = sc.ExpressionNode(name='d_y', expression=y.diff(x)) -pde3 = sc.ExpressionNode(name='dd_y', expression=y.diff(x).diff(x)) -pde4 = sc.ExpressionNode(name='ddd_y', expression=y.diff(x).diff(x).diff(x)) +pde1 = sc.ExpressionNode( + name="dddd_y", expression=y.diff(x).diff(x).diff(x).diff(x) + 1 +) +pde2 = sc.ExpressionNode(name="d_y", expression=y.diff(x)) +pde3 = sc.ExpressionNode(name="dd_y", expression=y.diff(x).diff(x)) +pde4 = sc.ExpressionNode(name="ddd_y", expression=y.diff(x).diff(x).diff(x)) solver = sc.Solver( - sample_domains=(Interior(), LeftBoundary1(), LeftBoundary2(), RightBoundary1(), RightBoundary2()), + sample_domains=( + Interior(), + LeftBoundary1(), + LeftBoundary2(), + RightBoundary1(), + RightBoundary2(), + ), netnodes=[net], pdes=[pde1, pde2, pde3, pde4], - max_iter=2000) + max_iter=2000, +) solver.solve() @@ -65,14 +74,14 @@ def exact(x): solver.sample_domains = (Infer(),) -points = solver.infer_step({'infer': ['x', 'y']}) -xs = points['infer']['x'].detach().cpu().numpy().ravel() -y_pred = points['infer']['y'].detach().cpu().numpy().ravel() -plt.plot(xs, y_pred, label='Pred') +points = solver.infer_step({"infer": ["x", "y"]}) +xs = points["infer"]["x"].detach().cpu().numpy().ravel() +y_pred = points["infer"]["y"].detach().cpu().numpy().ravel() +plt.plot(xs, y_pred, label="Pred") y_exact = exact(xs) -plt.plot(xs, y_exact, label='Exact', linestyle='--') +plt.plot(xs, y_exact, label="Exact", linestyle="--") plt.legend() -plt.xlabel('x') -plt.ylabel('w') -plt.savefig('Euler_beam.png', dpi=300, bbox_inches='tight') +plt.xlabel("x") +plt.ylabel("w") +plt.savefig("Euler_beam.png", dpi=300, bbox_inches="tight") plt.show() diff --git a/examples/inverse_wave_equation/inverse_wave_equation.py b/examples/inverse_wave_equation/inverse_wave_equation.py index 62ed14d..43b1793 100644 --- a/examples/inverse_wave_equation/inverse_wave_equation.py +++ b/examples/inverse_wave_equation/inverse_wave_equation.py @@ -10,104 +10,121 @@ import matplotlib.pyplot as plt L = float(pi) geo = sc.Line1D(0, L) -t_symbol = Symbol('t') -x = Symbol('x') +t_symbol = Symbol("t") +x = Symbol("x") time_range = {t_symbol: (0, 2 * L)} c = 1.54 -external_filename = 'external_sample.csv' +external_filename = "external_sample.csv" def generate_observed_data(): if os.path.exists(external_filename): return - points = geo.sample_interior(density=20, - bounds={x: (0, L)}, - param_ranges=time_range, - low_discrepancy=True) - points['u'] = np.sin(points['x']) * (np.sin(c * points['t']) + np.cos(c * points['t'])) - points['u'][np.random.choice(len(points['u']), 10, replace=False)] = 3. + points = geo.sample_interior( + density=20, bounds={x: (0, L)}, param_ranges=time_range, low_discrepancy=True + ) + points["u"] = np.sin(points["x"]) * ( + np.sin(c * points["t"]) + np.cos(c * points["t"]) + ) + points["u"][np.random.choice(len(points["u"]), 10, replace=False)] = 3.0 points = {k: v.ravel() for k, v in points.items()} points = pd.DataFrame.from_dict(points) - points.to_csv('external_sample.csv', index=False) + points.to_csv("external_sample.csv", index=False) generate_observed_data() # @sc.datanode(name='wave_domain') -@sc.datanode(name='wave_domain', loss_fn='L1') +@sc.datanode(name="wave_domain", loss_fn="L1") class WaveExternal(sc.SampleDomain): def __init__(self): - points = pd.read_csv('external_sample.csv') - self.points = {col: points[col].to_numpy().reshape(-1, 1) for col in points.columns} - self.constraints = {'u': self.points.pop('u')} + points = pd.read_csv("external_sample.csv") + self.points = { + col: points[col].to_numpy().reshape(-1, 1) for col in points.columns + } + self.constraints = {"u": self.points.pop("u")} def sampling(self, *args, **kwargs): return self.points, self.constraints -@sc.datanode(name='wave_external') +@sc.datanode(name="wave_external") class WaveEq(sc.SampleDomain): def sampling(self, *args, **kwargs): - points = geo.sample_interior(density=1000, bounds={x: (0, L)}, param_ranges=time_range) - constraints = {'wave_equation': 0.} + points = geo.sample_interior( + density=1000, bounds={x: (0, L)}, param_ranges=time_range + ) + constraints = {"wave_equation": 0.0} return points, constraints -@sc.datanode(name='center_infer') +@sc.datanode(name="center_infer") class CenterInfer(sc.SampleDomain): def __init__(self): self.points = sc.Variables() - self.points['t'] = np.linspace(0, 2 * L, 200).reshape(-1, 1) - self.points['x'] = np.ones_like(self.points['t']) * L / 2 - self.points['area'] = np.ones_like(self.points['t']) + self.points["t"] = np.linspace(0, 2 * L, 200).reshape(-1, 1) + self.points["x"] = np.ones_like(self.points["t"]) * L / 2 + self.points["area"] = np.ones_like(self.points["t"]) def sampling(self, *args, **kwargs): return self.points, {} -net = sc.get_net_node(inputs=('x', 't',), outputs=('u',), name='net1', arch=sc.Arch.mlp) -var_c = sc.get_net_node(inputs=('x',), outputs=('c',), arch=sc.Arch.single_var) -pde = sc.WaveNode(c='c', dim=1, time=True, u='u') -s = sc.Solver(sample_domains=(WaveExternal(), WaveEq()), - netnodes=[net, var_c], - pdes=[pde], - # network_dir='square_network_dir', - network_dir='network_dir', - max_iter=5000) +net = sc.get_net_node( + inputs=( + "x", + "t", + ), + outputs=("u",), + name="net1", + arch=sc.Arch.mlp, +) +var_c = sc.get_net_node(inputs=("x",), outputs=("c",), arch=sc.Arch.single_var) +pde = sc.WaveNode(c="c", dim=1, time=True, u="u") +s = sc.Solver( + sample_domains=(WaveExternal(), WaveEq()), + netnodes=[net, var_c], + pdes=[pde], + # network_dir='square_network_dir', + network_dir="network_dir", + max_iter=5000, +) s.solve() _, ax = plt.subplots(1, 1, figsize=(8, 4)) -coord = s.infer_step(domain_attr={'wave_domain': ['x', 't', 'u']}) -num_t = coord['wave_domain']['t'].cpu().detach().numpy().ravel() -num_u = coord['wave_domain']['u'].cpu().detach().numpy().ravel() -ax.scatter(num_t, num_u, c='r', marker='o', label='predicted points') +coord = s.infer_step(domain_attr={"wave_domain": ["x", "t", "u"]}) +num_t = coord["wave_domain"]["t"].cpu().detach().numpy().ravel() +num_u = coord["wave_domain"]["u"].cpu().detach().numpy().ravel() +ax.scatter(num_t, num_u, c="r", marker="o", label="predicted points") print("true paratmeter c: {:.4f}".format(c)) predict_c = var_c.evaluate(torch.Tensor([[1.0]])).item() print("predicted parameter c: {:.4f}".format(predict_c)) -num_t = WaveExternal().sample_fn.points['t'].ravel() -num_u = WaveExternal().sample_fn.constraints['u'].ravel() -ax.scatter(num_t, num_u, c='b', marker='x', label='observed points') +num_t = WaveExternal().sample_fn.points["t"].ravel() +num_u = WaveExternal().sample_fn.constraints["u"].ravel() +ax.scatter(num_t, num_u, c="b", marker="x", label="observed points") s.sample_domains = (CenterInfer(),) -points = s.infer_step({'center_infer': ['t', 'x', 'u']}) -num_t = points['center_infer']['t'].cpu().detach().numpy().ravel() -num_u = points['center_infer']['u'].cpu().detach().numpy().ravel() -num_x = points['center_infer']['x'].cpu().detach().numpy().ravel() -ax.plot(num_t, np.sin(num_x) * (np.sin(c * num_t) + np.cos(c * num_t)), c='k', label='exact') -ax.plot(num_t, num_u, '--', c='g', linewidth=4, label='predict') +points = s.infer_step({"center_infer": ["t", "x", "u"]}) +num_t = points["center_infer"]["t"].cpu().detach().numpy().ravel() +num_u = points["center_infer"]["u"].cpu().detach().numpy().ravel() +num_x = points["center_infer"]["x"].cpu().detach().numpy().ravel() +ax.plot( + num_t, np.sin(num_x) * (np.sin(c * num_t) + np.cos(c * num_t)), c="k", label="exact" +) +ax.plot(num_t, num_u, "--", c="g", linewidth=4, label="predict") ax.legend() -ax.set_xlabel('t') -ax.set_ylabel('u') +ax.set_xlabel("t") +ax.set_ylabel("u") # ax.set_title(f'Square loss ($x=0.5L$, c={predict_c:.4f}))') -ax.set_title(f'L1 loss ($x=0.5L$, c={predict_c:.4f})') +ax.set_title(f"L1 loss ($x=0.5L$, c={predict_c:.4f})") ax.grid(True) ax.set_xlim([-0.5, 6.5]) ax.set_ylim([-3.5, 4.5]) # plt.savefig('square.png', dpi=1000, bbox_inches='tight', pad_inches=0.02) -plt.savefig('L1.png', dpi=1000, bbox_inches='tight', pad_inches=0.02) +plt.savefig("L1.png", dpi=1000, bbox_inches="tight", pad_inches=0.02) plt.show() plt.close() diff --git a/examples/minimal_surface_of_revolution/minimal_surface_of_revolution.py b/examples/minimal_surface_of_revolution/minimal_surface_of_revolution.py index df3d8d5..68f7c35 100644 --- a/examples/minimal_surface_of_revolution/minimal_surface_of_revolution.py +++ b/examples/minimal_surface_of_revolution/minimal_surface_of_revolution.py @@ -9,26 +9,30 @@ import math import idrlnet.shortcut as sc -x = sp.Symbol('x') -u = sp.Function('u')(x) +x = sp.Symbol("x") +u = sp.Function("u")(x) geo = sc.Line1D(-1, 0.5) -@sc.datanode(sigma=1000.) +@sc.datanode(sigma=1000.0) class Boundary(sc.SampleDomain): def __init__(self): - self.points = geo.sample_boundary(1, ) - self.constraints = {'u': np.cosh(self.points['x'])} + self.points = geo.sample_boundary( + 1, + ) + self.constraints = {"u": np.cosh(self.points["x"])} def sampling(self, *args, **kwargs): return self.points, self.constraints -@sc.datanode(loss_fn='L1') +@sc.datanode(loss_fn="L1") class Interior(sc.SampleDomain): def sampling(self, *args, **kwargs): points = geo.sample_interior(10000) - constraints = {'integral_dx': 0, } + constraints = { + "integral_dx": 0, + } return points, constraints @@ -36,8 +40,8 @@ class Interior(sc.SampleDomain): class InteriorInfer(sc.SampleDomain): def __init__(self): self.points = sc.Variables() - self.points['x'] = np.linspace(-1, 0.5, 1001, endpoint=True).reshape(-1, 1) - self.points['area'] = np.ones_like(self.points['x']) + self.points["x"] = np.linspace(-1, 0.5, 1001, endpoint=True).reshape(-1, 1) + self.points["area"] = np.ones_like(self.points["x"]) def sampling(self, *args, **kwargs): return self.points, {} @@ -46,8 +50,8 @@ class InteriorInfer(sc.SampleDomain): # plot Intermediate results class PlotReceiver(sc.Receiver): def __init__(self): - if not os.path.exists('plot'): - os.mkdir('plot') + if not os.path.exists("plot"): + os.mkdir("plot") xx = np.linspace(-1, 0.5, 1001, endpoint=True) self.xx = xx angle = np.linspace(0, math.pi * 2, 100) @@ -58,28 +62,30 @@ class PlotReceiver(sc.Receiver): zz_mesh = yy * np.sin(angle_mesh) fig = plt.figure(figsize=(8, 8)) - ax = fig.gca(projection='3d') + ax = fig.gca(projection="3d") ax.set_zlim3d(-1.25 - 1, 0.75 + 1) ax.set_ylim3d(-2, 2) ax.set_xlim3d(-2, 2) my_col = cm.cool((yy * np.ones_like(angle_mesh) - 1.0) / 0.6) ax.plot_surface(yy_mesh, zz_mesh, xx_mesh, facecolors=my_col) - ax.view_init(elev=15., azim=0) + ax.view_init(elev=15.0, azim=0) ax.dist = 5 - plt.axis('off') - plt.tight_layout(pad=0., w_pad=0., h_pad=.0) - plt.savefig(f'plot/p_exact.png') + plt.axis("off") + plt.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0) + plt.savefig(f"plot/p_exact.png") plt.show() plt.close() self.predict_history = [] def receive_notify(self, obj: sc.Solver, message: Dict): - if sc.Signal.SOLVE_START in message or (sc.Signal.TRAIN_PIPE_END in message and obj.global_step % 200 == 0): + if sc.Signal.SOLVE_START in message or ( + sc.Signal.TRAIN_PIPE_END in message and obj.global_step % 200 == 0 + ): print("plotting") - points = s.infer_step({'InteriorInfer': ['x', 'u']}) - num_x = points['InteriorInfer']['x'].detach().cpu().numpy().ravel() - num_u = points['InteriorInfer']['u'].detach().cpu().numpy().ravel() + points = s.infer_step({"InteriorInfer": ["x", "u"]}) + num_x = points["InteriorInfer"]["x"].detach().cpu().numpy().ravel() + num_u = points["InteriorInfer"]["u"].detach().cpu().numpy().ravel() angle = np.linspace(0, math.pi * 2, 100) xx_mesh, angle_mesh = np.meshgrid(num_x, angle) @@ -87,28 +93,28 @@ class PlotReceiver(sc.Receiver): zz_mesh = num_u * np.sin(angle_mesh) fig = plt.figure(figsize=(8, 8)) - ax = fig.gca(projection='3d') + ax = fig.gca(projection="3d") ax.set_zlim3d(-1.25 - 1, 0.75 + 1) ax.set_ylim3d(-2, 2) ax.set_xlim3d(-2, 2) my_col = cm.cool((num_u * np.ones_like(angle_mesh) - 1.0) / 0.6) ax.plot_surface(yy_mesh, zz_mesh, xx_mesh, facecolors=my_col) - ax.view_init(elev=15., azim=0) + ax.view_init(elev=15.0, azim=0) ax.dist = 5 - plt.axis('off') - plt.tight_layout(pad=0., w_pad=0., h_pad=.0) - plt.savefig(f'plot/p_{obj.global_step}.png') + plt.axis("off") + plt.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0) + plt.savefig(f"plot/p_{obj.global_step}.png") plt.show() plt.close() self.predict_history.append((num_u, obj.global_step)) if sc.Signal.SOLVE_END in message: try: - with open('result.pickle', 'rb') as f: + with open("result.pickle", "rb") as f: self.predict_history = pickle.load(f) except: - with open('result.pickle', 'wb') as f: + with open("result.pickle", "wb") as f: pickle.dump(self.predict_history, f) for yy, step in self.predict_history: if step == 0: @@ -116,28 +122,35 @@ class PlotReceiver(sc.Receiver): if step == 200: plt.plot(yy, self.xx, label=f"iter={step}") if step == 800: - plt.plot(yy[::100], self.xx[::100], '-o', label=f"iter={step}") - plt.plot(np.cosh(self.xx)[::100], self.xx[::100], '-x', label='exact') - plt.plot([0, np.cosh(-1)], [-1, -1], '--', color='gray') - plt.plot([0, np.cosh(0.5)], [0.5, 0.5], '--', color='gray') + plt.plot(yy[::100], self.xx[::100], "-o", label=f"iter={step}") + plt.plot(np.cosh(self.xx)[::100], self.xx[::100], "-x", label="exact") + plt.plot([0, np.cosh(-1)], [-1, -1], "--", color="gray") + plt.plot([0, np.cosh(0.5)], [0.5, 0.5], "--", color="gray") plt.legend() plt.xlim([0, 1.7]) - plt.xlabel('y') - plt.ylabel('x') - plt.savefig('iterations.png') + plt.xlabel("y") + plt.ylabel("x") + plt.savefig("iterations.png") plt.show() plt.close() -dx_exp = sc.ExpressionNode(expression=sp.Abs(u) * sp.sqrt((u.diff(x)) ** 2 + 1), name='dx') -net = sc.get_net_node(inputs=('x',), outputs=('u',), name='net', arch=sc.Arch.mlp) +dx_exp = sc.ExpressionNode( + expression=sp.Abs(u) * sp.sqrt((u.diff(x)) ** 2 + 1), name="dx" +) +net = sc.get_net_node(inputs=("x",), outputs=("u",), name="net", arch=sc.Arch.mlp) -integral = sc.ICNode('dx', dim=1, time=False) +integral = sc.ICNode("dx", dim=1, time=False) -s = sc.Solver(sample_domains=(Boundary(), Interior(), InteriorInfer()), - netnodes=[net], - init_network_dirs=['pretrain_network_dir'], - pdes=[dx_exp, integral, ], - max_iter=1500) +s = sc.Solver( + sample_domains=(Boundary(), Interior(), InteriorInfer()), + netnodes=[net], + init_network_dirs=["pretrain_network_dir"], + pdes=[ + dx_exp, + integral, + ], + max_iter=1500, +) s.register_receiver(PlotReceiver()) s.solve() diff --git a/examples/minimal_surface_of_revolution/minimal_surface_of_revolution_pretrain.py b/examples/minimal_surface_of_revolution/minimal_surface_of_revolution_pretrain.py index 48c8b7c..ea83dc0 100644 --- a/examples/minimal_surface_of_revolution/minimal_surface_of_revolution_pretrain.py +++ b/examples/minimal_surface_of_revolution/minimal_surface_of_revolution_pretrain.py @@ -3,30 +3,34 @@ import numpy as np import sympy as sp import idrlnet.shortcut as sc -x = sp.Symbol('x') +x = sp.Symbol("x") geo = sc.Line1D(-1, 0.5) -@sc.datanode(loss_fn='L1') +@sc.datanode(loss_fn="L1") class Interior(sc.SampleDomain): def sampling(self, *args, **kwargs): points = geo.sample_interior(100) - constraints = {'u': (np.cosh(0.5) - np.cosh(-1)) / 1.5 * (x + 1.0) + np.cosh(-1)} + constraints = { + "u": (np.cosh(0.5) - np.cosh(-1)) / 1.5 * (x + 1.0) + np.cosh(-1) + } return points, constraints -net = sc.get_net_node(inputs=('x',), outputs=('u',), name='net', arch=sc.Arch.mlp) +net = sc.get_net_node(inputs=("x",), outputs=("u",), name="net", arch=sc.Arch.mlp) -s = sc.Solver(sample_domains=(Interior(),), - netnodes=[net], - pdes=[], - network_dir='pretrain_network_dir', - max_iter=1000) +s = sc.Solver( + sample_domains=(Interior(),), + netnodes=[net], + pdes=[], + network_dir="pretrain_network_dir", + max_iter=1000, +) s.solve() -points = s.infer_step({'Interior': ['x', 'u']}) -num_x = points['Interior']['x'].detach().cpu().numpy().ravel() -num_u = points['Interior']['u'].detach().cpu().numpy().ravel() +points = s.infer_step({"Interior": ["x", "u"]}) +num_x = points["Interior"]["x"].detach().cpu().numpy().ravel() +num_u = points["Interior"]["u"].detach().cpu().numpy().ravel() xx = np.linspace(-1, 0.5, 1000, endpoint=True) yy = np.cosh(xx) diff --git a/examples/parameterized_poisson/parameterized_poisson.py b/examples/parameterized_poisson/parameterized_poisson.py index d78d952..429da28 100644 --- a/examples/parameterized_poisson/parameterized_poisson.py +++ b/examples/parameterized_poisson/parameterized_poisson.py @@ -4,18 +4,20 @@ import matplotlib.pyplot as plt import matplotlib.tri as tri import numpy as np -x, y = sp.symbols('x y') -temp = sp.Symbol('temp') +x, y = sp.symbols("x y") +temp = sp.Symbol("temp") temp_range = {temp: (-0.2, 0.2)} -rec = sc.Rectangle((-1., -1.), (1., 1.)) +rec = sc.Rectangle((-1.0, -1.0), (1.0, 1.0)) @sc.datanode class Right(sc.SampleDomain): # Due to `name` is not specified, Right will be the name of datanode automatically def sampling(self, *args, **kwargs): - points = rec.sample_boundary(1000, sieve=(sp.Eq(x, 1.)), param_ranges=temp_range) - constraints = sc.Variables({'T': 0.}) + points = rec.sample_boundary( + 1000, sieve=(sp.Eq(x, 1.0)), param_ranges=temp_range + ) + constraints = sc.Variables({"T": 0.0}) return points, constraints @@ -23,16 +25,20 @@ class Right(sc.SampleDomain): class Left(sc.SampleDomain): # Due to `name` is not specified, Left will be the name of datanode automatically def sampling(self, *args, **kwargs): - points = rec.sample_boundary(1000, sieve=(sp.Eq(x, -1.)), param_ranges=temp_range) - constraints = sc.Variables({'T': temp}) + points = rec.sample_boundary( + 1000, sieve=(sp.Eq(x, -1.0)), param_ranges=temp_range + ) + constraints = sc.Variables({"T": temp}) return points, constraints @sc.datanode(name="up_down") class UpDownBoundaryDomain(sc.SampleDomain): def sampling(self, *args, **kwargs): - points = rec.sample_boundary(1000, sieve=((x > -1.) & (x < 1.)), param_ranges=temp_range) - constraints = sc.Variables({'normal_gradient_T': 0.}) + points = rec.sample_boundary( + 1000, sieve=((x > -1.0) & (x < 1.0)), param_ranges=temp_range + ) + constraints = sc.Variables({"normal_gradient_T": 0.0}) return points, constraints @@ -43,47 +49,53 @@ class HeatDomain(sc.SampleDomain): def sampling(self, *args, **kwargs): points = rec.sample_interior(self.points, param_ranges=temp_range) - constraints = sc.Variables({'diffusion_T': 1.}) + constraints = sc.Variables({"diffusion_T": 1.0}) return points, constraints -net = sc.get_net_node(inputs=('x', 'y', 'temp'), outputs=('T',), name='net1', arch=sc.Arch.mlp) -pde = sc.DiffusionNode(T='T', D=1., Q=0., dim=2, time=False) -grad = sc.NormalGradient('T', dim=2, time=False) -s = sc.Solver(sample_domains=(HeatDomain(), Left(), Right(), UpDownBoundaryDomain()), - netnodes=[net], - pdes=[pde, grad], - max_iter=3000) +net = sc.get_net_node( + inputs=("x", "y", "temp"), outputs=("T",), name="net1", arch=sc.Arch.mlp +) +pde = sc.DiffusionNode(T="T", D=1.0, Q=0.0, dim=2, time=False) +grad = sc.NormalGradient("T", dim=2, time=False) +s = sc.Solver( + sample_domains=(HeatDomain(), Left(), Right(), UpDownBoundaryDomain()), + netnodes=[net], + pdes=[pde, grad], + max_iter=3000, +) s.solve() def infer_temp(temp_num, file_suffix=None): temp_range[temp] = temp_num - s.set_domain_parameter('heat_domain', {'points': 10000}) - coord = s.infer_step({'heat_domain': ['x', 'y', 'T']}) - num_x = coord['heat_domain']['x'].cpu().detach().numpy().ravel() - num_y = coord['heat_domain']['y'].cpu().detach().numpy().ravel() - num_Tp = coord['heat_domain']['T'].cpu().detach().numpy().ravel() + s.set_domain_parameter("heat_domain", {"points": 10000}) + coord = s.infer_step({"heat_domain": ["x", "y", "T"]}) + num_x = coord["heat_domain"]["x"].cpu().detach().numpy().ravel() + num_y = coord["heat_domain"]["y"].cpu().detach().numpy().ravel() + num_Tp = coord["heat_domain"]["T"].cpu().detach().numpy().ravel() # Ground truth - num_T = -(num_x + 1 + temp_num) * (num_x - 1.) / 2 + num_T = -(num_x + 1 + temp_num) * (num_x - 1.0) / 2 fig, ax = plt.subplots(1, 3, figsize=(10, 3)) triang_total = tri.Triangulation(num_x, num_y) - ax[0].tricontourf(triang_total, num_Tp, 100, cmap='hot', vmin=-0.2, vmax=1.21 / 2) - ax[0].axis('off') - ax[0].set_title(f'prediction($T_l={temp_num:.2f}$)') - ax[1].tricontourf(triang_total, num_T, 100, cmap='hot', vmin=-0.2, vmax=1.21 / 2) - ax[1].axis('off') - ax[1].set_title(f'ground truth($T_l={temp_num:.2f}$)') - ax[2].tricontourf(triang_total, np.abs(num_T - num_Tp), 100, cmap='hot', vmin=0, vmax=1.21 / 2) - ax[2].axis('off') - ax[2].set_title('absolute error') + ax[0].tricontourf(triang_total, num_Tp, 100, cmap="hot", vmin=-0.2, vmax=1.21 / 2) + ax[0].axis("off") + ax[0].set_title(f"prediction($T_l={temp_num:.2f}$)") + ax[1].tricontourf(triang_total, num_T, 100, cmap="hot", vmin=-0.2, vmax=1.21 / 2) + ax[1].axis("off") + ax[1].set_title(f"ground truth($T_l={temp_num:.2f}$)") + ax[2].tricontourf( + triang_total, np.abs(num_T - num_Tp), 100, cmap="hot", vmin=0, vmax=1.21 / 2 + ) + ax[2].axis("off") + ax[2].set_title("absolute error") if file_suffix is None: - plt.savefig(f'poisson_{temp_num:.2f}.png', dpi=300, bbox_inches='tight') + plt.savefig(f"poisson_{temp_num:.2f}.png", dpi=300, bbox_inches="tight") plt.show() else: - plt.savefig(f'poisson_{file_suffix}.png', dpi=300, bbox_inches='tight') + plt.savefig(f"poisson_{file_suffix}.png", dpi=300, bbox_inches="tight") plt.show() diff --git a/examples/simple_poisson/simple_poisson.py b/examples/simple_poisson/simple_poisson.py index 6c0f8f2..f4c172f 100644 --- a/examples/simple_poisson/simple_poisson.py +++ b/examples/simple_poisson/simple_poisson.py @@ -4,24 +4,24 @@ import matplotlib.pyplot as plt import matplotlib.tri as tri import numpy as np -x, y = sp.symbols('x y') -rec = sc.Rectangle((-1., -1.), (1., 1.)) +x, y = sp.symbols("x y") +rec = sc.Rectangle((-1.0, -1.0), (1.0, 1.0)) @sc.datanode class LeftRight(sc.SampleDomain): # Due to `name` is not specified, LeftRight will be the name of datanode automatically def sampling(self, *args, **kwargs): - points = rec.sample_boundary(1000, sieve=((y > -1.) & (y < 1.))) - constraints = {'T': 0.} + points = rec.sample_boundary(1000, sieve=((y > -1.0) & (y < 1.0))) + constraints = {"T": 0.0} return points, constraints @sc.datanode(name="up_down") class UpDownBoundaryDomain(sc.SampleDomain): def sampling(self, *args, **kwargs): - points = rec.sample_boundary(1000, sieve=((x > -1.) & (x < 1.))) - constraints = {'normal_gradient_T': 0.} + points = rec.sample_boundary(1000, sieve=((x > -1.0) & (x < 1.0))) + constraints = {"normal_gradient_T": 0.0} return points, constraints @@ -32,39 +32,51 @@ class HeatDomain(sc.SampleDomain): def sampling(self, *args, **kwargs): points = rec.sample_interior(self.points) - constraints = {'diffusion_T': 1.} + constraints = {"diffusion_T": 1.0} return points, constraints -net = sc.get_net_node(inputs=('x', 'y',), outputs=('T',), name='net1', arch=sc.Arch.mlp) -pde = sc.DiffusionNode(T='T', D=1., Q=0., dim=2, time=False) -grad = sc.NormalGradient('T', dim=2, time=False) -s = sc.Solver(sample_domains=(HeatDomain(), LeftRight(), UpDownBoundaryDomain()), - netnodes=[net], - pdes=[pde, grad], - max_iter=1000) +net = sc.get_net_node( + inputs=( + "x", + "y", + ), + outputs=("T",), + name="net1", + arch=sc.Arch.mlp, +) +pde = sc.DiffusionNode(T="T", D=1.0, Q=0.0, dim=2, time=False) +grad = sc.NormalGradient("T", dim=2, time=False) +s = sc.Solver( + sample_domains=(HeatDomain(), LeftRight(), UpDownBoundaryDomain()), + netnodes=[net], + pdes=[pde, grad], + max_iter=1000, +) s.solve() # Inference -s.set_domain_parameter('heat_domain', {'points': 10000}) -coord = s.infer_step({'heat_domain': ['x', 'y', 'T']}) -num_x = coord['heat_domain']['x'].cpu().detach().numpy().ravel() -num_y = coord['heat_domain']['y'].cpu().detach().numpy().ravel() -num_Tp = coord['heat_domain']['T'].cpu().detach().numpy().ravel() +s.set_domain_parameter("heat_domain", {"points": 10000}) +coord = s.infer_step({"heat_domain": ["x", "y", "T"]}) +num_x = coord["heat_domain"]["x"].cpu().detach().numpy().ravel() +num_y = coord["heat_domain"]["y"].cpu().detach().numpy().ravel() +num_Tp = coord["heat_domain"]["T"].cpu().detach().numpy().ravel() # Ground truth num_T = -num_x * num_x / 2 + 0.5 fig, ax = plt.subplots(1, 3, figsize=(10, 3)) triang_total = tri.Triangulation(num_x, num_y) -ax[0].tricontourf(triang_total, num_Tp, 100, cmap='hot', vmin=0, vmax=0.5) -ax[0].axis('off') -ax[0].set_title('prediction') -ax[1].tricontourf(triang_total, num_T, 100, cmap='hot', vmin=0, vmax=0.5) -ax[1].axis('off') -ax[1].set_title('ground truth') -ax[2].tricontourf(triang_total, np.abs(num_T - num_Tp), 100, cmap='hot', vmin=0, vmax=0.5) -ax[2].axis('off') -ax[2].set_title('absolute error') +ax[0].tricontourf(triang_total, num_Tp, 100, cmap="hot", vmin=0, vmax=0.5) +ax[0].axis("off") +ax[0].set_title("prediction") +ax[1].tricontourf(triang_total, num_T, 100, cmap="hot", vmin=0, vmax=0.5) +ax[1].axis("off") +ax[1].set_title("ground truth") +ax[2].tricontourf( + triang_total, np.abs(num_T - num_Tp), 100, cmap="hot", vmin=0, vmax=0.5 +) +ax[2].axis("off") +ax[2].set_title("absolute error") -plt.savefig('simple_poisson.png', dpi=300, bbox_inches='tight') +plt.savefig("simple_poisson.png", dpi=300, bbox_inches="tight") diff --git a/idrlnet/__init__.py b/idrlnet/__init__.py index 593ccb4..169e941 100644 --- a/idrlnet/__init__.py +++ b/idrlnet/__init__.py @@ -1,15 +1,16 @@ import torch + # todo more careful check GPU_ENABLED = True if torch.cuda.is_available(): try: - _ = torch.Tensor([0., 0.]).cuda() - torch.set_default_tensor_type('torch.cuda.FloatTensor') - print('gpu available') + _ = torch.Tensor([0.0, 0.0]).cuda() + torch.set_default_tensor_type("torch.cuda.FloatTensor") + print("gpu available") GPU_ENABLED = True except: - print('gpu not available') + print("gpu not available") GPU_ENABLED = False else: - print('gpu not available') + print("gpu not available") GPU_ENABLED = False diff --git a/idrlnet/architecture/grid.py b/idrlnet/architecture/grid.py index a6a1264..40805c7 100644 --- a/idrlnet/architecture/grid.py +++ b/idrlnet/architecture/grid.py @@ -15,14 +15,28 @@ def indicator(xn: torch.Tensor, *axis_bounds): i = 0 lb, ub, lb_eq = axis_bounds[0] if lb_eq: - indic = torch.logical_and(xn[:, i:i + 1] >= axis_bounds[0][0], axis_bounds[0][1] >= xn[:, i:i + 1]) + indic = torch.logical_and( + xn[:, i : i + 1] >= axis_bounds[0][0], axis_bounds[0][1] >= xn[:, i : i + 1] + ) else: - indic = torch.logical_and(xn[:, i:i + 1] > axis_bounds[0][0], axis_bounds[0][1] >= xn[:, i:i + 1]) + indic = torch.logical_and( + xn[:, i : i + 1] > axis_bounds[0][0], axis_bounds[0][1] >= xn[:, i : i + 1] + ) for i, (lb, ub, lb_eq) in enumerate(axis_bounds[1:]): if lb_eq: - indic = torch.logical_and(indic, torch.logical_and(xn[:, i + 1:i + 2] >= lb, ub >= xn[:, i + 1:i + 2])) + indic = torch.logical_and( + indic, + torch.logical_and( + xn[:, i + 1 : i + 2] >= lb, ub >= xn[:, i + 1 : i + 2] + ), + ) else: - indic = torch.logical_and(indic, torch.logical_and(xn[:, i + 1:i + 2] > lb, ub >= xn[:, i + 1:i + 2])) + indic = torch.logical_and( + indic, + torch.logical_and( + xn[:, i + 1 : i + 2] > lb, ub >= xn[:, i + 1 : i + 2] + ), + ) return indic @@ -34,8 +48,8 @@ class NetEval(torch.nn.Module): self.n_columns = len(self.columns) - 1 self.n_rows = len(self.rows) - 1 self.nets = [] - if 'net_generator' in kwargs.keys(): - net_gen = kwargs.pop('net_generator') + if "net_generator" in kwargs.keys(): + net_gen = kwargs.pop("net_generator") else: net_gen = lambda: mlp.MLP([n_inputs, 20, 20, 20, 20, n_outputs]) for i in range(self.n_columns): @@ -50,8 +64,18 @@ class NetEval(torch.nn.Module): y = 0 for i in range(self.n_columns): for j in range(self.n_rows): - y += indicator(xn, (self.columns[i], self.columns[i + 1], True if i == 0 else False), - (self.rows[j], self.rows[j + 1], True if j == 0 else False)) * self.nets[i][j](x) + y += ( + indicator( + xn, + ( + self.columns[i], + self.columns[i + 1], + True if i == 0 else False, + ), + (self.rows[j], self.rows[j + 1], True if j == 0 else False), + ) + * self.nets[i][j](x) + ) return y @@ -59,7 +83,10 @@ class Interface: def __init__(self, points1, points2, nr, outputs, i1, j1, i2, j2, overlap=0.2): x_min, x_max = min(points1[0], points2[0]), max(points1[0], points2[0]) y_min, y_max = min(points1[1], points2[1]), max(points1[1], points2[1]) - self.geo = Rectangle((x_min - overlap / 2, y_min - overlap / 2), (x_max + overlap / 2, y_max + overlap / 2)) + self.geo = Rectangle( + (x_min - overlap / 2, y_min - overlap / 2), + (x_max + overlap / 2, y_max + overlap / 2), + ) self.nr = nr self.outputs = outputs self.i1 = i1 @@ -69,16 +96,26 @@ class Interface: def __call__(self, *args, **kwargs): points = self.geo.sample_boundary(self.nr) - return points, {f'difference_{output}_{self.i1}_{self.j1}_{output}_{self.i2}_{self.j2}': 0 - for output in self.outputs} + return points, { + f"difference_{output}_{self.i1}_{self.j1}_{output}_{self.i2}_{self.j2}": 0 + for output in self.outputs + } class NetGridNode(NetNode): - def __init__(self, inputs: Union[Tuple, List[str]], outputs: Union[Tuple, List[str]], - x_segments: List[float] = None, y_segments: List[float] = None, - z_segments: List[float] = None, t_segments: List[float] = None, columns: List[float] = None, - rows: List[float] = None, *args, - **kwargs): + def __init__( + self, + inputs: Union[Tuple, List[str]], + outputs: Union[Tuple, List[str]], + x_segments: List[float] = None, + y_segments: List[float] = None, + z_segments: List[float] = None, + t_segments: List[float] = None, + columns: List[float] = None, + rows: List[float] = None, + *args, + **kwargs, + ): if columns is None: columns = [] if rows is None: @@ -87,8 +124,16 @@ class NetGridNode(NetNode): fixed = False self.columns = columns self.rows = rows - self.main_net = NetEval(n_inputs=len(inputs), n_outputs=len(outputs), columns=columns, rows=rows, **kwargs) - super(NetGridNode, self).__init__(inputs, outputs, self.main_net, fixed, require_no_grad, *args, **kwargs) + self.main_net = NetEval( + n_inputs=len(inputs), + n_outputs=len(outputs), + columns=columns, + rows=rows, + **kwargs, + ) + super(NetGridNode, self).__init__( + inputs, outputs, self.main_net, fixed, require_no_grad, *args, **kwargs + ) def get_grid(self, overlap, nr_points_per_interface_area=100): n_columns = self.main_net.n_columns @@ -98,54 +143,119 @@ class NetGridNode(NetNode): constraints = [] for i in range(n_columns): for j in range(n_rows): - nn = NetNode(inputs=self.inputs, - outputs=tuple(f'{output}_{i}_{j}' for output in self.outputs), - net=self.main_net.nets[i][j], - name=f'{self.name}[{i}][{j}]') + nn = NetNode( + inputs=self.inputs, + outputs=tuple(f"{output}_{i}_{j}" for output in self.outputs), + net=self.main_net.nets[i][j], + name=f"{self.name}[{i}][{j}]", + ) nn.is_reference = True netnodes.append(nn) if i > 0: for output in self.outputs: - diff_Node = Difference(f'{output}_{i - 1}_{j}', f'{output}_{i}_{j}', dim=2, time=False) + diff_Node = Difference( + f"{output}_{i - 1}_{j}", + f"{output}_{i}_{j}", + dim=2, + time=False, + ) eqs.append(diff_Node) - interface = Interface((self.columns[i], self.rows[j]), (self.columns[i], self.rows[j + 1]), - nr_points_per_interface_area, self.outputs, i - 1, j, i, j, overlap=overlap) + interface = Interface( + (self.columns[i], self.rows[j]), + (self.columns[i], self.rows[j + 1]), + nr_points_per_interface_area, + self.outputs, + i - 1, + j, + i, + j, + overlap=overlap, + ) - constraints.append(get_data_node(interface, name=f'interface[{i - 1}][{j}]_[{i}][{j}]')) + constraints.append( + get_data_node( + interface, name=f"interface[{i - 1}][{j}]_[{i}][{j}]" + ) + ) if j > 0: for output in self.outputs: - diff_Node = Difference(f'{output}_{i}_{j - 1}', f'{output}_{i}_{j}', dim=2, time=False) + diff_Node = Difference( + f"{output}_{i}_{j - 1}", + f"{output}_{i}_{j}", + dim=2, + time=False, + ) eqs.append(diff_Node) - interface = Interface((self.columns[i], self.rows[j]), (self.columns[i + 1], self.rows[j]), - nr_points_per_interface_area, self.outputs, i, j - 1, i, j, overlap=overlap) + interface = Interface( + (self.columns[i], self.rows[j]), + (self.columns[i + 1], self.rows[j]), + nr_points_per_interface_area, + self.outputs, + i, + j - 1, + i, + j, + overlap=overlap, + ) - constraints.append(get_data_node(interface, name=f'interface[{i}][{j - 1}]_[{i}][{j}]')) + constraints.append( + get_data_node( + interface, name=f"interface[{i}][{j - 1}]_[{i}][{j}]" + ) + ) return netnodes, eqs, constraints -def get_net_reg_grid_2d(inputs: Union[Tuple, List[str]], outputs: Union[Tuple, List[str]], name: str, - columns: List[float], rows: List[float], **kwargs): - if 'overlap' in kwargs.keys(): - overlap = kwargs.pop('overlap') +def get_net_reg_grid_2d( + inputs: Union[Tuple, List[str]], + outputs: Union[Tuple, List[str]], + name: str, + columns: List[float], + rows: List[float], + **kwargs, +): + if "overlap" in kwargs.keys(): + overlap = kwargs.pop("overlap") else: overlap = 0.2 - net = NetGridNode(inputs=inputs, outputs=outputs, columns=columns, rows=rows, name=name, **kwargs) - nets, eqs, interfaces = net.get_grid(nr_points_per_interface_area=1000, overlap=overlap) + net = NetGridNode( + inputs=inputs, outputs=outputs, columns=columns, rows=rows, name=name, **kwargs + ) + nets, eqs, interfaces = net.get_grid( + nr_points_per_interface_area=1000, overlap=overlap + ) nets.append(net) return nets, eqs, interfaces -def get_net_reg_grid(inputs: Union[Tuple, List[str]], outputs: Union[Tuple, List[str]], name: str, - x_segments: List[float] = None, y_segments: List[float] = None, z_segments: List[float] = None, - t_segments: List[float] = None, **kwargs): - if 'overlap' in kwargs.keys(): - overlap = kwargs.pop('overlap') +def get_net_reg_grid( + inputs: Union[Tuple, List[str]], + outputs: Union[Tuple, List[str]], + name: str, + x_segments: List[float] = None, + y_segments: List[float] = None, + z_segments: List[float] = None, + t_segments: List[float] = None, + **kwargs, +): + if "overlap" in kwargs.keys(): + overlap = kwargs.pop("overlap") else: overlap = 0.2 - net = NetGridNode(inputs=inputs, outputs=outputs, x_segments=x_segments, y_segments=y_segments, - z_segments=z_segments, t_segments=t_segments, name=name, **kwargs) - nets, eqs, interfaces = net.get_grid(nr_points_per_interface_area=1000, overlap=overlap) + net = NetGridNode( + inputs=inputs, + outputs=outputs, + x_segments=x_segments, + y_segments=y_segments, + z_segments=z_segments, + t_segments=t_segments, + name=name, + **kwargs, + ) + nets, eqs, interfaces = net.get_grid( + nr_points_per_interface_area=1000, overlap=overlap + ) nets.append(net) return nets, eqs, interfaces diff --git a/idrlnet/architecture/layer.py b/idrlnet/architecture/layer.py index 742f4a6..02fefe0 100644 --- a/idrlnet/architecture/layer.py +++ b/idrlnet/architecture/layer.py @@ -5,35 +5,40 @@ import math import torch from idrlnet.header import logger -__all__ = ['Activation', 'Initializer', 'get_activation_layer', 'get_linear_layer'] +__all__ = ["Activation", "Initializer", "get_activation_layer", "get_linear_layer"] class Activation(enum.Enum): - relu = 'relu' - silu = 'silu' - selu = 'selu' - sigmoid = 'sigmoid' - tanh = 'tanh' - swish = 'swish' - poly = 'poly' - sin = 'sin' - leaky_relu = 'leaky_relu' + relu = "relu" + silu = "silu" + selu = "selu" + sigmoid = "sigmoid" + tanh = "tanh" + swish = "swish" + poly = "poly" + sin = "sin" + leaky_relu = "leaky_relu" class Initializer(enum.Enum): - Xavier_uniform = 'Xavier_uniform' - constant = 'constant' - kaiming_uniform = 'kaiming_uniform' - default = 'default' + Xavier_uniform = "Xavier_uniform" + constant = "constant" + kaiming_uniform = "kaiming_uniform" + default = "default" -def get_linear_layer(input_dim: int, output_dim: int, weight_norm=False, - initializer: Initializer = Initializer.Xavier_uniform, *args, - **kwargs): +def get_linear_layer( + input_dim: int, + output_dim: int, + weight_norm=False, + initializer: Initializer = Initializer.Xavier_uniform, + *args, + **kwargs, +): layer = torch.nn.Linear(input_dim, output_dim) init_method = InitializerFactory.get_initializer(initializer=initializer, **kwargs) init_method(layer.weight) - torch.nn.init.constant_(layer.bias, 0.) + torch.nn.init.constant_(layer.bias, 0.0) if weight_norm: layer = torch.nn.utils.weight_norm(layer) return layer @@ -81,8 +86,10 @@ class ActivationFactory: elif activation == Activation.silu: return Silu() else: - logger.error(f'Activation {activation} is not supported!') - raise NotImplementedError('Activation ' + activation.name + ' is not supported') + logger.error(f"Activation {activation} is not supported!") + raise NotImplementedError( + "Activation " + activation.name + " is not supported" + ) class Silu: @@ -105,8 +112,12 @@ def leaky_relu(x, leak=0.1): def triangle_wave(x): y = 0.0 for i in range(3): - y += (-1.0) ** (i) * torch.sin(2.0 * math.pi * (2.0 * i + 1.0) * x) / (2.0 * i + 1.0) ** (2) - y = 0.5 * (8 / (math.pi ** 2) * y) + .5 + y += ( + (-1.0) ** (i) + * torch.sin(2.0 * math.pi * (2.0 * i + 1.0) * x) + / (2.0 * i + 1.0) ** (2) + ) + y = 0.5 * (8 / (math.pi ** 2) * y) + 0.5 return y @@ -139,11 +150,15 @@ class InitializerFactory: if initializer == Initializer.Xavier_uniform: return torch.nn.init.xavier_uniform_ elif initializer == Initializer.constant: - return lambda x: torch.nn.init.constant_(x, kwargs['constant']) + return lambda x: torch.nn.init.constant_(x, kwargs["constant"]) elif initializer == Initializer.kaiming_uniform: - return lambda x: torch.nn.init.kaiming_uniform_(x, mode='fan_in', nonlinearity='relu') + return lambda x: torch.nn.init.kaiming_uniform_( + x, mode="fan_in", nonlinearity="relu" + ) elif initializer == Initializer.default: return lambda x: x else: - logger.error('initialization ' + initializer.name + ' is not supported') - raise NotImplementedError('initialization ' + initializer.name + ' is not supported') + logger.error("initialization " + initializer.name + " is not supported") + raise NotImplementedError( + "initialization " + initializer.name + " is not supported" + ) diff --git a/idrlnet/architecture/mlp.py b/idrlnet/architecture/mlp.py index 977221e..37d2b87 100644 --- a/idrlnet/architecture/mlp.py +++ b/idrlnet/architecture/mlp.py @@ -3,7 +3,12 @@ import torch import math from collections import OrderedDict -from idrlnet.architecture.layer import get_linear_layer, get_activation_layer, Initializer, Activation +from idrlnet.architecture.layer import ( + get_linear_layer, + get_activation_layer, + Initializer, + Activation, +) from typing import List, Union, Tuple from idrlnet.header import logger from idrlnet.net import NetNode @@ -28,25 +33,36 @@ class MLP(torch.nn.Module): :param kwargs: """ - def __init__(self, n_seq: List[int], activation: Union[Activation, List[Activation]] = Activation.swish, - initialization: Initializer = Initializer.kaiming_uniform, - weight_norm: bool = True, name: str = 'mlp', *args, **kwargs): + def __init__( + self, + n_seq: List[int], + activation: Union[Activation, List[Activation]] = Activation.swish, + initialization: Initializer = Initializer.kaiming_uniform, + weight_norm: bool = True, + name: str = "mlp", + *args, + **kwargs, + ): super().__init__() self.layers = OrderedDict() - current_activation = '' + current_activation = "" assert isinstance(n_seq, Activation) or isinstance(n_seq, list) for i in range(len(n_seq) - 1): if isinstance(activation, list): current_activation = activation[i] elif i < len(n_seq) - 2: current_activation = activation - self.layers['{}_{}'.format(name, i)] = get_linear_layer(n_seq[i], n_seq[i + 1], weight_norm, initialization, - *args, **kwargs) - if (isinstance(activation, Activation) and i < len(n_seq) - 2) or isinstance(activation, list): - if current_activation == 'none': + self.layers["{}_{}".format(name, i)] = get_linear_layer( + n_seq[i], n_seq[i + 1], weight_norm, initialization, *args, **kwargs + ) + if ( + isinstance(activation, Activation) and i < len(n_seq) - 2 + ) or isinstance(activation, list): + if current_activation == "none": continue - self.layers['{}_{}_activation'.format(name, i)] = get_activation_layer(current_activation, *args, - **kwargs) + self.layers["{}_{}_activation".format(name, i)] = get_activation_layer( + current_activation, *args, **kwargs + ) self.layers = torch.nn.ModuleDict(self.layers) def forward(self, x): @@ -61,8 +77,15 @@ class MLP(torch.nn.Module): class Siren(torch.nn.Module): - def __init__(self, n_seq: List[int], first_omega: float = 30.0, - omega: float = 30.0, name: str = 'siren', *args, **kwargs): + def __init__( + self, + n_seq: List[int], + first_omega: float = 30.0, + omega: float = 30.0, + name: str = "siren", + *args, + **kwargs, + ): super().__init__() self.layers = OrderedDict() self.first_omega = first_omega @@ -70,24 +93,37 @@ class Siren(torch.nn.Module): assert isinstance(n_seq, str) or isinstance(n_seq, list) for i in range(len(n_seq) - 1): if i == 0: - self.layers['{}_{}'.format(name, i)] = self.get_siren_layer(n_seq[i], n_seq[i + 1], True, first_omega) + self.layers["{}_{}".format(name, i)] = self.get_siren_layer( + n_seq[i], n_seq[i + 1], True, first_omega + ) else: - self.layers['{}_{}'.format(name, i)] = self.get_siren_layer(n_seq[i], n_seq[i + 1], False, omega) + self.layers["{}_{}".format(name, i)] = self.get_siren_layer( + n_seq[i], n_seq[i + 1], False, omega + ) if i < (len(n_seq) - 2): - self.layers['{}_{}_activation'.format(name, i)] = get_activation_layer(Activation.sin, *args, **kwargs) + self.layers["{}_{}_activation".format(name, i)] = get_activation_layer( + Activation.sin, *args, **kwargs + ) self.layers = torch.nn.ModuleDict(self.layers) @staticmethod - def get_siren_layer(input_dim: int, output_dim: int, is_first: bool, omega_0: float): + def get_siren_layer( + input_dim: int, output_dim: int, is_first: bool, omega_0: float + ): layer = torch.nn.Linear(input_dim, output_dim) dim = input_dim if is_first: torch.nn.init.uniform_(layer.weight.data, -1.0 / dim, 1.0 / dim) else: - torch.nn.init.uniform_(layer.weight.data, -1.0 * math.sqrt(6.0 / dim) / omega_0, - math.sqrt(6.0 / dim) / omega_0) - torch.nn.init.uniform_(layer.bias.data, -1 * math.sqrt(1 / dim), math.sqrt(1 / dim)) + torch.nn.init.uniform_( + layer.weight.data, + -1.0 * math.sqrt(6.0 / dim) / omega_0, + math.sqrt(6.0 / dim) / omega_0, + ) + torch.nn.init.uniform_( + layer.bias.data, -1 * math.sqrt(1 / dim), math.sqrt(1 / dim) + ) return layer def forward(self, x): @@ -113,7 +149,7 @@ class SingleVar(torch.nn.Module): self.value = torch.nn.Parameter(torch.Tensor([initialization])) def forward(self, x) -> torch.Tensor: - return x[:, :1] * 0. + self.value + return x[:, :1] * 0.0 + self.value def get_value(self) -> torch.Tensor: return self.value @@ -135,7 +171,7 @@ class BoundedSingleVar(torch.nn.Module): self.ub, self.lb = upper_bound, lower_bound def forward(self, x) -> torch.Tensor: - return x[:, :1] * 0. + self.layer(self.value) * (self.ub - self.lb) + self.lb + return x[:, :1] * 0.0 + self.layer(self.value) * (self.ub - self.lb) + self.lb def get_value(self) -> torch.Tensor: return self.layer(self.value) * (self.ub - self.lb) + self.lb @@ -144,18 +180,22 @@ class BoundedSingleVar(torch.nn.Module): class Arch(enum.Enum): """Enumerate pre-defined neural networks.""" - mlp = 'mlp' - toy = 'toy' - mlp_xl = 'mlp_xl' - single_var = 'single_var' - bounded_single_var = 'bounded_single_var' - siren = 'siren' + mlp = "mlp" + toy = "toy" + mlp_xl = "mlp_xl" + single_var = "single_var" + bounded_single_var = "bounded_single_var" + siren = "siren" -def get_net_node(inputs: Union[Tuple[str, ...], List[str]], outputs: Union[Tuple[str, ...], List[str]], - arch: Arch = None, name=None, - *args, - **kwargs) -> NetNode: +def get_net_node( + inputs: Union[Tuple[str, ...], List[str]], + outputs: Union[Tuple[str, ...], List[str]], + arch: Arch = None, + name=None, + *args, + **kwargs, +) -> NetNode: """Get a net node wrapping networks with pre-defined configurations :param inputs: Input symbols for the generated node. @@ -175,36 +215,65 @@ def get_net_node(inputs: Union[Tuple[str, ...], List[str]], outputs: Union[Tuple :return: """ arch = Arch.mlp if arch is None else arch - if 'evaluate' in kwargs.keys(): - evaluate = kwargs.pop('evaluate') + if "evaluate" in kwargs.keys(): + evaluate = kwargs.pop("evaluate") else: if arch == Arch.mlp: - seq = kwargs['seq'] if 'seq' in kwargs.keys() else [len(inputs), 20, 20, 20, 20, len(outputs)] - evaluate = MLP(n_seq=seq, activation=Activation.swish, initialization=Initializer.kaiming_uniform, - weight_norm=True) + seq = ( + kwargs["seq"] + if "seq" in kwargs.keys() + else [len(inputs), 20, 20, 20, 20, len(outputs)] + ) + evaluate = MLP( + n_seq=seq, + activation=Activation.swish, + initialization=Initializer.kaiming_uniform, + weight_norm=True, + ) elif arch == Arch.toy: evaluate = SimpleExpr("nothing") - elif arch == Arch.mlp_xl or arch == 'fc': - seq = kwargs['seq'] if 'seq' in kwargs.keys() else [len(inputs), 512, 512, 512, 512, 512, 512, len(outputs)] - evaluate = MLP(n_seq=seq, activation=Activation.silu, initialization=Initializer.kaiming_uniform, - weight_norm=True) + elif arch == Arch.mlp_xl or arch == "fc": + seq = ( + kwargs["seq"] + if "seq" in kwargs.keys() + else [len(inputs), 512, 512, 512, 512, 512, 512, len(outputs)] + ) + evaluate = MLP( + n_seq=seq, + activation=Activation.silu, + initialization=Initializer.kaiming_uniform, + weight_norm=True, + ) elif arch == Arch.single_var: - evaluate = SingleVar(initialization=kwargs.get('initialization', 1.)) + evaluate = SingleVar(initialization=kwargs.get("initialization", 1.0)) elif arch == Arch.bounded_single_var: - evaluate = BoundedSingleVar(lower_bound=kwargs['lower_bound'], upper_bound=kwargs['upper_bound']) + evaluate = BoundedSingleVar( + lower_bound=kwargs["lower_bound"], upper_bound=kwargs["upper_bound"] + ) elif arch == Arch.siren: - seq = kwargs['seq'] if 'seq' in kwargs.keys() else [len(inputs), 512, 512, 512, 512, 512, 512, len(outputs)] + seq = ( + kwargs["seq"] + if "seq" in kwargs.keys() + else [len(inputs), 512, 512, 512, 512, 512, 512, len(outputs)] + ) evaluate = Siren(n_seq=seq) else: - logger.error(f'{arch} is not supported!') - raise NotImplementedError(f'{arch} is not supported!') - nn = NetNode(inputs=inputs, outputs=outputs, net=evaluate, name=name, *args, **kwargs) + logger.error(f"{arch} is not supported!") + raise NotImplementedError(f"{arch} is not supported!") + nn = NetNode( + inputs=inputs, outputs=outputs, net=evaluate, name=name, *args, **kwargs + ) return nn -def get_shared_net_node(shared_node: NetNode, inputs: Union[Tuple[str, ...], List[str]], - outputs: Union[Tuple[str, ...], List[str]], name=None, *args, - **kwargs) -> NetNode: +def get_shared_net_node( + shared_node: NetNode, + inputs: Union[Tuple[str, ...], List[str]], + outputs: Union[Tuple[str, ...], List[str]], + name=None, + *args, + **kwargs, +) -> NetNode: """Construct a netnode, the net of which is shared by a given netnode. One can specify different inputs and outputs just like an independent netnode. However, the net parameters may have multiple references. Thus the step operations during optimization should only be applied once. @@ -221,22 +290,29 @@ def get_shared_net_node(shared_node: NetNode, inputs: Union[Tuple[str, ...], Lis :param kwargs: :return: """ - nn = NetNode(inputs, outputs, shared_node.net, is_reference=True, name=name, *args, **kwargs) + nn = NetNode( + inputs, outputs, shared_node.net, is_reference=True, name=name, *args, **kwargs + ) return nn def get_inter_name(length: int, prefix: str): - return [prefix + f'_{i}' for i in range(length)] + return [prefix + f"_{i}" for i in range(length)] class SimpleExpr(torch.nn.Module): """This class is for testing. One can override SimpleExper.forward to represent complex formulas.""" - def __init__(self, expr, name='expr'): + def __init__(self, expr, name="expr"): super().__init__() self.evaluate = expr self.name = name self._placeholder = torch.nn.Parameter(torch.Tensor([0.0])) def forward(self, x): - return self._placeholder + x[:, :1] * x[:, :1] / 2 + x[:, 1:] * x[:, 1:] / 2 - self._placeholder + return ( + self._placeholder + + x[:, :1] * x[:, :1] / 2 + + x[:, 1:] * x[:, 1:] / 2 + - self._placeholder + ) diff --git a/idrlnet/callbacks.py b/idrlnet/callbacks.py index 7311134..0f6691e 100644 --- a/idrlnet/callbacks.py +++ b/idrlnet/callbacks.py @@ -7,13 +7,13 @@ from torch.utils.tensorboard import SummaryWriter from idrlnet.receivers import Receiver, Signal from idrlnet.variable import Variables -__all__ = ['GradientReceiver', 'SummaryReceiver', 'HandleResultReceiver'] +__all__ = ["GradientReceiver", "SummaryReceiver", "HandleResultReceiver"] class GradientReceiver(Receiver): """Register the receiver to monitor gradient norm on the Tensorboard.""" - def receive_notify(self, solver: 'Solver', message): # noqa + def receive_notify(self, solver: "Solver", message): # noqa if not (Signal.TRAIN_PIPE_END in message): return for netnode in solver.netnodes: @@ -23,9 +23,11 @@ class GradientReceiver(Receiver): for p in model.parameters(): param_norm = p.grad.data.norm(2) total_norm += param_norm.item() ** 2 - total_norm = total_norm ** (1. / 2) + total_norm = total_norm ** (1.0 / 2) assert isinstance(solver.receivers[0], SummaryWriter) - solver.summary_receiver.add_scalar('gradient/total_norm', total_norm, solver.global_step) + solver.summary_receiver.add_scalar( + "gradient/total_norm", total_norm, solver.global_step + ) class SummaryReceiver(SummaryWriter, Receiver): @@ -34,15 +36,19 @@ class SummaryReceiver(SummaryWriter, Receiver): def __init__(self, *args, **kwargs): SummaryWriter.__init__(self, *args, **kwargs) - def receive_notify(self, solver: 'Solver', message: Dict): # noqa + def receive_notify(self, solver: "Solver", message: Dict): # noqa if Signal.AFTER_COMPUTE_LOSS in message.keys(): loss_component = message[Signal.AFTER_COMPUTE_LOSS] - self.add_scalars('loss_overview', loss_component, solver.global_step) + self.add_scalars("loss_overview", loss_component, solver.global_step) for key, value in loss_component.items(): - self.add_scalar(f'loss_component/{key}', value, solver.global_step) + self.add_scalar(f"loss_component/{key}", value, solver.global_step) if Signal.TRAIN_PIPE_END in message.keys(): for i, optimizer in enumerate(solver.optimizers): - self.add_scalar(f'optimizer/lr_{i}', optimizer.param_groups[0]['lr'], solver.global_step) + self.add_scalar( + f"optimizer/lr_{i}", + optimizer.param_groups[0]["lr"], + solver.global_step, + ) class HandleResultReceiver(Receiver): @@ -51,11 +57,13 @@ class HandleResultReceiver(Receiver): def __init__(self, result_dir): self.result_dir = result_dir - def receive_notify(self, solver: 'Solver', message: Dict): # noqa + def receive_notify(self, solver: "Solver", message: Dict): # noqa if Signal.SOLVE_END in message.keys(): samples = solver.sample_variables_from_domains() in_var, _, lambda_out = solver.generate_in_out_dict(samples) - pred_out_sample = solver.forward_through_all_graph(in_var, solver.outvar_dict_index) + pred_out_sample = solver.forward_through_all_graph( + in_var, solver.outvar_dict_index + ) diff_out_sample = {key: Variables() for key in pred_out_sample} results_path = pathlib.Path(self.result_dir) results_path.mkdir(exist_ok=True, parents=True) @@ -65,7 +73,15 @@ class HandleResultReceiver(Receiver): pred_out_sample[key][_key] = samples[key][_key] diff_out_sample[key][_key] = samples[key][_key] else: - diff_out_sample[key][_key] = pred_out_sample[key][_key] - samples[key][_key] - samples[key].save(os.path.join(results_path, f'{key}_true'), ['vtu', 'np', 'csv']) - pred_out_sample[key].save(os.path.join(results_path, f'{key}_pred'), ['vtu', 'np', 'csv']) - diff_out_sample[key].save(os.path.join(results_path, f'{key}_diff'), ['vtu', 'np', 'csv']) + diff_out_sample[key][_key] = ( + pred_out_sample[key][_key] - samples[key][_key] + ) + samples[key].save( + os.path.join(results_path, f"{key}_true"), ["vtu", "np", "csv"] + ) + pred_out_sample[key].save( + os.path.join(results_path, f"{key}_pred"), ["vtu", "np", "csv"] + ) + diff_out_sample[key].save( + os.path.join(results_path, f"{key}_diff"), ["vtu", "np", "csv"] + ) diff --git a/idrlnet/data.py b/idrlnet/data.py index b00910d..f66343d 100644 --- a/idrlnet/data.py +++ b/idrlnet/data.py @@ -36,6 +36,7 @@ class DataNode(Node): :param args: :param kwargs: """ + counter = 0 @property @@ -87,18 +88,27 @@ class DataNode(Node): try: output_vars[key] = lambdify_np(value, input_vars)(**input_vars) except: - logger.error('unsupported constraints type.') - raise ValueError('unsupported constraints type.') + logger.error("unsupported constraints type.") + raise ValueError("unsupported constraints type.") try: return Variables({**input_vars, **output_vars}).to_torch_tensor_() except: return Variables({**input_vars, **output_vars}) - def __init__(self, inputs: Union[Tuple[str, ...], List[str]], outputs: Union[Tuple[str, ...], List[str]], - sample_fn: Callable, loss_fn: str = 'square', lambda_outputs: Union[Tuple[str, ...], List[str]] = None, - name=None, sigma=1.0, var_sigma=False, - *args, **kwargs): + def __init__( + self, + inputs: Union[Tuple[str, ...], List[str]], + outputs: Union[Tuple[str, ...], List[str]], + sample_fn: Callable, + loss_fn: str = "square", + lambda_outputs: Union[Tuple[str, ...], List[str]] = None, + name=None, + sigma=1.0, + var_sigma=False, + *args, + **kwargs, + ): self.inputs: Union[Tuple, List[str]] = inputs self.outputs: Union[Tuple, List[str]] = outputs self.lambda_outputs = lambda_outputs @@ -113,13 +123,22 @@ class DataNode(Node): self.loss_fn = loss_fn def __str__(self): - str_list = ["DataNode properties:\n" - "lambda_outputs: {}\n".format(self.lambda_outputs)] - return super().__str__() + ''.join(str_list) + str_list = [ + "DataNode properties:\n" "lambda_outputs: {}\n".format(self.lambda_outputs) + ] + return super().__str__() + "".join(str_list) -def get_data_node(fun: Callable, name=None, loss_fn='square', sigma=1., var_sigma=False, *args, **kwargs) -> DataNode: - """ Construct a datanode from sampling functions. +def get_data_node( + fun: Callable, + name=None, + loss_fn="square", + sigma=1.0, + var_sigma=False, + *args, + **kwargs, +) -> DataNode: + """Construct a datanode from sampling functions. :param fun: Each call of the Callable object should return a sampling dict. :type fun: Callable @@ -135,26 +154,56 @@ def get_data_node(fun: Callable, name=None, loss_fn='square', sigma=1., var_sigm in_, out_ = fun() inputs = list(in_.keys()) outputs = list(out_.keys()) - lambda_outputs = list(filter(lambda x: x.startswith('lambda_'), outputs)) - outputs = list(filter(lambda x: not x.startswith('lambda_'), outputs)) - name = (fun.__name__ if inspect.isfunction(fun) else type(fun).__name__) if name is None else name - dn = DataNode(inputs=inputs, outputs=outputs, sample_fn=fun, lambda_outputs=lambda_outputs, loss_fn=loss_fn, - name=name, sigma=sigma, var_sigma=var_sigma, *args, **kwargs) + lambda_outputs = list(filter(lambda x: x.startswith("lambda_"), outputs)) + outputs = list(filter(lambda x: not x.startswith("lambda_"), outputs)) + name = ( + (fun.__name__ if inspect.isfunction(fun) else type(fun).__name__) + if name is None + else name + ) + dn = DataNode( + inputs=inputs, + outputs=outputs, + sample_fn=fun, + lambda_outputs=lambda_outputs, + loss_fn=loss_fn, + name=name, + sigma=sigma, + var_sigma=var_sigma, + *args, + **kwargs, + ) return dn -def datanode(_fun: Callable = None, name=None, loss_fn='square', sigma=1., var_sigma=False, **kwargs): +def datanode( + _fun: Callable = None, + name=None, + loss_fn="square", + sigma=1.0, + var_sigma=False, + **kwargs, +): """As an alternative, decorate Callable classes as Datanode.""" def wrap(fun): if inspect.isclass(fun): - assert issubclass(fun, SampleDomain), f"{fun} should be subclass of .data.Sample" + assert issubclass( + fun, SampleDomain + ), f"{fun} should be subclass of .data.Sample" fun = fun() assert isinstance(fun, Callable) @functools.wraps(fun) def wrapped_fun(): - dn = get_data_node(fun, name=name, loss_fn=loss_fn, sigma=sigma, var_sigma=var_sigma, **kwargs) + dn = get_data_node( + fun, + name=name, + loss_fn=loss_fn, + sigma=sigma, + var_sigma=var_sigma, + **kwargs, + ) return dn return wrapped_fun @@ -163,9 +212,12 @@ def datanode(_fun: Callable = None, name=None, loss_fn='square', sigma=1., var_s def get_data_nodes(funs: List[Callable], *args, **kwargs) -> Tuple[DataNode]: - if 'names' in kwargs: - names = kwargs.pop('names') - return tuple(get_data_node(fun, name=name, *args, **kwargs) for fun, name in zip(funs, names)) + if "names" in kwargs: + names = kwargs.pop("names") + return tuple( + get_data_node(fun, name=name, *args, **kwargs) + for fun, name in zip(funs, names) + ) else: return tuple(get_data_node(fun, *args, **kwargs) for fun in funs) diff --git a/idrlnet/geo_utils/geo_builder.py b/idrlnet/geo_utils/geo_builder.py index 5056616..0820d28 100644 --- a/idrlnet/geo_utils/geo_builder.py +++ b/idrlnet/geo_utils/geo_builder.py @@ -1,28 +1,42 @@ """ A simple factory for constructing Geometric Objects""" from .geo import Geometry -from .geo_obj import Line1D, Line, Tube2D, Rectangle, Circle, Plane, Tube3D, Box, Sphere, Cylinder, CircularTube, \ - Triangle, Heart +from .geo_obj import ( + Line1D, + Line, + Tube2D, + Rectangle, + Circle, + Plane, + Tube3D, + Box, + Sphere, + Cylinder, + CircularTube, + Triangle, + Heart, +) -__all__ = ['GeometryBuilder'] +__all__ = ["GeometryBuilder"] class GeometryBuilder: - GEOMAP = {'Line1D': Line1D, - 'Line': Line, - 'Rectangle': Rectangle, - 'Circle': Circle, - 'Channel2D': Tube2D, - 'Plane': Plane, - 'Sphere': Sphere, - 'Box': Box, - 'Channel': Tube3D, - 'Channel3D': Tube3D, - 'Cylinder': Cylinder, - 'CircularTube': CircularTube, - 'Triangle': Triangle, - 'Heart': Heart, - } + GEOMAP = { + "Line1D": Line1D, + "Line": Line, + "Rectangle": Rectangle, + "Circle": Circle, + "Channel2D": Tube2D, + "Plane": Plane, + "Sphere": Sphere, + "Box": Box, + "Channel": Tube3D, + "Channel3D": Tube3D, + "Cylinder": Cylinder, + "CircularTube": CircularTube, + "Triangle": Triangle, + "Heart": Heart, + } @staticmethod def get_geometry(geo: str, **kwargs) -> Geometry: @@ -33,5 +47,7 @@ class GeometryBuilder: :return: A geometry object with given kwargs. :rtype: Geometry """ - assert geo in GeometryBuilder.GEOMAP.keys(), f'The geometry {geo} not implemented!' + assert ( + geo in GeometryBuilder.GEOMAP.keys() + ), f"The geometry {geo} not implemented!" return GeometryBuilder.GEOMAP[geo](**kwargs) diff --git a/idrlnet/geo_utils/sympy_np.py b/idrlnet/geo_utils/sympy_np.py index 320e5c0..4c0247a 100644 --- a/idrlnet/geo_utils/sympy_np.py +++ b/idrlnet/geo_utils/sympy_np.py @@ -10,7 +10,7 @@ from functools import reduce import collections from sympy import Max, Min, Mul -__all__ = ['lambdify_np'] +__all__ = ["lambdify_np"] class WrapSympy: @@ -20,10 +20,14 @@ class WrapSympy: def _wrapper_guide(args): func_1 = args[0] func_2 = args[1] - cond_1 = (isinstance(func_1, WrapSympy) and not func_1.is_sympy) + cond_1 = isinstance(func_1, WrapSympy) and not func_1.is_sympy cond_2 = isinstance(func_2, WrapSympy) and not func_2.is_sympy - cond_3 = (not isinstance(func_1, WrapSympy)) and isinstance(func_1, collections.Callable) - cond_4 = (not isinstance(func_2, WrapSympy)) and isinstance(func_2, collections.Callable) + cond_3 = (not isinstance(func_1, WrapSympy)) and isinstance( + func_1, collections.Callable + ) + cond_4 = (not isinstance(func_2, WrapSympy)) and isinstance( + func_2, collections.Callable + ) return cond_1 or cond_2 or cond_3 or cond_4, func_1, func_2 @@ -111,8 +115,11 @@ def _try_float(fn): def _constant_bool(boolean: bool): def fn(**x): - return np.ones_like(next(iter(x.items()))[1], dtype=bool) if boolean else np.zeros_like( - next(iter(x.items()))[1], dtype=bool) + return ( + np.ones_like(next(iter(x.items()))[1], dtype=bool) + if boolean + else np.zeros_like(next(iter(x.items()))[1], dtype=bool) + ) return fn @@ -128,7 +135,7 @@ def lambdify_np(f, r: Iterable): if isinstance(r, dict): r = r.keys() if isinstance(f, WrapSympy) and f.is_sympy: - lambdify_f = lambdify([k for k in r], f, [PLACEHOLDER, 'numpy']) + lambdify_f = lambdify([k for k in r], f, [PLACEHOLDER, "numpy"]) lambdify_f.input_keys = [k for k in r] return lambdify_f if isinstance(f, WrapSympy) and not f.is_sympy: @@ -141,30 +148,31 @@ def lambdify_np(f, r: Iterable): if isinstance(f, float): return _constant_float(f) else: - lambdify_f = lambdify([k for k in r], f, [PLACEHOLDER, 'numpy']) + lambdify_f = lambdify([k for k in r], f, [PLACEHOLDER, "numpy"]) lambdify_f.input_keys = [k for k in r] return lambdify_f -PLACEHOLDER = {'amin': lambda x: reduce(lambda y, z: np.minimum(y, z), x), - 'amax': lambda x: reduce(lambda y, z: np.maximum(y, z), x), - 'Min': lambda *x: reduce(lambda y, z: np.minimum(y, z), x), - 'Max': lambda *x: reduce(lambda y, z: np.maximum(y, z), x), - 'Heaviside': lambda x: np.heaviside(x, 0), - 'equal': lambda x, y: np.isclose(x, y), - 'Xor': np.logical_xor, - 'cos': np.cos, - 'sin': np.sin, - 'tan': np.tan, - 'exp': np.exp, - 'sqrt': np.sqrt, - 'log': np.log, - 'sinh': np.sinh, - 'cosh': np.cosh, - 'tanh': np.tanh, - 'asin': np.arcsin, - 'acos': np.arccos, - 'atan': np.arctan, - 'Abs': np.abs, - 'DiracDelta': np.zeros_like, - } +PLACEHOLDER = { + "amin": lambda x: reduce(lambda y, z: np.minimum(y, z), x), + "amax": lambda x: reduce(lambda y, z: np.maximum(y, z), x), + "Min": lambda *x: reduce(lambda y, z: np.minimum(y, z), x), + "Max": lambda *x: reduce(lambda y, z: np.maximum(y, z), x), + "Heaviside": lambda x: np.heaviside(x, 0), + "equal": lambda x, y: np.isclose(x, y), + "Xor": np.logical_xor, + "cos": np.cos, + "sin": np.sin, + "tan": np.tan, + "exp": np.exp, + "sqrt": np.sqrt, + "log": np.log, + "sinh": np.sinh, + "cosh": np.cosh, + "tanh": np.tanh, + "asin": np.arcsin, + "acos": np.arccos, + "atan": np.arctan, + "Abs": np.abs, + "DiracDelta": np.zeros_like, +} diff --git a/idrlnet/graph.py b/idrlnet/graph.py index 3c6d545..c6885c6 100644 --- a/idrlnet/graph.py +++ b/idrlnet/graph.py @@ -13,15 +13,15 @@ from idrlnet.header import logger, DIFF_SYMBOL from idrlnet.pde import PdeNode from idrlnet.net import NetNode -__all__ = ['ComputableNodeList', 'Vertex', 'VertexTaskPipeline'] -x, y = sp.symbols('x y') +__all__ = ["ComputableNodeList", "Vertex", "VertexTaskPipeline"] +x, y = sp.symbols("x y") ComputableNodeList = [List[Union[PdeNode, NetNode]]] class Vertex(Node): counter = 0 - def __init__(self, pre=None, next=None, node=None, ntype='c'): + def __init__(self, pre=None, next=None, node=None, ntype="c"): node = Node() if node is None else node self.__dict__ = node.__dict__.copy() self.index = type(self).counter @@ -29,7 +29,7 @@ class Vertex(Node): self.pre = pre if pre is not None else set() self.next = next if pre is not None else set() self.ntype = ntype - assert self.ntype in ('d', 'c', 'r') + assert self.ntype in ("d", "c", "r") def __eq__(self, other): return self.index == other.index @@ -38,8 +38,11 @@ class Vertex(Node): return self.index def __str__(self): - info = f"index: {self.index}\n" + f"pre: {[node.index for node in self.pre]}\n" \ - + f"next: {[node.index for node in self.next]}\n" + info = ( + f"index: {self.index}\n" + + f"pre: {[node.index for node in self.pre]}\n" + + f"next: {[node.index for node in self.next]}\n" + ) return super().__str__() + info @@ -54,7 +57,9 @@ class VertexTaskPipeline: def evaluation_order_list(self, evaluation_order_list): self._evaluation_order_list = evaluation_order_list - def __init__(self, nodes: ComputableNodeList, invar: Variables, req_names: List[str]): + def __init__( + self, nodes: ComputableNodeList, invar: Variables, req_names: List[str] + ): self.nodes = nodes self.req_names = req_names self.computable = set(invar.keys()) @@ -74,14 +79,14 @@ class VertexTaskPipeline: final_graph_node.inputs = [req_name] final_graph_node.derivatives = tuple() final_graph_node.outputs = tuple() - final_graph_node.name = f'<{req_name}>' - final_graph_node.ntype = 'r' + final_graph_node.name = f"<{req_name}>" + final_graph_node.ntype = "r" graph_nodes.add(final_graph_node) req_name_dict[req_name].append(final_graph_node) required_stack.append(final_graph_node) final_graph_node.evaluate = lambda x: x - logger.info('Constructing computation graph...') + logger.info("Constructing computation graph...") while len(req_name_dict) > 0: to_be_removed = set() to_be_added = defaultdict(list) @@ -96,14 +101,20 @@ class VertexTaskPipeline: continue for output in gn.outputs: output = tuple(output.split(DIFF_SYMBOL)) - if len(output) <= len(req_name) and req_name[:len(output)] == output and len( - output) > match_score: + if ( + len(output) <= len(req_name) + and req_name[: len(output)] == output + and len(output) > match_score + ): match_score = len(output) match_gn = gn for p_in in invar.keys(): p_in = tuple(p_in.split(DIFF_SYMBOL)) - if len(p_in) <= len(req_name) and req_name[:len(p_in)] == p_in and len( - p_in) > match_score: + if ( + len(p_in) <= len(req_name) + and req_name[: len(p_in)] == p_in + and len(p_in) > match_score + ): match_score = len(p_in) match_gn = None for sub_gn in req_name_dict[DIFF_SYMBOL.join(req_name)]: @@ -112,9 +123,13 @@ class VertexTaskPipeline: raise Exception("Can't be computed: " + DIFF_SYMBOL.join(req_name)) elif match_gn is not None: for sub_gn in req_name_dict[DIFF_SYMBOL.join(req_name)]: - logger.info(f'{sub_gn.name}.{DIFF_SYMBOL.join(req_name)} <---- {match_gn.name}') + logger.info( + f"{sub_gn.name}.{DIFF_SYMBOL.join(req_name)} <---- {match_gn.name}" + ) match_gn.next.add(sub_gn) - self.egde_data[(match_gn.name, sub_gn.name)].add(DIFF_SYMBOL.join(req_name)) + self.egde_data[(match_gn.name, sub_gn.name)].add( + DIFF_SYMBOL.join(req_name) + ) required_stack.append(match_gn) for sub_gn in req_name_dict[DIFF_SYMBOL.join(req_name)]: sub_gn.pre.add(match_gn) @@ -148,51 +163,91 @@ class VertexTaskPipeline: node.name = key node.outputs = (key,) node.inputs = tuple() - node.ntype = 'd' + node.ntype = "d" self._graph_node_table[key] = node - logger.info('Computation graph constructed.') + logger.info("Computation graph constructed.") def operation_order(self, invar: Variables): for node in self.evaluation_order_list: if not set(node.derivatives).issubset(invar.keys()): - invar.differentiate_(independent_var=invar, required_derivatives=node.derivatives) - invar.update(node.evaluate({**invar.subset(node.inputs), **invar.subset(node.derivatives)})) + invar.differentiate_( + independent_var=invar, required_derivatives=node.derivatives + ) + invar.update( + node.evaluate( + {**invar.subset(node.inputs), **invar.subset(node.derivatives)} + ) + ) - def forward_pipeline(self, invar: Variables, req_names: List[str] = None) -> Variables: + def forward_pipeline( + self, invar: Variables, req_names: List[str] = None + ) -> Variables: if req_names is None or set(req_names).issubset(set(self.computable)): outvar = copy(invar) self.operation_order(outvar) return outvar.subset(self.req_names if req_names is None else req_names) else: - logger.info('The existing graph fails. Construct a temporary graph...') - return VertexTaskPipeline(self.nodes, invar, req_names).forward_pipeline(invar) + logger.info("The existing graph fails. Construct a temporary graph...") + return VertexTaskPipeline(self.nodes, invar, req_names).forward_pipeline( + invar + ) def to_json(self): pass def display(self, filename: str = None): _, ax = plt.subplots(1, 1, figsize=(8, 8)) - ax.axis('off') + ax.axis("off") pos = nx.spring_layout(self.G, k=10 / (math.sqrt(self.G.order()) + 0.1)) - nx.draw_networkx_nodes(self.G, pos, - nodelist=list( - node for node in self.G.nodes if self._graph_node_table[node].ntype == 'c'), - cmap=plt.get_cmap('jet'), - node_size=1300, node_color="pink", alpha=0.5) - nx.draw_networkx_nodes(self.G, pos, - nodelist=list( - node for node in self.G.nodes if self._graph_node_table[node].ntype == 'r'), - cmap=plt.get_cmap('jet'), - node_size=1300, node_color="green", alpha=0.3) - nx.draw_networkx_nodes(self.G, pos, - nodelist=list( - node for node in self.G.nodes if self._graph_node_table[node].ntype == 'd'), - cmap=plt.get_cmap('jet'), - node_size=1300, node_color="blue", alpha=0.3) - nx.draw_networkx_edges(self.G, pos, edge_color='r', arrows=True, arrowsize=30, arrowstyle="-|>") + nx.draw_networkx_nodes( + self.G, + pos, + nodelist=list( + node + for node in self.G.nodes + if self._graph_node_table[node].ntype == "c" + ), + cmap=plt.get_cmap("jet"), + node_size=1300, + node_color="pink", + alpha=0.5, + ) + nx.draw_networkx_nodes( + self.G, + pos, + nodelist=list( + node + for node in self.G.nodes + if self._graph_node_table[node].ntype == "r" + ), + cmap=plt.get_cmap("jet"), + node_size=1300, + node_color="green", + alpha=0.3, + ) + nx.draw_networkx_nodes( + self.G, + pos, + nodelist=list( + node + for node in self.G.nodes + if self._graph_node_table[node].ntype == "d" + ), + cmap=plt.get_cmap("jet"), + node_size=1300, + node_color="blue", + alpha=0.3, + ) + nx.draw_networkx_edges( + self.G, pos, edge_color="r", arrows=True, arrowsize=30, arrowstyle="-|>" + ) nx.draw_networkx_labels(self.G, pos) - nx.draw_networkx_edge_labels(self.G, pos, edge_labels={k: ", ".join(v) for k, v in self.egde_data.items()}, - font_size=10) + nx.draw_networkx_edge_labels( + self.G, + pos, + edge_labels={k: ", ".join(v) for k, v in self.egde_data.items()}, + font_size=10, + ) if filename is None: plt.show() else: diff --git a/idrlnet/header.py b/idrlnet/header.py index b6b0df4..18faf90 100644 --- a/idrlnet/header.py +++ b/idrlnet/header.py @@ -14,7 +14,7 @@ class TestFun: self.registered.append(self) def __call__(self, *args, **kwargs): - print(str(self.fun.__name__).center(50, '*')) + print(str(self.fun.__name__).center(50, "*")) self.fun() @staticmethod @@ -36,7 +36,12 @@ def testmemo(fun): testmemo.memo = set() -log_format = '[%(asctime)s] [%(levelname)s] %(message)s' -handlers = [logging.FileHandler('train.log', mode='a'), logging.StreamHandler()] -logging.basicConfig(format=log_format, level=logging.INFO, datefmt='%d-%b-%y %H:%M:%S', handlers=handlers) +log_format = "[%(asctime)s] [%(levelname)s] %(message)s" +handlers = [logging.FileHandler("train.log", mode="a"), logging.StreamHandler()] +logging.basicConfig( + format=log_format, + level=logging.INFO, + datefmt="%d-%b-%y %H:%M:%S", + handlers=handlers, +) logger = logging.getLogger(__name__) diff --git a/idrlnet/net.py b/idrlnet/net.py index 3f31f04..196bec1 100644 --- a/idrlnet/net.py +++ b/idrlnet/net.py @@ -4,11 +4,11 @@ from idrlnet.node import Node from typing import Tuple, List, Dict, Union from contextlib import ExitStack -__all__ = ['NetNode'] +__all__ = ["NetNode"] class WrapEvaluate: - def __init__(self, binding_node: 'NetNode'): + def __init__(self, binding_node: "NetNode"): self.binding_node = binding_node def __call__(self, inputs): @@ -16,15 +16,23 @@ class WrapEvaluate: if isinstance(inputs, dict): keep_type = dict inputs = torch.cat( - [torch.tensor(inputs[key], dtype=torch.float32) if not isinstance(inputs[key], torch.Tensor) else - inputs[ - key] for key in inputs], dim=1) + [ + torch.tensor(inputs[key], dtype=torch.float32) + if not isinstance(inputs[key], torch.Tensor) + else inputs[key] + for key in inputs + ], + dim=1, + ) with ExitStack() as es: if self.binding_node.require_no_grad: es.enter_context(torch.no_grad()) output_var = self.binding_node.net(inputs) if keep_type == dict: - output_var = {outkey: output_var[:, i:i + 1] for i, outkey in enumerate(self.binding_node.outputs)} + output_var = { + outkey: output_var[:, i : i + 1] + for i, outkey in enumerate(self.binding_node.outputs) + } return output_var @@ -63,9 +71,18 @@ class NetNode(Node): def net(self, net): self._net = net - def __init__(self, inputs: Union[Tuple, List[str]], outputs: Union[Tuple, List[str]], - net: torch.nn.Module, fixed: bool = False, require_no_grad: bool = False, is_reference=False, - name=None, *args, **kwargs): + def __init__( + self, + inputs: Union[Tuple, List[str]], + outputs: Union[Tuple, List[str]], + net: torch.nn.Module, + fixed: bool = False, + require_no_grad: bool = False, + is_reference=False, + name=None, + *args, + **kwargs + ): self.is_reference = is_reference self.inputs: Union[Tuple, List[str]] = inputs self.outputs: Union[Tuple, List[str]] = outputs @@ -89,5 +106,5 @@ class NetNode(Node): def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True): return self.net.load_state_dict(state_dict, strict) - def state_dict(self, destination=None, prefix: str = '', keep_vars: bool = False): + def state_dict(self, destination=None, prefix: str = "", keep_vars: bool = False): return self.net.state_dict(destination, prefix, keep_vars) diff --git a/idrlnet/node.py b/idrlnet/node.py index 9074b55..836083e 100644 --- a/idrlnet/node.py +++ b/idrlnet/node.py @@ -5,7 +5,7 @@ from idrlnet.torch_util import torch_lambdify from idrlnet.variable import Variables from idrlnet.header import DIFF_SYMBOL -__all__ = ['Node'] +__all__ = ["Node"] class Node(object): @@ -58,7 +58,7 @@ class Node(object): try: return self._name except: - self._name = 'Node' + str(id(self)) + self._name = "Node" + str(id(self)) return self._name @name.setter @@ -66,23 +66,33 @@ class Node(object): self._name = name @classmethod - def new_node(cls, name: str = None, tf_eq: Callable = None, free_symbols: List[str] = None, *args, - **kwargs) -> 'Node': + def new_node( + cls, + name: str = None, + tf_eq: Callable = None, + free_symbols: List[str] = None, + *args, + **kwargs + ) -> "Node": node = cls() node.evaluate = LambdaTorchFun(free_symbols, tf_eq, name) node.inputs = [x for x in free_symbols if DIFF_SYMBOL not in x] node.derivatives = [x for x in free_symbols if DIFF_SYMBOL in x] - node.outputs = [name, ] + node.outputs = [ + name, + ] node.name = name return node def __str__(self): - str_list = ["Basic properties:\n", - "name: {}\n".format(self.name), - "inputs: {}\n".format(self.inputs), - "derivatives: {}\n".format(self.derivatives), - "outputs: {}\n".format(self.outputs), ] - return ''.join(str_list) + str_list = [ + "Basic properties:\n", + "name: {}\n".format(self.name), + "inputs: {}\n".format(self.inputs), + "derivatives: {}\n".format(self.derivatives), + "outputs: {}\n".format(self.outputs), + ] + return "".join(str_list) class LambdaTorchFun: diff --git a/idrlnet/pde.py b/idrlnet/pde.py index 8bd901a..53f06ff 100644 --- a/idrlnet/pde.py +++ b/idrlnet/pde.py @@ -6,7 +6,7 @@ from idrlnet.torch_util import _replace_derivatives from idrlnet.header import DIFF_SYMBOL from idrlnet.variable import Variables -__all__ = ['PdeNode', 'ExpressionNode'] +__all__ = ["PdeNode", "ExpressionNode"] class PdeEvaluate: @@ -18,8 +18,11 @@ class PdeEvaluate: def __call__(self, inputs: Variables) -> Variables: result = Variables() for node in self.binding_pde.sub_nodes: - sub_inputs = {k: v for k, v in Variables(inputs).items() if - k in node.inputs or k in node.derivatives} + sub_inputs = { + k: v + for k, v in Variables(inputs).items() + if k in node.inputs or k in node.derivatives + } r = node.evaluate(sub_inputs) result.update(r) return result @@ -53,9 +56,9 @@ class PdeNode(Node): def __init__(self, suffix: str = "", **kwargs): if len(suffix) > 0: - self.suffix = '[' + kwargs['suffix'] + ']' # todo: check prefix + self.suffix = "[" + kwargs["suffix"] + "]" # todo: check prefix else: - self.suffix = '' + self.suffix = "" self.name = type(self).__name__ + self.suffix self.evaluate = PdeEvaluate(self) @@ -77,8 +80,10 @@ class PdeNode(Node): def __str__(self): subnode_str = "\n\n".join( - str(sub_node) + "Equation: \n" + str(self.equations[sub_node.name]) for sub_node in self.sub_nodes) - return super().__str__() + "subnodes".center(30, '-') + '\n' + subnode_str + str(sub_node) + "Equation: \n" + str(self.equations[sub_node.name]) + for sub_node in self.sub_nodes + ) + return super().__str__() + "subnodes".center(30, "-") + "\n" + subnode_str # todo: test required diff --git a/idrlnet/receivers.py b/idrlnet/receivers.py index 6f0e9e6..488c782 100644 --- a/idrlnet/receivers.py +++ b/idrlnet/receivers.py @@ -6,20 +6,20 @@ from typing import Dict, List class Signal(Enum): - REGISTER = 'signal_register' - SOLVE_START = 'signal_solve_start' - TRAIN_PIPE_START = 'signal_train_pipe_start' - BEFORE_COMPUTE_LOSS = 'before_compute_loss' - AFTER_COMPUTE_LOSS = 'compute_loss' - BEFORE_BACKWARD = 'signal_before_backward' - TRAIN_PIPE_END = 'signal_train_pipe_end' - SOLVE_END = 'signal_solve_end' + REGISTER = "signal_register" + SOLVE_START = "signal_solve_start" + TRAIN_PIPE_START = "signal_train_pipe_start" + BEFORE_COMPUTE_LOSS = "before_compute_loss" + AFTER_COMPUTE_LOSS = "compute_loss" + BEFORE_BACKWARD = "signal_before_backward" + TRAIN_PIPE_END = "signal_train_pipe_end" + SOLVE_END = "signal_solve_end" class Receiver(metaclass=abc.ABCMeta): @abc.abstractmethod def receive_notify(self, obj: object, message: Dict): - raise NotImplementedError('Method receive_notify() not implemented!') + raise NotImplementedError("Method receive_notify() not implemented!") class Notifier: diff --git a/idrlnet/solver.py b/idrlnet/solver.py index e69aa74..cd9723b 100644 --- a/idrlnet/solver.py +++ b/idrlnet/solver.py @@ -15,7 +15,7 @@ from idrlnet.variable import Variables, DomainVariables from idrlnet.graph import VertexTaskPipeline import idrlnet -__all__ = ['Solver'] +__all__ = ["Solver"] class Solver(Notifier, Optimizable): @@ -65,20 +65,23 @@ class Solver(Notifier, Optimizable): :param kwargs: """ - def __init__(self, sample_domains: Tuple[Union[DataNode, SampleDomain], ...], - netnodes: List[NetNode], - pdes: Optional[List] = None, - network_dir: str = './network_dir', - summary_dir: Optional[str] = None, - max_iter: int = 1000, - save_freq: int = 100, - print_freq: int = 10, - loading: bool = True, - init_network_dirs: Optional[List[str]] = None, - opt_config: Dict = None, - schedule_config: Dict = None, - result_dir='train_domain/results', - **kwargs): + def __init__( + self, + sample_domains: Tuple[Union[DataNode, SampleDomain], ...], + netnodes: List[NetNode], + pdes: Optional[List] = None, + network_dir: str = "./network_dir", + summary_dir: Optional[str] = None, + max_iter: int = 1000, + save_freq: int = 100, + print_freq: int = 10, + loading: bool = True, + init_network_dirs: Optional[List[str]] = None, + opt_config: Dict = None, + schedule_config: Dict = None, + result_dir="train_domain/results", + **kwargs, + ): self.network_dir: str = network_dir self.domain_losses = {domain.name: domain.loss_fn for domain in sample_domains} @@ -96,8 +99,16 @@ class Solver(Notifier, Optimizable): self.save_freq = save_freq self.print_freq = print_freq try: - self.parse_configure(**{**({"opt_config": opt_config} if opt_config is not None else {}), - **({"schedule_config": schedule_config} if schedule_config is not None else {})}) + self.parse_configure( + **{ + **({"opt_config": opt_config} if opt_config is not None else {}), + **( + {"schedule_config": schedule_config} + if schedule_config is not None + else {} + ), + } + ) except Exception: logger.error("Optimizer configuration failed") raise @@ -109,7 +120,10 @@ class Solver(Notifier, Optimizable): pass self.sample_domains: Tuple[DataNode, ...] = sample_domains self.summary_dir = self.network_dir if summary_dir is None else summary_dir - self.receivers: List[Receiver] = [SummaryReceiver(self.summary_dir), HandleResultReceiver(result_dir)] + self.receivers: List[Receiver] = [ + SummaryReceiver(self.summary_dir), + HandleResultReceiver(result_dir), + ] @property def network_dir(self): @@ -136,12 +150,23 @@ class Solver(Notifier, Optimizable): :return: A list of trainable parameters. :rtype: List[torch.nn.parameter.Parameter] """ - parameter_list = list(map(lambda _net_node: {'params': _net_node.net.parameters()}, - filter(lambda _net_node: not _net_node.is_reference and (not _net_node.fixed), - self.netnodes))) + parameter_list = list( + map( + lambda _net_node: {"params": _net_node.net.parameters()}, + filter( + lambda _net_node: not _net_node.is_reference + and (not _net_node.fixed), + self.netnodes, + ), + ) + ) if len(parameter_list) == 0: - '''To make sure successful initialization of optimizers.''' - parameter_list = [torch.nn.parameter.Parameter(data=torch.Tensor([0.]), requires_grad=True)] + """To make sure successful initialization of optimizers.""" + parameter_list = [ + torch.nn.parameter.Parameter( + data=torch.Tensor([0.0]), requires_grad=True + ) + ] logger.warning("No trainable parameters found!") return parameter_list @@ -158,15 +183,15 @@ class Solver(Notifier, Optimizable): """return sovler information, it will return components recursively""" str_list = [] str_list.append("nets: \n") - str_list.append(''.join([str(net) for net in self.netnodes])) + str_list.append("".join([str(net) for net in self.netnodes])) str_list.append("domains: \n") - str_list.append(''.join([str(domain) for domain in self.sample_domains])) - str_list.append('\n') - str_list.append('optimizer config:\n') + str_list.append("".join([str(domain) for domain in self.sample_domains])) + str_list.append("\n") + str_list.append("optimizer config:\n") for i, _class in enumerate(type(self).mro()): if _class == Optimizable: str_list.append(super(type(self).mro()[i - 1], self).__str__()) - return ''.join(str_list) + return "".join(str_list) def set_param_ranges(self, param_ranges: Dict): for domain in self.sample_domains: @@ -184,7 +209,7 @@ class Solver(Notifier, Optimizable): for value in self.sample_domains: if value.name == name: return value - raise KeyError(f'domain {name} not exist!') + raise KeyError(f"domain {name} not exist!") def generate_computation_pipeline(self): """Generate computation pipeline for all domains. @@ -195,28 +220,40 @@ class Solver(Notifier, Optimizable): self.vertex_pipelines = {} for domain_name, var in in_var.items(): logger.info(f"Constructing computation graph for domain <{domain_name}>") - self.vertex_pipelines[domain_name] = VertexTaskPipeline(self.netnodes + self.pdes, var, - self.outvar_dict_index[domain_name]) + self.vertex_pipelines[domain_name] = VertexTaskPipeline( + self.netnodes + self.pdes, var, self.outvar_dict_index[domain_name] + ) self.vertex_pipelines[domain_name].display( - os.path.join(self.network_dir, f'{domain_name}_{self.global_step}.png')) + os.path.join(self.network_dir, f"{domain_name}_{self.global_step}.png") + ) - def forward_through_all_graph(self, invar_dict: DomainVariables, - req_outvar_dict_index: Dict[str, List[str]]) -> DomainVariables: + def forward_through_all_graph( + self, invar_dict: DomainVariables, req_outvar_dict_index: Dict[str, List[str]] + ) -> DomainVariables: outvar_dict = {} for (key, req_outvar_names) in req_outvar_dict_index.items(): - outvar_dict[key] = self.vertex_pipelines[key].forward_pipeline(invar_dict[key], req_outvar_names) + outvar_dict[key] = self.vertex_pipelines[key].forward_pipeline( + invar_dict[key], req_outvar_names + ) return outvar_dict def append_sample_domain(self, datanode): self.sample_domains = self.sample_domains + (datanode,) def _generate_dict_index(self) -> None: - self.invar_dict_index = {domain.name: domain.inputs for domain in self.sample_domains} - self.outvar_dict_index = {domain.name: domain.outputs for domain in self.sample_domains} - self.lambda_dict_index = {domain.name: domain.lambda_outputs for domain in self.sample_domains} + self.invar_dict_index = { + domain.name: domain.inputs for domain in self.sample_domains + } + self.outvar_dict_index = { + domain.name: domain.outputs for domain in self.sample_domains + } + self.lambda_dict_index = { + domain.name: domain.lambda_outputs for domain in self.sample_domains + } - def generate_in_out_dict(self, samples: DomainVariables) -> \ - Tuple[DomainVariables, DomainVariables, DomainVariables]: + def generate_in_out_dict( + self, samples: DomainVariables + ) -> Tuple[DomainVariables, DomainVariables, DomainVariables]: invar_dict = {} for domain, variable in samples.items(): inner = {} @@ -226,20 +263,40 @@ class Solver(Notifier, Optimizable): invar_dict[domain] = inner invar_dict = { - domain: Variables({key: val for key, val in variable.items() if key in self.invar_dict_index[domain]}) for - domain, variable in samples.items()} + domain: Variables( + { + key: val + for key, val in variable.items() + if key in self.invar_dict_index[domain] + } + ) + for domain, variable in samples.items() + } outvar_dict = { - domain: Variables({key: val for key, val in variable.items() if key in self.outvar_dict_index[domain]}) for - domain, variable in samples.items()} + domain: Variables( + { + key: val + for key, val in variable.items() + if key in self.outvar_dict_index[domain] + } + ) + for domain, variable in samples.items() + } lambda_dict = { - domain: Variables({key: val for key, val in variable.items() if key in self.lambda_dict_index[domain]}) for - domain, variable in samples.items()} + domain: Variables( + { + key: val + for key, val in variable.items() + if key in self.lambda_dict_index[domain] + } + ) + for domain, variable in samples.items() + } return invar_dict, outvar_dict, lambda_dict def solve(self): - """After the solver instance is initialized, the method could be called to solve the entire problem. - """ - self.notify(self, message={Signal.SOLVE_START: 'default'}) + """After the solver instance is initialized, the method could be called to solve the entire problem.""" + self.notify(self, message={Signal.SOLVE_START: "default"}) while self.global_step < self.max_iter: loss = self.train_pipe() if self.global_step % self.print_freq == 0: @@ -247,13 +304,13 @@ class Solver(Notifier, Optimizable): if self.global_step % self.save_freq == 0: self.save() logger.info("Training Stage Ends") - self.notify(self, message={Signal.SOLVE_END: 'default'}) + self.notify(self, message={Signal.SOLVE_END: "default"}) def train_pipe(self): """Sample once; calculate the loss once; backward propagation once :return: None """ - self.notify(self, message={Signal.TRAIN_PIPE_START: 'defaults'}) + self.notify(self, message={Signal.TRAIN_PIPE_START: "defaults"}) for opt in self.optimizers: opt.zero_grad() samples = self.sample_variables_from_domains() @@ -263,7 +320,7 @@ class Solver(Notifier, Optimizable): loss = self.compute_loss(in_var, pred_out_sample, true_out, lambda_out) except RuntimeError: raise - self.notify(self, message={Signal.BEFORE_BACKWARD: 'defaults'}) + self.notify(self, message={Signal.BEFORE_BACKWARD: "defaults"}) loss.backward() for opt in self.optimizers: opt.step() @@ -271,40 +328,64 @@ class Solver(Notifier, Optimizable): for scheduler in self.schedulers: scheduler.step(self.global_step) - self.notify(self, message={Signal.TRAIN_PIPE_END: 'defaults'}) + self.notify(self, message={Signal.TRAIN_PIPE_END: "defaults"}) return loss - def compute_loss(self, in_var: DomainVariables, pred_out_sample: DomainVariables, - true_out: DomainVariables, - lambda_out: DomainVariables) -> torch.Tensor: - """Compute the total loss in one epoch. - - """ + def compute_loss( + self, + in_var: DomainVariables, + pred_out_sample: DomainVariables, + true_out: DomainVariables, + lambda_out: DomainVariables, + ) -> torch.Tensor: + """Compute the total loss in one epoch.""" diff = dict() for domain_name, domain_val in true_out.items(): if len(domain_val) == 0: continue - diff[domain_name] = pred_out_sample[domain_name] - domain_val.to_torch_tensor_() + diff[domain_name] = ( + pred_out_sample[domain_name] - domain_val.to_torch_tensor_() + ) diff[domain_name].update(lambda_out[domain_name]) - diff[domain_name].update(area=in_var[domain_name]['area']) + diff[domain_name].update(area=in_var[domain_name]["area"]) for domain, var in diff.items(): lambda_diff = dict() for constraint, _ in var.items(): - if 'lambda_' + constraint in in_var[domain].keys(): - lambda_diff['lambda_' + constraint] = in_var[domain]['lambda_' + constraint] + if "lambda_" + constraint in in_var[domain].keys(): + lambda_diff["lambda_" + constraint] = in_var[domain][ + "lambda_" + constraint + ] var.update(lambda_diff) self.loss_component = Variables( ChainMap( - *[diff[domain_name].weighted_loss(f"{domain_name}_loss", - loss_function=self.domain_losses[domain_name]) for - domain_name, domain_val in - diff.items()])) + *[ + diff[domain_name].weighted_loss( + f"{domain_name}_loss", + loss_function=self.domain_losses[domain_name], + ) + for domain_name, domain_val in diff.items() + ] + ) + ) self.notify(self, message={Signal.BEFORE_COMPUTE_LOSS: {**self.loss_component}}) - loss = sum({domain_name: self.get_sample_domain(domain_name).sigma * self.loss_component[f"{domain_name}_loss"] for - domain_name in diff}.values()) - self.notify(self, message={Signal.AFTER_COMPUTE_LOSS: {**self.loss_component, **{'total_loss': loss}}}) + loss = sum( + { + domain_name: self.get_sample_domain(domain_name).sigma + * self.loss_component[f"{domain_name}_loss"] + for domain_name in diff + }.values() + ) + self.notify( + self, + message={ + Signal.AFTER_COMPUTE_LOSS: { + **self.loss_component, + **{"total_loss": loss}, + } + }, + ) return loss def infer_step(self, domain_attr: Dict[str, List[str]]) -> DomainVariables: @@ -323,40 +404,46 @@ class Solver(Notifier, Optimizable): return {data_node.name: data_node.sample() for data_node in self.sample_domains} def save(self): - """Save parameters of netnodes and the global step to `model.ckpt`. - """ - save_path = os.path.join(self.network_dir, 'model.ckpt') + """Save parameters of netnodes and the global step to `model.ckpt`.""" + save_path = os.path.join(self.network_dir, "model.ckpt") logger.info("save to path: {}".format(os.path.abspath(save_path))) - save_dict = {f"{net_node.name}_dict": net_node.state_dict() for net_node in - filter(lambda _net: not _net.is_reference, self.netnodes)} + save_dict = { + f"{net_node.name}_dict": net_node.state_dict() + for net_node in filter(lambda _net: not _net.is_reference, self.netnodes) + } for i, opt in enumerate(self.optimizers): - save_dict['optimizer_{}_dict'.format(i)] = opt.state_dict() - save_dict['global_step'] = self.global_step + save_dict["optimizer_{}_dict".format(i)] = opt.state_dict() + save_dict["global_step"] = self.global_step torch.save(save_dict, save_path) def init_load(self): for network_dir in self.init_network_dirs: - save_path = os.path.join(network_dir, 'model.ckpt') + save_path = os.path.join(network_dir, "model.ckpt") save_dict = torch.load(save_path) for net_node in self.netnodes: - if f"{net_node.name}_dict" in save_dict.keys() and not net_node.is_reference: + if ( + f"{net_node.name}_dict" in save_dict.keys() + and not net_node.is_reference + ): net_node.load_state_dict(save_dict[f"{net_node.name}_dict"]) logger.info(f"Successfully loading initialization {net_node.name}.") def load(self): - """Load parameters of netnodes and the global step from `model.ckpt`. - """ - save_path = os.path.join(self.network_dir, 'model.ckpt') + """Load parameters of netnodes and the global step from `model.ckpt`.""" + save_path = os.path.join(self.network_dir, "model.ckpt") if not idrlnet.GPU_ENABLED: - save_dict = torch.load(save_path, map_location=torch.device('cpu')) + save_dict = torch.load(save_path, map_location=torch.device("cpu")) else: save_dict = torch.load(save_path) # todo: save on CPU, load on GPU for i, opt in enumerate(self.optimizers): - opt.load_state_dict(save_dict['optimizer_{}_dict'.format(i)]) - self.global_step = save_dict['global_step'] + opt.load_state_dict(save_dict["optimizer_{}_dict".format(i)]) + self.global_step = save_dict["global_step"] for net_node in self.netnodes: - if f"{net_node.name}_dict" in save_dict.keys() and not net_node.is_reference: + if ( + f"{net_node.name}_dict" in save_dict.keys() + and not net_node.is_reference + ): net_node.load_state_dict(save_dict[f"{net_node.name}_dict"]) logger.info(f"Successfully loading {net_node.name}.") @@ -364,27 +451,34 @@ class Solver(Notifier, Optimizable): """ Call interfaces of ``Optimizable`` """ - opt = self.optimizer_config['optimizer'] + opt = self.optimizer_config["optimizer"] if isinstance(opt, str) and opt in Optimizable.OPTIMIZER_MAP: - opt = Optimizable.OPTIMIZER_MAP[opt](self.trainable_parameters, - **{k: v for k, v in self.optimizer_config.items() if k != 'optimizer'}) + opt = Optimizable.OPTIMIZER_MAP[opt]( + self.trainable_parameters, + **{k: v for k, v in self.optimizer_config.items() if k != "optimizer"}, + ) elif isinstance(opt, Callable): opt = opt else: raise NotImplementedError( - 'The optimizer is not implemented. You may use one of the following optimizer:\n' + '\n'.join( - Optimizable.OPTIMIZER_MAP.keys()) + '\n Example: opt_config=dict(optimizer="Adam", lr=1e-3)') + "The optimizer is not implemented. You may use one of the following optimizer:\n" + + "\n".join(Optimizable.OPTIMIZER_MAP.keys()) + + '\n Example: opt_config=dict(optimizer="Adam", lr=1e-3)' + ) - lr_scheduler = self.schedule_config['scheduler'] + lr_scheduler = self.schedule_config["scheduler"] if isinstance(lr_scheduler, str) and lr_scheduler in Optimizable.SCHEDULE_MAP: - lr_scheduler = Optimizable.SCHEDULE_MAP[lr_scheduler](opt, - **{k: v for k, v in self.schedule_config.items() if - k != 'scheduler'}) + lr_scheduler = Optimizable.SCHEDULE_MAP[lr_scheduler]( + opt, + **{k: v for k, v in self.schedule_config.items() if k != "scheduler"}, + ) elif isinstance(lr_scheduler, Callable): lr_scheduler = lr_scheduler else: raise NotImplementedError( - 'The scheduler is not implemented. You may use one of the following scheduler:\n' + '\n'.join( - Optimizable.SCHEDULE_MAP.keys()) + '\n Example: schedule_config=dict(scheduler="ExponentialLR", gamma=0.999') + "The scheduler is not implemented. You may use one of the following scheduler:\n" + + "\n".join(Optimizable.SCHEDULE_MAP.keys()) + + '\n Example: schedule_config=dict(scheduler="ExponentialLR", gamma=0.999' + ) self.optimizers = [opt] self.schedulers = [lr_scheduler] diff --git a/idrlnet/torch_util.py b/idrlnet/torch_util.py index 0be76b5..d0d6198 100644 --- a/idrlnet/torch_util.py +++ b/idrlnet/torch_util.py @@ -10,7 +10,7 @@ import torch from idrlnet.header import DIFF_SYMBOL from functools import reduce -__all__ = ['integral', 'torch_lambdify'] +__all__ = ["integral", "torch_lambdify"] def integral_fun(x): @@ -19,7 +19,7 @@ def integral_fun(x): return x -integral = implemented_function('integral', lambda x: integral_fun(x)) +integral = implemented_function("integral", lambda x: integral_fun(x)) def torch_lambdify(r, f, *args, **kwargs): @@ -41,27 +41,27 @@ def torch_lambdify(r, f, *args, **kwargs): # todo: more functions TORCH_SYMPY_PRINTER = { - 'sin': torch.sin, - 'cos': torch.cos, - 'tan': torch.tan, - 'exp': torch.exp, - 'sqrt': torch.sqrt, - 'Abs': torch.abs, - 'tanh': torch.tanh, - 'DiracDelta': torch.zeros_like, - 'Heaviside': lambda x: torch.heaviside(x, torch.tensor([0.])), - 'amin': lambda x: reduce(lambda y, z: torch.minimum(y, z), x), - 'amax': lambda x: reduce(lambda y, z: torch.maximum(y, z), x), - 'Min': lambda *x: reduce(lambda y, z: torch.minimum(y, z), x), - 'Max': lambda *x: reduce(lambda y, z: torch.maximum(y, z), x), - 'equal': lambda x, y: torch.isclose(x, y), - 'Xor': torch.logical_xor, - 'log': torch.log, - 'sinh': torch.sinh, - 'cosh': torch.cosh, - 'asin': torch.arcsin, - 'acos': torch.arccos, - 'atan': torch.arctan, + "sin": torch.sin, + "cos": torch.cos, + "tan": torch.tan, + "exp": torch.exp, + "sqrt": torch.sqrt, + "Abs": torch.abs, + "tanh": torch.tanh, + "DiracDelta": torch.zeros_like, + "Heaviside": lambda x: torch.heaviside(x, torch.tensor([0.0])), + "amin": lambda x: reduce(lambda y, z: torch.minimum(y, z), x), + "amax": lambda x: reduce(lambda y, z: torch.maximum(y, z), x), + "Min": lambda *x: reduce(lambda y, z: torch.minimum(y, z), x), + "Max": lambda *x: reduce(lambda y, z: torch.maximum(y, z), x), + "equal": lambda x, y: torch.isclose(x, y), + "Xor": torch.logical_xor, + "log": torch.log, + "sinh": torch.sinh, + "cosh": torch.cosh, + "asin": torch.arcsin, + "acos": torch.arccos, + "atan": torch.arctan, } @@ -75,9 +75,12 @@ def _replace_derivatives(expr): expr = expr.subs(deriv, Function(str(deriv))(*deriv.free_symbols)) while True: try: - custom_fun = {_fun for _fun in expr.atoms(Function) if - (_fun.class_key()[1] == 0) and (not _fun.class_key()[2] == 'integral') - }.pop() + custom_fun = { + _fun + for _fun in expr.atoms(Function) + if (_fun.class_key()[1] == 0) + and (not _fun.class_key()[2] == "integral") + }.pop() new_symbol_name = str(custom_fun) expr = expr.subs(custom_fun, Symbol(new_symbol_name)) except KeyError: @@ -90,7 +93,10 @@ class UnderlineDerivativePrinter(StrPrinter): return expr.func.__name__ def _print_Derivative(self, expr): - return "".join([str(expr.args[0].func)] + [order * (DIFF_SYMBOL + str(key)) for key, order in expr.args[1:]]) + return "".join( + [str(expr.args[0].func)] + + [order * (DIFF_SYMBOL + str(key)) for key, order in expr.args[1:]] + ) def sstr(expr, **settings): diff --git a/idrlnet/variable.py b/idrlnet/variable.py index 82514f5..afc1755 100644 --- a/idrlnet/variable.py +++ b/idrlnet/variable.py @@ -13,14 +13,14 @@ from collections import defaultdict import pandas as pd from idrlnet.header import DIFF_SYMBOL -__all__ = ['Loss', 'Variables', 'DomainVariables', 'export_var'] +__all__ = ["Loss", "Variables", "DomainVariables", "export_var"] class Loss(enum.Enum): """Enumerate loss functions""" - L1 = 'L1' - square = 'square' + L1 = "L1" + square = "square" class LossFunction: @@ -35,56 +35,67 @@ class LossFunction: raise NotImplementedError(f"loss function {loss_function} is not defined!") @staticmethod - def weighted_L1_loss(variables: 'Variables', name: str) -> 'Variables': - loss = 0. + def weighted_L1_loss(variables: "Variables", name: str) -> "Variables": + loss = 0.0 for key, val in variables.items(): - if key.startswith("lambda_") or key == 'area': + if key.startswith("lambda_") or key == "area": continue elif "lambda_" + key in variables.keys(): - loss += torch.sum((torch.abs(val)) * variables["lambda_" + key] * variables["area"]) + loss += torch.sum( + (torch.abs(val)) * variables["lambda_" + key] * variables["area"] + ) else: loss += torch.sum((torch.abs(val)) * variables["area"]) return Variables({name: loss}) @staticmethod - def weighted_square_loss(variables: 'Variables', name: str) -> 'Variables': - loss = 0. + def weighted_square_loss(variables: "Variables", name: str) -> "Variables": + loss = 0.0 for key, val in variables.items(): - if key.startswith("lambda_") or key == 'area': + if key.startswith("lambda_") or key == "area": continue elif "lambda_" + key in variables.keys(): - loss += torch.sum((val ** 2) * variables["lambda_" + key] * variables["area"]) + loss += torch.sum( + (val ** 2) * variables["lambda_" + key] * variables["area"] + ) else: loss += torch.sum((val ** 2) * variables["area"]) return Variables({name: loss}) class Variables(dict): - def __sub__(self, other: 'Variables') -> 'Variables': + def __sub__(self, other: "Variables") -> "Variables": return Variables( - {key: (self[key] if key in self else 0) - (other[key] if key in other else 0) for key in {**self, **other}}) + { + key: (self[key] if key in self else 0) + - (other[key] if key in other else 0) + for key in {**self, **other} + } + ) - def weighted_loss(self, name: str, loss_function: Union[Loss, str]) -> 'Variables': + def weighted_loss(self, name: str, loss_function: Union[Loss, str]) -> "Variables": """Regard the variable as residuals and reduce to a weighted_loss.""" - return LossFunction.weighted_loss(variables=self, loss_function=loss_function, name=name) + return LossFunction.weighted_loss( + variables=self, loss_function=loss_function, name=name + ) - def subset(self, subset_keys: List[str]) -> 'Variables': + def subset(self, subset_keys: List[str]) -> "Variables": """Construct a new variable with subset references""" return Variables({name: self[name] for name in subset_keys if name in self}) - def to_torch_tensor_(self) -> 'Variables[str, torch.Tensor]': + def to_torch_tensor_(self) -> "Variables[str, torch.Tensor]": """Convert the variables to torch.Tensor""" for key, val in self.items(): if not isinstance(val, torch.Tensor): self[key] = torch.Tensor(val) - if (not key.startswith('lambda_')) and (not key == 'area'): + if (not key.startswith("lambda_")) and (not key == "area"): self[key].requires_grad_() return self - def to_ndarray_(self) -> 'Variables[str, np.ndarray]': + def to_ndarray_(self) -> "Variables[str, np.ndarray]": """convert to a numpy based variables""" for key, val in self.items(): @@ -92,7 +103,7 @@ class Variables(dict): self[key] = val.detach().cpu().numpy() return self - def to_ndarray(self) -> 'Variables[str, np.ndarray]': + def to_ndarray(self) -> "Variables[str, np.ndarray]": """Return a new numpy based variables""" new_var = Variables() @@ -130,26 +141,36 @@ class Variables(dict): variables[name] = var_t return variables - def differentiate_one_step_(self: 'Variables', independent_var: 'Variables', required_derivatives: List[str]): + def differentiate_one_step_( + self: "Variables", independent_var: "Variables", required_derivatives: List[str] + ): """One order of derivatives will be computed towards the required_derivatives.""" required_derivatives = [d for d in required_derivatives if d not in self] required_derivatives_set = set( - tuple(required_derivative.split(DIFF_SYMBOL)) for required_derivative in required_derivatives) + tuple(required_derivative.split(DIFF_SYMBOL)) + for required_derivative in required_derivatives + ) dependent_var_set = set(tuple(dv.split(DIFF_SYMBOL)) for dv in self.keys()) computable_derivative_dict = defaultdict(set) for dv, rd in itertools.product(dependent_var_set, required_derivatives_set): - if len(rd) > len(dv) and rd[:len(dv)] == dv and rd[:len(dv) + 1] not in dependent_var_set: + if ( + len(rd) > len(dv) + and rd[: len(dv)] == dv + and rd[: len(dv) + 1] not in dependent_var_set + ): computable_derivative_dict[rd[len(dv)]].add(DIFF_SYMBOL.join(dv)) derivative_variables = Variables() for key, value in computable_derivative_dict.items(): for v in value: - f__x = torch.autograd.grad(self[v], - independent_var[key], - grad_outputs=torch.ones_like(self[v]), - retain_graph=True, - create_graph=True, - allow_unused=True)[0] + f__x = torch.autograd.grad( + self[v], + independent_var[key], + grad_outputs=torch.ones_like(self[v]), + retain_graph=True, + create_graph=True, + allow_unused=True, + )[0] if f__x is not None: f__x.requires_grad_() else: @@ -157,7 +178,9 @@ class Variables(dict): derivative_variables[DIFF_SYMBOL.join([v, key])] = f__x self.update(derivative_variables) - def differentiate_(self: 'Variables', independent_var: 'Variables', required_derivatives: List[str]): + def differentiate_( + self: "Variables", independent_var: "Variables", required_derivatives: List[str] + ): """Derivatives will be computed towards the required_derivatives""" n_keys = 0 @@ -168,8 +191,11 @@ class Variables(dict): new_keys = len(self.keys()) @staticmethod - def var_differentiate_one_step(dependent_var: 'Variables', independent_var: 'Variables', - required_derivatives: List[str]): + def var_differentiate_one_step( + dependent_var: "Variables", + independent_var: "Variables", + required_derivatives: List[str], + ): """Perform one step of differentiate towards the required_derivatives""" dependent_var.differentiate_one_step_(independent_var, required_derivatives) @@ -177,15 +203,15 @@ class Variables(dict): def to_csv(self, filename: str) -> None: """Export variable to csv""" - if not filename.endswith('.csv'): - filename += '.csv' + if not filename.endswith(".csv"): + filename += ".csv" df = self.to_dataframe() df.to_csv(filename, index=False) def to_vtu(self, filename: str, coordinates=None) -> None: """Export variable to vtu""" - coordinates = ['x', 'y', 'z'] if coordinates is None else coordinates + coordinates = ["x", "y", "z"] if coordinates is None else coordinates shape = 0 for axis in coordinates: if axis not in self.keys(): @@ -196,27 +222,29 @@ class Variables(dict): if value.shape == (1, 1): self[key] = np.ones(shape) * value self[key] = np.asarray(self[key], dtype=np.float64) - pointsToVTK(filename, - self[coordinates[0]][:, 0].copy(), - self[coordinates[1]][:, 0].copy(), - self[coordinates[2]][:, 0].copy(), - data={key: value[:, 0].copy() for key, value in self.items()}) + pointsToVTK( + filename, + self[coordinates[0]][:, 0].copy(), + self[coordinates[1]][:, 0].copy(), + self[coordinates[2]][:, 0].copy(), + data={key: value[:, 0].copy() for key, value in self.items()}, + ) def save(self, path, formats=None): """Export variable to various formats""" if formats is None: - formats = ['np', 'csv', 'vtu'] + formats = ["np", "csv", "vtu"] np_var = self.to_ndarray() - if 'np' in formats: + if "np" in formats: np.savez(path, **np_var) - if 'csv' in formats: + if "csv" in formats: np_var.to_csv(path) - if 'vtu' in formats: + if "vtu" in formats: np_var.to_vtu(filename=path) @staticmethod - def cat(*var_list) -> 'Variables': + def cat(*var_list) -> "Variables": """todo: catenate in var list""" return Variables() @@ -224,12 +252,14 @@ class Variables(dict): DomainVariables = Dict[str, Variables] -def export_var(domain_var: DomainVariables, path='./inference_domain/results', formats=None): +def export_var( + domain_var: DomainVariables, path="./inference_domain/results", formats=None +): """Export a dict of variables to ``csv``, ``vtu`` or ``npz``.""" if formats is None: - formats = ['csv', 'vtu', 'np'] + formats = ["csv", "vtu", "np"] path = pathlib.Path(path) path.mkdir(exist_ok=True, parents=True) for key in domain_var.keys(): - domain_var[key].save(os.path.join(path, f'{key}'), formats) + domain_var[key].save(os.path.join(path, f"{key}"), formats)