forked from idrl/idrlnet
114 lines
4.0 KiB
Python
114 lines
4.0 KiB
Python
|
import idrlnet.shortcut as sc
|
||
|
from math import pi
|
||
|
from sympy import Symbol
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
import os
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
# Independent variables of the problem.
t_symbol = Symbol('t')
x = Symbol('x')

# Spatial domain: the segment [0, L] with L = pi.
L = float(pi)
geo = sc.Line1D(0, L)
# Time enters the sampler as a parameter ranging over [0, 2L].
time_range = {t_symbol: (0, 2 * L)}

# Ground-truth wave speed used to synthesize the observations; the training
# run below tries to recover it through the trainable `var_c` node.
c = 1.54

# CSV file that caches the synthetic observations between runs.
external_filename = 'external_sample.csv'
|
||
|
|
||
|
|
||
|
def generate_observed_data():
    """Synthesize noisy observations of the wave solution and cache them as CSV.

    Samples interior (x, t) points over x in [0, L], t in [0, 2L] using
    low-discrepancy sampling, and evaluates
    u = sin(x) * (sin(c t) + cos(c t)) at each point. It then overwrites up to
    10 randomly chosen values with the outlier 3.0 and writes every column to
    ``external_filename``.

    If that file already exists, nothing is done, so repeated runs reuse the
    same observations.
    """
    if os.path.exists(external_filename):
        return
    points = geo.sample_interior(density=20,
                                 bounds={x: (0, L)},
                                 param_ranges=time_range,
                                 low_discrepancy=True)
    points['u'] = np.sin(points['x']) * (np.sin(c * points['t']) + np.cos(c * points['t']))
    # Inject deliberate outliers (u = 3.0) so the effect of the L1 data loss
    # is visible. Cap the count because choice(..., replace=False) raises
    # when asked for more indices than there are points.
    n_points = len(points['u'])
    outlier_idx = np.random.choice(n_points, min(10, n_points), replace=False)
    points['u'][outlier_idx] = 3.
    columns = {k: v.ravel() for k, v in points.items()}
    # Write to the same path the existence check above inspects; previously a
    # hard-coded literal was used here, which could drift from the constant.
    pd.DataFrame.from_dict(columns).to_csv(external_filename, index=False)


# Create the observation CSV on the first run; later runs reuse the cached file.
generate_observed_data()
|
||
|
|
||
|
|
||
|
# Observation data node. loss_fn='L1' keeps the data-fit term robust to the
# outliers (u = 3.0) injected by generate_observed_data(). Omitting loss_fn
# falls back to the library's default data loss, which is presumably the square
# loss used for the comparison run.
@sc.datanode(name='wave_domain', loss_fn='L1')
class WaveExternal(sc.SampleDomain):
    """Serve the cached CSV observations as a fixed training batch.

    Every CSV column except ``u`` becomes an input array of shape (N, 1).
    The ``u`` column, also reshaped to (N, 1), is the target value for the
    network output ``u``.
    """

    def __init__(self):
        # Read from the shared constant rather than a duplicated literal, so the
        # loader always matches the file generate_observed_data() checks for.
        frame = pd.read_csv(external_filename)
        self.points = {col: frame[col].to_numpy().reshape(-1, 1) for col in frame.columns}
        self.constraints = {'u': self.points.pop('u')}

    def sampling(self, *args, **kwargs):
        # The same observation set is returned on every call.
        return self.points, self.constraints
|
||
|
|
||
|
|
||
|
@sc.datanode(name='wave_external')
class WaveEq(sc.SampleDomain):
    """Collocation points on which the wave-equation term is driven to zero.

    Interior points over x in [0, L], t in [0, 2L] are re-sampled on every
    call, at density 1000. Each point is paired with the target
    ``wave_equation == 0``.
    """

    def sampling(self, *args, **kwargs):
        collocation = geo.sample_interior(density=1000,
                                          bounds={x: (0, L)},
                                          param_ranges=time_range)
        return collocation, {'wave_equation': 0.}
