# idrlnet/examples/inverse_wave_equation/inverse_wave_equation.py
#
# 131 lines
# 4.0 KiB
# Python

import idrlnet.shortcut as sc
from math import pi
from sympy import Symbol
import torch
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
# Spatial domain is the interval [0, L]; the time window is [0, 2L].
L = float(pi)
x = Symbol("x")
t_symbol = Symbol("t")
geo = sc.Line1D(0, L)
time_range = {t_symbol: (0, 2 * L)}

# Wave speed used to synthesise the observations; it is reported later as the
# "true" value next to the one recovered by training.
c = 1.54

# CSV file that stores the synthetic observations.
external_filename = "external_sample.csv"
def generate_observed_data():
    """Create the CSV of synthetic observations unless it already exists.

    Interior space-time points are sampled from ``geo`` over ``time_range``
    and labelled with ``u = sin(x) * (sin(c * t) + cos(c * t))``. Ten distinct,
    randomly chosen labels are then overwritten with ``3.0`` -- a value outside
    the range of the clean signal (``|u| <= sqrt(2)``) -- and the table is
    written to ``external_filename`` with one column per variable.

    Returns:
        None. The function is a no-op when ``external_filename`` exists.
    """
    if os.path.exists(external_filename):
        return

    points = geo.sample_interior(
        density=20, bounds={x: (0, L)}, param_ranges=time_range, low_discrepancy=True
    )
    points["u"] = np.sin(points["x"]) * (
        np.sin(c * points["t"]) + np.cos(c * points["t"])
    )

    # Corrupt ten distinct samples with a constant value.
    corrupted_idx = np.random.choice(len(points["u"]), 10, replace=False)
    points["u"][corrupted_idx] = 3.0

    # Flatten every array so each variable becomes a single CSV column.
    frame = pd.DataFrame.from_dict({k: v.ravel() for k, v in points.items()})

    # Write to the same path the existence check above looks at, so changing
    # `external_filename` keeps the guard and the output in sync.
    frame.to_csv(external_filename, index=False)
# Make sure the observation file exists before the data domain below reads it.
generate_observed_data()
# Alternative decorator without an explicit loss_fn:
# @sc.datanode(name='wave_domain')
@sc.datanode(name="wave_domain", loss_fn="L1")
class WaveExternal(sc.SampleDomain):
    """Data-driven domain built from the observations in ``external_filename``.

    Every CSV column is loaded as an ``(N, 1)`` array. The ``u`` column is
    moved out of the inputs and becomes the constraint (target) for ``u``;
    the remaining columns are the sample points.
    """

    def __init__(self):
        # Read from the shared constant rather than a duplicated literal so the
        # reader always follows the file written by the data generator.
        table = pd.read_csv(external_filename)
        self.points = {
            col: table[col].to_numpy().reshape(-1, 1) for col in table.columns
        }
        # `pop` removes "u" from the inputs while handing it to the constraints.
        self.constraints = {"u": self.points.pop("u")}

    def sampling(self, *args, **kwargs):
        """Return the stored ``(points, constraints)`` pair; arguments are ignored."""
        return self.points, self.constraints
# NOTE(review): this PDE domain is registered as "wave_external" while the
# CSV-backed domain is registered as "wave_domain"; the names look swapped --
# confirm before renaming, since inference code refers to them by name.
@sc.datanode(name="wave_external")
class WaveEq(sc.SampleDomain):
    """Interior sampling domain constrained by the ``wave_equation`` target."""

    def sampling(self, *args, **kwargs):
        """Sample interior space-time points and pair them with a 0.0 target.

        Arguments are accepted for interface compatibility and ignored.
        """
        interior = geo.sample_interior(
            density=1000, param_ranges=time_range, bounds={x: (0, L)}
        )
        return interior, {"wave_equation": 0.0}
