idrlnet/examples/allen_cahn/allen_cahn.py

215 lines
6.2 KiB
Python
Raw Permalink Normal View History

2021-07-05 11:18:12 +08:00
from sympy import Symbol
import sympy as sp
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import idrlnet.shortcut as sc
import os
import torch
# parameter phase
# Upper end of the simulated time interval: t ranges over [0, L].
L = 1.0

# define geometry: the 1D spatial domain x in [-1, 1]
geo = sc.Line1D(-1.0, 1.0)

# define sympy variables to parametrize the domain and the unknown fields
x, t_symbol = Symbol("x"), Symbol("t")
# u(x, t) is the solution field; up(x, t) is its counterpart produced by the
# second network node (fed with the shifted coordinate "xp").
u = sp.Function("u")(x, t_symbol)
up = sp.Function("up")(x, t_symbol)

# Range of the time parameter used when sampling space-time points.
time_range = {t_symbol: (0, L)}
# constraint phase
@sc.datanode
class AllenInit(sc.SampleDomain):
    """Initial-condition constraint at t = 0.

    Samples interior points of ``geo`` on the t = 0 slice and asks the network
    to reproduce u(x, 0) = x**2 * cos(pi * x). ``lambda_u`` presumably weights
    this constraint in the loss (100 here) — confirm against idrlnet docs.
    """

    def sampling(self, *args, **kwargs):
        initial_slice = geo.sample_interior(density=300, param_ranges={t_symbol: 0.0})
        targets = {
            "u": x ** 2 * sp.cos(sp.pi * x),
            "lambda_u": 100,
        }
        return initial_slice, targets


@sc.datanode
class AllenBc(sc.SampleDomain):
    """Boundary constraint sampled on the x = -1 edge over the whole time range.

    Requires the two difference quantities to vanish. Their keys presumably
    correspond to the ``sc.Difference`` nodes comparing u with up and
    du/dx with dup/dx — confirm the naming convention in idrlnet.
    """

    def sampling(self, *args, **kwargs):
        left_edge = geo.sample_boundary(
            density=200,
            sieve=sp.Eq(x, -1),
            param_ranges=time_range,
        )
        zero_mismatch = {"difference_u_up": 0, "difference_diff_u_diff_up": 0}
        return left_edge, zero_mismatch


@sc.datanode(name="allen_domain")
class AllenEq(sc.SampleDomain):
    """Interior collocation points on which the Allen-Cahn residual is driven to 0.

    The points are drawn once, at construction (low-discrepancy, density 2000,
    over the full time range), and the same set is returned on every call.
    """

    def __init__(self):
        self.points = geo.sample_interior(
            density=2000,
            param_ranges=time_range,
            low_discrepancy=True,
        )

    def sampling(self, *args, **kwargs):
        # A fresh constraints dict per call, as before.
        return self.points, {"AllenCahn_u": 0}


@sc.datanode(name="data_evaluate")
class AllenPointsInference(sc.SampleDomain):
    """Fixed dense pool of space-time points used to evaluate the PDE residual.

    The points are sampled once, converted to torch tensors, and paired with a
    zero target for ``AllenCahn_u`` of matching shape. This pool is what the
    adaptive re-sampling hook queries via ``infer_step({"data_evaluate": ...})``.
    """

    def __init__(self):
        candidates = geo.sample_interior(
            density=5000, param_ranges=time_range, low_discrepancy=True
        )
        self.points = sc.Variables(candidates).to_torch_tensor_()
        zero_residual = torch.zeros_like(self.points["x"])
        self.constraints = {"AllenCahn_u": zero_residual}

    def sampling(self, *args, **kwargs):
        return self.points, self.constraints