|
||
|
|
||
|
|
||
|
@sc.datanode(name='center_infer')
class CenterInfer(sc.SampleDomain):
    """Inference-only probe along the centre line x = L/2.

    Provides 200 evenly spaced times over [0, 2L], each as an (N, 1) column.
    Alongside ``t`` it supplies a constant ``x = L/2`` and a unit ``area``
    weight. No constraints are attached.
    """

    def __init__(self):
        t_grid = np.linspace(0, 2 * L, 200).reshape(-1, 1)
        self.points = sc.Variables()
        self.points['t'] = t_grid
        self.points['x'] = np.full_like(t_grid, L / 2)
        self.points['area'] = np.ones_like(t_grid)

    def sampling(self, *args, **kwargs):
        no_constraints = {}
        return self.points, no_constraints
|
||
|
|
||
|
|
||
|
# Surrogate network u(x, t): an MLP mapping (x, t) -> u.
net = sc.get_net_node(inputs=('x', 't'), outputs=('u',), name='net1', arch=sc.Arch.mlp)
# The unknown wave speed, exposed as output 'c' of a single-variable node
# (sc.Arch.single_var). Presumably its value does not depend on the 'x'
# input; TODO confirm.
var_c = sc.get_net_node(inputs=('x',), outputs=('c',), arch=sc.Arch.single_var)
# 1-D time-dependent wave-equation node whose speed is the 'c' output above.
pde = sc.WaveNode(c='c', dim=1, time=True, u='u')

# Fit the observations (WaveExternal) and the PDE collocation term (WaveEq)
# together. For the square-loss comparison run, switch to
# network_dir='square_network_dir'.
s = sc.Solver(
    sample_domains=(WaveExternal(), WaveEq()),
    netnodes=[net, var_c],
    pdes=[pde],
    network_dir='network_dir',
    max_iter=5000,
)
s.solve()
|
||
|
|
||
|
# One-panel figure shared by every curve and scatter below.
_, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 4))

# Evaluate the trained network at the observation locations. 'wave_domain' is
# the data-node name given to WaveExternal.
coord = s.infer_step(domain_attr={'wave_domain': ['x', 't', 'u']})
wave_pred = coord['wave_domain']
num_t, num_u = (wave_pred[key].cpu().detach().numpy().ravel() for key in ('t', 'u'))
ax.scatter(num_t, num_u, c='r', marker='o', label='predicted points')
|
||
|
|
||
|
# Compare the wave speed used to synthesize the data with the identified one.
# (Fixes the "paratmeter" typo in the printed label.)
print(f"true parameter c: {c:.4f}")
# var_c is a single-variable node. The 1x1 input only drives the forward
# pass; presumably its value is irrelevant to the output (TODO confirm).
predict_c = var_c.evaluate(torch.Tensor([[1.0]])).item()
print(f"predicted parameter c: {predict_c:.4f}")
|
||
|
|
||
|
# Overlay the raw observations, including the injected outliers. Build the data
# node once: each WaveExternal() construction re-reads and re-parses the CSV.
observed = WaveExternal().sample_fn
num_t = observed.points['t'].ravel()
num_u = observed.constraints['u'].ravel()
ax.scatter(num_t, num_u, c='b', marker='x', label='observed points')
|
||
|
|
||
|
# Swap the training domains for the fixed probe line x = L/2, then query the
# trained network along it.
s.sample_domains = (CenterInfer(),)
points = s.infer_step({'center_infer': ['t', 'x', 'u']})
center = points['center_infer']
num_t, num_u, num_x = (center[key].cpu().detach().numpy().ravel() for key in ('t', 'u', 'x'))

# Reference curve: the same closed form generate_observed_data() uses to
# synthesize observations, evaluated with the true c.
exact_u = np.sin(num_x) * (np.sin(c * num_t) + np.cos(c * num_t))
ax.plot(num_t, exact_u, c='k', label='exact')
ax.plot(num_t, num_u, '--', c='g', linewidth=4, label='predict')
|
||
|
ax.legend()
# For the square-loss comparison run, title the plot 'Square loss (...)' and
# save it as 'square.png' instead. The upper y-limit leaves room for the
# injected outliers at u = 3.
ax.set(
    xlabel='t',
    ylabel='u',
    title=f'L1 loss ($x=0.5L$, c={predict_c:.4f})',
    xlim=[-0.5, 6.5],
    ylim=[-3.5, 4.5],
)
ax.grid(True)

plt.savefig('L1.png', dpi=1000, bbox_inches='tight', pad_inches=0.02)
plt.show()
plt.close()
|