2022.8.16 Commit

This commit is contained in:
zhangzeyu 2022-08-16 15:35:58 +08:00
commit da7820895e
16 changed files with 1287 additions and 0 deletions

View File

@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
'''
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
@ Author : Zeyu Zhang
@ Email : zhangzeyu_work@outlook.com
@ Date : 2021-11-17 10:52:26
@ LastEditTime : 2022-08-16 14:45:12
@ FilePath : /Topology_Optimization/Linear_TO/TONR_Linear_Stiffness.py
@
@ Description : Jax+Flax+optax TONR Linear Stiffness Problem
@ Reference :
'''
from pathlib import Path # 文件路径的相关操作
import sys
import pandas as pd
import xarray
import matplotlib.pyplot as plt
import seaborn
import time
import TO_Problem
import TO_Define
import TO_Model
import TO_Train
import jax.numpy as jnp
from jax import random
from jax.config import config
config.update("jax_enable_x64", True) # 对计算精度影响很大
sys.path.append(
'/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO')
here = Path(__file__).resolve().parent
design_condition = 1
if design_condition == 1:
problem = TO_Problem.cantilever_single()
max_iterations = 200
args = TO_Define.Toparams(problem)
TopOpt_env = TO_Define.Topology_Optimization(args)
model_kwargs = TO_Problem.cantilever_single_NN(TopOpt_env)
elif design_condition == 2:
problem = TO_Problem.cantilever_single_big()
max_iterations = 200
args = TO_Define.Toparams(problem)
TopOpt_env = TO_Define.Topology_Optimization(args)
model_kwargs = TO_Problem.cantilever_single_big_NN(TopOpt_env)
TONR_model = TO_Model.TO_CNN(TopOpt=TopOpt_env, **model_kwargs)
if __name__ == '__main__':
start_time = time.perf_counter()
seed = 1
ds_TopOpt = TO_Train.train_TO_Optax(
TONR_model, max_iterations, save_intermediate_designs=True, seed=1)
ds_TopOpt.to_netcdf('ds_TopOpt.nc') # save
# ds_TopOpt = xarray.open_dataset('ds_TopOpt.nc') # load
end_time = time.perf_counter()
whole_time = end_time - start_time
obj_value = ds_TopOpt.loss.data
x_designed = ds_TopOpt.design.data
obj_min = jnp.min(obj_value)
obj_max = jnp.max(obj_value)
obj_min_index = jnp.where(obj_value == obj_min)[0][0]
dims = pd.Index(['TONR'], name='model')
ds = xarray.concat([ds_TopOpt], dim=dims)
ds_TopOpt.loss.transpose().to_pandas().cummin().plot(linewidth=2)
plt.ylim(obj_min * 0.85, obj_value[0] * 1.15)
plt.ylabel('Compliance (loss)')
plt.xlabel('Optimization step')
seaborn.despine()
final_design = x_designed[obj_min_index, :, :]
final_obj = obj_min
plt.show()
ds_TopOpt.design.sel(step=obj_min_index).plot.imshow(x='x', y='y', size=2,
aspect=2.5, col_wrap=2, yincrease=False, add_colorbar=False, cmap='Greys')
def save_gif_movie(images, path, duration=200, loop=0, **kwargs):
images[0].save(path, save_all=True, append_images=images[1:],
duration=duration, loop=loop, **kwargs)
# images = [
# TO_Support.image_from_design(design, problem)
# for design in ds.design.sel(model='DL_TO')[:150]
# ]
# save_gif_movie([im.resize((5*120, 5*20)) for im in images], 'movie.gif')

253
Linear_TO/TO_Cal.py Normal file
View File