@sc.datanode(name="re_sampling_domain")
class SpaceAdaptiveSampling(sc.SampleDomain):
    """Residual constraint on an adaptively chosen set of points.

    Starts from a small low-discrepancy sample (density 100). Its ``points``
    attribute is later overwritten by the adaptive re-sampling hook through
    ``solver.set_domain_parameter("re_sampling_domain", {"points": ...})``.

    NOTE(review): ``self.constraints`` is built once from the initial sample and
    is not refreshed when ``points`` is replaced, so the re-sampled point count
    presumably has to equal the initial one — confirm.
    """

    def __init__(self):
        initial = geo.sample_interior(
            density=100, param_ranges=time_range, low_discrepancy=True
        )
        self.points = sc.Variables(initial).to_torch_tensor_()
        zero_residual = torch.zeros_like(self.points["x"])
        self.constraints = {"AllenCahn_u": zero_residual}

    def sampling(self, *args, **kwargs):
        return self.points, self.constraints


@sc.datanode(name="allen_test")
def generate_plot_data():
    """Build a regular 100 x 100 (x, t) grid over [-1, 1] x [0, 1] for plotting.

    Returns the grid as column vectors wrapped in ``sc.Variables`` and an empty
    constraints dict, since this domain is only used for inference.
    """
    x_axis = np.linspace(-1.0, 1.0, 100)
    t_axis = np.linspace(0, 1.0, 100)
    grid_x, grid_t = np.meshgrid(x_axis, t_axis)
    plot_points = sc.Variables(x=grid_x.reshape(-1, 1), t=grid_t.reshape(-1, 1))
    return plot_points, {}


# computational node phase
# Main network: (x, t) -> u.
# NOTE(review): the MLP is built with 2 output units while the NetNode declares
# a single output name "u"; the second unit looks unused — confirm how NetNode
# maps output columns to names. Left unchanged so checkpoints restored via
# ``loading=True`` keep a compatible architecture.
net_u = sc.NetNode(
    inputs=("x", "t"),
    outputs=("u",),
    name="net1",
    net=sc.MLP([2, 128, 128, 128, 128, 2], activation=sc.Activation.tanh),
)

# Shifted spatial coordinate xp = x + 2, fed (together with t) to a second node
# built from net_u via get_shared_net_node (presumably sharing its weights) to
# produce "up".
xp = sc.ExpressionNode(name="xp", expression=x + 2)
get_tilde_u = sc.get_shared_net_node(
    net_u,
    inputs=("xp", "t"),
    outputs=("up",),
    name="net2",
    arch="mlp",
)

# Spatial derivatives of u and up, used by the derivative-matching boundary term.
diff_u = sc.ExpressionNode(expression=sp.diff(u, x), name="diff_u")
diff_up = sc.ExpressionNode(expression=sp.diff(up, x), name="diff_up")

# Allen-Cahn PDE node acting on "u".
pde = sc.AllenCahnNode(u="u", gamma_1=0.0001, gamma_2=5)

# Differences between u / up and their x-derivatives, constrained to zero by AllenBc.
boundary_up = sc.Difference(T="diff_u", S="diff_up")
boundary_u = sc.Difference(T="u", S="up")


# Receiver hook phase
class SpaceAdaptiveReceiver(sc.Receiver):
    """Residual-driven adaptive re-sampling hook.

    On every ``interval``-th training step (at the end of the training pipe) the
    Allen-Cahn residual is evaluated on the ``data_evaluate`` pool. The
    ``num_points`` points with the largest absolute residual then replace the
    points of ``re_sampling_domain``. Each selected point receives an equal share
    of unit area.
    """

    # Training-step period between two re-sampling passes.
    interval = 1000
    # Number of highest-residual points kept per pass.
    # NOTE(review): the re_sampling_domain target tensor is sized from its
    # initial sample, so this presumably must match that size — confirm.
    num_points = 200

    def receive_notify(self, solver, message):
        if not (
            sc.Signal.TRAIN_PIPE_END in message.keys()
            and solver.global_step % self.interval == 0
        ):
            return

        sc.logger.info("space adaptive sampling...")
        results = solver.infer_step(
            {"data_evaluate": ["x", "t", "sdf", "AllenCahn_u"]}
        )
        evaluated = results["data_evaluate"]
        residual_data = evaluated["AllenCahn_u"].detach().cpu().numpy().ravel()

        # sort the points by residual magnitude, largest first, and keep the top ones
        index = np.argsort(-1.0 * np.abs(residual_data))[: self.num_points]
        # Keep only the coordinate columns; the residual itself is not a domain input.
        _points = {
            key: values[index].detach().cpu().numpy()
            for key, values in evaluated.items()
            if key != "AllenCahn_u"
        }
        # Equal area weights summing to 1 over the points actually selected
        # (fewer than num_points if the pool is smaller); guard against 0.
        _points["area"] = np.zeros_like(_points["sdf"]) + (1.0 / max(len(index), 1))
        solver.set_domain_parameter("re_sampling_domain", {"points": _points})


