2022.10.25 Commit
This commit is contained in:
parent
cd487502ba
commit
a011a13412
|
@ -0,0 +1,82 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
|
||||
@ Author : Zeyu Zhang
|
||||
@ Email : zhangzeyu_work@outlook.com
|
||||
@ Date : 2021-11-17 10:52:26
|
||||
@ LastEditTime : 2022-10-25 09:31:20
|
||||
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TONR_Linear_Stiffness.py
|
||||
@
|
||||
@ Description : Jax+Flax+optax TONR Linear Stiffness Problem
|
||||
@ Reference :
|
||||
'''
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import pandas as pd
|
||||
import xarray
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn
|
||||
import time
|
||||
import TO_Problem
|
||||
import TO_Define
|
||||
import TO_Model
|
||||
import TO_Train
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||
here = Path(__file__).resolve().parent
|
||||
|
||||
design_condition = 1
|
||||
if design_condition == 1:
|
||||
problem = TO_Problem.cantilever_single()
|
||||
max_iterations = 200
|
||||
args = TO_Define.Toparams(problem)
|
||||
TopOpt_env = TO_Define.Topology_Optimization(args)
|
||||
model_kwargs = TO_Problem.cantilever_single_NN(TopOpt_env)
|
||||
elif design_condition == 2:
|
||||
problem = TO_Problem.cantilever_single_big()
|
||||
max_iterations = 200
|
||||
args = TO_Define.Toparams(problem)
|
||||
TopOpt_env = TO_Define.Topology_Optimization(args)
|
||||
model_kwargs = TO_Problem.cantilever_single_big_NN(TopOpt_env)
|
||||
|
||||
TONR_model = TO_Model.TO_CNN(TopOpt=TopOpt_env, **model_kwargs)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
start_time = time.perf_counter()
|
||||
seed = 1
|
||||
ds_TopOpt = TO_Train.train_TO_Optax(
|
||||
TONR_model, max_iterations, save_intermediate_designs=True, seed=1)
|
||||
ds_TopOpt.to_netcdf('ds_TopOpt.nc') # save
|
||||
# ds_TopOpt = xarray.open_dataset('ds_TopOpt.nc') # load
|
||||
end_time = time.perf_counter()
|
||||
whole_time = end_time - start_time
|
||||
|
||||
obj_value = ds_TopOpt.loss.data
|
||||
x_designed = ds_TopOpt.design.data
|
||||
|
||||
obj_min = jnp.min(obj_value)
|
||||
obj_max = jnp.max(obj_value)
|
||||
obj_min_index = jnp.where(obj_value == obj_min)[0][0]
|
||||
|
||||
dims = pd.Index(['TONR'], name='model')
|
||||
ds = xarray.concat([ds_TopOpt], dim=dims)
|
||||
ds_TopOpt.loss.transpose().to_pandas().cummin().plot(linewidth=2)
|
||||
plt.ylim(obj_min * 0.85, obj_value[0] * 1.15)
|
||||
plt.ylabel('Compliance (loss)')
|
||||
plt.xlabel('Optimization step')
|
||||
seaborn.despine()
|
||||
|
||||
final_design = x_designed[obj_min_index, :, :]
|
||||
final_obj = obj_min
|
||||
|
||||
plt.show()
|
||||
ds_TopOpt.design.sel(step=obj_min_index).plot.imshow(x='x', y='y', size=2,
|
||||
aspect=2.5, col_wrap=2, yincrease=False, add_colorbar=False, cmap='Greys')
|
||||
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
|
||||
@ Author : Zeyu Zhang
|
||||
@ Email : zhangzeyu_work@outlook.com
|
||||
@ Date : 2021-11-17 10:52:26
|
||||
@ LastEditTime : 2022-10-25 09:33:11
|
||||
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Cal.py
|
||||
@
|
||||
@ Description : 存储TO计算过程所需的函数
|
||||
@ Reference :
|
||||
'''
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy.signal as jss
|
||||
import jax.lax as jl
|
||||
from jaxopt import Bisection
|
||||
from functools import partial
|
||||
import TO_Private_Autodiff as TPA
|
||||
import TO_FEA
|
||||
import TO_Obj
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(5, 6, 7))
|
||||
def design_variable(x, design_area_indices, nondesign_area_indices, filter_kernal, filter_weight, beta_x, volfrac, projection=False):
|
||||
x = filter(x, filter_kernal, filter_weight)
|
||||
x_1D = x.flatten(order='F')
|
||||
x_designed = design_and_volume_constraints(
|
||||
x_1D[design_area_indices], volfrac, beta_x)
|
||||
xPhys_1D = matrix_set_0(
|
||||
x_designed, design_area_indices, nondesign_area_indices)
|
||||
return xPhys_1D
|
||||
|
||||
|
||||
def projection(x, beta_TONR):
|
||||
output = jnp.tanh(beta_TONR * 0.5 * x)*.5 + 0.5
|
||||
return output
|
||||
|
||||
|
||||
def logit(obj_V):
|
||||
p = jnp.clip(
|
||||
obj_V, 0, 1)
|
||||
return jnp.log(p) - jnp.log1p(-p) # p从-inf到inf p取0.5为0 以0.5为界限 大于小于符号相反 数值相同
|
||||
|
||||
|
||||
# @partial(jax.jit, static_argnums=(1, 2))
|
||||
def design_and_volume_constraints(x, obj_V, beta_TONR):
|
||||
def find_yita(y, x, beta):
|
||||
projection_out = projection(x + y, beta)
|
||||
volume_diff = jnp.mean(projection_out) - obj_V
|
||||
return volume_diff
|
||||
lower_bound = logit(obj_V) - jl.stop_gradient(np.max(x))
|
||||
upper_bound = logit(obj_V) - jl.stop_gradient(np.min(x))
|
||||
bisec = Bisection(optimality_fun=find_yita, lower=lower_bound, upper=upper_bound, tol=1e-12, maxiter=64, check_bracket=False)
|
||||
yita = bisec.run(x=x, beta=beta_TONR).params
|
||||
x_design = projection(x + yita, beta_TONR)
|
||||
x_design = x_design + 0.00001
|
||||
return x_design
|
||||
|
||||
|
||||
def filter_define(design_map, filter_width):
|
||||
dy, dx = jnp.meshgrid(jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)),
|
||||
jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)))
|
||||
filter_kernal = jnp.maximum(0, filter_width - jnp.sqrt(dx ** 2 + dy ** 2))
|
||||
filter_weight = jss.convolve2d(design_map, filter_kernal, 'same')
|
||||
return filter_kernal, filter_weight
|
||||
|
||||
|
||||
# @jax.jit
|
||||
def filter(matrixA, filter_kernal, filter_weight):
|
||||
new_matrix = jss.convolve2d(matrixA, filter_kernal, 'same') / filter_weight
|
||||
return new_matrix
|
||||
|
||||
|
||||
def inverse_permutation(indices):
|
||||
inverse_perm = jnp.zeros(len(indices), dtype=jnp.int64)
|
||||
inverse_perm = inverse_perm.at[indices].set(
|
||||
jnp.arange(len(indices), dtype=jnp.int64))
|
||||
return inverse_perm
|
||||
|
||||
|
||||
# @jax.jit
|
||||
def matrix_set_0(nonzero_values, nonzero_indices, zero_indices):
|
||||
index_map = inverse_permutation(
|
||||
jnp.concatenate([nonzero_indices, zero_indices]))
|
||||
u_values = jnp.concatenate([nonzero_values, jnp.zeros(len(zero_indices))])
|
||||
new_matrix = u_values[index_map]
|
||||
return new_matrix
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3))
|
||||
def matrix_set_1(nonzero_values, nonzero_indices, origin_matrix_size, zero_size):
|
||||
all_indices = jnp.arange(origin_matrix_size, dtype=jnp.int64)
|
||||
zero_indices = jnp.setdiff1d(
|
||||
all_indices, nonzero_indices, size=zero_size, assume_unique=True)
|
||||
index_map = inverse_permutation(
|
||||
jnp.concatenate([nonzero_indices, zero_indices]))
|
||||
u_values = jnp.concatenate([nonzero_values, jnp.ones(len(zero_indices))])
|
||||
new_matrix = u_values[index_map]
|
||||
return new_matrix
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(1, 2, 10))
|
||||
def objective(xPhys_1D, young, young_min, freedofs, fext, s_K_predefined, dispTD_predefined, idx, edofMat, ke, penal):
|
||||
# xPhys_1D = xPhys.T.flatten()
|
||||
disp = TO_FEA.FEA(young, young_min, ke, xPhys_1D, freedofs,
|
||||
penal, idx, fext, s_K_predefined, dispTD_predefined)
|
||||
obj = TO_Obj.compliance(young, young_min, ke,
|
||||
xPhys_1D, edofMat, penal, disp)
|
||||
# print('obj is running')
|
||||
return obj
|
|
@ -0,0 +1,199 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
|
||||
@ Author : Zeyu Zhang
|
||||
@ Email : zhangzeyu_work@outlook.com
|
||||
@ Date : 2022-10-25 09:30:27
|
||||
@ LastEditTime : 2022-10-25 09:37:58
|
||||
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Define.py
|
||||
@
|
||||
@ Description :
|
||||
@ Reference :
|
||||
'''
|
||||
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import jax.ops as jops
|
||||
import TO_Cal
|
||||
import TO_FEA
|
||||
|
||||
|
||||
def Toparams(problem):
|
||||
young = 1.0
|
||||
rou = 1.0
|
||||
nelx, nely, ele_length, ele_width = problem.nelx, problem.nely, problem.ele_length, problem.ele_width
|
||||
Total_ele, Total_nod, Total_dof = nelx * \
|
||||
nely, (nelx+1)*(nely+1), 2*(nelx+1)*(nely+1)
|
||||
gobalcoords0xy, M_elenod, M_edofMat, M_iK, M_jK = coordinate_and_preFEA(
|
||||
nelx, nely, ele_length, ele_width, Total_ele, Total_nod)
|
||||
elenod, edofMat, iK, jK = M_elenod-1, M_edofMat-1, M_iK-1, M_jK - \
|
||||
1
|
||||
edofMat_3D = TO_edofMat_3DTensor(nelx, nely)
|
||||
design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices = handle_design_region(
|
||||
nelx, nely, Total_ele, problem.design_map)
|
||||
params = {
|
||||
'young': young,
|
||||
'young_min': 1e-9*young,
|
||||
'rou': rou,
|
||||
'poisson': 0.3,
|
||||
'g': 0,
|
||||
'ele_length': ele_length,
|
||||
'ele_width': ele_width,
|
||||
'ele_thickness': problem.ele_thickness,
|
||||
'nelx': nelx,
|
||||
'nely': nely,
|
||||
'Total_ele': Total_ele,
|
||||
'Total_nod': Total_nod,
|
||||
'Total_dof': Total_dof,
|
||||
'volfrac': problem.volfrac,
|
||||
'xmin': 0.001,
|
||||
'xmax': 1.0,
|
||||
'design_map': problem.design_map,
|
||||
# 'design_map_bool': design_map_bool,
|
||||
# 'num_design_ele': num_design_ele,
|
||||
# 'num_nondesign_ele': num_nondesign_ele,
|
||||
'design_area_indices': design_area_indices,
|
||||
'nondesign_area_indices': nondesign_area_indices,
|
||||
'freedofs': problem.freedofs,
|
||||
'fixeddofs': problem.fixeddofs,
|
||||
'fext': problem.fext,
|
||||
'penal': 3.0,
|
||||
'filter_width': problem.filter_width,
|
||||
'gobalcoords0xy': gobalcoords0xy,
|
||||
'M_elenod': M_elenod,
|
||||
'M_edofMat': M_edofMat,
|
||||
'M_iK': M_iK,
|
||||
'M_jK': M_jK,
|
||||
'elenod': elenod,
|
||||
'edofMat': edofMat,
|
||||
'edofMat_3D': edofMat_3D,
|
||||
'iK': iK,
|
||||
'jK': jK,
|
||||
'design_condition': problem.design_condition,
|
||||
}
|
||||
return params
|
||||
|
||||
|
||||
class Topology_Optimization:
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.young, self.young_min, self.poisson = args['young'], args['young_min'], args['poisson']
|
||||
self.Total_dof, self.Total_ele, self.nelx, self.nely, self.ele_thickness = args[
|
||||
'Total_dof'], args['Total_ele'], args['nelx'], args['nely'], args['ele_thickness']
|
||||
self.freedofs, self.fext, self.iK, self.jK, self.edofMat = args[
|
||||
'freedofs'], args['fext'], args['iK'], args['jK'], args['edofMat']
|
||||
self.filter_width, self.design_condition, self.volfrac = args[
|
||||
'filter_width'], args['design_condition'], args['volfrac']
|
||||
self.design_map, self.design_area_indices, self.nondesign_area_indices = args[
|
||||
'design_map'], args['design_area_indices'], args['nondesign_area_indices']
|
||||
self.ke = TO_FEA.linear_stiffness_matrix(
|
||||
self.young, self.poisson, self.ele_thickness)
|
||||
self.filter_kernal, self.filter_weight = TO_Cal.filter_define(
|
||||
self.design_map, self.filter_width)
|
||||
self.s_K_predefined, self.dispTD_predefined, self.idx = jnp.zeros(
|
||||
[self.Total_dof, self.Total_dof]), jnp.zeros([self.Total_dof]), jops.index[self.iK, self.jK]
|
||||
self.loop_1, self.iCont = 0, 0
|
||||
self.penal_use, self.beta_x_use = 3.0, 1.0
|
||||
self.beta_x_use_max, self.beta_x_deltath1, self.beta_x_deltath2 = 16, 50, 30,
|
||||
self.xPhys = self.volfrac * jnp.ones([self.nely, self.nelx])
|
||||
|
||||
def xPhys_transform(self, params, projection=False):
|
||||
x = params.reshape(self.nely, self.nelx)
|
||||
beta_x = self.penal_and_betax()
|
||||
xPhys_1D = TO_Cal.design_variable(x, self.design_area_indices, self.nondesign_area_indices,
|
||||
self.filter_kernal, self.filter_weight, beta_x, self.volfrac, projection=projection)
|
||||
return xPhys_1D
|
||||
|
||||
def objective(self, params, projection=False):
|
||||
self.loop_1 = self.loop_1 + 1
|
||||
xPhys_1D = self.xPhys_transform(params, projection=projection)
|
||||
Obj = TO_Cal.objective(xPhys_1D, self.young, self.young_min, self.freedofs, self.fext,
|
||||
self.s_K_predefined, self.dispTD_predefined, self.idx, self.edofMat, self.ke, self.penal_use)
|
||||
self.xPhys = xPhys_1D.reshape(self.nely, self.nelx, order='F')
|
||||
return Obj
|
||||
|
||||
def penal_and_betax(self, *args):
|
||||
self.iCont = self.iCont + 1
|
||||
if (self.iCont > self.beta_x_deltath1 and self.beta_x_use < 2):
|
||||
self.beta_x_use = self.beta_x_use * 2
|
||||
self.iCont = 1
|
||||
print(' beta_x increased to :%7.3f\n' % (self.beta_x_use))
|
||||
elif (self.iCont > self.beta_x_deltath2 and self.beta_x_use < self.beta_x_use_max and self.beta_x_use >= 2):
|
||||
self.beta_x_use = self.beta_x_use * 2
|
||||
self.iCont = 1
|
||||
print(' beta_x increased to :%7.3f\n' % (self.beta_x_use))
|
||||
return self.beta_x_use
|
||||
|
||||
|
||||
def coordinate_and_preFEA(nelx, nely, ele_length, ele_width, Total_ele, Total_nod):
|
||||
i0, j0 = jnp.meshgrid(jnp.arange(nelx+1), jnp.arange(nely+1))
|
||||
gobalcoords0x = i0*ele_length
|
||||
gobalcoords0y = ele_width*nely-j0*ele_width
|
||||
gobalcoords0xy = jnp.hstack((matrix_array(gobalcoords0x), matrix_array(
|
||||
gobalcoords0y)))
|
||||
gobalcoords0xyT = gobalcoords0xy.T
|
||||
gobalcoords0 = matrix_array(gobalcoords0xyT)
|
||||
nodegrd = matrix_order_arrange(1, Total_nod, nely+1, nelx+1)
|
||||
nodelast = Matlab_reshape(Matlab_matrix_extract(
|
||||
nodegrd, 1, -1, 1, -1), Total_ele, 1)
|
||||
edofVec = 2*matrix_array(nodelast)+1
|
||||
elenod = jnp.tile(nodelast, (1, 4)) + \
|
||||
jnp.tile(jnp.array([1, nely+2, nely+1, 0]), (Total_ele, 1))
|
||||
edofMat = jnp.tile(edofVec, (1, 8))+jnp.tile(jnp.array(
|
||||
[0, 1, 2*nely+2, 2*nely+3, 2*nely+0, 2*nely+1, -2, -1]), (Total_ele, 1))
|
||||
iK = jnp.kron(edofMat, jnp.ones((8, 1))).flatten()
|
||||
jK = jnp.kron(edofMat, jnp.ones((1, 8))).flatten()
|
||||
iK_int = iK.astype(jnp.int32)
|
||||
jK_int = jK.astype(jnp.int32)
|
||||
return gobalcoords0xy, elenod, edofMat, iK_int, jK_int
|
||||
|
||||
|
||||
def handle_design_region(nelx, nely, Total_ele, design_map):
|
||||
shape = (nely, nelx)
|
||||
design_map_bool = np.broadcast_to(design_map, shape) > 0
|
||||
num_design_ele = np.sum(design_map)
|
||||
num_nondesign_ele = Total_ele - num_design_ele
|
||||
|
||||
design_map_bool_1D = design_map_bool.flatten(order='F')
|
||||
all_indices = np.arange(Total_ele, dtype=jnp.int64)
|
||||
design_area_indices = jnp.flatnonzero(
|
||||
design_map_bool_1D, size=num_design_ele)
|
||||
nondesign_area_indices = jnp.setdiff1d(
|
||||
all_indices, design_area_indices, size=num_nondesign_ele, assume_unique=True) # setdiff1d通常不支持jit 需要指定size
|
||||
return design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices
|
||||
|
||||
|
||||
def Matlab_matrix_array1D(matrixA):
|
||||
return matrixA.T.flatten()
|
||||
|
||||
|
||||
def matrix_array(matrixA):
|
||||
return matrixA.T.reshape(-1, 1)
|
||||
|
||||
|
||||
def matrix_order_arrange(numbegin, numend, ydim, xdim):
|
||||
matrixA = jnp.array([[i for i in range(numbegin, numend+1)]])
|
||||
matrixA = matrixA.reshape(xdim, ydim)
|
||||
matrixA = matrixA.T
|
||||
return matrixA
|
||||
|
||||
|
||||
def Matlab_reshape(matrixA, ydim, xdim):
|
||||
return matrixA.reshape((ydim, xdim), order='F')
|
||||
|
||||
|
||||
def Matlab_matrix_extract(matrixA, xbegin, xend, ybegin, yend):
|
||||
matrixA = matrixA[ybegin-1:yend, xbegin-1:xend]
|
||||
return matrixA
|
||||
|
||||
|
||||
def TO_edofMat_3DTensor(nelx, nely):
|
||||
ely, elx = jnp.meshgrid(jnp.arange(nely), jnp.arange(nelx)) # x, y coords
|
||||
n1 = (nely+1)*(elx+0) + (ely+0)
|
||||
n2 = (nely+1)*(elx+1) + (ely+0)
|
||||
n3 = (nely+1)*(elx+1) + (ely+1)
|
||||
n4 = (nely+1)*(elx+0) + (ely+1)
|
||||
edofMat_3D = jnp.array(
|
||||
[2*n4, 2*n4+1, 2*n3, 2*n3+1, 2*n2, 2*n2+1, 2*n1, 2*n1+1])
|
||||
return edofMat_3D
|
|
@ -0,0 +1,88 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
|
||||
@ Author : Zeyu Zhang
|
||||
@ Email : zhangzeyu_work@outlook.com
|
||||
@ Date : 2022-10-25 09:30:27
|
||||
@ LastEditTime : 2022-10-25 09:38:08
|
||||
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_FEA.py
|
||||
@
|
||||
@ Description :
|
||||
@ Reference :
|
||||
'''
|
||||
|
||||
import numpy as np
|
||||
from numpy import float64
|
||||
import jax.numpy as jnp
|
||||
import jax.ops as jops
|
||||
import jax.scipy.linalg as jsl
|
||||
from jax import jit
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
def linear_stiffness_matrix(young, poisson, ele_thickness):
|
||||
young, poisson, h = young, poisson, ele_thickness
|
||||
k = np.array([1/2-poisson/6, 1/8+poisson/8, -1/4-poisson/12, -1/8+3*poisson/8,
|
||||
-1/4+poisson/12, -1/8-poisson/8, poisson/6, 1/8-3*poisson/8])
|
||||
stiffness_ele = h/(1-poisson**2)*np.array([[k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]],
|
||||
[k[1], k[0], k[7], k[6],
|
||||
k[5], k[4], k[3], k[2]],
|
||||
[k[2], k[7], k[0], k[5],
|
||||
k[6], k[3], k[4], k[1]],
|
||||
[k[3], k[6], k[5], k[0],
|
||||
k[7], k[2], k[1], k[4]],
|
||||
[k[4], k[5], k[6], k[7],
|
||||
k[0], k[1], k[2], k[3]],
|
||||
[k[5], k[4], k[3], k[2],
|
||||
k[1], k[0], k[7], k[6]],
|
||||
[k[6], k[3], k[4], k[1],
|
||||
k[2], k[7], k[0], k[5]],
|
||||
[k[7], k[2], k[1], k[4],
|
||||
k[3], k[6], k[5], k[0]]
|
||||
], dtype=float64)
|
||||
return stiffness_ele
|
||||
|
||||
|
||||
def density_matrix(rou, ele_length, ele_width, ele_thickness):
|
||||
ele_volume = ele_length*ele_width*ele_thickness
|
||||
density_ele = ele_volume / 36 * np.array([[4, 0, 2, 0, 1, 0, 2, 0],
|
||||
[0, 4, 0, 2,
|
||||
0, 1, 0, 2],
|
||||
[2, 0, 4, 0,
|
||||
2, 0, 1, 0],
|
||||
[0, 2, 0, 4,
|
||||
0, 2, 0, 1],
|
||||
[1, 0, 2, 0,
|
||||
4, 0, 2, 0],
|
||||
[0, 1, 0, 2,
|
||||
0, 4, 0, 2],
|
||||
[2, 0, 1, 0,
|
||||
2, 0, 4, 0],
|
||||
[0, 2, 0, 1,
|
||||
0, 2, 0, 4],
|
||||
], dtype=float64)
|
||||
return density_ele
|
||||
|
||||
|
||||
# @jit
|
||||
def FEA(young, young_min, ke, xPhys_1D, freedofs, penal, idx, fext, s_K_predefined, dispTD_predefined):
|
||||
|
||||
def kuf_solve(fext_1D_free, s_K_free):
|
||||
u_free = jsl.solve(s_K_free, fext_1D_free,
|
||||
sym_pos=True, check_finite=False)
|
||||
return u_free
|
||||
|
||||
def young_modulus(xPhys_1D, young, young_min, ke, penal):
|
||||
|
||||
sK = ((ke.flatten()[jnp.newaxis]).T * (young_min + (xPhys_1D) ** penal * (young-young_min))
|
||||
).flatten(order='F')
|
||||
return sK
|
||||
|
||||
sK = young_modulus(xPhys_1D, young, young_min, ke, penal)
|
||||
s_K = jops.index_add(s_K_predefined, idx, sK)
|
||||
s_K_free = s_K[freedofs, :][:, freedofs]
|
||||
fext_free = fext[freedofs]
|
||||
u_free = kuf_solve(fext_free, s_K_free)
|
||||
dispTD = jops.index_add(dispTD_predefined, freedofs, u_free)
|
||||
return dispTD
|
|
@ -0,0 +1,177 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
|
||||
@ Author : Zeyu Zhang
|
||||
@ Email : zhangzeyu_work@outlook.com
|
||||
@ Date : 2022-10-25 09:30:27
|
||||
@ LastEditTime : 2022-10-25 09:36:54
|
||||
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Model.py
|
||||
@
|
||||
@ Description :
|
||||
@ Reference :
|
||||
'''
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.image as ji
|
||||
import jax.tree_util as jtree
|
||||
from jax import random
|
||||
import flax.linen as nn
|
||||
import flax.linen.initializers as fli
|
||||
import flax.traverse_util as flu
|
||||
import flax.core.frozen_dict as fcfd
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
def set_random_seed(seed):
|
||||
if seed is not None:
|
||||
rand_key = random.PRNGKey(seed)
|
||||
return rand_key
|
||||
|
||||
|
||||
def constant_array(rng_key, shape, dtype=jnp.float64, value=0.0):
|
||||
out = jnp.full(shape, value, dtype)
|
||||
# print("i am running")
|
||||
return out
|
||||
|
||||
|
||||
def extract_model_params(params_use):
|
||||
params_unfreeze = params_use.unfreeze()
|
||||
extracted_params = {
|
||||
'/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
|
||||
params_shape = jtree.tree_map(jnp.shape, extracted_params)
|
||||
params_total_num = sum(p.size for p in jtree.tree_leaves(extracted_params))
|
||||
return extracted_params, params_shape, params_total_num
|
||||
|
||||
|
||||
def replace_model_params(pre_trained_params_dict, params_use):
|
||||
params_unfreeze = params_use.unfreeze()
|
||||
#
|
||||
extracted_params = {
|
||||
'/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
|
||||
for key, val in pre_trained_params_dict.items():
|
||||
extracted_params[key] = val
|
||||
new_params_use = flu.unflatten_dict(
|
||||
{tuple(k.split('/')): v for k, v in extracted_params.items()})
|
||||
#
|
||||
params_leaves, params_struct = jtree.tree_flatten(params_unfreeze)
|
||||
i = 0
|
||||
for key, val in pre_trained_params_dict.items():
|
||||
params_leaves[i] = val
|
||||
i = i + 1
|
||||
new_params_use = jtree.tree_unflatten(params_struct, params_leaves)
|
||||
#
|
||||
new_params_use = fcfd.freeze(new_params_use)
|
||||
return new_params_use
|
||||
|
||||
|
||||
def batched_loss(x_inputs, TopOpt_envs):
|
||||
loss_list = [TopOpt.objective(x_inputs[i], projection=True)
|
||||
for i, TopOpt in enumerate(TopOpt_envs)]
|
||||
losses_array = jnp.stack(loss_list)[0]
|
||||
return losses_array
|
||||
|
||||
|
||||
# @jax.jit
|
||||
class Normalize_TO(nn.Module):
|
||||
|
||||
epsilon: float = 1e-6
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input):
|
||||
input_mean = jnp.mean(input, axis=(1, 2, 3), keepdims=True)
|
||||
input_var = jnp.var(input, axis=(1, 2, 3), keepdims=True)
|
||||
output = input - input_mean
|
||||
output = output * jax.lax.rsqrt(input_var + self.epsilon)
|
||||
return output
|
||||
|
||||
|
||||
class Add_offset(nn.Module):
|
||||
scale: int = 10
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input):
|
||||
add_bias = self.param('add_bias', lambda rng, shape: constant_array(
|
||||
rng, shape, value=0.0), input.shape)
|
||||
return input + self.scale * add_bias
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
resize_scale_one: int
|
||||
conv_output_one: int
|
||||
norm_use: nn.Module = Normalize_TO
|
||||
resize_method: str = "bilinear"
|
||||
kernel_size: int = 5
|
||||
conv_kernel_init: partial = fli.variance_scaling(
|
||||
scale=1.0, mode="fan_in", distribution="truncated_normal")
|
||||
offset: nn.Module = Add_offset
|
||||
offset_scale: int = 10
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input):
|
||||
output = jnp.tanh(input)
|
||||
output_shape = output.shape
|
||||
output = ji.resize(output, (output_shape[0], self.resize_scale_one * output_shape[1],
|
||||
self.resize_scale_one * output_shape[2], output_shape[-1]), method=self.resize_method, antialias=True)
|
||||
output = self.norm_use()(output)
|
||||
output = nn.Conv(features=self.conv_output_one, kernel_size=(
|
||||
self.kernel_size, self.kernel_size), kernel_init=self.conv_kernel_init)(output)
|
||||
output = self.offset(scale=self.offset_scale)(output)
|
||||
return output
|
||||
|
||||
|
||||
class My_TO_Model(nn.Module):
|
||||
TopOpt: Any
|
||||
seed: int = 1
|
||||
replace_params: Callable = replace_model_params
|
||||
extract_params: Callable = extract_model_params
|
||||
|
||||
def loss(self, input_TO):
|
||||
obj_TO = batched_loss(input_TO, [self.TopOpt])
|
||||
return obj_TO
|
||||
|
||||
|
||||
class TO_CNN(My_TO_Model):
|
||||
h: int = 5
|
||||
w: int = 10
|
||||
dense_scale_init: Any = 1.0
|
||||
dense_output_channel: int = 1000
|
||||
dtype: Any = jnp.float32
|
||||
input_init_std: float = 1.0
|
||||
dense_input_channel: int = 128
|
||||
conv_input_0: int = 32
|
||||
conv_output: tuple = (128, 64, 32, 16, 1)
|
||||
up_sampling_scale: tuple = (1, 2, 2, 2, 1)
|
||||
offset_scale: int = 10
|
||||
block: nn.Module = BasicBlock
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x=None): # [N, H, W, C]
|
||||
nn_input = self.param('z', lambda rng, shape: constant_array(
|
||||
rng, shape, value=self.TopOpt.volfrac), self.dense_input_channel)
|
||||
|
||||
output = nn.Dense(features=self.dense_output_channel, kernel_init=fli.orthogonal(
|
||||
scale=self.dense_scale_init))(nn_input)
|
||||
output = output.reshape([1, self.h, self.w, self.conv_input_0])
|
||||
|
||||
output = self.block(
|
||||
self.up_sampling_scale[0], self.conv_output[0])(output)
|
||||
output = self.block(
|
||||
self.up_sampling_scale[1], self.conv_output[1])(output)
|
||||
output = self.block(
|
||||
self.up_sampling_scale[2], self.conv_output[2])(output)
|
||||
output = self.block(
|
||||
self.up_sampling_scale[3], self.conv_output[3])(output)
|
||||
output = self.block(
|
||||
self.up_sampling_scale[4], self.conv_output[4])(output)
|
||||
|
||||
nn_output = jnp.squeeze(output, axis=-1)
|
||||
self.sow('intermediates', 'model_out',
|
||||
nn_output)
|
||||
return nn_output
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
|
||||
@ Author : Zeyu Zhang
|
||||
@ Email : zhangzeyu_work@outlook.com
|
||||
@ Date : 2022-10-25 09:30:27
|
||||
@ LastEditTime : 2022-10-25 09:34:00
|
||||
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Obj.py
|
||||
@
|
||||
@ Description :
|
||||
@ Reference :
|
||||
'''
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
|
||||
|
||||
# @jit
|
||||
def compliance(young, young_min, ke, xPhys_1D, edofMat, penal, disp):
|
||||
ce = jnp.sum(jnp.matmul(disp[edofMat], ke) * disp[edofMat], 1)
|
||||
ce = (young_min + xPhys_1D ** penal * (young - young_min)) * ce
|
||||
obj = jnp.sum(ce)
|
||||
return obj
|
|
@ -0,0 +1,87 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
|
||||
@ Author : Zeyu Zhang
|
||||
@ Email : zhangzeyu_work@outlook.com
|
||||
@ Date : 2022-10-25 09:30:27
|
||||
@ LastEditTime : 2022-10-25 09:33:20
|
||||
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Problem.py
|
||||
@
|
||||
@ Description :
|
||||
@ Reference :
|
||||
'''
|
||||
|
||||
import jax.ops as jops
|
||||
import numpy as np
|
||||
import dataclasses
|
||||
from typing import Optional, Union
|
||||
import jax.numpy as jnp
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ToProblem:
|
||||
L: int
|
||||
W: int
|
||||
nelx: int
|
||||
nely: int
|
||||
design_condition: int
|
||||
ele_length: np.float64
|
||||
ele_width: np.float64
|
||||
ele_thickness: np.float64
|
||||
volfrac: np.float64
|
||||
fixeddofs: np.ndarray
|
||||
freedofs: np.ndarray
|
||||
fext: np.ndarray
|
||||
filter_width: float
|
||||
design_map: np.ndarray
|
||||
|
||||
|
||||
def cantilever_single():
|
||||
L = 0.8
|
||||
W = 0.4
|
||||
nelx = 80
|
||||
nely = 40
|
||||
design_condition = 1
|
||||
Total_dof = 2*(nelx+1)*(nely+1)
|
||||
ele_length = L/nelx
|
||||
ele_width = W/nely
|
||||
ele_thickness = 1.0
|
||||
volfrac = 0.5
|
||||
dofs = np.arange(Total_dof)
|
||||
loaddof = Total_dof-1
|
||||
# fext = jnp.zeros(Total_dof)
|
||||
# fext = jops.index_update(fext, jops.index[loaddof], -1.)
|
||||
fext = np.zeros(Total_dof)
|
||||
fext[loaddof] = -1
|
||||
fixeddofs = np.array(dofs[0:2*(nely+1):1])
|
||||
freedofs = np.setdiff1d(dofs, fixeddofs)
|
||||
filter_width = 3.0
|
||||
design_map = np.ones([nely, nelx], dtype=np.int64)
|
||||
return ToProblem(L, W, nelx, nely, design_condition, ele_length, ele_width, ele_thickness, volfrac, fixeddofs, freedofs, fext, filter_width, design_map)
|
||||
|
||||
|
||||
def cantilever_single_NN(TopOpt):
|
||||
dense_input_channel = 128
|
||||
dense_init_scale = 1.0
|
||||
conv_input_0 = 32
|
||||
up_sampling_scale = (1, 2, 2, 2, 1)
|
||||
total_resize = int(np.prod(up_sampling_scale))
|
||||
h = TopOpt.nely // total_resize
|
||||
w = TopOpt.nelx // total_resize
|
||||
dense_output_channel = h * w * conv_input_0
|
||||
dense_scale_init = dense_init_scale * \
|
||||
jnp.sqrt(jnp.maximum(dense_output_channel / dense_input_channel, 1))
|
||||
model_kwargs = {
|
||||
'h': h,
|
||||
'w': w,
|
||||
'dense_scale_init': dense_scale_init,
|
||||
'dense_output_channel': dense_output_channel,
|
||||
'dense_input_channel': dense_input_channel,
|
||||
'conv_input_0': conv_input_0,
|
||||
'up_sampling_scale': up_sampling_scale,
|
||||
}
|
||||
return model_kwargs
|
||||
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
|
||||
@ Author : Zeyu Zhang
|
||||
@ Email : zhangzeyu_work@outlook.com
|
||||
@ Date : 2022-10-25 09:30:27
|
||||
@ LastEditTime : 2022-10-25 09:37:03
|
||||
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Train.py
|
||||
@
|
||||
@ Description :
|
||||
@ Reference :
|
||||
'''
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from absl import logging
|
||||
import xarray
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
from jaxopt import OptaxSolver
|
||||
import optax
|
||||
import time
|
||||
from functools import partial
|
||||
import matplotlib.pyplot as plt
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
sys.path.append(
|
||||
'/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO')
|
||||
here = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def set_random_seed(seed):
|
||||
if seed is not None:
|
||||
rand_key = random.PRNGKey(seed)
|
||||
return rand_key
|
||||
|
||||
|
||||
def optimizer_result_dataset(losses, frames, save_intermediate_designs=False):
|
||||
best_design = jnp.nanargmin(losses)
|
||||
logging.info(f'Final loss: {losses[best_design]}')
|
||||
if save_intermediate_designs:
|
||||
ds = xarray.Dataset({
|
||||
'loss': (('step',), losses),
|
||||
'design': (('step', 'y', 'x'), frames),
|
||||
}, coords={'step': jnp.arange(len(losses))})
|
||||
else:
|
||||
ds = xarray.Dataset({
|
||||
'loss': (('step',), losses),
|
||||
'design': (('y', 'x'), frames[best_design]),
|
||||
}, coords={'step': jnp.arange(len(losses))})
|
||||
return ds
|
||||
|
||||
|
||||
def train_TO_Optax(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs):
|
||||
losses = []
|
||||
frames = []
|
||||
# Initialize parameters
|
||||
init_params = model.init(set_random_seed(seed), None)
|
||||
batch_state, params = init_params.pop("params")
|
||||
del init_params
|
||||
# Instantiate Optimizer
|
||||
design_condition = model.TopOpt.design_condition
|
||||
if design_condition == 1:
|
||||
learning_rate = 0.001
|
||||
optimizer = optax.adam(learning_rate)
|
||||
elif design_condition == 2:
|
||||
learning_rate = 0.001
|
||||
optimizer = optax.adam(learning_rate)
|
||||
# loss
|
||||
def loss_cal(params):
|
||||
all_params = {"params": params}
|
||||
model_out, net_state = model.apply(all_params, None, mutable="intermediates")
|
||||
loss_out = model.loss(model_out)
|
||||
return loss_out, net_state["intermediates"]
|
||||
loss_grad_fn = jax.value_and_grad(loss_cal, has_aux=True)
|
||||
# Initialize Optimizer State
|
||||
state = optimizer.init(params)
|
||||
# Updated
|
||||
for iter in range(max_iterations):
|
||||
start_time_epoch = time.perf_counter()
|
||||
(loss_val, batch_stats), grads = loss_grad_fn(params)
|
||||
losses.append(jax.device_get(loss_val))
|
||||
updates_params, state = optimizer.update(grads, state)
|
||||
params = optax.apply_updates(params, updates_params)
|
||||
design_TO = jax.device_get(model.TopOpt.xPhys.aval.val)
|
||||
frames.append(design_TO)
|
||||
plt.figure()
|
||||
plt.imshow(design_TO, cmap='Greys')
|
||||
plt.show()
|
||||
whole_time_epoch = time.perf_counter() - start_time_epoch
|
||||
print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' %
|
||||
(iter+1, losses[-1], jnp.mean(design_TO), whole_time_epoch))
|
||||
|
||||
return optimizer_result_dataset(jnp.array(losses), jnp.array(frames), save_intermediate_designs)
|
Loading…
Reference in New Issue