@ -0,0 +1,253 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 23 15:08:01 2020
@author: zzy
Object: 存储TO计算过程所需的函数
1. 所有函数均可jit
2. 在调用时 选择在最外层封装jit 内层函数则将jit注释
3. 对于部分确定的static variable 应采用numpy定义(仅限CPU)
"""
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jss
import jax.lax as jl
from jaxopt import Bisection
from functools import partial
import TO_Private_Autodiff as TPA
import TO_FEA
import TO_Obj
from jax.config import config
config.update("jax_enable_x64", True)
@partial(jax.jit, static_argnums=(5, 6, 7))
def design_variable(x, design_area_indices, nondesign_area_indices, filter_kernal, filter_weight, beta_x, volfrac, projection=False):
"""
Parameters
----------
x : shape:[nely, nelx](2D) 经过reshape的NN输出
design_area_indices : 设计域的单元编号索引
nondesign_area_indices : 非设计域的单元编号索引
filter_kernal : TOP88中的H 卷积过滤中的卷积核 计算影响域内每一个单元的权重
filter_weight : TOP88中的HS 设计域内每一个单元的总权重 与网格划分同维度
beta_x : Projection参数
volfrac : 目标体积
projection : 是否采用Projection
Returns
-------
xPhys_1D : shape:[Total_ele](1D) 参与有限元计算的单元相对密度阵
"""
# density filtering 用在了转换之前 目前看也是有效果的
x = filter(x, filter_kernal, filter_weight)
x_1D = x.flatten(order='F')
x_designed = design_and_volume_constraints(
x_1D[design_area_indices], volfrac, beta_x)
xPhys_1D = matrix_set_0(
x_designed, design_area_indices, nondesign_area_indices)
# 直接等于xPhys_1D 但尚需要在真实存在非设计域时验证
# 原始方案
# x_in = x[design_map_bool]
# x_designed = design_and_volume_constraints(x_in, volfrac, beta_x) # x[design_map_bool]内含了一步matrix.flatten()
# x_flat = matrix_set_0(x_designed, jnp.flatnonzero(design_map_bool, size=num_design_ele), Total_ele, num_nondesign_ele)
# 因为x_designed内含flatten() 所以这里对应要复原 由于flatten时没有传入order='F' 这里也不能传入
# xPhys = x_flat.reshape(x.shape)
return xPhys_1D
def projection(x, beta_TONR):
"""
实现projection [-z, z]的数值映射至(0, 1)
目前
原理 tanh能够将元素映射至(-1, 1) 因此 tanh_out * 0.5 + 0.5可以映射至(0, 1)
注意 sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5
Parameters
----------
x : 待映射的矩阵
beta_TONR : projection函数斜率的控制参数
Returns
-------
output : sigmoid/projection变换后的矩阵
"""
output = jnp.tanh(beta_TONR * 0.5 * x)*.5 + 0.5
return output
def logit(obj_V):
"""
根据obj_V确定sigmoid分母部分e ** (x-b) b搜索的初始上下界
np.clip(a, a_min, a_max) 将a的取值截断在(a_min, a_max)
Parameters
----------
obj_V : 目标体积
Returns
-------
"""
p = jnp.clip(
obj_V, 0, 1)
return jnp.log(p) - jnp.log1p(-p) # p从-inf到inf p取0.5为0 以0.5为界限 大于小于符号相反 数值相同
# @partial(jax.jit, static_argnums=(1, 2))
def design_and_volume_constraints(x, obj_V, beta_TONR):
"""
用来将神经网络的输出转化为满足设计约束 体积约束和projection的用于有限元计算的单元密度
x_design = 1/(1 + e ** (x - b(x, obj_V)))
Parameters
----------
x : 经过密度过滤后的单元相对密度阵
obj_V : 目标体积.
beta_TONR : Projection参数
Returns
-------
x_design : 满足设计约束和体积约束的相对密度 向量形式
"""
def find_yita(y, x, beta):
projection_out = projection(x + y, beta)
volume_diff = jnp.mean(projection_out) - obj_V
return volume_diff
lower_bound = logit(obj_V) - jl.stop_gradient(np.max(x))
upper_bound = logit(obj_V) - jl.stop_gradient(np.min(x))
bisec = Bisection(optimality_fun=find_yita, lower=lower_bound, upper=upper_bound, tol=1e-12, maxiter=64, check_bracket=False)
yita = bisec.run(x=x, beta=beta_TONR).params
# yita = TPA.binary_search(find_yita, x, beta_TONR, lower_bound, upper_bound) # 自编二分法
# 用来搜索yita的取值 由二分法获得 本质上这里和保体积projection搜索yita的过程完全相同
# 实验和公式表明 在保体积projection和此处添加体积约束的sigmoid中 对yita的求解均是一个关于(x, obj_V)的函数
# 其敏度必然要考虑
x_design = projection(x + yita, beta_TONR)
x_design = x_design + 0.00001
return x_design
def filter_define(design_map, filter_width):
"""
定义TO中过滤器的卷积核 权重等参数
Parameters
----------
design_map : 设计域和非设计域定义矩阵
filter_width : 过滤半径
Returns
-------
filter_kernal : TOP88中的H 卷积过滤中的卷积核 计算影响域内每一个单元的权重
filter_weight : TOP88中的HS 设计域内每一个单元的总权重 与网格划分同维度
"""
dy, dx = jnp.meshgrid(jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)),
jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)))
filter_kernal = jnp.maximum(0, filter_width - jnp.sqrt(dx ** 2 + dy ** 2))
filter_weight = jss.convolve2d(design_map, filter_kernal, 'same')
return filter_kernal, filter_weight
# @jax.jit
def filter(matrixA, filter_kernal, filter_weight):
"""
Parameters
----------
matrixA : 待过滤的矩阵
filter_kernal : TOP88中的H 卷积过滤中的卷积核 计算影响域内每一个单元的权重
filter_weight : TOP88中的HS 设计域内每一个单元的总权重 与网格划分同维度
Returns
-------
new_matrix : 过滤后的矩阵
"""
new_matrix = jss.convolve2d(matrixA, filter_kernal, 'same') / filter_weight
return new_matrix
def inverse_permutation(indices):
inverse_perm = jnp.zeros(len(indices), dtype=jnp.int64)
inverse_perm = inverse_perm.at[indices].set(
jnp.arange(len(indices), dtype=jnp.int64))
return inverse_perm
# @jax.jit
def matrix_set_0(nonzero_values, nonzero_indices, zero_indices):
"""
补齐矩阵 对原始矩阵指定区域置0 可用于非设计域进行置0 以及节点位移添加指定节点的位移
Parameters
----------
nonzero_values : 已存在数值的矩阵 如设计域的相对密度阵 freedofs对应的节点位移
nonzero_indices : size=存在数值的个数 已存在数值矩阵对应的索引编号 如设计域对应的索引
注意 这里为根据design_map自动生成的 是基于行展开编号的 仅用于此处计算 并非真正TO中的单元编号
zero_indices : 矩阵中未存在数值的单元对应的索引编号
origin_matrix_size : 原始矩阵元素个数 即包含设计域和非设计域
zero_size : 矩阵中0元素的个数 代表非设计域的单元个数
Returns
-------
new_matrix : 补齐后的矩阵
"""
# all_indices = np.arange(origin_matrix_size, dtype=jnp.int64)
# zero_indices = jnp.setdiff1d(
# all_indices, nonzero_indices, size=zero_size, assume_unique=True) # setdiff1d通常不支持jit 需要指定size
index_map = inverse_permutation(
jnp.concatenate([nonzero_indices, zero_indices]))
u_values = jnp.concatenate([nonzero_values, jnp.zeros(len(zero_indices))])
new_matrix = u_values[index_map]
return new_matrix
@partial(jax.jit, static_argnums=(2, 3))
def matrix_set_1(nonzero_values, nonzero_indices, origin_matrix_size, zero_size):
"""
补齐矩阵 对原始矩阵指定区域置1
Parameters
----------
nonzero_values : 已存在数值的矩阵
nonzero_indices : 已存在数值矩阵对应的索引编号
origin_matrix_size : 原始矩阵元素个数 即包含设计域和非设计域
zero_size : 矩阵中0元素的个数 代表非设计域的单元个数
Returns
-------
new_matrix : 补齐后的矩阵
"""
all_indices = jnp.arange(origin_matrix_size, dtype=jnp.int64)
zero_indices = jnp.setdiff1d(
all_indices, nonzero_indices, size=zero_size, assume_unique=True)
index_map = inverse_permutation(
jnp.concatenate([nonzero_indices, zero_indices]))
u_values = jnp.concatenate([nonzero_values, jnp.ones(len(zero_indices))])
new_matrix = u_values[index_map]
return new_matrix
@partial(jax.jit, static_argnums=(1, 2, 10))
def objective(xPhys_1D, young, young_min, freedofs, fext, s_K_predefined, dispTD_predefined, idx, edofMat, ke, penal):
"""
Parameters
----------
xPhys_1D : shape:[Total_ele](1D) 参与有限元计算的相对密度
young : 杨氏模量
young_min : 弱单元的杨氏模量
freedofs : 有限元的自由节点
fext : Python中的1D矢量 外力
s_K_predefined : 预定义的刚度阵
dispTD_predefined : 预定义的位移阵
idx : 刚度阵组装编号
edofMat : shape:[Total_ele, 8](2D) 按照单元顺序存储每一个单元的自由度编号
ke : 单元刚度阵 不考虑杨氏模量
penal : 惩罚因子
Returns
-------
obj : 目标函数输出 此处为结构的柔度
"""
# xPhys_1D = xPhys.T.flatten()
disp = TO_FEA.FEA(young, young_min, ke, xPhys_1D, freedofs,
penal, idx, fext, s_K_predefined, dispTD_predefined)
obj = TO_Obj.compliance(young, young_min, ke,
xPhys_1D, edofMat, penal, disp)
# print('obj is running')
return obj

245
Linear_TO/TO_Define.py Normal file
View File

@ -0,0 +1,245 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 22 17:40:00 2020
@author: zzy
Object: 对TO问题进行定义
1. 对于部分确定的static variable 应采用numpy定义
"""
import numpy as np
import jax.numpy as jnp
import jax.ops as jops
import TO_Cal
import TO_FEA
def Toparams(problem):
young = 1.0
rou = 1.0
nelx, nely, ele_length, ele_width = problem.nelx, problem.nely, problem.ele_length, problem.ele_width
Total_ele, Total_nod, Total_dof = nelx * \
nely, (nelx+1)*(nely+1), 2*(nelx+1)*(nely+1)
gobalcoords0xy, M_elenod, M_edofMat, M_iK, M_jK = coordinate_and_preFEA(
nelx, nely, ele_length, ele_width, Total_ele, Total_nod)
elenod, edofMat, iK, jK = M_elenod-1, M_edofMat-1, M_iK-1, M_jK - \
1 # gobalcoords0xy是绝对坐标无需更改 其他根据索引方式均-1 纯粹的数值处理 后面索引用的IE也要从0起
edofMat_3D = TO_edofMat_3DTensor(nelx, nely) # 这是一个Tensor版本edofMat
design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices = handle_design_region(
nelx, nely, Total_ele, problem.design_map)
params = {
'young': young,
'young_min': 1e-9*young,
'rou': rou,
'poisson': 0.3,
'g': 0,
'ele_length': ele_length,
'ele_width': ele_width,
'ele_thickness': problem.ele_thickness,
'nelx': nelx,
'nely': nely,
'Total_ele': Total_ele,
'Total_nod': Total_nod,
'Total_dof': Total_dof,
'volfrac': problem.volfrac,
'xmin': 0.001,
'xmax': 1.0,
'design_map': problem.design_map,
# 'design_map_bool': design_map_bool,
# 'num_design_ele': num_design_ele,
# 'num_nondesign_ele': num_nondesign_ele,
'design_area_indices': design_area_indices,
'nondesign_area_indices': nondesign_area_indices,
'freedofs': problem.freedofs,
'fixeddofs': problem.fixeddofs,
'fext': problem.fext,
'penal': 3.0,
'filter_width': problem.filter_width,
'gobalcoords0xy': gobalcoords0xy,
'M_elenod': M_elenod,
'M_edofMat': M_edofMat,
'M_iK': M_iK,
'M_jK': M_jK,
'elenod': elenod,
'edofMat': edofMat,
'edofMat_3D': edofMat_3D,
'iK': iK,
'jK': jK,
'design_condition': problem.design_condition,
}
return params
class Topology_Optimization:
def __init__(self, args):
self.args = args
self.young, self.young_min, self.poisson = args['young'], args['young_min'], args['poisson']
self.Total_dof, self.Total_ele, self.nelx, self.nely, self.ele_thickness = args[
'Total_dof'], args['Total_ele'], args['nelx'], args['nely'], args['ele_thickness']
self.freedofs, self.fext, self.iK, self.jK, self.edofMat = args[
'freedofs'], args['fext'], args['iK'], args['jK'], args['edofMat']
self.filter_width, self.design_condition, self.volfrac = args[
'filter_width'], args['design_condition'], args['volfrac']
self.design_map, self.design_area_indices, self.nondesign_area_indices = args[
'design_map'], args['design_area_indices'], args['nondesign_area_indices']
self.ke = TO_FEA.linear_stiffness_matrix(
self.young, self.poisson, self.ele_thickness)
self.filter_kernal, self.filter_weight = TO_Cal.filter_define(
self.design_map, self.filter_width)
self.s_K_predefined, self.dispTD_predefined, self.idx = jnp.zeros(
[self.Total_dof, self.Total_dof]), jnp.zeros([self.Total_dof]), jops.index[self.iK, self.jK]
self.loop_1, self.iCont = 0, 0
self.penal_use, self.beta_x_use = 3.0, 1.0
self.beta_x_use_max, self.beta_x_deltath1, self.beta_x_deltath2 = 16, 50, 30,
self.xPhys = self.volfrac * jnp.ones([self.nely, self.nelx])
def xPhys_transform(self, params, projection=False):
x = params.reshape(self.nely, self.nelx)
beta_x = self.penal_and_betax()
xPhys_1D = TO_Cal.design_variable(x, self.design_area_indices, self.nondesign_area_indices,
self.filter_kernal, self.filter_weight, beta_x, self.volfrac, projection=projection)
return xPhys_1D
def objective(self, params, projection=False):
self.loop_1 = self.loop_1 + 1
xPhys_1D = self.xPhys_transform(params, projection=projection)
Obj = TO_Cal.objective(xPhys_1D, self.young, self.young_min, self.freedofs, self.fext,
self.s_K_predefined, self.dispTD_predefined, self.idx, self.edofMat, self.ke, self.penal_use)
self.xPhys = xPhys_1D.reshape(self.nely, self.nelx, order='F')
return Obj
def penal_and_betax(self, *args):
self.iCont = self.iCont + 1
if (self.iCont > self.beta_x_deltath1 and self.beta_x_use < 2):
self.beta_x_use = self.beta_x_use * 2
self.iCont = 1
print(' beta_x increased to :%7.3f\n' % (self.beta_x_use))
elif (self.iCont > self.beta_x_deltath2 and self.beta_x_use < self.beta_x_use_max and self.beta_x_use >= 2):
self.beta_x_use = self.beta_x_use * 2
self.iCont = 1
print(' beta_x increased to :%7.3f\n' % (self.beta_x_use))
return self.beta_x_use
def coordinate_and_preFEA(nelx, nely, ele_length, ele_width, Total_ele, Total_nod):
"""
有限元计算相关量 单元及节点编号顺序为从左至右 从上至下
采用Matlab计数法 对单一单元内物理量 按左下逆时针顺序
Parameters
----------
nelx : 横向网格划分
nely : 纵向网格划分
ele_length : 单元长度
ele_width : 单元宽度
Total_ele : 单元总数
Total_nod : 节点总数
Returns
-------
gobalcoords0xy : shape:[Total_nod, 2](2D) 按照节点顺序存储每一个节点的物理坐标(x, y)
elenod : shape:[Total_ele, 4](2D) 按照单元顺序存储每一个单元的节点编号
edofMat : shape:[Total_ele, 8](2D) 按照单元顺序存储每一个单元的自由度编号
iK_int : shape:[Total_ele * 64](2D) 用于刚度阵组装 按照特定顺序储存自由度编号
jK_int : shape:[Total_ele * 64](2D) 用于刚度阵组装 按照特定顺序储存自由度编号
"""
i0, j0 = jnp.meshgrid(jnp.arange(nelx+1), jnp.arange(nely+1))
gobalcoords0x = i0*ele_length
gobalcoords0y = ele_width*nely-j0*ele_width
gobalcoords0xy = jnp.hstack((matrix_array(gobalcoords0x), matrix_array(
gobalcoords0y)))
gobalcoords0xyT = gobalcoords0xy.T
gobalcoords0 = matrix_array(gobalcoords0xyT)
nodegrd = matrix_order_arrange(1, Total_nod, nely+1, nelx+1)
nodelast = Matlab_reshape(Matlab_matrix_extract(
nodegrd, 1, -1, 1, -1), Total_ele, 1)
edofVec = 2*matrix_array(nodelast)+1
elenod = jnp.tile(nodelast, (1, 4)) + \
jnp.tile(jnp.array([1, nely+2, nely+1, 0]), (Total_ele, 1))
edofMat = jnp.tile(edofVec, (1, 8))+jnp.tile(jnp.array(
[0, 1, 2*nely+2, 2*nely+3, 2*nely+0, 2*nely+1, -2, -1]), (Total_ele, 1))
iK = jnp.kron(edofMat, jnp.ones((8, 1))).flatten()
jK = jnp.kron(edofMat, jnp.ones((1, 8))).flatten()
iK_int = iK.astype(jnp.int32)
jK_int = jK.astype(jnp.int32)
return gobalcoords0xy, elenod, edofMat, iK_int, jK_int
def handle_design_region(nelx, nely, Total_ele, design_map):
"""
对设计域/非设计域相关参数进行预定义
Parameters
----------
nelx : 横向网格划分
nely : 纵向网格划分
Total_ele : 单元总数
design_map : 设计域和非设计域定义矩阵
Returns
-------
design_map_bool : bool形式的设计域和非设计域定义矩阵
num_design_ele : 可设计单元总数
num_nondesign_ele : 非设计单元总数
design_area_indices : 设计域的单元编号索引
nondesign_area_indices : 非设计域的单元编号索引
"""
shape = (nely, nelx)
design_map_bool = np.broadcast_to(design_map, shape) > 0
num_design_ele = np.sum(design_map)
num_nondesign_ele = Total_ele - num_design_ele
design_map_bool_1D = design_map_bool.flatten(order='F')
all_indices = np.arange(Total_ele, dtype=jnp.int64)
design_area_indices = jnp.flatnonzero(
design_map_bool_1D, size=num_design_ele)
nondesign_area_indices = jnp.setdiff1d(
all_indices, design_area_indices, size=num_nondesign_ele, assume_unique=True) # setdiff1d通常不支持jit 需要指定size
return design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices
def Matlab_matrix_array1D(matrixA):
"""
shape:[*] 1D array 按照matlab的排列
"""
return matrixA.T.flatten()
def matrix_array(matrixA):
"""
shape:[*, 1] 2D array 按照matlab的排列 matrixA(:)
"""
return matrixA.T.reshape(-1, 1)
def matrix_order_arrange(numbegin, numend, ydim, xdim):
'''
shape:[ydim, xdim] 2D array 实现一个元素按列顺序排列的矩阵 注意索引编号和matlab的差1
'''
matrixA = jnp.array([[i for i in range(numbegin, numend+1)]])
matrixA = matrixA.reshape(xdim, ydim)
matrixA = matrixA.T
return matrixA
def Matlab_reshape(matrixA, ydim, xdim):
'''
shape:[ydim, xdim] 2D array matlab的reshape效果
'''
return matrixA.reshape((ydim, xdim), order='F')
def Matlab_matrix_extract(matrixA, xbegin, xend, ybegin, yend):
matrixA = matrixA[ybegin-1:yend, xbegin-1:xend]
return matrixA
def TO_edofMat_3DTensor(nelx, nely):
ely, elx = jnp.meshgrid(jnp.arange(nely), jnp.arange(nelx)) # x, y coords
n1 = (nely+1)*(elx+0) + (ely+0)
n2 = (nely+1)*(elx+1) + (ely+0)
n3 = (nely+1)*(elx+1) + (ely+1)
n4 = (nely+1)*(elx+0) + (ely+1)
edofMat_3D = jnp.array(
[2*n4, 2*n4+1, 2*n3, 2*n3+1, 2*n2, 2*n2+1, 2*n1, 2*n1+1])
return edofMat_3D

