2022.8.16 Commit
commit da7820895e

TONR_Linear_Stiffness.py
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
'''
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
@ Author : Zeyu Zhang
@ Email : zhangzeyu_work@outlook.com
@ Date : 2021-11-17 10:52:26
@ LastEditTime : 2022-08-16 14:45:12
@ FilePath : /Topology_Optimization/Linear_TO/TONR_Linear_Stiffness.py
@
@ Description : Jax+Flax+optax TONR Linear Stiffness Problem
@ Reference :
'''

from pathlib import Path  # file-path utilities
import sys
import time
import pandas as pd
import xarray
import matplotlib.pyplot as plt
import seaborn
import jax.numpy as jnp
from jax import random
from jax.config import config
config.update("jax_enable_x64", True)  # has a large effect on numerical precision

# make the local modules importable before importing them
sys.path.append(
    '/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO')
import TO_Problem
import TO_Define
import TO_Model
import TO_Train
here = Path(__file__).resolve().parent


design_condition = 1
if design_condition == 1:
    problem = TO_Problem.cantilever_single()
    max_iterations = 200
    args = TO_Define.Toparams(problem)
    TopOpt_env = TO_Define.Topology_Optimization(args)
    model_kwargs = TO_Problem.cantilever_single_NN(TopOpt_env)
elif design_condition == 2:
    problem = TO_Problem.cantilever_single_big()
    max_iterations = 200
    args = TO_Define.Toparams(problem)
    TopOpt_env = TO_Define.Topology_Optimization(args)
    model_kwargs = TO_Problem.cantilever_single_big_NN(TopOpt_env)

TONR_model = TO_Model.TO_CNN(TopOpt=TopOpt_env, **model_kwargs)

if __name__ == '__main__':

    start_time = time.perf_counter()
    seed = 1
    ds_TopOpt = TO_Train.train_TO_Optax(
        TONR_model, max_iterations, save_intermediate_designs=True, seed=seed)
    ds_TopOpt.to_netcdf('ds_TopOpt.nc')  # save
    # ds_TopOpt = xarray.open_dataset('ds_TopOpt.nc')  # load
    end_time = time.perf_counter()
    whole_time = end_time - start_time

    obj_value = ds_TopOpt.loss.data
    x_designed = ds_TopOpt.design.data

    obj_min = jnp.min(obj_value)
    obj_max = jnp.max(obj_value)
    obj_min_index = jnp.where(obj_value == obj_min)[0][0]

    dims = pd.Index(['TONR'], name='model')
    ds = xarray.concat([ds_TopOpt], dim=dims)
    ds_TopOpt.loss.transpose().to_pandas().cummin().plot(linewidth=2)
    plt.ylim(obj_min * 0.85, obj_value[0] * 1.15)
    plt.ylabel('Compliance (loss)')
    plt.xlabel('Optimization step')
    seaborn.despine()

    final_design = x_designed[obj_min_index, :, :]
    final_obj = obj_min

    plt.show()
    ds_TopOpt.design.sel(step=obj_min_index).plot.imshow(
        x='x', y='y', size=2, aspect=2.5, col_wrap=2, yincrease=False,
        add_colorbar=False, cmap='Greys')


def save_gif_movie(images, path, duration=200, loop=0, **kwargs):
    images[0].save(path, save_all=True, append_images=images[1:],
                   duration=duration, loop=loop, **kwargs)

# images = [
#     TO_Support.image_from_design(design, problem)
#     for design in ds.design.sel(model='DL_TO')[:150]
# ]
# save_gif_movie([im.resize((5*120, 5*20)) for im in images], 'movie.gif')

TO_Cal.py
@@ -0,0 +1,253 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 23 15:08:01 2020

@author: zzy

Object: functions needed during the TO computation
    1. Every function here can be jit-compiled
    2. Wrap jit at the outermost call site and keep jit commented out on the inner functions
    3. Static variables that are already fixed should be defined with numpy (CPU only)
"""

import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jss
import jax.lax as jl
from jaxopt import Bisection
from functools import partial
import TO_Private_Autodiff as TPA
import TO_FEA
import TO_Obj
from jax.config import config
config.update("jax_enable_x64", True)

@partial(jax.jit, static_argnums=(5, 6, 7))
def design_variable(x, design_area_indices, nondesign_area_indices, filter_kernal, filter_weight, beta_x, volfrac, projection=False):
    """
    Parameters
    ----------
    x : shape:[nely, nelx](2D) reshaped NN output
    design_area_indices : element indices of the design domain
    nondesign_area_indices : element indices of the non-design domain
    filter_kernal : H in top88; convolution kernel of the density filter, weighting every element inside the filter radius
    filter_weight : Hs in top88; total weight of every element in the design domain, same shape as the mesh
    beta_x : projection parameter
    volfrac : target volume fraction
    projection : whether to apply the projection

    Returns
    -------
    xPhys_1D : shape:[Total_ele](1D) element relative densities used in the FE analysis
    """

    # density filtering is applied before the transformation; so far this works as well
    x = filter(x, filter_kernal, filter_weight)
    x_1D = x.flatten(order='F')
    x_designed = design_and_volume_constraints(
        x_1D[design_area_indices], volfrac, beta_x)
    xPhys_1D = matrix_set_0(
        x_designed, design_area_indices, nondesign_area_indices)
    # equals xPhys_1D directly, but still needs verification once a real non-design domain exists
    # original scheme:
    # x_in = x[design_map_bool]
    # x_designed = design_and_volume_constraints(x_in, volfrac, beta_x)  # x[design_map_bool] implicitly flattens the matrix
    # x_flat = matrix_set_0(x_designed, jnp.flatnonzero(design_map_bool, size=num_design_ele), Total_ele, num_nondesign_ele)
    # x_designed was flattened without order='F', so the reshape here must not use it either
    # xPhys = x_flat.reshape(x.shape)
    return xPhys_1D

def projection(x, beta_TONR):
    """
    Projection: map values in [-z, z] into (0, 1).
    Principle: tanh maps its input into (-1, 1), so tanh_out * 0.5 + 0.5 lands in (0, 1).
    Note: sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5.

    Parameters
    ----------
    x : matrix to be mapped
    beta_TONR : parameter controlling the slope of the projection

    Returns
    -------
    output : matrix after the sigmoid/projection transform
    """
    output = jnp.tanh(beta_TONR * 0.5 * x) * .5 + 0.5
    return output

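# A minimal sketch (added note, not part of the original commit) checking the
# identity quoted in the docstring above: for beta_TONR = 1, projection()
# reduces exactly to the logistic sigmoid, since sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5:
#
#     x_test = jnp.linspace(-4.0, 4.0, 9)
#     assert jnp.allclose(projection(x_test, 1.0), jax.nn.sigmoid(x_test))
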
def logit(obj_V):
    """
    From obj_V, determine the initial lower/upper bounds for searching b in the
    sigmoid denominator e ** (x - b).
    jnp.clip(a, a_min, a_max) clamps the values of a into (a_min, a_max).

    Parameters
    ----------
    obj_V : target volume fraction

    Returns
    -------
    """
    p = jnp.clip(
        obj_V, 0, 1)
    return jnp.log(p) - jnp.log1p(-p)  # ranges over (-inf, inf); zero at p = 0.5 and antisymmetric about it

# @partial(jax.jit, static_argnums=(1, 2))
def design_and_volume_constraints(x, obj_V, beta_TONR):
    """
    Turn the NN output into element densities for the FE analysis that satisfy
    the design constraint, the volume constraint and the projection:
    x_design = 1 / (1 + e ** (x - b(x, obj_V)))

    Parameters
    ----------
    x : element relative densities after density filtering
    obj_V : target volume fraction
    beta_TONR : projection parameter

    Returns
    -------
    x_design : relative densities satisfying the design and volume constraints, as a vector
    """
    def find_yita(y, x, beta):
        projection_out = projection(x + y, beta)
        volume_diff = jnp.mean(projection_out) - obj_V
        return volume_diff
    lower_bound = logit(obj_V) - jl.stop_gradient(jnp.max(x))
    upper_bound = logit(obj_V) - jl.stop_gradient(jnp.min(x))
    bisec = Bisection(optimality_fun=find_yita, lower=lower_bound, upper=upper_bound,
                      tol=1e-12, maxiter=64, check_bracket=False)
    yita = bisec.run(x=x, beta=beta_TONR).params
    # yita = TPA.binary_search(find_yita, x, beta_TONR, lower_bound, upper_bound)  # hand-written bisection
    # yita is found by bisection; essentially the same search as in volume-preserving projection
    # experiments and the formulas show that in both cases yita is a function of (x, obj_V),
    # so its sensitivity must be taken into account
    x_design = projection(x + yita, beta_TONR)
    x_design = x_design + 0.00001
    return x_design

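# Quick sanity check (added sketch, not part of the original commit): after the
# bisection the projected densities should hit the target volume fraction:
#
#     x_raw = jax.random.normal(jax.random.PRNGKey(0), (100,))
#     x_out = design_and_volume_constraints(x_raw, 0.5, 8.0)
#     # jnp.mean(x_out) ~= 0.5 + 0.00001, up to the bisection tolerance
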
def filter_define(design_map, filter_width):
    """
    Define the convolution kernel and weights of the density filter used in TO.

    Parameters
    ----------
    design_map : matrix defining the design / non-design domains
    filter_width : filter radius

    Returns
    -------
    filter_kernal : H in top88; convolution kernel of the density filter, weighting every element inside the filter radius
    filter_weight : Hs in top88; total weight of every element in the design domain, same shape as the mesh
    """
    dy, dx = jnp.meshgrid(jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)),
                          jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)))
    filter_kernal = jnp.maximum(0, filter_width - jnp.sqrt(dx ** 2 + dy ** 2))
    filter_weight = jss.convolve2d(design_map, filter_kernal, 'same')
    return filter_kernal, filter_weight

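# Worked example (added note, not part of the original commit): for
# filter_width = 1.5 the kernel is the 3x3 linear "cone" weight
# max(0, 1.5 - dist) familiar from top88:
#
#     filter_define(jnp.ones((4, 4)), 1.5)[0]
#     # [[0.0858, 0.5, 0.0858],
#     #  [0.5   , 1.5, 0.5   ],
#     #  [0.0858, 0.5, 0.0858]]
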
# @jax.jit
def filter(matrixA, filter_kernal, filter_weight):
    """
    Parameters
    ----------
    matrixA : matrix to be filtered
    filter_kernal : H in top88; convolution kernel of the density filter
    filter_weight : Hs in top88; total weight of every element in the design domain

    Returns
    -------
    new_matrix : filtered matrix
    """
    new_matrix = jss.convolve2d(matrixA, filter_kernal, 'same') / filter_weight
    return new_matrix

def inverse_permutation(indices):
    # return the permutation that undoes `indices`
    inverse_perm = jnp.zeros(len(indices), dtype=jnp.int64)
    inverse_perm = inverse_perm.at[indices].set(
        jnp.arange(len(indices), dtype=jnp.int64))
    return inverse_perm

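# Tiny example (added note, not part of the original commit):
#
#     perm = jnp.array([2, 0, 1])
#     inverse_permutation(perm)        # -> [1, 2, 0]
#     perm[inverse_permutation(perm)]  # -> [0, 1, 2]
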
# @jax.jit
def matrix_set_0(nonzero_values, nonzero_indices, zero_indices):
    """
    Complete a matrix: the unspecified entries are set to 0. Usable for zeroing
    the non-design domain, and for scattering the free-DOF displacements into
    the full displacement vector.

    Parameters
    ----------
    nonzero_values : the values that already exist, e.g. the design-domain densities or the displacements at freedofs
    nonzero_indices : size = number of existing values; indices of those values, e.g. the design-domain indices
        note: generated automatically from design_map by flattening; only used here, not the real TO element numbering
    zero_indices : indices of the entries that hold no value

    Returns
    -------
    new_matrix : the completed matrix
    """
    # all_indices = np.arange(origin_matrix_size, dtype=jnp.int64)
    # zero_indices = jnp.setdiff1d(
    #     all_indices, nonzero_indices, size=zero_size, assume_unique=True)  # setdiff1d usually cannot be jitted; size must be given
    index_map = inverse_permutation(
        jnp.concatenate([nonzero_indices, zero_indices]))
    u_values = jnp.concatenate([nonzero_values, jnp.zeros(len(zero_indices))])
    new_matrix = u_values[index_map]
    return new_matrix

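# Tiny example (added note, not part of the original commit):
#
#     matrix_set_0(jnp.array([7.0, 8.0]), jnp.array([0, 2]), jnp.array([1, 3]))
#     # -> [7., 0., 8., 0.]
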
@partial(jax.jit, static_argnums=(2, 3))
def matrix_set_1(nonzero_values, nonzero_indices, origin_matrix_size, zero_size):
    """
    Complete a matrix: the unspecified entries are set to 1.

    Parameters
    ----------
    nonzero_values : the values that already exist
    nonzero_indices : indices of those values
    origin_matrix_size : number of entries of the original matrix, design plus non-design domains
    zero_size : number of zero entries, i.e. the number of non-design elements

    Returns
    -------
    new_matrix : the completed matrix
    """
    all_indices = jnp.arange(origin_matrix_size, dtype=jnp.int64)
    zero_indices = jnp.setdiff1d(
        all_indices, nonzero_indices, size=zero_size, assume_unique=True)
    index_map = inverse_permutation(
        jnp.concatenate([nonzero_indices, zero_indices]))
    u_values = jnp.concatenate([nonzero_values, jnp.ones(len(zero_indices))])
    new_matrix = u_values[index_map]
    return new_matrix

@partial(jax.jit, static_argnums=(1, 2, 10))
def objective(xPhys_1D, young, young_min, freedofs, fext, s_K_predefined, dispTD_predefined, idx, edofMat, ke, penal):
    """
    Parameters
    ----------
    xPhys_1D : shape:[Total_ele](1D) relative densities used in the FE analysis
    young : Young's modulus
    young_min : Young's modulus of the soft (void) elements
    freedofs : free DOFs of the FE model
    fext : 1D vector of external forces
    s_K_predefined : pre-allocated stiffness matrix
    dispTD_predefined : pre-allocated displacement vector
    idx : assembly indices of the stiffness matrix
    edofMat : shape:[Total_ele, 8](2D) DOF numbers of every element, in element order
    ke : element stiffness matrix, Young's modulus excluded
    penal : penalization factor

    Returns
    -------
    obj : objective value; here the structural compliance
    """
    # xPhys_1D = xPhys.T.flatten()
    disp = TO_FEA.FEA(young, young_min, ke, xPhys_1D, freedofs,
                      penal, idx, fext, s_K_predefined, dispTD_predefined)
    obj = TO_Obj.compliance(young, young_min, ke,
                            xPhys_1D, edofMat, penal, disp)
    # print('obj is running')
    return obj

TO_Define.py
@@ -0,0 +1,245 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 22 17:40:00 2020

@author: zzy

Object: definition of the TO problem
    1. Static variables that are already fixed should be defined with numpy
"""

import numpy as np
import jax.numpy as jnp
import jax.ops as jops
import TO_Cal
import TO_FEA


def Toparams(problem):
    young = 1.0
    rou = 1.0
    nelx, nely, ele_length, ele_width = problem.nelx, problem.nely, problem.ele_length, problem.ele_width
    Total_ele, Total_nod, Total_dof = nelx * \
        nely, (nelx+1)*(nely+1), 2*(nelx+1)*(nely+1)
    gobalcoords0xy, M_elenod, M_edofMat, M_iK, M_jK = coordinate_and_preFEA(
        nelx, nely, ele_length, ele_width, Total_ele, Total_nod)
    # gobalcoords0xy holds absolute coordinates and stays unchanged; everything used
    # for indexing is shifted by -1 (pure bookkeeping for 0-based indexing;
    # the element indices used later also start at 0)
    elenod, edofMat, iK, jK = M_elenod-1, M_edofMat-1, M_iK-1, M_jK-1
    edofMat_3D = TO_edofMat_3DTensor(nelx, nely)  # a tensor version of edofMat
    design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices = handle_design_region(
        nelx, nely, Total_ele, problem.design_map)
    params = {
        'young': young,
        'young_min': 1e-9*young,
        'rou': rou,
        'poisson': 0.3,
        'g': 0,
        'ele_length': ele_length,
        'ele_width': ele_width,
        'ele_thickness': problem.ele_thickness,
        'nelx': nelx,
        'nely': nely,
        'Total_ele': Total_ele,
        'Total_nod': Total_nod,
        'Total_dof': Total_dof,
        'volfrac': problem.volfrac,
        'xmin': 0.001,
        'xmax': 1.0,
        'design_map': problem.design_map,
        # 'design_map_bool': design_map_bool,
        # 'num_design_ele': num_design_ele,
        # 'num_nondesign_ele': num_nondesign_ele,
        'design_area_indices': design_area_indices,
        'nondesign_area_indices': nondesign_area_indices,
        'freedofs': problem.freedofs,
        'fixeddofs': problem.fixeddofs,
        'fext': problem.fext,
        'penal': 3.0,
        'filter_width': problem.filter_width,
        'gobalcoords0xy': gobalcoords0xy,
        'M_elenod': M_elenod,
        'M_edofMat': M_edofMat,
        'M_iK': M_iK,
        'M_jK': M_jK,
        'elenod': elenod,
        'edofMat': edofMat,
        'edofMat_3D': edofMat_3D,
        'iK': iK,
        'jK': jK,
        'design_condition': problem.design_condition,
    }
    return params


class Topology_Optimization:

    def __init__(self, args):
        self.args = args
        self.young, self.young_min, self.poisson = args['young'], args['young_min'], args['poisson']
        self.Total_dof, self.Total_ele, self.nelx, self.nely, self.ele_thickness = args[
            'Total_dof'], args['Total_ele'], args['nelx'], args['nely'], args['ele_thickness']
        self.freedofs, self.fext, self.iK, self.jK, self.edofMat = args[
            'freedofs'], args['fext'], args['iK'], args['jK'], args['edofMat']
        self.filter_width, self.design_condition, self.volfrac = args[
            'filter_width'], args['design_condition'], args['volfrac']
        self.design_map, self.design_area_indices, self.nondesign_area_indices = args[
            'design_map'], args['design_area_indices'], args['nondesign_area_indices']
        self.ke = TO_FEA.linear_stiffness_matrix(
            self.young, self.poisson, self.ele_thickness)
        self.filter_kernal, self.filter_weight = TO_Cal.filter_define(
            self.design_map, self.filter_width)
        self.s_K_predefined, self.dispTD_predefined, self.idx = jnp.zeros(
            [self.Total_dof, self.Total_dof]), jnp.zeros([self.Total_dof]), jops.index[self.iK, self.jK]
        self.loop_1, self.iCont = 0, 0
        self.penal_use, self.beta_x_use = 3.0, 1.0
        self.beta_x_use_max, self.beta_x_deltath1, self.beta_x_deltath2 = 16, 50, 30
        self.xPhys = self.volfrac * jnp.ones([self.nely, self.nelx])

    def xPhys_transform(self, params, projection=False):
        x = params.reshape(self.nely, self.nelx)
        beta_x = self.penal_and_betax()
        xPhys_1D = TO_Cal.design_variable(x, self.design_area_indices, self.nondesign_area_indices,
                                          self.filter_kernal, self.filter_weight, beta_x, self.volfrac, projection=projection)
        return xPhys_1D

    def objective(self, params, projection=False):
        self.loop_1 = self.loop_1 + 1
        xPhys_1D = self.xPhys_transform(params, projection=projection)
        Obj = TO_Cal.objective(xPhys_1D, self.young, self.young_min, self.freedofs, self.fext,
                               self.s_K_predefined, self.dispTD_predefined, self.idx, self.edofMat, self.ke, self.penal_use)
        self.xPhys = xPhys_1D.reshape(self.nely, self.nelx, order='F')
        return Obj

    def penal_and_betax(self, *args):
        # continuation scheme: double beta_x every beta_x_deltath1 (or beta_x_deltath2)
        # iterations, up to beta_x_use_max
        self.iCont = self.iCont + 1
        if (self.iCont > self.beta_x_deltath1 and self.beta_x_use < 2):
            self.beta_x_use = self.beta_x_use * 2
            self.iCont = 1
            print(' beta_x increased to :%7.3f\n' % (self.beta_x_use))
        elif (self.iCont > self.beta_x_deltath2 and self.beta_x_use < self.beta_x_use_max and self.beta_x_use >= 2):
            self.beta_x_use = self.beta_x_use * 2
            self.iCont = 1
            print(' beta_x increased to :%7.3f\n' % (self.beta_x_use))
        return self.beta_x_use


def coordinate_and_preFEA(nelx, nely, ele_length, ele_width, Total_ele, Total_nod):
    """
    FE-related quantities. Elements and nodes are numbered left to right, top to bottom.
    Matlab (1-based) numbering is used; within one element the quantities follow
    counter-clockwise order starting at the bottom-left corner.

    Parameters
    ----------
    nelx : number of elements in x
    nely : number of elements in y
    ele_length : element length
    ele_width : element width
    Total_ele : total number of elements
    Total_nod : total number of nodes

    Returns
    -------
    gobalcoords0xy : shape:[Total_nod, 2](2D) physical coordinates (x, y) of every node, in node order
    elenod : shape:[Total_ele, 4](2D) node numbers of every element, in element order
    edofMat : shape:[Total_ele, 8](2D) DOF numbers of every element, in element order
    iK_int : shape:[Total_ele * 64](1D) DOF numbers, ordered for stiffness-matrix assembly
    jK_int : shape:[Total_ele * 64](1D) DOF numbers, ordered for stiffness-matrix assembly
    """
    i0, j0 = jnp.meshgrid(jnp.arange(nelx+1), jnp.arange(nely+1))
    gobalcoords0x = i0*ele_length
    gobalcoords0y = ele_width*nely-j0*ele_width
    gobalcoords0xy = jnp.hstack((matrix_array(gobalcoords0x), matrix_array(
        gobalcoords0y)))
    gobalcoords0xyT = gobalcoords0xy.T
    gobalcoords0 = matrix_array(gobalcoords0xyT)
    nodegrd = matrix_order_arrange(1, Total_nod, nely+1, nelx+1)
    nodelast = Matlab_reshape(Matlab_matrix_extract(
        nodegrd, 1, -1, 1, -1), Total_ele, 1)
    edofVec = 2*matrix_array(nodelast)+1
    elenod = jnp.tile(nodelast, (1, 4)) + \
        jnp.tile(jnp.array([1, nely+2, nely+1, 0]), (Total_ele, 1))
    edofMat = jnp.tile(edofVec, (1, 8))+jnp.tile(jnp.array(
        [0, 1, 2*nely+2, 2*nely+3, 2*nely+0, 2*nely+1, -2, -1]), (Total_ele, 1))
    iK = jnp.kron(edofMat, jnp.ones((8, 1))).flatten()
    jK = jnp.kron(edofMat, jnp.ones((1, 8))).flatten()
    iK_int = iK.astype(jnp.int32)
    jK_int = jK.astype(jnp.int32)
    return gobalcoords0xy, elenod, edofMat, iK_int, jK_int

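# Worked example (added note, not part of the original commit): for a single
# element (nelx = nely = 1) the helpers above give nodelast = [[1]] and
# edofVec = [[3]], so in 1-based Matlab numbering
#
#     elenod  = [[2, 4, 3, 1]]
#     edofMat = [[3, 4, 7, 8, 5, 6, 1, 2]]
#
# which is the usual top88 counter-clockwise DOF ordering starting at the
# bottom-left node.
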
def handle_design_region(nelx, nely, Total_ele, design_map):
    """
    Pre-define the quantities related to the design / non-design domains.

    Parameters
    ----------
    nelx : number of elements in x
    nely : number of elements in y
    Total_ele : total number of elements
    design_map : matrix defining the design / non-design domains

    Returns
    -------
    design_map_bool : boolean version of design_map
    num_design_ele : number of designable elements
    num_nondesign_ele : number of non-design elements
    design_area_indices : element indices of the design domain
    nondesign_area_indices : element indices of the non-design domain
    """
    shape = (nely, nelx)
    design_map_bool = np.broadcast_to(design_map, shape) > 0
    num_design_ele = np.sum(design_map)
    num_nondesign_ele = Total_ele - num_design_ele

    design_map_bool_1D = design_map_bool.flatten(order='F')
    all_indices = np.arange(Total_ele, dtype=np.int64)
    design_area_indices = jnp.flatnonzero(
        design_map_bool_1D, size=num_design_ele)
    nondesign_area_indices = jnp.setdiff1d(
        all_indices, design_area_indices, size=num_nondesign_ele, assume_unique=True)  # setdiff1d usually cannot be jitted; size must be given
    return design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices


def Matlab_matrix_array1D(matrixA):
    """
    shape:[*] 1D array, in Matlab (column-major) order
    """
    return matrixA.T.flatten()


def matrix_array(matrixA):
    """
    shape:[*, 1] 2D array, in Matlab order, i.e. matrixA(:)
    """
    return matrixA.T.reshape(-1, 1)


def matrix_order_arrange(numbegin, numend, ydim, xdim):
    '''
    shape:[ydim, xdim] 2D array whose entries increase column by column;
    note the indices differ from Matlab's by 1
    '''
    matrixA = jnp.array([[i for i in range(numbegin, numend+1)]])
    matrixA = matrixA.reshape(xdim, ydim)
    matrixA = matrixA.T
    return matrixA


def Matlab_reshape(matrixA, ydim, xdim):
    '''
    shape:[ydim, xdim] 2D array, reshape with Matlab (column-major) semantics
    '''
    return matrixA.reshape((ydim, xdim), order='F')


def Matlab_matrix_extract(matrixA, xbegin, xend, ybegin, yend):
    # emulate Matlab 1-based submatrix extraction matrixA(ybegin:yend, xbegin:xend)
    matrixA = matrixA[ybegin-1:yend, xbegin-1:xend]
    return matrixA


def TO_edofMat_3DTensor(nelx, nely):
    ely, elx = jnp.meshgrid(jnp.arange(nely), jnp.arange(nelx))  # x, y coords
    n1 = (nely+1)*(elx+0) + (ely+0)
    n2 = (nely+1)*(elx+1) + (ely+0)
    n3 = (nely+1)*(elx+1) + (ely+1)
    n4 = (nely+1)*(elx+0) + (ely+1)
    edofMat_3D = jnp.array(
        [2*n4, 2*n4+1, 2*n3, 2*n3+1, 2*n2, 2*n2+1, 2*n1, 2*n1+1])
    return edofMat_3D

TO_FEA.py
@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 25 01:57:01 2020

@author: zzy

Object: functions needed for the FE analysis
    1. Every function here can be jit-compiled
    2. Wrap jit at the outermost call site and keep jit commented out on the inner functions
    3. Static variables that are already fixed should be defined with numpy (CPU only)
"""

import numpy as np
from numpy import float64
import jax.numpy as jnp
import jax.ops as jops
import jax.scipy.linalg as jsl
from jax import jit
import TO_Private_Autodiff as TPA
from jax.config import config
config.update("jax_enable_x64", True)  # has a large effect on numerical precision


def linear_stiffness_matrix(young, poisson, ele_thickness):
    """
    Parameters
    ----------
    young : Young's modulus
    poisson : Poisson's ratio
    ele_thickness : element thickness

    Returns
    -------
    stiffness_ele : element stiffness matrix; note that for the TO problem the
        element Young's modulus is deliberately factored out here
    """
    h = ele_thickness
    k = np.array([1/2-poisson/6, 1/8+poisson/8, -1/4-poisson/12, -1/8+3*poisson/8,
                  -1/4+poisson/12, -1/8-poisson/8, poisson/6, 1/8-3*poisson/8])
    stiffness_ele = h/(1-poisson**2)*np.array(
        [[k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]],
         [k[1], k[0], k[7], k[6], k[5], k[4], k[3], k[2]],
         [k[2], k[7], k[0], k[5], k[6], k[3], k[4], k[1]],
         [k[3], k[6], k[5], k[0], k[7], k[2], k[1], k[4]],
         [k[4], k[5], k[6], k[7], k[0], k[1], k[2], k[3]],
         [k[5], k[4], k[3], k[2], k[1], k[0], k[7], k[6]],
         [k[6], k[3], k[4], k[1], k[2], k[7], k[0], k[5]],
         [k[7], k[2], k[1], k[4], k[3], k[6], k[5], k[0]]], dtype=float64)
    return stiffness_ele

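# Sanity check (added note, not part of the original commit): this is the
# standard top88 Q4 plane-stress element, so the result is a symmetric 8x8 matrix:
#
#     ke = linear_stiffness_matrix(1.0, 0.3, 1.0)
#     # ke.shape == (8, 8) and np.allclose(ke, ke.T)
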
def density_matrix(rou, ele_length, ele_width, ele_thickness):
    """
    Parameters
    ----------
    rou : density
    ele_length : element length
    ele_width : element width
    ele_thickness : element thickness

    Returns
    -------
    density_ele : element consistent mass matrix; note that for the TO problem the
        element relative density is not included here, otherwise the factor in
        front should be the element mass
    """
    ele_volume = ele_length*ele_width*ele_thickness
    density_ele = ele_volume / 36 * np.array(
        [[4, 0, 2, 0, 1, 0, 2, 0],
         [0, 4, 0, 2, 0, 1, 0, 2],
         [2, 0, 4, 0, 2, 0, 1, 0],
         [0, 2, 0, 4, 0, 2, 0, 1],
         [1, 0, 2, 0, 4, 0, 2, 0],
         [0, 1, 0, 2, 0, 4, 0, 2],
         [2, 0, 1, 0, 2, 0, 4, 0],
         [0, 2, 0, 1, 0, 2, 0, 4]], dtype=float64)
    return density_ele

# @jit
def FEA(young, young_min, ke, xPhys_1D, freedofs, penal, idx, fext, s_K_predefined, dispTD_predefined):
    """
    Parameters
    ----------
    young : Young's modulus
    young_min : Young's modulus of the soft (void) elements
    ke : element stiffness matrix, Young's modulus excluded
    xPhys_1D : shape:[Total_ele](1D) element relative densities used in the FE analysis (this form is necessary)
    freedofs : free DOFs of the FE model
    penal : penalization factor
    idx : assembly indices of the stiffness matrix
    fext : shape:[Total_dof](1D) external forces
    s_K_predefined : shape:[Total_dof, Total_dof](2D) pre-allocated stiffness matrix
    dispTD_predefined : shape:[Total_dof](1D) pre-allocated displacement vector

    Returns
    -------
    dispTD : shape:[Total_dof](1D) nodal displacements
    """

    def kuf_solve(fext_1D_free, s_K_free):
        # solve K u = f on the free DOFs; the reduced stiffness matrix is symmetric positive definite
        u_free = jsl.solve(s_K_free, fext_1D_free,
                           sym_pos=True, check_finite=False)
        return u_free

    def young_modulus(xPhys_1D, young, young_min, ke, penal):
        """
        SIMP interpolation E(x) = young_min + x ** penal * (young - young_min).

        Returns
        -------
        sK : shape:[64 * Total_ele](1D) stiffness entries of every element after
            applying the interpolated modulus (stored column-wise)
        """
        sK = ((ke.flatten()[jnp.newaxis]).T * (young_min + (xPhys_1D) ** penal * (young-young_min))
              ).flatten(order='F')
        return sK

    sK = young_modulus(xPhys_1D, young, young_min, ke, penal)
    s_K = jops.index_add(s_K_predefined, idx, sK)  # assemble the global stiffness matrix
    s_K_free = s_K[freedofs, :][:, freedofs]
    fext_free = fext[freedofs]
    u_free = kuf_solve(fext_free, s_K_free)
    dispTD = jops.index_add(dispTD_predefined, freedofs, u_free)  # scatter back into the full vector
    return dispTD

TO_Model.py
@@ -0,0 +1,249 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 19 00:00:06 2021

@author: zzy

Object: model construction for AuTONR
    1. Based on Jax + Flax + JAXopt/Optax
    2. For dataclasses and flax.linen, see:
        learning_dataclasses.py
        learning_flax_module.py
"""

import jax
import jax.numpy as jnp
import jax.image as ji
import jax.tree_util as jtree
from jax import random
import flax.linen as nn
import flax.linen.initializers as fli
import flax.traverse_util as flu
import flax.core.frozen_dict as fcfd
from functools import partial
from typing import Any, Callable
from jax.config import config
config.update("jax_enable_x64", True)


def set_random_seed(seed):
    # note: seed must not be None, otherwise rand_key is undefined
    if seed is not None:
        rand_key = random.PRNGKey(seed)
    return rand_key


def constant_array(rng_key, shape, dtype=jnp.float64, value=0.0):
    """
    Custom parameter initializer: fill an array with a constant value.

    Parameters
    ----------
    rng_key : PRNG key (unused; kept for the initializer signature)
    shape : shape of the target array
    dtype : dtype. The default is jnp.float64.
    value : fill value. The default is 0.0.

    Returns
    -------
    out : the initialized array
    """
    out = jnp.full(shape, value, dtype)
    # print("i am running")
    return out


def extract_model_params(params_use):
    """
    Extract the network parameters.

    Parameters
    ----------
    params_use : FrozenDict, the parameters to extract

    Returns
    -------
    extracted_params : dict of extracted parameters, {'Conv_0/bias': ..., 'Conv_0/kernel': ..., ...}
    params_shape : dict with the shape of every parameter group, {'Conv_0/bias': (...), 'Conv_0/kernel': (...), ...}
    params_total_num : total number of parameters
    """
    params_unfreeze = params_use.unfreeze()
    extracted_params = {
        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
    params_shape = jtree.tree_map(jnp.shape, extracted_params)
    params_total_num = sum(p.size for p in jtree.tree_leaves(extracted_params))
    return extracted_params, params_shape, params_total_num


def replace_model_params(pre_trained_params_dict, params_use):
    """
    Replace the network parameters.

    Parameters
    ----------
    pre_trained_params_dict : dict of pre-trained parameters, {'Conv_0/bias': ..., 'Conv_0/kernel': ..., ...}
    params_use : FrozenDict, the parameters to be replaced

    Returns
    -------
    new_params_use : FrozenDict, the new parameters
    """
    params_unfreeze = params_use.unfreeze()
    # two equivalent implementations
    # Method A: via flax.traverse_util
    extracted_params = {
        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
    for key, val in pre_trained_params_dict.items():
        extracted_params[key] = val
    new_params_use = flu.unflatten_dict(
        {tuple(k.split('/')): v for k, v in extracted_params.items()})
    # Method B: via jax.tree_util (assumes pre_trained_params_dict covers the leading leaves in order)
    params_leaves, params_struct = jtree.tree_flatten(params_unfreeze)
    i = 0
    for key, val in pre_trained_params_dict.items():
        params_leaves[i] = val
        i = i + 1
    new_params_use = jtree.tree_unflatten(params_struct, params_leaves)
    # freeze the new parameters
    new_params_use = fcfd.freeze(new_params_use)
    return new_params_use


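# Tiny example (added note, not part of the original commit) of the key format
# the two helpers above use, assuming a params FrozenDict holding one Dense layer:
#
#     extracted, shapes, n = extract_model_params(params)
#     # extracted keys look like 'Dense_0/kernel', 'Dense_0/bias'
#     # replace_model_params({'Dense_0/bias': new_bias}, params) swaps that leaf
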
def batched_loss(x_inputs, TopOpt_envs):
    """
    Handle a batch dimension together with several instantiated TopOpt objects.

    Parameters
    ----------
    x_inputs : shape:[1, nely, nelx] (3D array); the first axis is the batch, the
        last two are the element densities of the 2D TO problem; usually batch_size = 1
    TopOpt_envs : list of instantiated TopOpt objects; usually just one

    Returns
    -------
    losses_array : loss_list holds the objective of every batch/TopOpt pair;
        converted here to an array
    """
    loss_list = [TopOpt.objective(x_inputs[i], projection=True)
                 for i, TopOpt in enumerate(TopOpt_envs)]
    losses_array = jnp.stack(loss_list)[0]
    return losses_array


# @jax.jit
class Normalize_TO(nn.Module):
    """
    Custom normalization module.
    """
    epsilon: float = 1e-6

    @nn.compact
    def __call__(self, input):
        # different normalizations can be obtained by changing the axes
        input_mean = jnp.mean(input, axis=(1, 2, 3), keepdims=True)
        input_var = jnp.var(input, axis=(1, 2, 3), keepdims=True)
        output = input - input_mean
        output = output * jax.lax.rsqrt(input_var + self.epsilon)
        return output


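# Behavior note (added sketch, not part of the original commit): with axes
# (1, 2, 3) this acts like LayerNorm over H, W and C for each sample. Assuming a
# stateless apply with an empty variables dict:
#
#     x = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 8, 4))
#     y = Normalize_TO().apply({}, x)
#     # jnp.mean(y) ~= 0 and jnp.var(y) ~= 1 (up to epsilon)
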
class Add_offset(nn.Module):
    # learnable bias, scaled by `scale`
    scale: int = 10

    @nn.compact
    def __call__(self, input):
        add_bias = self.param('add_bias', lambda rng, shape: constant_array(
            rng, shape, value=0.0), input.shape)
        return input + self.scale * add_bias


class BasicBlock(nn.Module):
    # tanh -> bilinear upsampling -> normalization -> conv -> learnable offset
    resize_scale_one: int
    conv_output_one: int
    norm_use: nn.Module = Normalize_TO
    resize_method: str = "bilinear"
    kernel_size: int = 5
    conv_kernel_init: partial = fli.variance_scaling(
        scale=1.0, mode="fan_in", distribution="truncated_normal")
    offset: nn.Module = Add_offset
    offset_scale: int = 10

    @nn.compact
    def __call__(self, input):
        output = jnp.tanh(input)
        output_shape = output.shape
        output = ji.resize(output, (output_shape[0], self.resize_scale_one * output_shape[1],
                                    self.resize_scale_one * output_shape[2], output_shape[-1]),
                           method=self.resize_method, antialias=True)
        output = self.norm_use()(output)
        output = nn.Conv(features=self.conv_output_one, kernel_size=(
            self.kernel_size, self.kernel_size), kernel_init=self.conv_kernel_init)(output)
        output = self.offset(scale=self.offset_scale)(output)
        return output


class My_TO_Model(nn.Module):
    TopOpt: Any
    seed: int = 1
    replace_params: Callable = replace_model_params
    extract_params: Callable = extract_model_params

    def loss(self, input_TO):
        # pass TopOpt as a list (a tuple would work too)
        obj_TO = batched_loss(input_TO, [self.TopOpt])
        return obj_TO


class TO_CNN(My_TO_Model):
    h: int = 5
    w: int = 10
    dense_scale_init: Any = 1.0
    dense_output_channel: int = 1000
    dtype: Any = jnp.float32
    input_init_std: float = 1.0  # std of the normal distribution initializing the design variables
    dense_input_channel: int = 128
    conv_input_0: int = 32
    conv_output: tuple = (128, 64, 32, 16, 1)
    up_sampling_scale: tuple = (1, 2, 2, 2, 1)
    offset_scale: int = 10
    block: nn.Module = BasicBlock

    @nn.compact
    def __call__(self, x=None):  # [N, H, W, C]
        nn_input = self.param('z', lambda rng, shape: constant_array(
            rng, shape, value=self.TopOpt.volfrac), self.dense_input_channel)

        output = nn.Dense(features=self.dense_output_channel, kernel_init=fli.orthogonal(
            scale=self.dense_scale_init))(nn_input)
        output = output.reshape([1, self.h, self.w, self.conv_input_0])

        output = self.block(
            self.up_sampling_scale[0], self.conv_output[0])(output)
        output = self.block(
            self.up_sampling_scale[1], self.conv_output[1])(output)
        output = self.block(
            self.up_sampling_scale[2], self.conv_output[2])(output)
        output = self.block(
            self.up_sampling_scale[3], self.conv_output[3])(output)
        output = self.block(
            self.up_sampling_scale[4], self.conv_output[4])(output)

        nn_output = jnp.squeeze(output, axis=-1)
        self.sow('intermediates', 'model_out',
                 nn_output)  # could the batch dimension be dropped here?
        return nn_output


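# Shape walk-through (added note, not part of the original commit): with the
# defaults h = 5, w = 10 and up_sampling_scale = (1, 2, 2, 2, 1), the dense
# output is reshaped to [1, 5, 10, 32] and upsampled by a total factor of 8, so
# the model emits a [1, 40, 80] density field matching the 80 x 40 cantilever mesh.
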
if __name__ == '__main__':

    def funA(params):
        out = 1 * jnp.sum(params)
        return out

    def funB(params):
        out = 2 * jnp.sum(params)
        return out

    def funC(params):
        out = 3 * jnp.sum(params)
        return out

    # test: a direct demonstration of batched_loss
    params_1 = jnp.ones([3, 4, 4])
    fun_list = [funA, funB, funC]
    losses = [fun_used(params_1[i]) for i, fun_used in enumerate(fun_list)]
    loss_out = jnp.stack(losses)
    print("losses:", losses)
    print("loss_out:", loss_out)

TO_Obj.py
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 26 01:59:52 2020

@author: zzy

Object: functions needed when evaluating the TO objective
    1. Every function here can be jit-compiled
    2. Wrap jit at the outermost call site and keep jit commented out on the inner functions
    3. Static variables that are already fixed should be defined with numpy
"""

import jax.numpy as jnp
from jax import jit


# @jit
def compliance(young, young_min, ke, xPhys_1D, edofMat, penal, disp):
    # element-wise strain energy: ce_e = u_e^T * ke * u_e
    ce = jnp.sum(jnp.matmul(disp[edofMat], ke) * disp[edofMat], 1)
    # SIMP-weighted compliance contribution of every element
    ce = (young_min + xPhys_1D ** penal * (young - young_min)) * ce
    obj = jnp.sum(ce)
    return obj
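
# Formula note (added, not part of the original commit): this computes the
# standard SIMP compliance
#
#     c(x) = sum_e [ Emin + x_e**p * (E0 - Emin) ] * u_e^T k0 u_e
#
# with k0 the unit-modulus element stiffness matrix from TO_FEA.linear_stiffness_matrix.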

TO_Problem.py
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 11 14:20:11 2020

@author: zzy

Define the design domain.
"""

import jax.ops as jops
import numpy as np
import dataclasses
from typing import Optional, Union
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)


@dataclasses.dataclass
class ToProblem:
    L: float
    W: float
    nelx: int
    nely: int
    design_condition: int
    ele_length: np.float64
    ele_width: np.float64
    ele_thickness: np.float64
    volfrac: np.float64
    fixeddofs: np.ndarray
    freedofs: np.ndarray
    fext: np.ndarray
    filter_width: float
    design_map: np.ndarray


def cantilever_single():
    # cantilever beam: left edge clamped, unit point load at the bottom-right corner
    L = 0.8
    W = 0.4
    nelx = 80
    nely = 40
    design_condition = 1
    Total_dof = 2*(nelx+1)*(nely+1)
    ele_length = L/nelx
    ele_width = W/nely
    ele_thickness = 1.0
    volfrac = 0.5
    dofs = np.arange(Total_dof)
    loaddof = Total_dof-1
    # fext = jnp.zeros(Total_dof)
    # fext = jops.index_update(fext, jops.index[loaddof], -1.)
    fext = np.zeros(Total_dof)
    fext[loaddof] = -1
    fixeddofs = np.array(dofs[0:2*(nely+1):1])
    freedofs = np.setdiff1d(dofs, fixeddofs)
    filter_width = 3.0
    design_map = np.ones([nely, nelx], dtype=np.int64)
    return ToProblem(L, W, nelx, nely, design_condition, ele_length, ele_width, ele_thickness, volfrac, fixeddofs, freedofs, fext, filter_width, design_map)

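# DOF-numbering note (added, not part of the original commit): with two DOFs per
# node and nodes numbered column by column, dofs[0 : 2*(nely+1)] are exactly the
# DOFs of the left-edge (x = 0) nodes, and loaddof = Total_dof - 1 is the y-DOF
# of the last node, i.e. the bottom-right corner -- the standard cantilever setup.
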
def cantilever_single_NN(TopOpt):
    # NN hyper-parameters matched to the 80 x 40 mesh
    dense_input_channel = 128
    dense_init_scale = 1.0
    conv_input_0 = 32
    up_sampling_scale = (1, 2, 2, 2, 1)
    total_resize = int(np.prod(up_sampling_scale))
    h = TopOpt.nely // total_resize
    w = TopOpt.nelx // total_resize
    dense_output_channel = h * w * conv_input_0
    dense_scale_init = dense_init_scale * \
        jnp.sqrt(jnp.maximum(dense_output_channel / dense_input_channel, 1))
    model_kwargs = {
        'h': h,
        'w': w,
        'dense_scale_init': dense_scale_init,
        'dense_output_channel': dense_output_channel,
        'dense_input_channel': dense_input_channel,
        'conv_input_0': conv_input_0,
        'up_sampling_scale': up_sampling_scale,
    }
    return model_kwargs


def cantilever_single_big():
    # the same cantilever on a finer 160 x 80 mesh
    L = 1.6
    W = 0.8
    nelx = 160
    nely = 80
    design_condition = 2
    Total_dof = 2*(nelx+1)*(nely+1)
    ele_length = L/nelx
    ele_width = W/nely
    ele_thickness = 1.0
    volfrac = 0.5
    dofs = np.arange(Total_dof)
    loaddof = Total_dof-1
    # fext = jnp.zeros(Total_dof)
    # fext = jops.index_update(fext, jops.index[loaddof], -1.)
    fext = np.zeros(Total_dof)
    fext[loaddof] = -1
    fixeddofs = np.array(dofs[0:2*(nely+1):1])
    freedofs = np.setdiff1d(dofs, fixeddofs)
    filter_width = 3.0
    design_map = np.ones([nely, nelx], dtype=np.int64)
    return ToProblem(L, W, nelx, nely, design_condition, ele_length, ele_width, ele_thickness, volfrac, fixeddofs, freedofs, fext, filter_width, design_map)


def cantilever_single_big_NN(TopOpt):
    # NN hyper-parameters matched to the 160 x 80 mesh (one more 2x upsampling stage)
    dense_input_channel = 128
    dense_init_scale = 1.0
    conv_input_0 = 32
    up_sampling_scale = (1, 2, 2, 2, 2)
    total_resize = int(np.prod(up_sampling_scale))
    h = TopOpt.nely // total_resize
    w = TopOpt.nelx // total_resize
    dense_output_channel = h * w * conv_input_0
    dense_scale_init = dense_init_scale * \
        jnp.sqrt(jnp.maximum(dense_output_channel / dense_input_channel, 1))
    model_kwargs = {
        'h': h,
        'w': w,
        'dense_scale_init': dense_scale_init,
        'dense_output_channel': dense_output_channel,
        'dense_input_channel': dense_input_channel,
        'conv_input_0': conv_input_0,
        'up_sampling_scale': up_sampling_scale,
    }
    return model_kwargs

TO_Train.py
@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 19 09:53:14 2021

@author: zzy

Object: training functions for AuTONR
    1. Implemented with Flax + JAXopt/Optax
    2. For the training loop, see:
        learning_optax_training.py
        learning_flax_managing_params.py
        learning_flax_training.py
"""

import sys
from pathlib import Path
from absl import logging
import xarray
import jax
import jax.numpy as jnp
from jax import random
from jaxopt import OptaxSolver
import optax
import time
from functools import partial
import matplotlib.pyplot as plt
from jax.config import config
config.update("jax_enable_x64", True)
sys.path.append(
    '/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO')
here = Path(__file__).resolve().parent


def set_random_seed(seed):
    # note: seed must not be None, otherwise rand_key is undefined
    if seed is not None:
        rand_key = random.PRNGKey(seed)
    return rand_key


def optimizer_result_dataset(losses, frames, save_intermediate_designs=False):
    # The best design will often but not always be the final one.
    best_design = jnp.nanargmin(losses)
    logging.info(f'Final loss: {losses[best_design]}')
    if save_intermediate_designs:
        ds = xarray.Dataset({
            'loss': (('step',), losses),
            'design': (('step', 'y', 'x'), frames),
        }, coords={'step': jnp.arange(len(losses))})
    else:
        ds = xarray.Dataset({
            'loss': (('step',), losses),
            'design': (('y', 'x'), frames[best_design]),
        }, coords={'step': jnp.arange(len(losses))})
    return ds

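# Usage note (added, not part of the original commit): with intermediate designs
# saved, the best design can be pulled back out of the dataset via
#
#     ds.design.isel(step=int(ds.loss.argmin()))
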
def train_TO_Optax(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs):
    losses = []
    frames = []
    # Initialize parameters
    init_params = model.init(set_random_seed(seed), None)
    batch_state, params = init_params.pop("params")
    del init_params
    # Instantiate the optimizer
    design_condition = model.TopOpt.design_condition
    if design_condition == 1:
        learning_rate = 0.001
        optimizer = optax.adam(learning_rate)
    elif design_condition == 2:
        learning_rate = 0.001
        optimizer = optax.adam(learning_rate)

    # loss function
    def loss_cal(params):
        all_params = {"params": params}
        model_out, net_state = model.apply(all_params, None, mutable="intermediates")
        loss_out = model.loss(model_out)
        return loss_out, net_state["intermediates"]
    loss_grad_fn = jax.value_and_grad(loss_cal, has_aux=True)
    # Initialize the optimizer state
    state = optimizer.init(params)
    # Update loop
    for it in range(max_iterations):
        start_time_epoch = time.perf_counter()
        (loss_val, batch_stats), grads = loss_grad_fn(params)
        losses.append(jax.device_get(loss_val))
        updates_params, state = optimizer.update(grads, state)
        params = optax.apply_updates(params, updates_params)
        # read the current design out of the TopOpt object (set as a side effect of the loss)
        design_TO = jax.device_get(model.TopOpt.xPhys.aval.val)
        frames.append(design_TO)
        plt.figure()
        plt.imshow(design_TO, cmap='Greys')
        plt.show()
        whole_time_epoch = time.perf_counter() - start_time_epoch
        print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' %
              (it+1, losses[-1], jnp.mean(design_TO), whole_time_epoch))

    return optimizer_result_dataset(jnp.array(losses), jnp.array(frames), save_intermediate_designs)


def train_TO_JAXopt(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs):
    losses = []
    frames = []
    # Initialize parameters
    init_params = model.init(set_random_seed(seed), None)
    # batch_stats, params = init_params.pop("params")
    params = init_params["params"]
    # batch_stats = init_params["batch_stats"]
    del init_params

    # Initialize the solver
    # @jax.jit
    def loss_cal(params):
        # all_params = {"params": params, "batch_stats": aux}
        # model_out, net_state = model.apply(all_params, None)
        all_params = {"params": params}
        model_out, net_state = model.apply(
            all_params, None, mutable="intermediates")
        loss_out = model.loss(model_out)
        # design_TO = model.TopOpt.x_calculation(
        #     model_out, volume_contraint=True, projection=True)
        # losses.append(jax.device_get(loss_out))
        return loss_out, net_state["intermediates"]

    design_condition = model.TopOpt.design_condition

    def opt_solver_setting(design_condition, loss_fun, max_iterations):
        # pick the learning rate per design condition, then build the solver once
        if design_condition == 1:
            learning_rate = 0.001
        elif design_condition == 2:
            learning_rate = 0.000375
        elif design_condition == 6:
            learning_rate = 0.001
        elif design_condition == 5:
            learning_rate = 0.000375
        opt_use = optax.adam(learning_rate)
        solver = OptaxSolver(
            fun=loss_fun, opt=opt_use, maxiter=max_iterations, has_aux=True)
        return solver

    solver = opt_solver_setting(design_condition, loss_cal, max_iterations)

    state = solver.init_state(params)
    for it in range(max_iterations):
        start_time_epoch = time.time()
        # params, state = solver.update(
        #     params=params, state=state, aux=batch_stats)

        params, state = solver.update(params=params, state=state)
        batch_stats = state.aux
        losses.append(jax.device_get(state.value))  # probably better placed here
        # design_TO = jax.device_get(batch_stats._dict['model_out'])[-1]
        # design_TO = model.TopOpt.x_calculation(design_TO, volume_constraint=True, projection=True)
        design_TO = jax.device_get(model.TopOpt.xPhys.aval.val)
        frames.append(design_TO)
        plt.figure()
        plt.imshow(design_TO, cmap='Greys')
        plt.show()
        end_time_epoch = time.time()
        whole_time_epoch = end_time_epoch - start_time_epoch
        print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' %
              (it+1, losses[-1], jnp.mean(design_TO), whole_time_epoch))

    # frames = jax.device_get(batch_stats._dict['intermediates']['model_out'])
    return optimizer_result_dataset(jnp.array(losses), jnp.array(frames), save_intermediate_designs)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.