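# idrlnet/examples/allen_cahn/allen_cahn.py
#
# Allen-Cahn example: trains a network u(x, t) on x in [-1, 1], t in [0, L]
# with a periodic boundary condition and residual-based adaptive resampling.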

from sympy import Symbol
import sympy as sp
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import idrlnet.shortcut as sc
import os
import torch
# ---- parameter phase ----
L = 1.0  # end of the time interval: t ranges over [0, L] (see time_range)

# ---- geometry: the spatial interval x in [-1, 1] ----
geo = sc.Line1D(-1.0, 1.0)

# Sympy variables used to parameterize the domain and to build expressions.
t_symbol = Symbol("t")
x = Symbol("x")
# Symbolic network outputs: "u" comes from net1 and "up" from net2, which is
# fed the shifted coordinate xp (both are defined in the computational nodes).
u, up = (sp.Function(name)(x, t_symbol) for name in ("u", "up"))
time_range = {t_symbol: (0, L)}
# constraint phase
@sc.datanode
class AllenInit(sc.SampleDomain):
    """Initial-condition domain: u(x, 0) = x**2 * cos(pi * x)."""

    def sampling(self, *args, **kwargs):
        # Interior points of the spatial interval, all at t = 0.
        initial_points = geo.sample_interior(
            density=300, param_ranges={t_symbol: 0.0}
        )
        targets = {
            "u": x**2 * sp.cos(sp.pi * x),
            # Presumably the loss weight for the "u" target — TODO confirm
            # against sc's lambda_ naming convention.
            "lambda_u": 100,
        }
        return initial_points, targets
@sc.datanode
class AllenBc(sc.SampleDomain):
    """Periodic boundary condition, imposed at the left end x = -1.

    ``up`` is the shared network evaluated at ``xp = x + 2``. At x = -1 it
    therefore sees the right end x = 1. Driving both differences to zero
    matches u and u_x across the two ends for every sampled t.
    """

    def sampling(self, *args, **kwargs):
        # The sieve presumably keeps only boundary points with x == -1.
        boundary_points = geo.sample_boundary(
            density=200,
            sieve=sp.Eq(x, -1),
            param_ranges=time_range,
        )
        # Outputs of the two sc.Difference nodes, both required to vanish.
        targets = {"difference_u_up": 0, "difference_diff_u_diff_up": 0}
        return boundary_points, targets
@sc.datanode(name="allen_domain")
class AllenEq(sc.SampleDomain):
    """Interior collocation points where the Allen-Cahn residual must vanish.

    The points are drawn once, at construction time. Every ``sampling`` call
    returns that same set.
    """

    def __init__(self):
        self.points = geo.sample_interior(
            density=2000,
            param_ranges=time_range,
            low_discrepancy=True,
        )

    def sampling(self, *args, **kwargs):
        return self.points, {"AllenCahn_u": 0}
@sc.datanode(name="data_evaluate")
class AllenPointsInference(sc.SampleDomain):
    """Dense evaluation points used to find high-residual regions.

    SpaceAdaptiveReceiver runs inference on this domain ("data_evaluate") and
    ranks its points by |AllenCahn_u|.
    """

    def __init__(self):
        sampled = geo.sample_interior(
            density=5000,
            param_ranges=time_range,
            low_discrepancy=True,
        )
        self.points = sc.Variables(sampled).to_torch_tensor_()
        # Zero residual target shaped like the "x" column of the points.
        zero_target = torch.zeros_like(self.points["x"])
        self.constraints = {"AllenCahn_u": zero_target}

    def sampling(self, *args, **kwargs):
        return self.points, self.constraints
@sc.datanode(name="re_sampling_domain")
class SpaceAdaptiveSampling(sc.SampleDomain):
    """Collocation domain whose points are replaced during training.

    It starts from a low-discrepancy interior sample. SpaceAdaptiveReceiver
    later overwrites ``points`` through ``solver.set_domain_parameter``.
    """

    def __init__(self):
        initial = geo.sample_interior(
            density=100,
            param_ranges=time_range,
            low_discrepancy=True,
        )
        self.points = sc.Variables(initial).to_torch_tensor_()
        # NOTE(review): this zero target is sized once, from the initial
        # points. It is not rebuilt when the receiver swaps in new points
        # (200 per round). Confirm that the initial sample also has 200
        # points, or that sc tolerates the size mismatch.
        self.constraints = {"AllenCahn_u": torch.zeros_like(self.points["x"])}

    def sampling(self, *args, **kwargs):
        return self.points, self.constraints
@sc.datanode(name="allen_test")
def generate_plot_data():
    """Build the plotting grid: 100 x 100 points over x in [-1, 1], t in [0, L].

    The time axis follows the module-level ``L``, so the plot keeps covering
    the training time range if ``L`` is changed. The local names
    (``grid_x``/``grid_t``) avoid shadowing the module-level sympy symbol ``x``.

    Returns:
        (points, constraints): an ``sc.Variables`` holding column arrays ``x``
        and ``t`` of shape (10000, 1), and an empty constraint dict.
    """
    grid_x, grid_t = np.meshgrid(
        np.linspace(-1.0, 1.0, 100),
        np.linspace(0, L, 100),
    )
    points = sc.Variables(x=grid_x.reshape(-1, 1), t=grid_t.reshape(-1, 1))
    return points, {}
# computational node phase
# net1: u(x, t) from an MLP with tanh activations.
# NOTE(review): the MLP's layer list ends in width 2, while only one output
# ("u") is declared. Confirm whether the extra column is intended. It is left
# unchanged because the solver uses loading=True and may restore saved
# weights of this shape.
net_u = sc.NetNode(
    inputs=("x", "t"),
    outputs=("u",),
    name="net1",
    net=sc.MLP([2, 128, 128, 128, 128, 2], activation=sc.Activation.tanh),
)

# Periodic shift: xp = x + 2 maps the left end x = -1 onto the right end x = 1.
xp = sc.ExpressionNode(name="xp", expression=x + 2)

# net2 reuses net1 (per get_shared_net_node) but is fed (xp, t), so its
# output "up" is the solution evaluated at the shifted coordinate.
get_tilde_u = sc.get_shared_net_node(
    net_u,
    inputs=("xp", "t"),
    outputs=("up",),
    name="net2",
    arch="mlp",
)

# First spatial derivatives, used by the derivative part of the periodic BC.
diff_u = sc.ExpressionNode(expression=u.diff(x), name="diff_u")
diff_up = sc.ExpressionNode(expression=up.diff(x), name="diff_up")

# Allen-Cahn residual node. Where gamma_1/gamma_2 enter the equation is
# defined by sc.AllenCahnNode.
pde = sc.AllenCahnNode(u="u", gamma_1=0.0001, gamma_2=5)

# Periodic-BC residuals, presumably emitted as "difference_diff_u_diff_up"
# and "difference_u_up" (the keys AllenBc constrains to zero).
boundary_up = sc.Difference(T="diff_u", S="diff_up")
boundary_u = sc.Difference(T="u", S="up")
# Receiver hook phase
class SpaceAdaptiveReceiver(sc.Receiver):
    """Periodically refill "re_sampling_domain" with high-residual points.

    Every ``resample_every`` global steps, at the end of a training pipe,
    this does three things:

    1. Evaluates the Allen-Cahn residual on the "data_evaluate" points.
    2. Keeps the ``n_resample`` points with the largest absolute residual.
    3. Installs them as the points of "re_sampling_domain".

    Each selected point gets an equal ``area`` weight of 1 / (number selected).
    """

    # Resampling period, in solver global steps.
    resample_every = 1000
    # Number of highest-residual points kept per resampling round.
    n_resample = 200

    def receive_notify(self, solver, message):
        if not (
            sc.Signal.TRAIN_PIPE_END in message.keys()
            and solver.global_step % self.resample_every == 0
        ):
            return
        sc.logger.info("space adaptive sampling...")
        results = solver.infer_step(
            {"data_evaluate": ["x", "t", "sdf", "AllenCahn_u"]}
        )
        evaluated = results["data_evaluate"]
        residual_data = evaluated["AllenCahn_u"].detach().cpu().numpy().ravel()
        # Indices sorted by decreasing |residual|. Slicing caps the selection
        # at however many points were actually evaluated.
        index = np.argsort(-np.abs(residual_data))[: self.n_resample]
        # The residual itself is not part of the new point set, so skip it.
        _points = {
            key: values[index].detach().cpu().numpy()
            for key, values in evaluated.items()
            if key != "AllenCahn_u"
        }
        # Normalize by the real count rather than a fixed n_resample. The
        # max() only guards the empty case, where the array is empty anyway.
        _points["area"] = np.zeros_like(_points["sdf"]) + (
            1.0 / max(index.size, 1)
        )
        solver.set_domain_parameter("re_sampling_domain", {"points": _points})
class PostProcessReceiver(sc.Receiver):
    """Plot u(x, t) together with the current adaptive points.

    The plot is made one step after each resampling round
    (global_step % 1000 == 1). It is saved to ``image/result_<step>.png``
    and then shown.
    """

    def __init__(self):
        super().__init__()
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs("image", exist_ok=True)

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a flat numpy array; torch tensors are detached first."""
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        return np.asarray(values).ravel()

    def receive_notify(self, solver, message):
        if not (
            sc.Signal.TRAIN_PIPE_END in message.keys()
            and solver.global_step % 1000 == 1
        ):
            return
        sc.logger.info("Post Processing...")
        # Query the solver that sent the notification, not a module global.
        results = solver.infer_step({"allen_test": ["x", "t", "u"]})
        plot_data = results["allen_test"]
        # Use a fresh figure per call so contours and colorbars do not pile up
        # across rounds. Always release it, even if plotting fails.
        fig = plt.figure()
        try:
            triang_total = tri.Triangulation(
                self._to_numpy(plot_data["t"]),
                self._to_numpy(plot_data["x"]),
            )
            plt.tricontourf(
                triang_total,
                self._to_numpy(plot_data["u"]),
                100,
                vmin=-1,
                vmax=1,
            )
            tc_bar = plt.colorbar()
            tc_bar.ax.tick_params(labelsize=12)
            # These points may be torch tensors (the initial sample) or numpy
            # arrays (after SpaceAdaptiveReceiver replaced them).
            _points = solver.get_domain_parameter("re_sampling_domain", "points")
            plt.scatter(
                self._to_numpy(_points["t"]),
                self._to_numpy(_points["x"]),
                marker="x",
                s=8,
            )
            plt.xlabel("$t$")
            plt.ylabel("$x$")
            plt.title("$u(x,t)$")
            plt.savefig(f"image/result_{solver.global_step}.png")
            plt.show()
        finally:
            plt.close(fig)
# Solver phase
# `s` stays a module-level name: other code in this file may refer to it.
s = sc.Solver(
    sample_domains=(
        AllenInit(),
        AllenBc(),
        AllenEq(),
        AllenPointsInference(),
        SpaceAdaptiveSampling(),
        generate_plot_data(),
    ),
    netnodes=[net_u, get_tilde_u],
    pdes=[pde, xp, diff_up, diff_u, boundary_up, boundary_u],
    max_iter=60000,
    # Presumably resumes from a saved checkpoint if one exists — TODO confirm
    # against sc.Solver.
    loading=True,
)
# Resampling is registered before post-processing.
for receiver in (SpaceAdaptiveReceiver(), PostProcessReceiver()):
    s.register_receiver(receiver)
s.solve()