133
Linear_TO/TO_FEA.py Normal file
View File

@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 25 01:57:01 2020
@author: zzy
Object: 存储有限元计算所需函数
1. 所有函数均可jit
2. 在调用时 选择在最外层封装jit 内层函数则将jit注释
3. 对于部分确定的static variable 应采用numpy定义(仅限CPU)
"""
import numpy as np
from numpy import float64
import jax.numpy as jnp
import jax.ops as jops
import jax.scipy.linalg as jsl
from jax import jit
import TO_Private_Autodiff as TPA
from jax.config import config
config.update("jax_enable_x64", True) # 对计算精度影响很大
def linear_stiffness_matrix(young, poisson, ele_thickness):
"""
Parameters
----------
young : 杨氏模量
poisson : 泊松比
ele_thickness : 单元厚度
Returns
-------
stiffness_ele : 单元刚度矩阵 注意 考虑到TO问题 此处没有考虑单元杨式模量的影响
"""
young, poisson, h = young, poisson, ele_thickness
k = np.array([1/2-poisson/6, 1/8+poisson/8, -1/4-poisson/12, -1/8+3*poisson/8,
-1/4+poisson/12, -1/8-poisson/8, poisson/6, 1/8-3*poisson/8])
stiffness_ele = h/(1-poisson**2)*np.array([[k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]],
[k[1], k[0], k[7], k[6],
k[5], k[4], k[3], k[2]],
[k[2], k[7], k[0], k[5],
k[6], k[3], k[4], k[1]],
[k[3], k[6], k[5], k[0],
k[7], k[2], k[1], k[4]],
[k[4], k[5], k[6], k[7],
k[0], k[1], k[2], k[3]],
[k[5], k[4], k[3], k[2],
k[1], k[0], k[7], k[6]],
[k[6], k[3], k[4], k[1],
k[2], k[7], k[0], k[5]],
[k[7], k[2], k[1], k[4],
k[3], k[6], k[5], k[0]]
], dtype=float64)
return stiffness_ele
def density_matrix(rou, ele_length, ele_width, ele_thickness):
"""
Parameters
----------
rou : 密度
ele_length : 单元长度
ele_width : 单元宽度
ele_thickness : 单元厚度
Returns
-------
density_ele : 单元密度矩阵(一致质量矩阵) 注意 考虑到TO问题 此处没有考虑单元相对密度的影响 否则分子应当是质量
"""
ele_volume = ele_length*ele_width*ele_thickness
density_ele = ele_volume / 36 * np.array([[4, 0, 2, 0, 1, 0, 2, 0],
[0, 4, 0, 2,
0, 1, 0, 2],
[2, 0, 4, 0,
2, 0, 1, 0],
[0, 2, 0, 4,
0, 2, 0, 1],
[1, 0, 2, 0,
4, 0, 2, 0],
[0, 1, 0, 2,
0, 4, 0, 2],
[2, 0, 1, 0,
2, 0, 4, 0],
[0, 2, 0, 1,
0, 2, 0, 4],
], dtype=float64)
return density_ele
# @jit
def FEA(young, young_min, ke, xPhys_1D, freedofs, penal, idx, fext, s_K_predefined, dispTD_predefined):
"""
Parameters
----------
young : 杨氏模量
young_min : 弱单元的杨氏模量
ke : 单元刚度阵 不考虑杨氏模量
xPhys_1D : shape:[Total_ele](1D) 参与有限元计算的单元相对密度阵 (这种形式是必要的)
freedofs : 有限元的自由节点
penal : 惩罚因子
idx : 刚度阵组装编号
fext : shape:[Total_dof](1D) 外力
s_K_predefined : shape:[Total_dof, Total_dof](2D) 预定义的刚度阵
dispTD_predefined : shape:[Total_dof](1D) 预定义的位移阵
Returns
-------
dispTD shape:[Total_dof](1D) 节点位移
"""
def kuf_solve(fext_1D_free, s_K_free):
u_free = jsl.solve(s_K_free, fext_1D_free,
sym_pos=True, check_finite=False)
return u_free
def young_modulus(xPhys_1D, young, young_min, ke, penal):
"""
Returns
-------
sK : shape:[64, Total_ele](2D) 考虑相对密度后的每个单元的刚度阵(列形式存储)
"""
sK = ((ke.flatten()[jnp.newaxis]).T * (young_min + (xPhys_1D) ** penal * (young-young_min))
).flatten(order='F')
return sK
sK = young_modulus(xPhys_1D, young, young_min, ke, penal)
s_K = jops.index_add(s_K_predefined, idx, sK)
s_K_free = s_K[freedofs, :][:, freedofs]
fext_free = fext[freedofs]
u_free = kuf_solve(fext_free, s_K_free)
dispTD = jops.index_add(dispTD_predefined, freedofs, u_free)
return dispTD

249
Linear_TO/TO_Model.py Normal file
View File

@ -0,0 +1,249 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 19 00:00:06 2021
@author: zzy
Object: AuTONR的model构建
1. 基于Jax+Flax+JAXopt/Optax
2. 关于dataclasses和flax.linen可以参考
learning_dataclasses.py
learning_flax_module.py
"""
import jax
import jax.numpy as jnp
import jax.image as ji
import jax.tree_util as jtree
from jax import random
import flax.linen as nn
import flax.linen.initializers as fli
import flax.traverse_util as flu
import flax.core.frozen_dict as fcfd
from functools import partial
from typing import Any, Callable
from jax.config import config
config.update("jax_enable_x64", True)
def set_random_seed(seed):
if seed is not None:
rand_key = random.PRNGKey(seed)
return rand_key
def constant_array(rng_key, shape, dtype=jnp.float64, value=0.0):
"""
自定义参数初始化函数 实现矩阵填充数值
Parameters
----------
rng_key :
shape : 目标矩阵的shape
dtype : 数据格式 The default is jnp.float64.
value : 填充数值 The default is 0.0.
Returns
-------
out : 初始化后的矩阵
"""
out = jnp.full(shape, value, dtype)
# print("i am running")
return out
def extract_model_params(params_use):
"""
提取网络参数
Parameters
----------
params_use : FrozenDict 需要提取的参数
Returns
-------
extracted_params : dict 提取出的参数 {'Conv_0/bias': ..., 'Conv_0/kernel': ..., ...}
params_shape : dict 提取的每一组参数的shape {'Conv_0/bias': (...), 'Conv_0/kernel': (...), ...}
params_total_num : 参数总数
"""
params_unfreeze = params_use.unfreeze()
extracted_params = {
'/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
params_shape = jtree.tree_map(jnp.shape, extracted_params)
params_total_num = sum(p.size for p in jtree.tree_leaves(extracted_params))
return extracted_params, params_shape, params_total_num
def replace_model_params(pre_trained_params_dict, params_use):
"""
替换网络参数
Parameters
----------
pre_trained_params_dict : dict 预训练的参数 {'Conv_0/bias': ..., 'Conv_0/kernel': ..., ...}
params_use : FrozenDict 需要被替换的参数
Returns
-------
new_params_use : FrozenDict 新的参数
"""
params_unfreeze = params_use.unfreeze()
# 有两种方式均可实现功能
# Method A 基于flax.traverse_util实现
extracted_params = {
'/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
for key, val in pre_trained_params_dict.items():
extracted_params[key] = val
new_params_use = flu.unflatten_dict(
{tuple(k.split('/')): v for k, v in extracted_params.items()})
# Method B 基于jax.tree实现
params_leaves, params_struct = jtree.tree_flatten(params_unfreeze)
i = 0
for key, val in pre_trained_params_dict.items():
params_leaves[i] = val
i = i + 1
new_params_use = jtree.tree_unflatten(params_struct, params_leaves)
# 封装新的参数
new_params_use = fcfd.freeze(new_params_use)
return new_params_use
def batched_loss(x_inputs, TopOpt_envs):
"""
处理带有batch_size和多个实例化的TopOpt类的情况
Parameters
----------
x_inputs : shape:[1, nely, nelx] (3D array) 第一个维度代表batch 第二和第三个维度代表TO_2D问题中的单元相对密度 通常batch_size = 1
TopOpt_envs : 多个实例化的TopOpt类 以列表形式存储 通常仅有一个实例化的类
Returns
-------
losses_array : loss_list以列表形式存储对应于每一个batch和TopOpt的obj输出 将其转化为array形式
"""
loss_list = [TopOpt.objective(x_inputs[i], projection=True)
for i, TopOpt in enumerate(TopOpt_envs)]
losses_array = jnp.stack(loss_list)[0]
return losses_array
# @jax.jit
class Normalize_TO(nn.Module):
"""
自定义Normalize类
"""
epsilon: float = 1e-6
@nn.compact
def __call__(self, input):
# 通过对axis设置可实现不同normalization的效果
input_mean = jnp.mean(input, axis=(1, 2, 3), keepdims=True)
input_var = jnp.var(input, axis=(1, 2, 3), keepdims=True)
output = input - input_mean
output = output * jax.lax.rsqrt(input_var + self.epsilon)
return output
class Add_offset(nn.Module):
scale: int = 10
@nn.compact
def __call__(self, input):
add_bias = self.param('add_bias', lambda rng, shape: constant_array(
rng, shape, value=0.0), input.shape)
return input + self.scale * add_bias
class BasicBlock(nn.Module):
resize_scale_one: int
conv_output_one: int
norm_use: nn.Module = Normalize_TO
resize_method: str = "bilinear"
kernel_size: int = 5
conv_kernel_init: partial = fli.variance_scaling(
scale=1.0, mode="fan_in", distribution="truncated_normal")
offset: nn.Module = Add_offset
offset_scale: int = 10
@nn.compact
def __call__(self, input):
output = jnp.tanh(input)
output_shape = output.shape
output = ji.resize(output, (output_shape[0], self.resize_scale_one * output_shape[1],
self.resize_scale_one * output_shape[2], output_shape[-1]), method=self.resize_method, antialias=True)
output = self.norm_use()(output)
output = nn.Conv(features=self.conv_output_one, kernel_size=(
self.kernel_size, self.kernel_size), kernel_init=self.conv_kernel_init)(output)
output = self.offset(scale=self.offset_scale)(output)
return output
class My_TO_Model(nn.Module):
TopOpt: Any
seed: int = 1
replace_params: Callable = replace_model_params
extract_params: Callable = extract_model_params
def loss(self, input_TO):
# 以list形式传入TopOpt 其实tuple形式也可以
obj_TO = batched_loss(input_TO, [self.TopOpt])
return obj_TO
class TO_CNN(My_TO_Model):
h: int = 5
w: int = 10
dense_scale_init: Any = 1.0
dense_output_channel: int = 1000
dtype: Any = jnp.float32
input_init_std: float = 1.0 # 初始化设计变量的标准差 正态分布
dense_input_channel: int = 128
conv_input_0: int = 32
conv_output: tuple = (128, 64, 32, 16, 1)
up_sampling_scale: tuple = (1, 2, 2, 2, 1)
offset_scale: int = 10
block: nn.Module = BasicBlock
@nn.compact
def __call__(self, x=None): # [N, H, W, C]
nn_input = self.param('z', lambda rng, shape: constant_array(
rng, shape, value=self.TopOpt.volfrac), self.dense_input_channel)
output = nn.Dense(features=self.dense_output_channel, kernel_init=fli.orthogonal(
scale=self.dense_scale_init))(nn_input)
output = output.reshape([1, self.h, self.w, self.conv_input_0])
output = self.block(
self.up_sampling_scale[0], self.conv_output[0])(output)
output = self.block(
self.up_sampling_scale[1], self.conv_output[1])(output)
output = self.block(
self.up_sampling_scale[2], self.conv_output[2])(output)
output = self.block(
self.up_sampling_scale[3], self.conv_output[3])(output)
output = self.block(
self.up_sampling_scale[4], self.conv_output[4])(output)
nn_output = jnp.squeeze(output, axis=-1)
self.sow('intermediates', 'model_out',
nn_output) # 在这里是否可以考虑去掉batch维度
return nn_output
if __name__ == '__main__':
def funA(params):
out = 1 * jnp.sum(params)
return out
def funB(params):
out = 2 * jnp.sum(params)
return out
def funC(params):
out = 3 * jnp.sum(params)
return out
# 测试 直观展示batched_loss
params_1 = jnp.ones([3, 4, 4])
fun_list = [funA, funB, funC]
losses = [fun_used(params_1[i]) for i, fun_used in enumerate(fun_list)]
loss_out = jnp.stack(losses)
print("losses:", losses)
print("loss_out:", loss_out)

22
Linear_TO/TO_Obj.py Normal file
View File

@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 26 01:59:52 2020
@author: zzy
Object: 存储TO计算目标函数过程所需的函数
1. 所有函数均可jit
2. 在调用时 选择在最外层封装jit 内层函数则将jit注释
3. 对于部分确定的static variable 应采用numpy定义
"""
import jax.numpy as jnp
from jax import jit
# @jit
def compliance(young, young_min, ke, xPhys_1D, edofMat, penal, disp):
ce = jnp.sum(jnp.matmul(disp[edofMat], ke) * disp[edofMat], 1)
ce = (young_min + xPhys_1D ** penal * (young - young_min)) * ce
obj = jnp.sum(ce)
return obj

129
Linear_TO/TO_Problem.py Normal file
View File

@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 11 14:20:11 2020
@author: zzy
定义设计域
"""
import jax.ops as jops
import numpy as np
import dataclasses
from typing import Optional, Union
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
@dataclasses.dataclass
class ToProblem:
L: int
W: int
nelx: int
nely: int
design_condition: int
ele_length: np.float64
ele_width: np.float64
ele_thickness: np.float64
volfrac: np.float64
fixeddofs: np.ndarray
freedofs: np.ndarray
fext: np.ndarray
filter_width: float
design_map: np.ndarray
def cantilever_single():
L = 0.8
W = 0.4
nelx = 80
nely = 40
design_condition = 1
Total_dof = 2*(nelx+1)*(nely+1)
ele_length = L/nelx
ele_width = W/nely
ele_thickness = 1.0
volfrac = 0.5
dofs = np.arange(Total_dof)
loaddof = Total_dof-1
# fext = jnp.zeros(Total_dof)
# fext = jops.index_update(fext, jops.index[loaddof], -1.)
fext = np.zeros(Total_dof)
fext[loaddof] = -1
fixeddofs = np.array(dofs[0:2*(nely+1):1])
freedofs = np.setdiff1d(dofs, fixeddofs)
filter_width = 3.0
design_map = np.ones([nely, nelx], dtype=np.int64)
return ToProblem(L, W, nelx, nely, design_condition, ele_length, ele_width, ele_thickness, volfrac, fixeddofs, freedofs, fext, filter_width, design_map)
def cantilever_single_NN(TopOpt):
dense_input_channel = 128
dense_init_scale = 1.0
conv_input_0 = 32
up_sampling_scale = (1, 2, 2, 2, 1)
total_resize = int(np.prod(up_sampling_scale))
h = TopOpt.nely // total_resize
w = TopOpt.nelx // total_resize
dense_output_channel = h * w * conv_input_0
dense_scale_init = dense_init_scale * \
jnp.sqrt(jnp.maximum(dense_output_channel / dense_input_channel, 1))
model_kwargs = {
'h': h,
'w': w,
'dense_scale_init': dense_scale_init,
'dense_output_channel': dense_output_channel,
'dense_input_channel': dense_input_channel,
'conv_input_0': conv_input_0,
'up_sampling_scale': up_sampling_scale,
}
return model_kwargs
def cantilever_single_big():
L = 1.6
W = 0.8
nelx = 160
nely = 80
design_condition = 2
Total_dof = 2*(nelx+1)*(nely+1)
ele_length = L/nelx
ele_width = W/nely
ele_thickness = 1.0
volfrac = 0.5
dofs = np.arange(Total_dof)
loaddof = Total_dof-1
# fext = jnp.zeros(Total_dof)
# fext = jops.index_update(fext, jops.index[loaddof], -1.)
fext = np.zeros(Total_dof)
fext[loaddof] = -1
fixeddofs = np.array(dofs[0:2*(nely+1):1])
freedofs = np.setdiff1d(dofs, fixeddofs)
filter_width = 3.0
design_map = np.ones([nely, nelx], dtype=np.int64)
return ToProblem(L, W, nelx, nely, design_condition, ele_length, ele_width, ele_thickness, volfrac, fixeddofs, freedofs, fext, filter_width, design_map)
def cantilever_single_big_NN(TopOpt):
dense_input_channel = 128
dense_init_scale = 1.0
conv_input_0 = 32
up_sampling_scale = (1, 2, 2, 2, 2)
total_resize = int(np.prod(up_sampling_scale))
h = TopOpt.nely // total_resize
w = TopOpt.nelx // total_resize
dense_output_channel = h * w * conv_input_0
dense_scale_init = dense_init_scale * \
jnp.sqrt(jnp.maximum(dense_output_channel / dense_input_channel, 1))
model_kwargs = {
'h': h,
'w': w,
'dense_scale_init': dense_scale_init,
'dense_output_channel': dense_output_channel,
'dense_input_channel': dense_input_channel,
'conv_input_0': conv_input_0,
'up_sampling_scale': up_sampling_scale,
}
return model_kwargs

164
Linear_TO/TO_Train.py Normal file
View File

@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 19 09:53:14 2021
@author: zzy
Object: AuTONR的训练函数
1. 基于Flax+JAXopt/Optax实现
2. 关于train过程的实现可以参考:
learning_optax_training.py
learning_flax_managing_params.py
learning_flax_training.py
"""
import sys
from pathlib import Path
from absl import logging
import xarray
import jax
import jax.numpy as jnp
from jax import random
from jaxopt import OptaxSolver
import optax
import time
from functools import partial
import matplotlib.pyplot as plt
from jax.config import config
config.update("jax_enable_x64", True)
sys.path.append(
'/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO')
here = Path(__file__).resolve().parent
def set_random_seed(seed):
if seed is not None:
rand_key = random.PRNGKey(seed)
return rand_key
def optimizer_result_dataset(losses, frames, save_intermediate_designs=False):
# The best design will often but not always be the final one.
best_design = jnp.nanargmin(losses)
logging.info(f'Final loss: {losses[best_design]}')
if save_intermediate_designs:
ds = xarray.Dataset({
'loss': (('step',), losses),
'design': (('step', 'y', 'x'), frames),
}, coords={'step': jnp.arange(len(losses))})
else:
ds = xarray.Dataset({
'loss': (('step',), losses),
'design': (('y', 'x'), frames[best_design]),
}, coords={'step': jnp.arange(len(losses))})
return ds
def train_TO_Optax(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs):
losses = []
frames = []
# Initialize parameters
init_params = model.init(set_random_seed(seed), None)
batch_state, params = init_params.pop("params")
del init_params
# Instantiate Optimizer
design_condition = model.TopOpt.design_condition
if design_condition == 1:
learning_rate = 0.001
optimizer = optax.adam(learning_rate)
elif design_condition == 2:
learning_rate = 0.001
optimizer = optax.adam(learning_rate)
# loss
def loss_cal(params):
all_params = {"params": params}
model_out, net_state = model.apply(all_params, None, mutable="intermediates")
loss_out = model.loss(model_out)
return loss_out, net_state["intermediates"]
loss_grad_fn = jax.value_and_grad(loss_cal, has_aux=True)
# Initialize Optimizer State
state = optimizer.init(params)
# Updated
for iter in range(max_iterations):
start_time_epoch = time.perf_counter()
(loss_val, batch_stats), grads = loss_grad_fn(params)
losses.append(jax.device_get(loss_val))
updates_params, state = optimizer.update(grads, state)
params = optax.apply_updates(params, updates_params)
design_TO = jax.device_get(model.TopOpt.xPhys.aval.val)
frames.append(design_TO)
plt.figure()
plt.imshow(design_TO, cmap='Greys')
plt.show()
whole_time_epoch = time.perf_counter() - start_time_epoch
print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' %
(iter+1, losses[-1], jnp.mean(design_TO), whole_time_epoch))
return optimizer_result_dataset(jnp.array(losses), jnp.array(frames), save_intermediate_designs)
def train_TO_JAXopt(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs):
losses = []
frames = []
# Initialize parameters
init_params = model.init(set_random_seed(seed), None)
# batch_stats, params = init_params.pop("params")
params = init_params["params"]
# batch_stats = init_params["batch_stats"]
del init_params
# Initialize solver
# @jax.jit
def loss_cal(params):
# all_params = {"params": params, "batch_stats": aux}
# model_out, net_state = model.apply(all_params, None)
all_params = {"params": params}
model_out, net_state = model.apply(
all_params, None, mutable="intermediates")
loss_out = model.loss(model_out)
# design_TO = model.TopOpt.x_calculation(
# model_out, volume_contraint=True, projection=True)
# losses.append(jax.device_get(loss_out))
return loss_out, net_state["intermediates"]
design_condition = model.TopOpt.design_condition
def opt_solver_setting(design_condition, loss_fun, max_iterations):
if design_condition == 1:
learning_rate = 0.001
opt_use = optax.adam(learning_rate)
solver = OptaxSolver(
fun=loss_fun, opt=opt_use, maxiter=max_iterations, has_aux=True)
elif design_condition == 2:
learning_rate = 0.000375
elif design_condition == 6:
learning_rate = 0.001
elif design_condition == 5:
learning_rate = 0.000375
return solver
solver = opt_solver_setting(design_condition, loss_cal, max_iterations)
state = solver.init_state(params)
for iter in range(max_iterations):
start_time_epoch = time.time()
# params, state = solver.update(
# params=params, state=state, aux=batch_stats)
params, state = solver.update(params=params, state=state)
batch_stats = state.aux
losses.append(jax.device_get(state.value)) # 放在这里应该更好一些
# design_TO = jax.device_get(batch_stats._dict['model_out'])[-1]
# design_TO = model.TopOpt.x_calculation(design_TO, volume_constraint=True, projection=True)
design_TO = jax.device_get(model.TopOpt.xPhys.aval.val)
frames.append(design_TO)
plt.figure()
plt.imshow(design_TO, cmap='Greys')
plt.show()
end_time_epoch = time.time()
whole_time_epoch = end_time_epoch - start_time_epoch
print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' %
(iter+1, losses[-1], jnp.mean(design_TO), whole_time_epoch))
# frames = jax.device_get(batch_stats._dict['intermediates']['model_out'])
return optimizer_result_dataset(jnp.array(losses), jnp.array(frames), save_intermediate_designs)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.