@sc.datanode(name="center_infer")
class CenterInfer(sc.SampleDomain):
    """Inference-only domain: 200 time samples on [0, 2L] at ``x = L / 2``."""

    def __init__(self):
        self.points = sc.Variables()
        self.points["t"] = np.linspace(0, 2 * L, 200).reshape(-1, 1)
        # Re-read the stored column so the other fields share its shape/dtype.
        t_col = self.points["t"]
        self.points["x"] = np.full_like(t_col, L / 2)
        self.points["area"] = np.ones_like(t_col)

    def sampling(self, *args, **kwargs):
        """Return the fixed points with no constraints; arguments are ignored."""
        no_constraints = {}
        return self.points, no_constraints
# MLP node mapping (x, t) to the field u.
net = sc.get_net_node(
    inputs=("x", "t"),
    outputs=("u",),
    name="net1",
    arch=sc.Arch.mlp,
)

# Node with the `single_var` architecture; its output "c" is the coefficient
# the wave-equation node below refers to by name.
var_c = sc.get_net_node(
    arch=sc.Arch.single_var,
    inputs=("x",),
    outputs=("c",),
)

pde = sc.WaveNode(u="u", c="c", dim=1, time=True)
# Train on the observation domain and the PDE domain together.
s = sc.Solver(
    pdes=[pde],
    netnodes=[net, var_c],
    sample_domains=(WaveExternal(), WaveEq()),
    max_iter=5000,
    # network_dir='square_network_dir',
    network_dir="network_dir",
)
s.solve()
# ---- Post-processing: compare prediction, observations and exact solution ----
_, ax = plt.subplots(1, 1, figsize=(8, 4))

# Network prediction evaluated on the "wave_domain" samples.
coord = s.infer_step(domain_attr={"wave_domain": ["x", "t", "u"]})
num_t = coord["wave_domain"]["t"].cpu().detach().numpy().ravel()
num_u = coord["wave_domain"]["u"].cpu().detach().numpy().ravel()
ax.scatter(num_t, num_u, c="r", marker="o", label="predicted points")

# Report the coefficient used to generate the data next to the learned one.
print(f"true parameter c: {c:.4f}")
predict_c = var_c.evaluate(torch.Tensor([[1.0]])).item()
print(f"predicted parameter c: {predict_c:.4f}")

# Observed samples. Build the data node once: every construction re-reads the
# CSV, so doing it separately for "t" and "u" is wasted work.
observed = WaveExternal().sample_fn
num_t = observed.points["t"].ravel()
num_u = observed.constraints["u"].ravel()
ax.scatter(num_t, num_u, c="b", marker="x", label="observed points")

# Swap the solver's domains for the inference-only line at x = L / 2.
s.sample_domains = (CenterInfer(),)
points = s.infer_step({"center_infer": ["t", "x", "u"]})
num_t = points["center_infer"]["t"].cpu().detach().numpy().ravel()
num_u = points["center_infer"]["u"].cpu().detach().numpy().ravel()
num_x = points["center_infer"]["x"].cpu().detach().numpy().ravel()

# Closed-form solution with the true c, evaluated on the same points.
ax.plot(
    num_t, np.sin(num_x) * (np.sin(c * num_t) + np.cos(c * num_t)), c="k", label="exact"
)
ax.plot(num_t, num_u, "--", c="g", linewidth=4, label="predict")

ax.legend()
ax.set_xlabel("t")
ax.set_ylabel("u")
# ax.set_title(f'Square loss ($x=0.5L$, c={predict_c:.4f}))')
ax.set_title(f"L1 loss ($x=0.5L$, c={predict_c:.4f})")
ax.grid(True)
ax.set_xlim([-0.5, 6.5])
ax.set_ylim([-3.5, 4.5])

# plt.savefig('square.png', dpi=1000, bbox_inches='tight', pad_inches=0.02)
plt.savefig("L1.png", dpi=1000, bbox_inches="tight", pad_inches=0.02)
plt.show()
plt.close()