# idrlnet/examples/inverse_wave_equation/inverse_wave_equation.py
#
# Inverse wave-equation example: recover the wave speed c from synthetic,
# outlier-contaminated observations of u(x, t).
import idrlnet.shortcut as sc
from math import pi
from sympy import Symbol
import torch
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
# Spatial domain: the segment [0, L] with L = pi.
L = float(pi)
geo = sc.Line1D(0, L)

# Time and space symbols, plus the time window [0, 2L] used for sampling.
t_symbol, x = Symbol("t"), Symbol("x")
time_range = {t_symbol: (0, 2 * L)}

# Ground-truth wave speed used to synthesize the observations; the script
# later compares it against the learned value.
c = 1.54

# CSV file that caches the synthetic observations between runs.
external_filename = "external_sample.csv"

def generate_observed_data():
    """Write synthetic, outlier-contaminated observations to ``external_filename``.

    Does nothing if that file already exists, so the same observations are
    reused across runs. Otherwise:

    1. Sample low-discrepancy interior points of ``geo`` over ``time_range``.
    2. Evaluate ``u(x, t) = sin(x) * (sin(c t) + cos(c t))`` on them.
    3. Overwrite up to 10 randomly chosen values with 3.0 (gross outliers).
    4. Save the flattened columns as CSV, without an index column.
    """
    if os.path.exists(external_filename):
        return

    points = geo.sample_interior(
        density=20, bounds={x: (0, L)}, param_ranges=time_range, low_discrepancy=True
    )
    points["u"] = np.sin(points["x"]) * (
        np.sin(c * points["t"]) + np.cos(c * points["t"])
    )

    # Corrupt a few samples so the data contain outliers. Cap the count so
    # sampling without replacement cannot fail on a small point set.
    num_samples = len(points["u"])
    num_outliers = min(10, num_samples)
    outlier_idx = np.random.choice(num_samples, num_outliers, replace=False)
    points["u"][outlier_idx] = 3.0

    table = pd.DataFrame.from_dict({k: v.ravel() for k, v in points.items()})
    # Write to the same path tested above so the cache check stays consistent.
    table.to_csv(external_filename, index=False)


generate_observed_data()

# Fit the observations with an L1 loss, which is less sensitive to the
# injected outliers than a squared loss. Removing ``loss_fn`` uses the
# datanode's default loss instead.
# NOTE(review): this observed-data node is named "wave_domain" while the PDE
# node is "wave_external"; the names look swapped, but they are kept because
# they are used as lookup keys elsewhere in the script.
@sc.datanode(name="wave_domain", loss_fn="L1")
class WaveExternal(sc.SampleDomain):
    """Observed data points loaded from ``external_filename``.

    ``points`` maps every CSV column except ``u`` to an (N, 1) array.
    ``constraints`` maps ``"u"`` to the observed values, also shaped (N, 1).
    """

    def __init__(self):
        table = pd.read_csv(external_filename)
        self.points = {
            col: table[col].to_numpy().reshape(-1, 1) for col in table.columns
        }
        self.constraints = {"u": self.points.pop("u")}

    def sampling(self, *args, **kwargs):
        # Fixed data set: the same points and targets are returned every call.
        return self.points, self.constraints

@sc.datanode(name="wave_external")
class WaveEq(sc.SampleDomain):
    """Collocation points where the wave-equation residual is constrained to 0."""

    def sampling(self, *args, **kwargs):
        # Re-sample the interior of the space-time window on every call.
        interior = geo.sample_interior(
            density=1000,
            bounds={x: (0, L)},
            param_ranges=time_range,
        )
        return interior, {"wave_equation": 0.0}

@sc.datanode(name="center_infer")
class CenterInfer(sc.SampleDomain):
    """Inference-only domain: 200 evenly spaced times in [0, 2L] at x = L/2."""

    def __init__(self):
        self.points = sc.Variables()
        self.points["t"] = np.linspace(0, 2 * L, 200).reshape(-1, 1)
        t_col = self.points["t"]
        self.points["x"] = np.ones_like(t_col) * L / 2
        # Unit "area" column; presumably per-point weights expected by the
        # solver. TODO confirm.
        self.points["area"] = np.ones_like(t_col)

    def sampling(self, *args, **kwargs):
        # No constraints: these points are only used for prediction.
        return self.points, {}

# Surrogate u(x, t), approximated by an MLP.
net = sc.get_net_node(
    inputs=("x", "t"),
    outputs=("u",),
    name="net1",
    arch=sc.Arch.mlp,
)
# Learnable "c" (single_var architecture): the unknown wave speed.
var_c = sc.get_net_node(inputs=("x",), outputs=("c",), arch=sc.Arch.single_var)
# 1-D time-dependent wave equation whose speed is the learned "c".
pde = sc.WaveNode(c="c", dim=1, time=True, u="u")

s = sc.Solver(
    sample_domains=(WaveExternal(), WaveEq()),
    netnodes=[net, var_c],
    pdes=[pde],
    network_dir="network_dir",  # e.g. "square_network_dir" for a square-loss run
    max_iter=5000,
)
s.solve()

def _to_numpy(tensor):
    """Detach a (possibly GPU) torch tensor and return it as a flat NumPy array."""
    return tensor.cpu().detach().numpy().ravel()


_, ax = plt.subplots(1, 1, figsize=(8, 4))

# Network predictions at the observed (x, t) locations.
coord = s.infer_step(domain_attr={"wave_domain": ["x", "t", "u"]})
ax.scatter(
    _to_numpy(coord["wave_domain"]["t"]),
    _to_numpy(coord["wave_domain"]["u"]),
    c="r",
    marker="o",
    label="predicted points",
)

print("true parameter c: {:.4f}".format(c))
# [[1.0]] is a dummy input; presumably the single_var node ignores its value.
predict_c = var_c.evaluate(torch.Tensor([[1.0]])).item()
print("predicted parameter c: {:.4f}".format(predict_c))

# Observed data: build the domain once rather than re-reading the CSV twice.
observed = WaveExternal().sample_fn
ax.scatter(
    observed.points["t"].ravel(),
    observed.constraints["u"].ravel(),
    c="b",
    marker="x",
    label="observed points",
)

# Prediction vs. exact solution along the line x = L/2.
s.sample_domains = (CenterInfer(),)
center = s.infer_step({"center_infer": ["t", "x", "u"]})["center_infer"]
num_t = _to_numpy(center["t"])
num_x = _to_numpy(center["x"])
num_u = _to_numpy(center["u"])
exact_u = np.sin(num_x) * (np.sin(c * num_t) + np.cos(c * num_t))
ax.plot(num_t, exact_u, c="k", label="exact")
ax.plot(num_t, num_u, "--", c="g", linewidth=4, label="predict")

ax.legend()
ax.set_xlabel("t")
ax.set_ylabel("u")
# Square-loss variant: title f"Square loss ($x=0.5L$, c={predict_c:.4f})"
# and save to "square.png".
ax.set_title(f"L1 loss ($x=0.5L$, c={predict_c:.4f})")
ax.grid(True)
ax.set_xlim([-0.5, 6.5])
ax.set_ylim([-3.5, 4.5])
plt.savefig("L1.png", dpi=1000, bbox_inches="tight", pad_inches=0.02)
plt.show()
plt.close()