88 lines
2.4 KiB
Python
88 lines
2.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
'''
|
|
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
|
|
@ Author : Zeyu Zhang
|
|
@ Email : zhangzeyu_work@outlook.com
|
|
@ Date : 2022-10-25 09:30:27
|
|
@ LastEditTime : 2022-10-25 09:33:20
|
|
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Problem.py
|
|
@
|
|
@ Description :
|
|
@ Reference :
|
|
'''
|
|
|
|
import jax.ops as jops
|
|
import numpy as np
|
|
import dataclasses
|
|
from typing import Optional, Union
|
|
import jax.numpy as jnp
|
|
from jax.config import config
|
|
# Switch JAX to 64-bit (double precision) types for the whole process; this
# runs as an import-time side effect of this module.
# NOTE(review): `from jax.config import config` is not importable on newer JAX
# releases; `import jax; jax.config.update(...)` is the supported spelling --
# confirm against the JAX version this project pins.
config.update("jax_enable_x64", True)
|
|
|
|
|
|
@dataclasses.dataclass
class ToProblem:
    """Plain container for the parameters of one topology-optimization problem.

    The dataclass performs no validation or conversion: every field is stored
    exactly as it is passed in.  Real-valued fields are annotated ``float``
    (which ``np.float64`` values also satisfy) because the factory in this
    file fills them with ordinary Python floats such as ``0.8`` or ``L / nelx``.
    Field order is part of the interface, since instances are built
    positionally.
    """

    # Domain size; the factory in this file derives the element size from
    # these as L / nelx and W / nely.
    L: float
    W: float
    # Element counts along the two grid directions.
    nelx: int
    nely: int
    # Integer selector for the design condition.
    # NOTE(review): its meaning is not visible in this file -- confirm with
    # the code that consumes it.
    design_condition: int
    # Size of a single element.
    ele_length: float
    ele_width: float
    ele_thickness: float
    # Volume fraction setting.
    # NOTE(review): presumably the target material fraction -- confirm.
    volfrac: float
    # Index arrays of the fixed and the remaining (free) degrees of freedom.
    fixeddofs: np.ndarray
    freedofs: np.ndarray
    # External load vector with one entry per degree of freedom.
    fext: np.ndarray
    # Filter width setting.
    # NOTE(review): units are not visible in this file -- confirm.
    filter_width: float
    # Integer array; the factory in this file builds it with shape
    # (nely, nelx).
    design_map: np.ndarray
|
|
|
|
|
|
def cantilever_single():
    """Build the single-load cantilever problem on an 80 x 40 element grid.

    All values are hard-coded.  The load vector is zero except for ``-1`` on
    the last degree of freedom, and the first ``2 * (nely + 1)`` degrees of
    freedom are marked as fixed.

    Returns:
        ToProblem: the populated problem description.
    """
    # Domain size and mesh resolution.
    L, W = 0.8, 0.4
    nelx, nely = 80, 40

    # Two degrees of freedom for each of the (nelx + 1) * (nely + 1) grid points.
    n_dof = 2 * (nelx + 1) * (nely + 1)
    all_dofs = np.arange(n_dof)

    # Single load of -1 applied on the last degree of freedom.
    fext = np.zeros(n_dof)
    fext[n_dof - 1] = -1

    # Leading block of degrees of freedom is fixed; everything else is free.
    # NOTE(review): presumably this is the clamped edge of the cantilever --
    # confirm against the node-numbering convention used by the solver.
    fixeddofs = np.array(all_dofs[:2 * (nely + 1)])
    freedofs = np.setdiff1d(all_dofs, fixeddofs)

    return ToProblem(
        L=L,
        W=W,
        nelx=nelx,
        nely=nely,
        design_condition=1,
        ele_length=L / nelx,
        ele_width=W / nely,
        ele_thickness=1.0,
        volfrac=0.5,
        fixeddofs=fixeddofs,
        freedofs=freedofs,
        fext=fext,
        filter_width=3.0,
        design_map=np.ones((nely, nelx), dtype=np.int64),
    )
|
|
|
|
|
|
def cantilever_single_NN(TopOpt):
    """Assemble the network keyword arguments matching a problem's grid.

    Args:
        TopOpt: object exposing integer ``nely`` and ``nelx`` attributes.

    Returns:
        dict: keys ``h``, ``w``, ``dense_scale_init``, ``dense_output_channel``,
        ``dense_input_channel``, ``conv_input_0`` and ``up_sampling_scale``.
        ``h`` and ``w`` are ``nely`` and ``nelx`` floor-divided by the product
        of ``up_sampling_scale``; ``dense_scale_init`` is a JAX scalar.
    """
    # Fixed architecture settings.
    up_sampling_scale = (1, 2, 2, 2, 1)
    conv_input_0 = 32
    dense_input_channel = 128
    dense_init_scale = 1.0

    # Coarse grid size: problem grid floor-divided by the total up-sampling factor.
    total_resize = int(np.prod(up_sampling_scale))
    h, w = (extent // total_resize for extent in (TopOpt.nely, TopOpt.nelx))

    dense_output_channel = h * w * conv_input_0

    # Scale by sqrt(output / input), with the ratio clamped so it never drops below 1.
    channel_ratio = dense_output_channel / dense_input_channel
    dense_scale_init = dense_init_scale * jnp.sqrt(jnp.maximum(channel_ratio, 1))

    return dict(
        h=h,
        w=w,
        dense_scale_init=dense_scale_init,
        dense_output_channel=dense_output_channel,
        dense_input_channel=dense_input_channel,
        conv_input_0=conv_input_0,
        up_sampling_scale=up_sampling_scale,
    )
|
|
|
|
|