class PostProcessReceiver(sc.Receiver):
    """Periodic visualization hook.

    On steps where ``global_step % 1000 == 1`` (at the end of the training
    pipe) it does the following:
    - evaluates ``u`` on the ``allen_test`` grid;
    - draws a filled contour of u over the (t, x) plane;
    - overlays the current ``re_sampling_domain`` points as crosses;
    - saves the figure to ``image/result_<step>.png``.
    """

    def __init__(self):
        super().__init__()
        # Output directory for saved figures; tolerate it already existing.
        os.makedirs("image", exist_ok=True)

    @staticmethod
    def _to_numpy(values):
        """Return *values* as a flat NumPy array (accepts torch tensors or arrays)."""
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        return np.asarray(values).ravel()

    def receive_notify(self, solver, message):
        if not (
            sc.Signal.TRAIN_PIPE_END in message.keys()
            and solver.global_step % 1000 == 1
        ):
            return

        sc.logger.info("Post Processing...")
        # Use the solver handed to the hook (not a module-level global), so the
        # receiver works with whichever solver it is registered on.
        points = solver.infer_step({"allen_test": ["x", "t", "u"]})
        test_data = points["allen_test"]
        triang_total = tri.Triangulation(
            self._to_numpy(test_data["t"]),
            self._to_numpy(test_data["x"]),
        )

        # Draw on a fresh figure each time; otherwise contours and colorbars
        # accumulate on the same figure across calls.
        fig = plt.figure()
        plt.tricontourf(
            triang_total,
            self._to_numpy(test_data["u"]),
            100,
            vmin=-1,
            vmax=1,
        )
        tc_bar = plt.colorbar()
        tc_bar.ax.tick_params(labelsize=12)

        # Points may be tensors (initial sample) or NumPy arrays (after re-sampling).
        _points = solver.get_domain_parameter("re_sampling_domain", "points")
        plt.scatter(
            self._to_numpy(_points["t"]),
            self._to_numpy(_points["x"]),
            marker="x",
            s=8,
        )
        plt.xlabel("$t$")
        plt.ylabel("$x$")
        plt.title("$u(x,t)$")
        plt.savefig(f"image/result_{solver.global_step}.png")
        plt.show()
        # Release the figure so repeated calls do not leak memory.
        plt.close(fig)


# Solver phase
# Every sample domain used in training, inference, re-sampling and plotting.
sample_domains = (
    AllenInit(),
    AllenBc(),
    AllenEq(),
    AllenPointsInference(),
    SpaceAdaptiveSampling(),
    generate_plot_data(),
)
s = sc.Solver(
    sample_domains=sample_domains,
    netnodes=[net_u, get_tilde_u],
    pdes=[pde, xp, diff_up, diff_u, boundary_up, boundary_u],
    max_iter=60000,
    loading=True,
)

# Hooks: adaptive re-sampling first, then plotting.
s.register_receiver(SpaceAdaptiveReceiver())
s.register_receiver(PostProcessReceiver())
s.solve()