Compare commits

...

2 Commits

Author SHA1 Message Date
weipengOO98 403238c9ee no message 2023-06-30 14:27:11 +08:00
weipengOO98 6267eda37f no message 2023-06-29 10:46:16 +08:00
14 changed files with 943 additions and 18 deletions

View File

@ -53,7 +53,7 @@ pip install idrlnet
### From Source ### From Source
``` ```
git clone https://osredm.com/idrl/idrlnet.git git clone https://github.com/idrl-lab/idrlnet
cd idrlnet cd idrlnet
pip install -e . pip install -e .
``` ```
@ -100,7 +100,7 @@ It is also easy to customize IDRLnet to meet new demands.
First off, thanks for taking the time to contribute! First off, thanks for taking the time to contribute!
- **Reporting bugs.** To report a bug, simply open an issue(疑修) in the osredm "Issues" section. - **Reporting bugs.** To report a bug, simply open an issue in the GitHub "Issues" section.
- **Suggesting enhancements.** To submit an enhancement suggestion for IDRLnet, including completely new features and minor improvements to existing functionality, let us know by opening an issue. - **Suggesting enhancements.** To submit an enhancement suggestion for IDRLnet, including completely new features and minor improvements to existing functionality, let us know by opening an issue.

View File

@ -18,11 +18,11 @@ sys.path.insert(0, os.path.abspath(".."))
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "idrlnet" project = "idrlnet"
copyright = "2021, IDRL" copyright = "2023, IDRL"
author = "IDRL" author = "IDRL"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
release = "0.1.0" release = "2.0.0-rc3"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@ -0,0 +1,123 @@
# Deepritz
This section repeats the Deepritz method presented by [Weinan E and Bing Yu](https://link.springer.com/article/10.1007/s40304-018-0127-z).
Consider the 2d Poisson's equation such as the following:
$$
\begin{equation}
\begin{aligned}
-\Delta u=f, & \text { in } \Omega \\
u=0, & \text { on } \partial \Omega
\end{aligned}
\end{equation}
$$
Based on the scattering theorem, its weak form is that both sides are multiplied by$ v \in H_0^1$(which can be interpreted as some function bounded by 0),to get
$$
\int f v=-\int v \Delta u=(\nabla u, \nabla v)
$$
The above equation holds for any $v \in H_0^1$. The bilinear part of the right-hand side of the equation with respect to $u,v$ is symmetric and yields the bilinear term:
$$
a(u, v)=\int \nabla u \cdot \nabla v
$$
By the Poincaré inequality, the $a(\cdot, \cdot)$ is a positive definite operator. By positive definite, we mean that there exists $\alpha >0$, such that
$$
a(u, u) \geq \alpha\|u\|^2, \quad \forall u \in H_0^1
$$
The remaining term is a linear generalization of $v$, which is $l(v)$, which yields the equation:
$$
a(u, v) = l(v)
$$
For this equation, by discretizing $u,v$ in the same finite dimensional subspace, we can obtain a symmetric positive definite system of equations, which is the family of Galerkin methods, or we can transform it into a polarization problem to solve it.
To find $u$ satisfies
$$
a(u, v) = l(v), \quad \forall v \in H_0^1
$$
For a symmetric positive definite $a$ , which is equivalent to solving the variational minimization problem, that is, finding $u$, such that holds, where
$$
J(u) = \frac{1}{2} a(u, u) - l(u)
$$
Specifically
$$
\min _{u \in H_0^1} J(u)=\frac{1}{2} \int\|\nabla u\|_2^2-\int f v
$$
The DeepRitz method is similar to the PINN approach, replacing the neural network with u, and after sampling the region, just solve it with a solver like Adam. Written as
$$
\begin{equation}
\min _{\left.\hat{u}\right|_{\partial \Omega}=0} \hat{J}(\hat{u})=\frac{1}{2} \frac{S_{\Omega}}{N_{\Omega}} \sum\left\|\nabla \hat{u}\left(x_i, y_i\right)\right\|_2^2-\frac{S_{\Omega}}{N_{\partial \Omega}} \sum f\left(x_i, y_i\right) \hat{u}\left(x_i, y_i\right)
\end{equation}
$$
Note that the original $u \in H_0^1$, which is zero on the boundary, is transformed into an unconstrained problem by adding the penalty function term:
$$
\begin{equation}
\begin{gathered}
\min \hat{J}(\hat{u})=\frac{1}{2} \frac{S_{\Omega}}{N_{\Omega}} \sum\left\|\nabla \hat{u}\left(x_i, y_i\right)\right\|_2^2-\frac{S_{\Omega}}{N_{\Omega}} \sum f\left(x_i, y_i\right) \hat{u}\left(x_i, y_i\right)+\beta \frac{S_{\partial \Omega}}{N_{\partial \Omega}} \\
\sum \hat{u}^2\left(x_i, y_i\right)
\end{gathered}
\end{equation}
$$
Consider the 2d Poisson's equation defined on $\Omega=[-1,1]\times[-1,1]$, which satisfies $f=2 \pi^2 \sin (\pi x) \sin (\pi y)$.
### Define Sampling Methods and Constraints
For the problem, boundary condition and PDE constraint are presented and use the Identity loss.
```python
@sc.datanode(sigma=1000.0)
class Boundary(sc.SampleDomain):
def __init__(self):
self.points = geo.sample_boundary(100,)
self.constraints = {"u": 0.}
def sampling(self, *args, **kwargs):
return self.points, self.constraints
@sc.datanode(loss_fn="Identity")
class Interior(sc.SampleDomain):
def __init__(self):
self.points = geo.sample_interior(1000)
self.constraints = {"integral_dxdy": 0,}
def sampling(self, *args, **kwargs):
return self.points, self.constraints
```
### Define Neural Networks and PDEs
In the PDE definition section, based on the DeepRitz method we add two types of PDE nodes:
```python
def f(x, y):
return 2 * sp.pi ** 2 * sp.sin(sp.pi * x) * sp.sin(sp.pi * y)
dx_exp = sc.ExpressionNode(
expression=0.5*(u.diff(x) ** 2 + u.diff(y) ** 2) - u * f(x, y), name="dxdy"
)
net = sc.get_net_node(inputs=("x", "y"), outputs=("u",), name="net", arch=sc.Arch.mlp)
integral = sc.ICNode("dxdy", dim=2, time=False)
```
The result is shown as follows:
![deepritz](https://github.com/xiangzixuebit/picture/raw/3d73005f3642f10400975659479e856fb99f6518/deepritz.png)

View File

@ -0,0 +1,227 @@
# Navier-Stokes equations
This section repeats the Robust PINN method presented by [Peng et.al](https://deepai.org/publication/robust-regression-with-highly-corrupted-data-via-physics-informed-neural-networks).
## Steady 2D NS equations
The prototype problem of incompressible flow past a circular cylinder is considered.
![image](https://github.com/xiangzixuebit/picture/raw/3d73005f3642f10400975659479e856fb99f6518/NS1.png)
The velocity vector is set to zero at all walls and the pressure is set to p = 0 at the outlet. The fluid density is taken as $\rho = 1kg/m^3$ and the dynamic viscosity is taken as $\mu = 2 · 10^{2}kg/m^3$ . The velocity profile on the inlet is set as $u(0, y)=4 \frac{U_M}{H^2}(H-y) y$ with $U_M = 1m/s$ and $H = 0.41m$.
The two-dimensional steady-state Navier-Stokes equation is equivalently transformed into the following equations:
$$
\begin{equation}
\begin{aligned}
\sigma^{11} &=-p+2 \mu u_x \\
\sigma^{22} &=-p+2 \mu v_y \\
\sigma^{12} &=\mu\left(u_y+v_x\right) \\
p &=-\frac{1}{2}\left(\sigma^{11}+\sigma^{22}\right) \\
\left(u u_x+v u_y\right) &=\mu\left(\sigma_x^{11}+\sigma_y^{12}\right) \\
\left(u v_x+v v_y\right) &=\mu\left(\sigma_x^{12}+\sigma_y^{22}\right)
\end{aligned}
\end{equation}
$$
We construct a neural network with six outputs to satisfy the PDE constraints above:
$$
\begin{equation}
u, v, p, \sigma^{11}, \sigma^{12}, \sigma^{22}=net(x, y)
\end{equation}
$$
### Define Symbols and Geometric Objects
For the 2d problem, we define two coordinate symbols`x`and`y`, six variables$ u, v, p, \sigma^{11}, \sigma^{12}, \sigma^{22}$ are defined.
The geometry object is a simple rectangle and circle with the operator `-`.
```python
x = Symbol('x')
y = Symbol('y')
rec = sc.Rectangle((0., 0.), (1.1, 0.41))
cir = sc.Circle((0.2, 0.2), 0.05)
geo = rec - cir
u = sp.Function('u')(x, y)
v = sp.Function('v')(x, y)
p = sp.Function('p')(x, y)
s11 = sp.Function('s11')(x, y)
s22 = sp.Function('s22')(x, y)
s12 = sp.Function('s12')(x, y)
```
### Define Sampling Methods and Constraints
For the problem, three boundary conditions , PDE constraint and external data are presented. We use the robust-PINN model inspired by the traditional LAD (Least Absolute Derivation) approach, where the L1 loss replaces the squared L2 data loss.
```python
@sc.datanode
class Inlet(sc.SampleDomain):
def sampling(self, *args, **kwargs):
points = rec.sample_boundary(1000, sieve=(sp.Eq(x, 0.)))
constraints = sc.Variables({'u': 4 * (0.41 - y) * y / (0.41 * 0.41)})
return points, constraints
@sc.datanode
class Outlet(sc.SampleDomain):
def sampling(self, *args, **kwargs):
points = geo.sample_boundary(1000, sieve=(sp.Eq(x, 1.1)))
constraints = sc.Variables({'p': 0.})
return points, constraints
@sc.datanode
class Wall(sc.SampleDomain):
def sampling(self, *args, **kwargs):
points = geo.sample_boundary(1000, sieve=((x > 0.) & (x < 1.1)))
#print("points3", points)
constraints = sc.Variables({'u': 0., 'v': 0.})
return points, constraints
@sc.datanode(name='NS_external')
class Interior_domain(sc.SampleDomain):
def __init__(self):
self.density = 2000
def sampling(self, *args, **kwargs):
points = geo.sample_interior(2000)
constraints = {'f_s11': 0., 'f_s22': 0., 'f_s12': 0., 'f_u': 0., 'f_v': 0., 'f_p': 0.}
return points, constraints
@sc.datanode(name='NS_domain', loss_fn='L1')
class NSExternal(sc.SampleDomain):
def __init__(self):
points = pd.read_csv('NSexternel_sample.csv')
self.points = {col: points[col].to_numpy().reshape(-1, 1) for col in points.columns}
self.constraints = {'u': self.points.pop('u'), 'v': self.points.pop('v'), 'p': self.points.pop('p')}
def sampling(self, *args, **kwargs):
return self.points, self.constraints
```
### Define Neural Networks and PDEs
In the PDE definition part, we add these PDE nodes:
```python
net = sc.MLP([2, 40, 40, 40, 40, 40, 40, 40, 40, 6], activation=sc.Activation.tanh)
net = sc.get_net_node(inputs=('x', 'y'), outputs=('u', 'v', 'p', 's11', 's22', 's12'), name='net', arch=sc.Arch.mlp)
pde1 = sc.ExpressionNode(name='f_s11', expression=-p + 2 * nu * u.diff(x) - s11)
pde2 = sc.ExpressionNode(name='f_s22', expression=-p + 2 * nu * v.diff(y) - s22)
pde3 = sc.ExpressionNode(name='f_s12', expression=nu * (u.diff(y) + v.diff(x)) - s12)
pde4 = sc.ExpressionNode(name='f_u', expression=u * u.diff(x) + v * u.diff(y) - nu * (s11.diff(x) + s12.diff(y)))
pde5 = sc.ExpressionNode(name='f_v', expression=u * v.diff(x) + v * v.diff(y) - nu * (s12.diff(x) + s22.diff(y)))
pde6 = sc.ExpressionNode(name='f_p', expression=p + (s11 + s22) / 2)
```
### Define A Solver
Direct use of Adam optimization is less effective, so the LBFGS optimization method or a combination of both (Adam+LBFGS) is used for training:
```python
s = sc.Solver(sample_domains=(Inlet(), Outlet(), Wall(), Interior_domain(), NSExternal()),
netnodes=[net],
init_network_dirs=['network_dir_adam'],
pdes=[pde1, pde2, pde3, pde4, pde5, pde6],
max_iter=300,
opt_config = dict(optimizer='LBFGS', lr=1)
)
```
The result is shown as follows:
![image](https://github.com/xiangzixuebit/picture/raw/3d73005f3642f10400975659479e856fb99f6518/NS11.png)
## Unsteady 2D N-S equations with unknown parameters
A two-dimensional incompressible flow and dynamic vortex shedding past a circular cylinder in a steady-state are numerically simulated. Respectively, the Reynolds number of the incompressible flow is $Re = 100$. The kinematic viscosity of the fluid is $\nu = 0.01$. The cylinder diameter D is 1. The simulation domain size is
$[-15,25] × [-8,8]$. The computational domain is much smaller: $[1,8] × [-2,2]× [0,20]$.
![image](https://github.com/xiangzixuebit/picture/raw/3d73005f3642f10400975659479e856fb99f6518/NS2.png)
$$
\begin{equation}
\begin{aligned}
&u_t+\lambda_1\left(u u_x+v u_y\right)=-p_x+\lambda_2\left(u_{x x}+u_{y y}\right) \\
&v_t+\lambda_1\left(u v_x+v v_y\right)=-p_y+\lambda_2\left(v_{x x}+v_{y y}\right)
\end{aligned}
\end{equation}
$$
where $\lambda_1$ and $\lambda_2$ are two unknown parameters to be recovered. We make the assumption that $u=\psi_y, \quad v=-\psi_x$
for some stream function $\psi(x, y)$. Under this assumption, the continuity equation will be automatically satisfied. The following architecture is used in this example,
$$
\begin{equation}
\psi, p=net\left(t, x, y, \lambda_1, \lambda_2\right)
\end{equation}
$$
### Define Symbols and Geometric Objects
We define three coordinate symbols `x`, `y` and `t`, three variables $u,v,p$ are defined.
```python
x = Symbol('x')
y = Symbol('y')
t = Symbol('t')
geo = sc.Rectangle((1., -2.), (8., 2.))
u = sp.Function('u')(x, y, t)
v = sp.Function('v')(x, y, t)
p = sp.Function('p')(x, y, t)
time_range = {t: (0, 20)}
```
### Define Sampling Methods and Constraints
This example has only two equation constraints, while the former has six equation constraints. We also use the LAD-PINN model. Then the PDE constrained optimization model is formulated as:
$$
\min _{\theta, \lambda} \frac{1}{\# \mathbf{D}_u} \sum_{\left(t_i, x_i, u_i\right) \in \mathbf{D}_u}\left|u_i-u_\theta\left(t_i, x_i ; \lambda\right)\right|+\omega \cdot L_{p d e} .
$$
```python
@sc.datanode(name='NS_domain', loss_fn='L1')
class NSExternal(sc.SampleDomain):
def __init__(self):
points = pd.read_csv('NSexternel_sample.csv')
self.points = {col: points[col].to_numpy().reshape(-1, 1) for col in points.columns}
self.constraints = {'u': self.points.pop('u'), 'v': self.points.pop('v'), 'p': self.points.pop('p')}
def sampling(self, *args, **kwargs):
return self.points, self.constraints
@sc.datanode(name='NS_external')
class NSEq(sc.SampleDomain):
def sampling(self, *args, **kwargs):
points = geo.sample_interior(density=2000, param_ranges=time_range)
constraints = {'continuity': 0, 'momentum_x': 0, 'momentum_y': 0}
return points, constraints
```
### Define Neural Networks and PDEs
IDRLnet defines a network node to represent the unknown Parameters.
```python
net = sc.MLP([3, 20, 20, 20, 20, 20, 20, 20, 20, 3], activation=sc.Activation.tanh)
net = sc.get_net_node(inputs=('x', 'y', 't'), outputs=('u', 'v', 'p'), name='net', arch=sc.Arch.mlp)
var_nr = sc.get_net_node(inputs=('x', 'y'), outputs=('nu', 'rho'), arch=sc.Arch.single_var)
pde = sc.NavierStokesNode(nu='nu', rho='rho', dim=2, time=True, u='u', v='v', p='p')
```
### Define A Solver
Two nodes trained together
```python
s = sc.Solver(sample_domains=(NSExternal(), NSEq()),
netnodes=[net, var_nr],
pdes=[pde],
network_dir='network_dir',
max_iter=10000)
```
Finally, the real velocity field and pressure field at t=10s are compared with the predicted results:
![image](https://github.com/xiangzixuebit/picture/raw/3d73005f3642f10400975659479e856fb99f6518/NS22.png)

View File

@ -14,6 +14,8 @@ To make full use of IDRLnet. We strongly suggest following the following example
6. :ref:`Parameterized poisson equation <Parameterized Poisson>`. The example introduces how to train a surrogate with parameters. 6. :ref:`Parameterized poisson equation <Parameterized Poisson>`. The example introduces how to train a surrogate with parameters.
7. :ref:`Variational Minimization <Variational Minimization>`. The example introduces how to solve variational minimization problems. 7. :ref:`Variational Minimization <Variational Minimization>`. The example introduces how to solve variational minimization problems.
8. :ref:`Volterra integral differential equation <Volterra Integral Differential Equation>`. The example introduces the way to solve IDEs. 8. :ref:`Volterra integral differential equation <Volterra Integral Differential Equation>`. The example introduces the way to solve IDEs.
9. :ref:`Navier-Stokes equation <Navier-Stokes equations>`. The example introduces how to use the LBFGS optimizer.
10. :ref:`Deepritz method <Deepritz>`. The example introduces the way to solve PDEs with the Deepritz method.
@ -28,3 +30,5 @@ To make full use of IDRLnet. We strongly suggest following the following example
6_parameterized_poisson 6_parameterized_poisson
7_minimal_surface 7_minimal_surface
8_volterra_ide 8_volterra_ide
9_navier_stokes_equation
10_deepritz

View File

@ -0,0 +1,75 @@
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
import matplotlib.tri as tri
import idrlnet.shortcut as sc
x, y = sp.symbols("x y")
u = sp.Function("u")(x, y)
geo = sc.Rectangle((-1, -1), (1., 1.))
@sc.datanode(sigma=1000.0)
class Boundary(sc.SampleDomain):
def __init__(self):
self.points = geo.sample_boundary(100,)
self.constraints = {"u": 0.}
def sampling(self, *args, **kwargs):
return self.points, self.constraints
@sc.datanode(loss_fn="Identity")
class Interior(sc.SampleDomain):
def __init__(self):
self.points = geo.sample_interior(1000)
self.constraints = {"integral_dxdy": 0,}
def sampling(self, *args, **kwargs):
return self.points, self.constraints
def f(x, y):
return 2 * sp.pi ** 2 * sp.sin(sp.pi * x) * sp.sin(sp.pi * y)
dx_exp = sc.ExpressionNode(
expression=0.5*(u.diff(x) ** 2 + u.diff(y) ** 2) - u * f(x, y), name="dxdy"
)
net = sc.get_net_node(inputs=("x", "y"), outputs=("u",), name="net", arch=sc.Arch.mlp)
integral = sc.ICNode("dxdy", dim=2, time=False)
s = sc.Solver(
sample_domains=(Boundary(), Interior()),
netnodes=[net],
pdes=[
dx_exp,
integral,
],
max_iter=10000,
)
s.solve()
coord = s.infer_step({"Interior": ["x", "y", "u"]})
num_x = coord["Interior"]["x"].cpu().detach().numpy().ravel()
num_y = coord["Interior"]["y"].cpu().detach().numpy().ravel()
num_Up = coord["Interior"]["u"].cpu().detach().numpy().ravel()
# Ground truth
num_U = np.sin(np.pi*num_x)*np.sin(np.pi*num_y)
fig, ax = plt.subplots(1, 3, figsize=(10, 3))
triang_total = tri.Triangulation(num_x, num_y)
ax[0].tricontourf(triang_total, num_Up, 100, cmap="bwr", vmin=-1, vmax=1)
ax[0].axis("off")
ax[0].set_title("prediction")
ax[1].tricontourf(triang_total, num_U, 100, cmap="bwr", vmin=-1, vmax=1)
ax[1].axis("off")
ax[1].set_title("ground truth")
ax[2].tricontourf(
triang_total, np.abs(num_U - num_Up), 100, cmap="bwr", vmin=0, vmax=0.5
)
ax[2].axis("off")
ax[2].set_title("absolute error")
plt.savefig("deepritz.png", dpi=300, bbox_inches="tight")

View File

@ -63,8 +63,7 @@ solver = sc.Solver(
), ),
netnodes=[net], netnodes=[net],
pdes=[pde1, pde2, pde3, pde4], pdes=[pde1, pde2, pde3, pde4],
max_iter=2000, max_iter=2000)
)
solver.solve() solver.solve()

View File

@ -0,0 +1,79 @@
import matplotlib.pyplot as plt
import sympy as sp
import numpy as np
import idrlnet.shortcut as sc
x = sp.symbols('x')
Line = sc.Line1D(0, 1)
y = sp.Function('y')(x)
@sc.datanode(name='interior')
class Interior(sc.SampleDomain):
def sampling(self, *args, **kwargs):
return Line.sample_interior(1000), {'dddd_y': 0}
@sc.datanode(name='left_boundary1')
class LeftBoundary1(sc.SampleDomain):
def sampling(self, *args, **kwargs):
return Line.sample_boundary(100, sieve=(sp.Eq(x, 0))), {'y': 0}
@sc.datanode(name='left_boundary2')
class LeftBoundary2(sc.SampleDomain):
def sampling(self, *args, **kwargs):
return Line.sample_boundary(100, sieve=(sp.Eq(x, 0))), {'d_y': 0}
@sc.datanode(name='right_boundary1')
class RightBoundary1(sc.SampleDomain):
def sampling(self, *args, **kwargs):
return Line.sample_boundary(100, sieve=(sp.Eq(x, 1))), {'dd_y': 0}
@sc.datanode(name='right_boundary2')
class RightBoundary2(sc.SampleDomain):
def sampling(self, *args, **kwargs):
return Line.sample_boundary(100, sieve=(sp.Eq(x, 1))), {'ddd_y': 0}
@sc.datanode(name='infer')
class Infer(sc.SampleDomain):
def sampling(self, *args, **kwargs):
return {'x': np.linspace(0, 1, 1000).reshape(-1, 1)}, {}
net = sc.get_net_node(inputs=('x',), outputs=('y',), name='net', arch=sc.Arch.mlp)
pde1 = sc.ExpressionNode(name='dddd_y', expression=y.diff(x).diff(x).diff(x).diff(x) + 1)
pde2 = sc.ExpressionNode(name='d_y', expression=y.diff(x))
pde3 = sc.ExpressionNode(name='dd_y', expression=y.diff(x).diff(x))
pde4 = sc.ExpressionNode(name='ddd_y', expression=y.diff(x).diff(x).diff(x))
solver = sc.Solver(
sample_domains=(Interior(), LeftBoundary1(), LeftBoundary2(), RightBoundary1(), RightBoundary2()),
netnodes=[net],
pdes=[pde1, pde2, pde3, pde4],
max_iter=200,
opt_config=dict(optimizer='LBFGS', lr=1))
solver.solve()
# inference
def exact(x):
return -(x ** 4) / 24 + x ** 3 / 6 - x ** 2 / 4
solver.sample_domains = (Infer(),)
points = solver.infer_step({'infer': ['x', 'y']})
xs = points['infer']['x'].detach().cpu().numpy().ravel()
y_pred = points['infer']['y'].detach().cpu().numpy().ravel()
plt.plot(xs, y_pred, label='Pred')
y_exact = exact(xs)
plt.plot(xs, y_exact, label='Exact', linestyle='--')
plt.legend()
plt.xlabel('x')
plt.ylabel('w')
plt.savefig('Euler_beam_LBFGS.png', dpi=300, bbox_inches='tight')
plt.show()

View File

@ -0,0 +1,194 @@
import matplotlib.pyplot as plt
import sympy as sp
import numpy as np
import idrlnet.shortcut as sc
from sympy import Symbol, sin
import pandas as pd
import torch
import matplotlib.tri as tri
x = Symbol('x')
y = Symbol('y')
rec = sc.Rectangle((0., 0.), (1.1, 0.41))
cir = sc.Circle((0.2, 0.2), 0.05)
geo = rec - cir
u = sp.Function('u')(x, y)
v = sp.Function('v')(x, y)
p = sp.Function('p')(x, y)
s11 = sp.Function('s11')(x, y)
s22 = sp.Function('s22')(x, y)
s12 = sp.Function('s12')(x, y)
nu=0.02
rho=1
@sc.datanode
class Inlet(sc.SampleDomain):
def sampling(self, *args, **kwargs):
points = rec.sample_boundary(1000, sieve=(sp.Eq(x, 0.)))
constraints = sc.Variables({'u': 4 * (0.41 - y) * y / (0.41 * 0.41)})
return points, constraints
@sc.datanode
class Outlet(sc.SampleDomain):
def sampling(self, *args, **kwargs):
points = geo.sample_boundary(1000, sieve=(sp.Eq(x, 1.1)))
constraints = sc.Variables({'p': 0.})
return points, constraints
@sc.datanode
class Wall(sc.SampleDomain):
def sampling(self, *args, **kwargs):
points = geo.sample_boundary(1000, sieve=((x > 0.) & (x < 1.1)))
#print("points3", points)
constraints = sc.Variables({'u': 0., 'v': 0.})
return points, constraints
@sc.datanode(name='NS_external')
class Interior_domain(sc.SampleDomain):
def __init__(self):
self.density = 2000
def sampling(self, *args, **kwargs):
points = geo.sample_interior(2000)
constraints = {'f_s11': 0., 'f_s22': 0., 'f_s12': 0., 'f_u': 0., 'f_v': 0., 'f_p': 0.}
return points, constraints
@sc.datanode(name='NS_domain', loss_fn='L1')
class NSExternal(sc.SampleDomain):
def __init__(self):
points = pd.read_csv('NSexternel_sample.csv')
self.points = {col: points[col].to_numpy().reshape(-1, 1) for col in points.columns}
self.constraints = {'u': self.points.pop('u'), 'v': self.points.pop('v'), 'p': self.points.pop('p')}
def sampling(self, *args, **kwargs):
return self.points, self.constraints
net = sc.MLP([2, 40, 40, 40, 40, 40, 40, 40, 40, 6], activation=sc.Activation.tanh)
net = sc.get_net_node(inputs=('x', 'y'), outputs=('u', 'v', 'p', 's11', 's22', 's12'), name='net', arch=sc.Arch.mlp)
pde1 = sc.ExpressionNode(name='f_s11', expression=-p + 2 * nu * u.diff(x) - s11)
pde2 = sc.ExpressionNode(name='f_s22', expression=-p + 2 * nu * v.diff(y) - s22)
pde3 = sc.ExpressionNode(name='f_s12', expression=nu * (u.diff(y) + v.diff(x)) - s12)
pde4 = sc.ExpressionNode(name='f_u', expression=u * u.diff(x) + v * u.diff(y) - nu * (s11.diff(x) + s12.diff(y)))
pde5 = sc.ExpressionNode(name='f_v', expression=u * v.diff(x) + v * v.diff(y) - nu * (s12.diff(x) + s22.diff(y)))
pde6 = sc.ExpressionNode(name='f_p', expression=p + (s11 + s22) / 2)
s = sc.Solver(sample_domains=(Inlet(), Outlet(), Wall(), Interior_domain(), NSExternal()),
netnodes=[net],
init_network_dirs=['network_dir_adam'],
pdes=[pde1, pde2, pde3, pde4, pde5, pde6],
max_iter=300,
opt_config = dict(optimizer='LBFGS', lr=1)
)
#opt_config = dict(optimizer='LBFGS', lr=1)
#init_network_dirs=['network_dir_lbfgs'],
s.solve()
points1 = pd.read_csv('NSexternel_test.csv')
points1 = {col: points1[col].to_numpy().reshape(-1, 1) for col in points1.columns}
x_test = torch.tensor(points1['x_test'].astype(np.float32))
y_test = torch.tensor(points1['y_test'].astype(np.float32))
u_test = torch.tensor(points1['u_test'].astype(np.float32))
v_test = torch.tensor(points1['v_test'].astype(np.float32))
p_test = torch.tensor(points1['p_test'].astype(np.float32))
U = s.netnodes[0].net(torch.cat([x_test, y_test], dim=1))
num_x = x_test.cpu().detach().numpy().ravel()
num_y = y_test.cpu().detach().numpy().ravel()
num_u = u_test.cpu().detach().numpy().ravel()
num_v = v_test.cpu().detach().numpy().ravel()
num_p = p_test.cpu().detach().numpy().ravel()
num_up = U[:, 0:1].cpu().detach().numpy().ravel()
num_vp = U[:, 1:2].cpu().detach().numpy().ravel()
num_pp = U[:, 2:3].cpu().detach().numpy().ravel()
triang_total = tri.Triangulation(num_x, num_y)
font2 = {'family': 'Times New Roman',
'weight': 'normal',
'size': 15,
}
fig = plt.figure(figsize=(20, 4))
ax1 = fig.add_subplot(131)
tcf = ax1.tricontourf(triang_total, num_u, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax1.set_xlabel('$x$', font2)
ax1.set_ylabel('$y$', font2)
ax1.set_title('Exact $u$', fontsize=18)
ax2 = fig.add_subplot(132)
tcf = ax2.tricontourf(triang_total, num_up, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax2.set_xlabel('$x$', font2)
ax2.set_ylabel('$y$', font2)
ax2.set_title('Predicted $u$', fontsize=18)
ax3 = fig.add_subplot(133)
tcf = ax3.tricontourf(triang_total, num_u - num_up, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax3.set_xlabel('$x$', font2)
ax3.set_ylabel('$y$', font2)
ax3.set_title('Point-wise Error', fontsize=18)
plt.savefig('test_NS_u_Adam.png', dpi=300, bbox_inches='tight')
plt.show()
fig = plt.figure(figsize=(20, 4))
ax1 = fig.add_subplot(131)
tcf = ax1.tricontourf(triang_total, num_v, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax1.set_xlabel('$x$', font2)
ax1.set_ylabel('$y$', font2)
ax1.set_title('Exact $v$', fontsize=18)
ax2 = fig.add_subplot(132)
tcf = ax2.tricontourf(triang_total, num_vp, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax2.set_xlabel('$x$', font2)
ax2.set_ylabel('$y$', font2)
ax2.set_title('Predicted $v$', fontsize=18)
ax3 = fig.add_subplot(133)
tcf = ax3.tricontourf(triang_total, num_v - num_vp, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax3.set_xlabel('$x$', font2)
ax3.set_ylabel('$y$', font2)
ax3.set_title('Point-wise Error', fontsize=18)
plt.savefig('test_NS_v_Adam.png', dpi=300, bbox_inches='tight')
plt.show()
fig = plt.figure(figsize=(20, 4))
ax1 = fig.add_subplot(131)
tcf = ax1.tricontourf(triang_total, num_p, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax1.set_xlabel('$x$', font2)
ax1.set_ylabel('$y$', font2)
ax1.set_title('Exact $p$', fontsize=18)
ax2 = fig.add_subplot(132)
tcf = ax2.tricontourf(triang_total, num_pp, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax2.set_xlabel('$x$', font2)
ax2.set_ylabel('$y$', font2)
ax2.set_title('Predicted $p$', fontsize=18)
ax3 = fig.add_subplot(133)
tcf = ax3.tricontourf(triang_total, num_p - num_pp, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax3.set_xlabel('$x$', font2)
ax3.set_ylabel('$y$', font2)
ax3.set_title('Point-wise Error', fontsize=18)
plt.savefig('test_NS_p_Adam.png', dpi=300, bbox_inches='tight')
plt.show()

View File

@ -0,0 +1,188 @@
import matplotlib.pyplot as plt
import sympy as sp
import numpy as np
import idrlnet.shortcut as sc
from sympy import Symbol, sin
import pandas as pd
import torch
import matplotlib.tri as tri
x = Symbol('x')
y = Symbol('y')
t = Symbol('t')
geo = sc.Rectangle((1., -2.), (8., 2.))
u = sp.Function('u')(x, y, t)
v = sp.Function('v')(x, y, t)
p = sp.Function('p')(x, y, t)
time_range = {t: (0, 20)}
nu=0.01
rho=1
@sc.datanode(name='NS_domain', loss_fn='L1')
class NSExternal(sc.SampleDomain):
def __init__(self):
points = pd.read_csv('NSexternel_sample.csv')
self.points = {col: points[col].to_numpy().reshape(-1, 1) for col in points.columns}
self.constraints = {'u': self.points.pop('u'), 'v': self.points.pop('v'), 'p': self.points.pop('p')}
def sampling(self, *args, **kwargs):
return self.points, self.constraints
@sc.datanode(name='NS_external')
class NSEq(sc.SampleDomain):
def sampling(self, *args, **kwargs):
points = geo.sample_interior(density=2000, param_ranges=time_range)
constraints = {'continuity': 0, 'momentum_x': 0, 'momentum_y': 0}
return points, constraints
net = sc.MLP([3, 20, 20, 20, 20, 20, 20, 20, 20, 3], activation=sc.Activation.tanh)
net = sc.get_net_node(inputs=('x', 'y', 't'), outputs=('u', 'v', 'p'), name='net', arch=sc.Arch.mlp)
#var_nr = sc.get_net_node(inputs=('x', 'y'), outputs=('nu', 'rho'), arch=sc.Arch.single_var)
#pde = sc.NavierStokesNode(nu='nu', rho='rho', dim=2, time=True, u='u', v='v', p='p')
pde = sc.NavierStokesNode(nu=0.01, rho=1.0, dim=2, time=True)
s = sc.Solver(sample_domains=(NSExternal(), NSEq()),
netnodes=[net],
init_network_dirs=['network_dir_adam'],
pdes=[pde],
max_iter=100,
opt_config=dict(optimizer='LBFGS', lr=1)
)
#opt_config=dict(optimizer='LBFGS', lr=1)
# s = sc.Solver(sample_domains=(NSExternal(), NSEq()),
# netnodes=[net, var_nr],
# pdes=[pde],
# network_dir='network_dir',
# max_iter=10)
s.solve()
coord = s.infer_step(domain_attr={'NS_domain': ['x', 'y', 'u', 'v', 'p']})
num_xd = coord['NS_domain']['x'].cpu().detach().numpy().ravel()
num_yd = coord['NS_domain']['y'].cpu().detach().numpy().ravel()
num_ud = coord['NS_domain']['u'].cpu().detach().numpy().ravel()
num_vd = coord['NS_domain']['v'].cpu().detach().numpy().ravel()
num_pd = coord['NS_domain']['p'].cpu().detach().numpy().ravel()
# print("true paratmeter rho: {:.4f}".format(rho))
# predict_rho = var_nr.evaluate(torch.Tensor([[1.0]])).item()
# print("predicted parameter rho: {:.4f}".format(predict_rho))
points1 = pd.read_csv('NSexternel_test.csv')
points1 = {col: points1[col].to_numpy().reshape(-1, 1) for col in points1.columns}
x_test = torch.tensor(points1['x_test'].astype(np.float32))
y_test = torch.tensor(points1['y_test'].astype(np.float32))
t_test = torch.tensor(points1['t_test'].astype(np.float32))
u_test = torch.tensor(points1['u_test'].astype(np.float32))
v_test = torch.tensor(points1['v_test'].astype(np.float32))
p_test = torch.tensor(points1['p_test'].astype(np.float32))
U = s.netnodes[0].net(torch.cat([x_test, y_test, t_test], dim=1))
num_x = x_test.cpu().detach().numpy().ravel()
num_y = y_test.cpu().detach().numpy().ravel()
num_u = u_test.cpu().detach().numpy().ravel()
num_v = v_test.cpu().detach().numpy().ravel()
num_p = p_test.cpu().detach().numpy().ravel()
num_up = U[:, 0:1].cpu().detach().numpy().ravel()
num_vp = U[:, 1:2].cpu().detach().numpy().ravel()
num_pp = U[:, 2:3].cpu().detach().numpy().ravel()
triang_total = tri.Triangulation(num_x, num_y)
font2 = {'family': 'Times New Roman',
'weight': 'normal',
'size': 15,
}
# fig = plt.figure(figsize=(6, 6))
# ax = fig.add_subplot(111)
# ax.scatter(num_xi, num_yi, c='b', s=1, label='Domain')
# ax.set_xlabel('$x$', font2)
# ax.set_ylabel('$y$', font2)
# ax.set_title('collocation points', fontsize=18)
# plt.savefig('points.png', dpi=300, bbox_inches='tight')
# plt.show()
fig = plt.figure(figsize=(20, 4))
ax1 = fig.add_subplot(131)
tcf = ax1.tricontourf(triang_total, num_u, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax1.set_xlabel('$x$', font2)
ax1.set_ylabel('$y$', font2)
ax1.set_title('Exact $u$', fontsize=18)
ax2 = fig.add_subplot(132)
tcf = ax2.tricontourf(triang_total, num_up, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax2.set_xlabel('$x$', font2)
ax2.set_ylabel('$y$', font2)
ax2.set_title('Predicted $u$', fontsize=18)
ax3 = fig.add_subplot(133)
tcf = ax3.tricontourf(triang_total, num_u - num_up, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax3.set_xlabel('$x$', font2)
ax3.set_ylabel('$y$', font2)
ax3.set_title('Point-wise Error', fontsize=18)
plt.savefig('test_NS_u_c.png', dpi=300, bbox_inches='tight')
plt.show()
fig = plt.figure(figsize=(20, 4))
ax1 = fig.add_subplot(131)
tcf = ax1.tricontourf(triang_total, num_v, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax1.set_xlabel('$x$', font2)
ax1.set_ylabel('$y$', font2)
ax1.set_title('Exact $v$', fontsize=18)
ax2 = fig.add_subplot(132)
tcf = ax2.tricontourf(triang_total, num_v, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax2.set_xlabel('$x$', font2)
ax2.set_ylabel('$y$', font2)
ax2.set_title('Predicted $v$', fontsize=18)
ax3 = fig.add_subplot(133)
tcf = ax3.tricontourf(triang_total, num_v - num_vp, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax3.set_xlabel('$x$', font2)
ax3.set_ylabel('$y$', font2)
ax3.set_title('Point-wise Error', fontsize=18)
plt.savefig('test_NS_v_c.png', dpi=300, bbox_inches='tight')
plt.show()
fig = plt.figure(figsize=(20, 4))
ax1 = fig.add_subplot(131)
tcf = ax1.tricontourf(triang_total, num_p, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax1.set_xlabel('$x$', font2)
ax1.set_ylabel('$y$', font2)
ax1.set_title('Exact $p$', fontsize=18)
ax2 = fig.add_subplot(132)
tcf = ax2.tricontourf(triang_total, num_pp, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax2.set_xlabel('$x$', font2)
ax2.set_ylabel('$y$', font2)
ax2.set_title('Predicted $p$', fontsize=18)
ax3 = fig.add_subplot(133)
tcf = ax3.tricontourf(triang_total, num_p - num_pp, 100, cmap='jet')
tc_bar = plt.colorbar(tcf)
tc_bar.ax.tick_params(labelsize=12)
ax3.set_xlabel('$x$', font2)
ax3.set_ylabel('$y$', font2)
ax3.set_title('Point-wise Error', fontsize=18)
plt.savefig('test_NS_p_c.png', dpi=300, bbox_inches='tight')
plt.show()

View File

@ -312,18 +312,37 @@ class Solver(Notifier, Optimizable):
""" """
self.notify(self, message={Signal.TRAIN_PIPE_START: "defaults"}) self.notify(self, message={Signal.TRAIN_PIPE_START: "defaults"})
for opt in self.optimizers: for opt in self.optimizers:
opt.zero_grad() # print('Running optimization with %s'%(self.optimizer_config['optimizer']))
if self.optimizer_config['optimizer'] == 'LBFGS':
def closure():
opt.zero_grad()
samples = self.sample_variables_from_domains()
in_var, true_out, lambda_out = self.generate_in_out_dict(samples)
pred_out_sample = self.forward_through_all_graph(in_var, self.outvar_dict_index)
loss = self.compute_loss(in_var, pred_out_sample, true_out, lambda_out)
self.notify(self, message={Signal.BEFORE_BACKWARD: 'defaults'})
loss.backward()
return loss
opt.step(closure)
else:
opt.zero_grad()
samples = self.sample_variables_from_domains()
in_var, true_out, lambda_out = self.generate_in_out_dict(samples)
pred_out_sample = self.forward_through_all_graph(in_var, self.outvar_dict_index)
try:
loss = self.compute_loss(in_var, pred_out_sample, true_out, lambda_out)
except RuntimeError:
raise
self.notify(self, message={Signal.BEFORE_BACKWARD: 'defaults'})
loss.backward()
opt.step()
samples = self.sample_variables_from_domains() samples = self.sample_variables_from_domains()
in_var, true_out, lambda_out = self.generate_in_out_dict(samples) in_var, true_out, lambda_out = self.generate_in_out_dict(samples)
pred_out_sample = self.forward_through_all_graph(in_var, self.outvar_dict_index) pred_out_sample = self.forward_through_all_graph(in_var, self.outvar_dict_index)
try: loss = self.compute_loss(in_var, pred_out_sample, true_out, lambda_out)
loss = self.compute_loss(in_var, pred_out_sample, true_out, lambda_out)
except RuntimeError:
raise
self.notify(self, message={Signal.BEFORE_BACKWARD: "defaults"})
loss.backward()
for opt in self.optimizers:
opt.step()
self.global_step += 1 self.global_step += 1
for scheduler in self.schedulers: for scheduler in self.schedulers:

View File

@ -21,6 +21,7 @@ class Loss(enum.Enum):
L1 = "L1" L1 = "L1"
square = "square" square = "square"
Identity = "Identity"
class LossFunction: class LossFunction:
@ -32,6 +33,8 @@ class LossFunction:
return LossFunction.weighted_L1_loss(variables, name=name) return LossFunction.weighted_L1_loss(variables, name=name)
elif loss_function == Loss.square.name or loss_function == Loss.square: elif loss_function == Loss.square.name or loss_function == Loss.square:
return LossFunction.weighted_square_loss(variables, name=name) return LossFunction.weighted_square_loss(variables, name=name)
elif loss_function == Loss.Identity.name or loss_function == Loss.Identity:
return LossFunction.weighted_identity_loss(variables, name=name)
raise NotImplementedError(f"loss function {loss_function} is not defined!") raise NotImplementedError(f"loss function {loss_function} is not defined!")
@staticmethod @staticmethod
@ -62,6 +65,20 @@ class LossFunction:
loss += torch.sum((val ** 2) * variables["area"]) loss += torch.sum((val ** 2) * variables["area"])
return Variables({name: loss}) return Variables({name: loss})
@staticmethod
def weighted_identity_loss(variables: "Variables", name: str) -> "Variables":
loss = 0.0
for key, val in variables.items():
if key.startswith("lambda_") or key == "area":
continue
elif "lambda_" + key in variables.keys():
loss += torch.sum(
val * variables["lambda_" + key] * variables["area"]
)
else:
loss += torch.sum(val * variables["area"])
return Variables({name: loss})
class Variables(dict): class Variables(dict):
def __sub__(self, other: "Variables") -> "Variables": def __sub__(self, other: "Variables") -> "Variables":

View File

@ -1,6 +1,5 @@
transforms3d transforms3d
typing typing
numpy
keras keras
h5py h5py
pandas pandas
@ -11,12 +10,13 @@ sphinx
matplotlib matplotlib
myst_parser myst_parser
sphinx_markdown_parser sphinx_markdown_parser
numpy==1.21.0
sphinx_rtd_theme==0.5.2 sphinx_rtd_theme==0.5.2
tensorboard==2.4.1 tensorboard==2.4.1
sympy==1.5.1 sympy==1.5.1
pyevtk==1.1.1 pyevtk==1.1.1
flask==1.1.2 flask==1.1.2
requests==2.25.0 requests==2.25.0
torch>=1.7.1 torch==1.7.1
networkx==2.5.1 networkx==2.5.1
protobuf~=3.20 protobuf~=3.20

View File

@ -21,7 +21,7 @@ def load_requirements(path_dir=here, comment_char="#"):
setuptools.setup( setuptools.setup(
name="idrlnet", # Replace with your own username name="idrlnet", # Replace with your own username
version="0.1.0", version="2.0.0-rc3",
author="Intelligent Design & Robust Learning lab", author="Intelligent Design & Robust Learning lab",
author_email="weipeng@deepinfar.cn", author_email="weipeng@deepinfar.cn",
description="IDRLnet", description="IDRLnet",