# idrlnet/examples/euler_beam/euler_beam.py
import matplotlib.pyplot as plt
import sympy as sp
import numpy as np
import idrlnet.shortcut as sc
# Symbolic setup: the independent variable x, the 1-D geometry [0, 1] the
# sample domains draw points from, and the unknown deflection y(x) that the
# network is trained to represent.
x = sp.Symbol("x")
Line = sc.Line1D(0, 1)
y = sp.Function("y")(x)
@sc.datanode(name="interior")
class Interior(sc.SampleDomain):
    """Interior collocation points for the governing equation.

    Draws 1000 points inside ``Line`` and asks the solver to drive the
    ``"dddd_y"`` quantity (the expression ``y'''' + 1``) to zero there.
    """

    def sampling(self, *args, **kwargs):
        points = Line.sample_interior(1000)
        constraints = {"dddd_y": 0}
        return points, constraints
@sc.datanode(name="left_boundary1")
class LeftBoundary1(sc.SampleDomain):
    """Boundary condition y(0) = 0 (zero deflection at the left end).

    Samples 100 boundary points of ``Line`` and keeps only those at x = 0.
    """

    def sampling(self, *args, **kwargs):
        at_left_end = sp.Eq(x, 0)
        points = Line.sample_boundary(100, sieve=at_left_end)
        return points, {"y": 0}
@sc.datanode(name="left_boundary2")
class LeftBoundary2(sc.SampleDomain):
    """Boundary condition y'(0) = 0 (zero slope at the left end).

    Together with ``y(0) = 0`` this makes the left end clamped. The
    ``"d_y"`` key refers to the first-derivative expression node.
    """

    def sampling(self, *args, **kwargs):
        at_left_end = sp.Eq(x, 0)
        points = Line.sample_boundary(100, sieve=at_left_end)
        return points, {"d_y": 0}
@sc.datanode(name="right_boundary1")
class RightBoundary1(sc.SampleDomain):
    """Boundary condition y''(1) = 0 at the right end.

    In Euler-Bernoulli terms this is zero bending moment. Only boundary
    samples satisfying x = 1 are kept.
    """

    def sampling(self, *args, **kwargs):
        at_right_end = sp.Eq(x, 1)
        points = Line.sample_boundary(100, sieve=at_right_end)
        return points, {"dd_y": 0}
@sc.datanode(name="right_boundary2")
class RightBoundary2(sc.SampleDomain):
    """Boundary condition y'''(1) = 0 at the right end.

    In Euler-Bernoulli terms this is zero shear force, which makes x = 1 a
    free end together with ``y''(1) = 0``.
    """

    def sampling(self, *args, **kwargs):
        at_right_end = sp.Eq(x, 1)
        points = Line.sample_boundary(100, sieve=at_right_end)
        return points, {"ddd_y": 0}
@sc.datanode(name="infer")
class Infer(sc.SampleDomain):
    """Evaluation grid used after training.

    Provides 1000 evenly spaced x values on [0, 1] as a column array and no
    constraints (empty dict), so it is only useful for inference.
    """

    def sampling(self, *args, **kwargs):
        grid = np.linspace(0.0, 1.0, num=1000).reshape(-1, 1)
        no_constraints = {}
        return {"x": grid}, no_constraints
# Neural surrogate for the deflection: maps the input "x" to the output "y"
# using idrlnet's MLP architecture.
net = sc.get_net_node(
    inputs=("x",),
    outputs=("y",),
    name="net",
    arch=sc.Arch.mlp,
)
# Expression nodes whose names match the constraint keys used by the sample
# domains:
#   dddd_y = y'''' + 1   (governing equation, constrained to 0 in the interior)
#   d_y    = y'          (slope, constrained at x = 0)
#   dd_y   = y''         (constrained at x = 1)
#   ddd_y  = y'''        (constrained at x = 1)
pde1 = sc.ExpressionNode(name="dddd_y", expression=y.diff(x, 4) + 1)
pde2 = sc.ExpressionNode(name="d_y", expression=y.diff(x))
pde3 = sc.ExpressionNode(name="dd_y", expression=y.diff(x, 2))
pde4 = sc.ExpressionNode(name="ddd_y", expression=y.diff(x, 3))
# Wire everything together: the network supplies y(x), the expression nodes
# supply the derivative quantities, and each sample domain contributes the
# points and target values for one equation or boundary condition.
solver = sc.Solver(
    netnodes=[net],
    pdes=[pde1, pde2, pde3, pde4],
    sample_domains=(
        Interior(),
        LeftBoundary1(),
        LeftBoundary2(),
        RightBoundary1(),
        RightBoundary2(),
    ),
    max_iter=2000,
)
solver.solve()
# inference
def exact(x):
    """Return the closed-form deflection for the beam problem in this script.

    It satisfies y'''' = -1 on [0, 1], with y(0) = y'(0) = 0 and
    y''(1) = y'''(1) = 0, i.e. the same conditions imposed by the expression
    nodes and sample domains. Evaluates elementwise, so *x* may be a scalar
    or a NumPy array.
    """
    quartic = x ** 4 / 24
    cubic = x ** 3 / 6
    quadratic = x ** 2 / 4
    return -quartic + cubic - quadratic
# Replace the training domains with the evaluation grid and run the network on it.
solver.sample_domains = (Infer(),)
points = solver.infer_step({"infer": ["x", "y"]})
infer_out = points["infer"]
# The solver hands back tensors (hence .detach().cpu()); flatten them to 1-D
# NumPy arrays for plotting.
xs, y_pred = (infer_out[key].detach().cpu().numpy().ravel() for key in ("x", "y"))
y_exact = exact(xs)

# Predicted vs. analytical deflection.
plt.plot(xs, y_pred, label="Pred")
plt.plot(xs, y_exact, label="Exact", linestyle="--")
plt.legend()
plt.xlabel("x")
plt.ylabel("w")
plt.savefig("Euler_beam.png", dpi=300, bbox_inches="tight")
plt.show()