commit da7820895e416cb9e184ec0036d4981c2295fef5 Author: zhangzeyu Date: Tue Aug 16 15:35:58 2022 +0800 2022.8.16 Commit diff --git a/Linear_TO/TONR_Linear_Stiffness.py b/Linear_TO/TONR_Linear_Stiffness.py new file mode 100644 index 0000000..93850e9 --- /dev/null +++ b/Linear_TO/TONR_Linear_Stiffness.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +''' +@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved. +@ Author : Zeyu Zhang +@ Email : zhangzeyu_work@outlook.com +@ Date : 2021-11-17 10:52:26 +@ LastEditTime : 2022-08-16 14:45:12 +@ FilePath : /Topology_Optimization/Linear_TO/TONR_Linear_Stiffness.py +@ +@ Description : Jax+Flax+optax TONR Linear Stiffness Problem +@ Reference : +''' + +from pathlib import Path # 文件路径的相关操作 +import sys +import pandas as pd +import xarray +import matplotlib.pyplot as plt +import seaborn +import time +import TO_Problem +import TO_Define +import TO_Model +import TO_Train +import jax.numpy as jnp +from jax import random +from jax.config import config +config.update("jax_enable_x64", True) # 对计算精度影响很大 +sys.path.append( + '/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO') +here = Path(__file__).resolve().parent + + +design_condition = 1 +if design_condition == 1: + problem = TO_Problem.cantilever_single() + max_iterations = 200 + args = TO_Define.Toparams(problem) + TopOpt_env = TO_Define.Topology_Optimization(args) + model_kwargs = TO_Problem.cantilever_single_NN(TopOpt_env) +elif design_condition == 2: + problem = TO_Problem.cantilever_single_big() + max_iterations = 200 + args = TO_Define.Toparams(problem) + TopOpt_env = TO_Define.Topology_Optimization(args) + model_kwargs = TO_Problem.cantilever_single_big_NN(TopOpt_env) + +TONR_model = TO_Model.TO_CNN(TopOpt=TopOpt_env, **model_kwargs) + +if __name__ == '__main__': + + start_time = time.perf_counter() + seed = 1 + ds_TopOpt = TO_Train.train_TO_Optax( + TONR_model, max_iterations, save_intermediate_designs=True, seed=1) + ds_TopOpt.to_netcdf('ds_TopOpt.nc') # save + # ds_TopOpt = xarray.open_dataset('ds_TopOpt.nc') # load + end_time = time.perf_counter() + whole_time = end_time - start_time + + obj_value = ds_TopOpt.loss.data + x_designed = ds_TopOpt.design.data + + obj_min = jnp.min(obj_value) + obj_max = jnp.max(obj_value) + obj_min_index = jnp.where(obj_value == obj_min)[0][0] + + dims = pd.Index(['TONR'], name='model') + ds = xarray.concat([ds_TopOpt], dim=dims) + ds_TopOpt.loss.transpose().to_pandas().cummin().plot(linewidth=2) + plt.ylim(obj_min * 0.85, obj_value[0] * 1.15) + plt.ylabel('Compliance (loss)') + plt.xlabel('Optimization step') + seaborn.despine() + + final_design = x_designed[obj_min_index, :, :] + final_obj = obj_min + + plt.show() + ds_TopOpt.design.sel(step=obj_min_index).plot.imshow(x='x', y='y', size=2, + aspect=2.5, col_wrap=2, yincrease=False, add_colorbar=False, cmap='Greys') + + def save_gif_movie(images, path, duration=200, loop=0, **kwargs): + images[0].save(path, save_all=True, append_images=images[1:], + duration=duration, loop=loop, **kwargs) + + # images = [ + # TO_Support.image_from_design(design, problem) + # for design in ds.design.sel(model='DL_TO')[:150] + # ] + + # save_gif_movie([im.resize((5*120, 5*20)) for im in images], 'movie.gif') diff --git a/Linear_TO/TO_Cal.py b/Linear_TO/TO_Cal.py new file mode 100644 index 0000000..596b50c --- /dev/null +++ b/Linear_TO/TO_Cal.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Apr 23 15:08:01 2020 + +@author: zzy + +Object: 存储TO计算过程所需的函数 +1. 所有函数均可jit +2. 在调用时 选择在最外层封装jit 内层函数则将jit注释 +3. 对于部分确定的static variable 应采用numpy定义(仅限CPU) +""" + +import numpy as np +import jax +import jax.numpy as jnp +import jax.scipy.signal as jss +import jax.lax as jl +from jaxopt import Bisection +from functools import partial +import TO_Private_Autodiff as TPA +import TO_FEA +import TO_Obj +from jax.config import config +config.update("jax_enable_x64", True) + + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def design_variable(x, design_area_indices, nondesign_area_indices, filter_kernal, filter_weight, beta_x, volfrac, projection=False): + """ + Parameters + ---------- + x : shape:[nely, nelx](2D) 经过reshape的NN输出 + design_area_indices : 设计域的单元编号索引 + nondesign_area_indices : 非设计域的单元编号索引 + filter_kernal : TOP88中的H 卷积过滤中的卷积核 计算影响域内每一个单元的权重 + filter_weight : TOP88中的HS 设计域内每一个单元的总权重 与网格划分同维度 + beta_x : Projection参数 + volfrac : 目标体积 + projection : 是否采用Projection + + Returns + ------- + xPhys_1D : shape:[Total_ele](1D) 参与有限元计算的单元相对密度阵 + """ + + # density filtering 用在了转换之前 目前看也是有效果的 + x = filter(x, filter_kernal, filter_weight) + x_1D = x.flatten(order='F') + x_designed = design_and_volume_constraints( + x_1D[design_area_indices], volfrac, beta_x) + xPhys_1D = matrix_set_0( + x_designed, design_area_indices, nondesign_area_indices) + # 直接等于xPhys_1D 但尚需要在真实存在非设计域时验证 + # 原始方案 + # x_in = x[design_map_bool] + # x_designed = design_and_volume_constraints(x_in, volfrac, beta_x) # x[design_map_bool]内含了一步matrix.flatten() + # x_flat = matrix_set_0(x_designed, jnp.flatnonzero(design_map_bool, size=num_design_ele), Total_ele, num_nondesign_ele) + # 因为x_designed内含flatten() 所以这里对应要复原 由于flatten时没有传入order='F' 这里也不能传入 + # xPhys = x_flat.reshape(x.shape) + return xPhys_1D + + +def projection(x, beta_TONR): + """ + 实现projection 将[-z, z]的数值映射至(0, 1) + 目前 + 原理 tanh能够将元素映射至(-1, 1) 因此 tanh_out * 0.5 + 0.5可以映射至(0, 1) + 注意 sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5 + Parameters + ---------- + x : 待映射的矩阵 + beta_TONR : projection函数斜率的控制参数 + + Returns + ------- + output : sigmoid/projection变换后的矩阵 + """ + output = jnp.tanh(beta_TONR * 0.5 * x)*.5 + 0.5 + return output + + +def logit(obj_V): + """ + 根据obj_V确定sigmoid分母部分e ** (x-b)中 b搜索的初始上下界 + np.clip(a, a_min, a_max) 将a的取值截断在(a_min, a_max)内 + Parameters + ---------- + obj_V : 目标体积 + + Returns + ------- + """ + p = jnp.clip( + obj_V, 0, 1) + return jnp.log(p) - jnp.log1p(-p) # p从-inf到inf p取0.5为0 以0.5为界限 大于小于符号相反 数值相同 + + +# @partial(jax.jit, static_argnums=(1, 2)) +def design_and_volume_constraints(x, obj_V, beta_TONR): + """ + 用来将神经网络的输出转化为满足设计约束 体积约束和projection的用于有限元计算的单元密度 + x_design = 1/(1 + e ** (x - b(x, obj_V))) + Parameters + ---------- + x : 经过密度过滤后的单元相对密度阵 + obj_V : 目标体积. + beta_TONR : Projection参数 + + Returns + ------- + x_design : 满足设计约束和体积约束的相对密度 向量形式 + """ + def find_yita(y, x, beta): + projection_out = projection(x + y, beta) + volume_diff = jnp.mean(projection_out) - obj_V + return volume_diff + lower_bound = logit(obj_V) - jl.stop_gradient(np.max(x)) + upper_bound = logit(obj_V) - jl.stop_gradient(np.min(x)) + bisec = Bisection(optimality_fun=find_yita, lower=lower_bound, upper=upper_bound, tol=1e-12, maxiter=64, check_bracket=False) + yita = bisec.run(x=x, beta=beta_TONR).params + # yita = TPA.binary_search(find_yita, x, beta_TONR, lower_bound, upper_bound) # 自编二分法 + # 用来搜索yita的取值 由二分法获得 本质上这里和保体积projection搜索yita的过程完全相同 + # 实验和公式表明 在保体积projection和此处添加体积约束的sigmoid中 对yita的求解均是一个关于(x, obj_V)的函数 + # 其敏度必然要考虑 + x_design = projection(x + yita, beta_TONR) + x_design = x_design + 0.00001 + return x_design + + +def filter_define(design_map, filter_width): + """ + 定义TO中过滤器的卷积核 权重等参数 + Parameters + ---------- + design_map : 设计域和非设计域定义矩阵 + filter_width : 过滤半径 + + Returns + ------- + filter_kernal : TOP88中的H 卷积过滤中的卷积核 计算影响域内每一个单元的权重 + filter_weight : TOP88中的HS 设计域内每一个单元的总权重 与网格划分同维度 + """ + dy, dx = jnp.meshgrid(jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)), + jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width))) + filter_kernal = jnp.maximum(0, filter_width - jnp.sqrt(dx ** 2 + dy ** 2)) + filter_weight = jss.convolve2d(design_map, filter_kernal, 'same') + return filter_kernal, filter_weight + + +# @jax.jit +def filter(matrixA, filter_kernal, filter_weight): + """ + Parameters + ---------- + matrixA : 待过滤的矩阵 + filter_kernal : TOP88中的H 卷积过滤中的卷积核 计算影响域内每一个单元的权重 + filter_weight : TOP88中的HS 设计域内每一个单元的总权重 与网格划分同维度 + + Returns + ------- + new_matrix : 过滤后的矩阵 + """ + new_matrix = jss.convolve2d(matrixA, filter_kernal, 'same') / filter_weight + return new_matrix + + +def inverse_permutation(indices): + inverse_perm = jnp.zeros(len(indices), dtype=jnp.int64) + inverse_perm = inverse_perm.at[indices].set( + jnp.arange(len(indices), dtype=jnp.int64)) + return inverse_perm + + +# @jax.jit +def matrix_set_0(nonzero_values, nonzero_indices, zero_indices): + """ + 补齐矩阵 对原始矩阵指定区域置0 可用于非设计域进行置0 以及节点位移添加指定节点的位移 + Parameters + ---------- + nonzero_values : 已存在数值的矩阵 如设计域的相对密度阵 freedofs对应的节点位移 + nonzero_indices : size=存在数值的个数 已存在数值矩阵对应的索引编号 如设计域对应的索引 + 注意 这里为根据design_map自动生成的 是基于行展开编号的 仅用于此处计算 并非真正TO中的单元编号 + zero_indices : 矩阵中未存在数值的单元对应的索引编号 + origin_matrix_size : 原始矩阵元素个数 即包含设计域和非设计域 + zero_size : 矩阵中0元素的个数 代表非设计域的单元个数 + + Returns + ------- + new_matrix : 补齐后的矩阵 + """ + # all_indices = np.arange(origin_matrix_size, dtype=jnp.int64) + # zero_indices = jnp.setdiff1d( + # all_indices, nonzero_indices, size=zero_size, assume_unique=True) # setdiff1d通常不支持jit 需要指定size + index_map = inverse_permutation( + jnp.concatenate([nonzero_indices, zero_indices])) + u_values = jnp.concatenate([nonzero_values, jnp.zeros(len(zero_indices))]) + new_matrix = u_values[index_map] + return new_matrix + + +@partial(jax.jit, static_argnums=(2, 3)) +def matrix_set_1(nonzero_values, nonzero_indices, origin_matrix_size, zero_size): + """ + 补齐矩阵 对原始矩阵指定区域置1 + Parameters + ---------- + nonzero_values : 已存在数值的矩阵 + nonzero_indices : 已存在数值矩阵对应的索引编号 + origin_matrix_size : 原始矩阵元素个数 即包含设计域和非设计域 + zero_size : 矩阵中0元素的个数 代表非设计域的单元个数 + + Returns + ------- + new_matrix : 补齐后的矩阵 + """ + all_indices = jnp.arange(origin_matrix_size, dtype=jnp.int64) + zero_indices = jnp.setdiff1d( + all_indices, nonzero_indices, size=zero_size, assume_unique=True) + index_map = inverse_permutation( + jnp.concatenate([nonzero_indices, zero_indices])) + u_values = jnp.concatenate([nonzero_values, jnp.ones(len(zero_indices))]) + new_matrix = u_values[index_map] + return new_matrix + + +@partial(jax.jit, static_argnums=(1, 2, 10)) +def objective(xPhys_1D, young, young_min, freedofs, fext, s_K_predefined, dispTD_predefined, idx, edofMat, ke, penal): + """ + Parameters + ---------- + xPhys_1D : shape:[Total_ele](1D) 参与有限元计算的相对密度 + young : 杨氏模量 + young_min : 弱单元的杨氏模量 + freedofs : 有限元的自由节点 + fext : Python中的1D矢量 外力 + s_K_predefined : 预定义的刚度阵 + dispTD_predefined : 预定义的位移阵 + idx : 刚度阵组装编号 + edofMat : shape:[Total_ele, 8](2D) 按照单元顺序存储每一个单元的自由度编号 + ke : 单元刚度阵 不考虑杨氏模量 + penal : 惩罚因子 + + Returns + ------- + obj : 目标函数输出 此处为结构的柔度 + """ + # xPhys_1D = xPhys.T.flatten() + disp = TO_FEA.FEA(young, young_min, ke, xPhys_1D, freedofs, + penal, idx, fext, s_K_predefined, dispTD_predefined) + obj = TO_Obj.compliance(young, young_min, ke, + xPhys_1D, edofMat, penal, disp) + # print('obj is running') + return obj \ No newline at end of file diff --git a/Linear_TO/TO_Define.py b/Linear_TO/TO_Define.py new file mode 100644 index 0000000..f5df482 --- /dev/null +++ b/Linear_TO/TO_Define.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Apr 22 17:40:00 2020 + +@author: zzy + +Object: 对TO问题进行定义 +1. 对于部分确定的static variable 应采用numpy定义 +""" + +import numpy as np +import jax.numpy as jnp +import jax.ops as jops +import TO_Cal +import TO_FEA + + +def Toparams(problem): + young = 1.0 + rou = 1.0 + nelx, nely, ele_length, ele_width = problem.nelx, problem.nely, problem.ele_length, problem.ele_width + Total_ele, Total_nod, Total_dof = nelx * \ + nely, (nelx+1)*(nely+1), 2*(nelx+1)*(nely+1) + gobalcoords0xy, M_elenod, M_edofMat, M_iK, M_jK = coordinate_and_preFEA( + nelx, nely, ele_length, ele_width, Total_ele, Total_nod) + elenod, edofMat, iK, jK = M_elenod-1, M_edofMat-1, M_iK-1, M_jK - \ + 1 # gobalcoords0xy是绝对坐标无需更改 其他根据索引方式均-1 纯粹的数值处理 后面索引用的IE也要从0起 + edofMat_3D = TO_edofMat_3DTensor(nelx, nely) # 这是一个Tensor版本edofMat + design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices = handle_design_region( + nelx, nely, Total_ele, problem.design_map) + params = { + 'young': young, + 'young_min': 1e-9*young, + 'rou': rou, + 'poisson': 0.3, + 'g': 0, + 'ele_length': ele_length, + 'ele_width': ele_width, + 'ele_thickness': problem.ele_thickness, + 'nelx': nelx, + 'nely': nely, + 'Total_ele': Total_ele, + 'Total_nod': Total_nod, + 'Total_dof': Total_dof, + 'volfrac': problem.volfrac, + 'xmin': 0.001, + 'xmax': 1.0, + 'design_map': problem.design_map, + # 'design_map_bool': design_map_bool, + # 'num_design_ele': num_design_ele, + # 'num_nondesign_ele': num_nondesign_ele, + 'design_area_indices': design_area_indices, + 'nondesign_area_indices': nondesign_area_indices, + 'freedofs': problem.freedofs, + 'fixeddofs': problem.fixeddofs, + 'fext': problem.fext, + 'penal': 3.0, + 'filter_width': problem.filter_width, + 'gobalcoords0xy': gobalcoords0xy, + 'M_elenod': M_elenod, + 'M_edofMat': M_edofMat, + 'M_iK': M_iK, + 'M_jK': M_jK, + 'elenod': elenod, + 'edofMat': edofMat, + 'edofMat_3D': edofMat_3D, + 'iK': iK, + 'jK': jK, + 'design_condition': problem.design_condition, + } + return params + + +class Topology_Optimization: + + def __init__(self, args): + self.args = args + self.young, self.young_min, self.poisson = args['young'], args['young_min'], args['poisson'] + self.Total_dof, self.Total_ele, self.nelx, self.nely, self.ele_thickness = args[ + 'Total_dof'], args['Total_ele'], args['nelx'], args['nely'], args['ele_thickness'] + self.freedofs, self.fext, self.iK, self.jK, self.edofMat = args[ + 'freedofs'], args['fext'], args['iK'], args['jK'], args['edofMat'] + self.filter_width, self.design_condition, self.volfrac = args[ + 'filter_width'], args['design_condition'], args['volfrac'] + self.design_map, self.design_area_indices, self.nondesign_area_indices = args[ + 'design_map'], args['design_area_indices'], args['nondesign_area_indices'] + self.ke = TO_FEA.linear_stiffness_matrix( + self.young, self.poisson, self.ele_thickness) + self.filter_kernal, self.filter_weight = TO_Cal.filter_define( + self.design_map, self.filter_width) + self.s_K_predefined, self.dispTD_predefined, self.idx = jnp.zeros( + [self.Total_dof, self.Total_dof]), jnp.zeros([self.Total_dof]), jops.index[self.iK, self.jK] + self.loop_1, self.iCont = 0, 0 + self.penal_use, self.beta_x_use = 3.0, 1.0 + self.beta_x_use_max, self.beta_x_deltath1, self.beta_x_deltath2 = 16, 50, 30, + self.xPhys = self.volfrac * jnp.ones([self.nely, self.nelx]) + + def xPhys_transform(self, params, projection=False): + x = params.reshape(self.nely, self.nelx) + beta_x = self.penal_and_betax() + xPhys_1D = TO_Cal.design_variable(x, self.design_area_indices, self.nondesign_area_indices, + self.filter_kernal, self.filter_weight, beta_x, self.volfrac, projection=projection) + return xPhys_1D + + def objective(self, params, projection=False): + self.loop_1 = self.loop_1 + 1 + xPhys_1D = self.xPhys_transform(params, projection=projection) + Obj = TO_Cal.objective(xPhys_1D, self.young, self.young_min, self.freedofs, self.fext, + self.s_K_predefined, self.dispTD_predefined, self.idx, self.edofMat, self.ke, self.penal_use) + self.xPhys = xPhys_1D.reshape(self.nely, self.nelx, order='F') + return Obj + + def penal_and_betax(self, *args): + self.iCont = self.iCont + 1 + if (self.iCont > self.beta_x_deltath1 and self.beta_x_use < 2): + self.beta_x_use = self.beta_x_use * 2 + self.iCont = 1 + print(' beta_x increased to :%7.3f\n' % (self.beta_x_use)) + elif (self.iCont > self.beta_x_deltath2 and self.beta_x_use < self.beta_x_use_max and self.beta_x_use >= 2): + self.beta_x_use = self.beta_x_use * 2 + self.iCont = 1 + print(' beta_x increased to :%7.3f\n' % (self.beta_x_use)) + return self.beta_x_use + + +def coordinate_and_preFEA(nelx, nely, ele_length, ele_width, Total_ele, Total_nod): + """ + 有限元计算相关量 单元及节点编号顺序为从左至右 从上至下 + 采用Matlab计数法 对单一单元内物理量 按左下逆时针顺序 + Parameters + ---------- + nelx : 横向网格划分 + nely : 纵向网格划分 + ele_length : 单元长度 + ele_width : 单元宽度 + Total_ele : 单元总数 + Total_nod : 节点总数 + + Returns + ------- + gobalcoords0xy : shape:[Total_nod, 2](2D) 按照节点顺序存储每一个节点的物理坐标(x, y) + elenod : shape:[Total_ele, 4](2D) 按照单元顺序存储每一个单元的节点编号 + edofMat : shape:[Total_ele, 8](2D) 按照单元顺序存储每一个单元的自由度编号 + iK_int : shape:[Total_ele * 64](2D) 用于刚度阵组装 按照特定顺序储存自由度编号 + jK_int : shape:[Total_ele * 64](2D) 用于刚度阵组装 按照特定顺序储存自由度编号 + """ + i0, j0 = jnp.meshgrid(jnp.arange(nelx+1), jnp.arange(nely+1)) + gobalcoords0x = i0*ele_length + gobalcoords0y = ele_width*nely-j0*ele_width + gobalcoords0xy = jnp.hstack((matrix_array(gobalcoords0x), matrix_array( + gobalcoords0y))) + gobalcoords0xyT = gobalcoords0xy.T + gobalcoords0 = matrix_array(gobalcoords0xyT) + nodegrd = matrix_order_arrange(1, Total_nod, nely+1, nelx+1) + nodelast = Matlab_reshape(Matlab_matrix_extract( + nodegrd, 1, -1, 1, -1), Total_ele, 1) + edofVec = 2*matrix_array(nodelast)+1 + elenod = jnp.tile(nodelast, (1, 4)) + \ + jnp.tile(jnp.array([1, nely+2, nely+1, 0]), (Total_ele, 1)) + edofMat = jnp.tile(edofVec, (1, 8))+jnp.tile(jnp.array( + [0, 1, 2*nely+2, 2*nely+3, 2*nely+0, 2*nely+1, -2, -1]), (Total_ele, 1)) + iK = jnp.kron(edofMat, jnp.ones((8, 1))).flatten() + jK = jnp.kron(edofMat, jnp.ones((1, 8))).flatten() + iK_int = iK.astype(jnp.int32) + jK_int = jK.astype(jnp.int32) + return gobalcoords0xy, elenod, edofMat, iK_int, jK_int + + +def handle_design_region(nelx, nely, Total_ele, design_map): + """ + 对设计域/非设计域相关参数进行预定义 + Parameters + ---------- + nelx : 横向网格划分 + nely : 纵向网格划分 + Total_ele : 单元总数 + design_map : 设计域和非设计域定义矩阵 + + Returns + ------- + design_map_bool : bool形式的设计域和非设计域定义矩阵 + num_design_ele : 可设计单元总数 + num_nondesign_ele : 非设计单元总数 + design_area_indices : 设计域的单元编号索引 + nondesign_area_indices : 非设计域的单元编号索引 + """ + shape = (nely, nelx) + design_map_bool = np.broadcast_to(design_map, shape) > 0 + num_design_ele = np.sum(design_map) + num_nondesign_ele = Total_ele - num_design_ele + + design_map_bool_1D = design_map_bool.flatten(order='F') + all_indices = np.arange(Total_ele, dtype=jnp.int64) + design_area_indices = jnp.flatnonzero( + design_map_bool_1D, size=num_design_ele) + nondesign_area_indices = jnp.setdiff1d( + all_indices, design_area_indices, size=num_nondesign_ele, assume_unique=True) # setdiff1d通常不支持jit 需要指定size + return design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices + + +def Matlab_matrix_array1D(matrixA): + """ + shape:[*] 1D array 按照matlab的排列 + """ + return matrixA.T.flatten() + + +def matrix_array(matrixA): + """ + shape:[*, 1] 2D array 按照matlab的排列 matrixA(:) + """ + return matrixA.T.reshape(-1, 1) + + +def matrix_order_arrange(numbegin, numend, ydim, xdim): + ''' + shape:[ydim, xdim] 2D array 实现一个元素按列顺序排列的矩阵 注意索引编号和matlab的差1 + ''' + matrixA = jnp.array([[i for i in range(numbegin, numend+1)]]) + matrixA = matrixA.reshape(xdim, ydim) + matrixA = matrixA.T + return matrixA + + +def Matlab_reshape(matrixA, ydim, xdim): + ''' + shape:[ydim, xdim] 2D array matlab的reshape效果 + ''' + return matrixA.reshape((ydim, xdim), order='F') + + +def Matlab_matrix_extract(matrixA, xbegin, xend, ybegin, yend): + matrixA = matrixA[ybegin-1:yend, xbegin-1:xend] + return matrixA + + +def TO_edofMat_3DTensor(nelx, nely): + ely, elx = jnp.meshgrid(jnp.arange(nely), jnp.arange(nelx)) # x, y coords + n1 = (nely+1)*(elx+0) + (ely+0) + n2 = (nely+1)*(elx+1) + (ely+0) + n3 = (nely+1)*(elx+1) + (ely+1) + n4 = (nely+1)*(elx+0) + (ely+1) + edofMat_3D = jnp.array( + [2*n4, 2*n4+1, 2*n3, 2*n3+1, 2*n2, 2*n2+1, 2*n1, 2*n1+1]) + return edofMat_3D diff --git a/Linear_TO/TO_FEA.py b/Linear_TO/TO_FEA.py new file mode 100644 index 0000000..d7cc135 --- /dev/null +++ b/Linear_TO/TO_FEA.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +""" +Created on Sat Apr 25 01:57:01 2020 + +@author: zzy + +Object: 存储有限元计算所需函数 +1. 所有函数均可jit +2. 在调用时 选择在最外层封装jit 内层函数则将jit注释 +3. 对于部分确定的static variable 应采用numpy定义(仅限CPU) +""" + +import numpy as np +from numpy import float64 +import jax.numpy as jnp +import jax.ops as jops +import jax.scipy.linalg as jsl +from jax import jit +import TO_Private_Autodiff as TPA +from jax.config import config +config.update("jax_enable_x64", True) # 对计算精度影响很大 + + +def linear_stiffness_matrix(young, poisson, ele_thickness): + """ + Parameters + ---------- + young : 杨氏模量 + poisson : 泊松比 + ele_thickness : 单元厚度 + + Returns + ------- + stiffness_ele : 单元刚度矩阵 注意 考虑到TO问题 此处没有考虑单元杨式模量的影响 + """ + young, poisson, h = young, poisson, ele_thickness + k = np.array([1/2-poisson/6, 1/8+poisson/8, -1/4-poisson/12, -1/8+3*poisson/8, + -1/4+poisson/12, -1/8-poisson/8, poisson/6, 1/8-3*poisson/8]) + stiffness_ele = h/(1-poisson**2)*np.array([[k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]], + [k[1], k[0], k[7], k[6], + k[5], k[4], k[3], k[2]], + [k[2], k[7], k[0], k[5], + k[6], k[3], k[4], k[1]], + [k[3], k[6], k[5], k[0], + k[7], k[2], k[1], k[4]], + [k[4], k[5], k[6], k[7], + k[0], k[1], k[2], k[3]], + [k[5], k[4], k[3], k[2], + k[1], k[0], k[7], k[6]], + [k[6], k[3], k[4], k[1], + k[2], k[7], k[0], k[5]], + [k[7], k[2], k[1], k[4], + k[3], k[6], k[5], k[0]] + ], dtype=float64) + return stiffness_ele + + +def density_matrix(rou, ele_length, ele_width, ele_thickness): + """ + Parameters + ---------- + rou : 密度 + ele_length : 单元长度 + ele_width : 单元宽度 + ele_thickness : 单元厚度 + + Returns + ------- + density_ele : 单元密度矩阵(一致质量矩阵) 注意 考虑到TO问题 此处没有考虑单元相对密度的影响 否则分子应当是质量 + """ + ele_volume = ele_length*ele_width*ele_thickness + density_ele = ele_volume / 36 * np.array([[4, 0, 2, 0, 1, 0, 2, 0], + [0, 4, 0, 2, + 0, 1, 0, 2], + [2, 0, 4, 0, + 2, 0, 1, 0], + [0, 2, 0, 4, + 0, 2, 0, 1], + [1, 0, 2, 0, + 4, 0, 2, 0], + [0, 1, 0, 2, + 0, 4, 0, 2], + [2, 0, 1, 0, + 2, 0, 4, 0], + [0, 2, 0, 1, + 0, 2, 0, 4], + ], dtype=float64) + return density_ele + + +# @jit +def FEA(young, young_min, ke, xPhys_1D, freedofs, penal, idx, fext, s_K_predefined, dispTD_predefined): + """ + Parameters + ---------- + young : 杨氏模量 + young_min : 弱单元的杨氏模量 + ke : 单元刚度阵 不考虑杨氏模量 + xPhys_1D : shape:[Total_ele](1D) 参与有限元计算的单元相对密度阵 (这种形式是必要的) + freedofs : 有限元的自由节点 + penal : 惩罚因子 + idx : 刚度阵组装编号 + fext : shape:[Total_dof](1D) 外力 + s_K_predefined : shape:[Total_dof, Total_dof](2D) 预定义的刚度阵 + dispTD_predefined : shape:[Total_dof](1D) 预定义的位移阵 + + Returns + ------- + dispTD :shape:[Total_dof](1D) 节点位移 + """ + + def kuf_solve(fext_1D_free, s_K_free): + u_free = jsl.solve(s_K_free, fext_1D_free, + sym_pos=True, check_finite=False) + return u_free + + def young_modulus(xPhys_1D, young, young_min, ke, penal): + """ + Returns + ------- + sK : shape:[64, Total_ele](2D) 考虑相对密度后的每个单元的刚度阵(列形式存储) + """ + sK = ((ke.flatten()[jnp.newaxis]).T * (young_min + (xPhys_1D) ** penal * (young-young_min)) + ).flatten(order='F') + return sK + + sK = young_modulus(xPhys_1D, young, young_min, ke, penal) + s_K = jops.index_add(s_K_predefined, idx, sK) + s_K_free = s_K[freedofs, :][:, freedofs] + fext_free = fext[freedofs] + u_free = kuf_solve(fext_free, s_K_free) + dispTD = jops.index_add(dispTD_predefined, freedofs, u_free) + return dispTD diff --git a/Linear_TO/TO_Model.py b/Linear_TO/TO_Model.py new file mode 100644 index 0000000..d1d992b --- /dev/null +++ b/Linear_TO/TO_Model.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Oct 19 00:00:06 2021 + +@author: zzy + +Object: AuTONR的model构建 +1. 基于Jax+Flax+JAXopt/Optax +2. 关于dataclasses和flax.linen可以参考: + learning_dataclasses.py + learning_flax_module.py +""" + +import jax +import jax.numpy as jnp +import jax.image as ji +import jax.tree_util as jtree +from jax import random +import flax.linen as nn +import flax.linen.initializers as fli +import flax.traverse_util as flu +import flax.core.frozen_dict as fcfd +from functools import partial +from typing import Any, Callable +from jax.config import config +config.update("jax_enable_x64", True) + + +def set_random_seed(seed): + if seed is not None: + rand_key = random.PRNGKey(seed) + return rand_key + + +def constant_array(rng_key, shape, dtype=jnp.float64, value=0.0): + """ + 自定义参数初始化函数 实现矩阵填充数值 + Parameters + ---------- + rng_key : + shape : 目标矩阵的shape + dtype : 数据格式 The default is jnp.float64. + value : 填充数值 The default is 0.0. + + Returns + ------- + out : 初始化后的矩阵 + """ + out = jnp.full(shape, value, dtype) + # print("i am running") + return out + + +def extract_model_params(params_use): + """ + 提取网络参数 + Parameters + ---------- + params_use : FrozenDict 需要提取的参数 + + Returns + ------- + extracted_params : dict 提取出的参数 {'Conv_0/bias': ..., 'Conv_0/kernel': ..., ...} + params_shape : dict 提取的每一组参数的shape {'Conv_0/bias': (...), 'Conv_0/kernel': (...), ...} + params_total_num : 参数总数 + """ + params_unfreeze = params_use.unfreeze() + extracted_params = { + '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()} + params_shape = jtree.tree_map(jnp.shape, extracted_params) + params_total_num = sum(p.size for p in jtree.tree_leaves(extracted_params)) + return extracted_params, params_shape, params_total_num + + +def replace_model_params(pre_trained_params_dict, params_use): + """ + 替换网络参数 + Parameters + ---------- + pre_trained_params_dict : dict 预训练的参数 {'Conv_0/bias': ..., 'Conv_0/kernel': ..., ...} + params_use : FrozenDict 需要被替换的参数 + + Returns + ------- + new_params_use : FrozenDict 新的参数 + """ + params_unfreeze = params_use.unfreeze() + # 有两种方式均可实现功能 + # Method A 基于flax.traverse_util实现 + extracted_params = { + '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()} + for key, val in pre_trained_params_dict.items(): + extracted_params[key] = val + new_params_use = flu.unflatten_dict( + {tuple(k.split('/')): v for k, v in extracted_params.items()}) + # Method B 基于jax.tree实现 + params_leaves, params_struct = jtree.tree_flatten(params_unfreeze) + i = 0 + for key, val in pre_trained_params_dict.items(): + params_leaves[i] = val + i = i + 1 + new_params_use = jtree.tree_unflatten(params_struct, params_leaves) + # 封装新的参数 + new_params_use = fcfd.freeze(new_params_use) + return new_params_use + + +def batched_loss(x_inputs, TopOpt_envs): + """ + 处理带有batch_size和多个实例化的TopOpt类的情况 + Parameters + ---------- + x_inputs : shape:[1, nely, nelx] (3D array) 第一个维度代表batch 第二和第三个维度代表TO_2D问题中的单元相对密度 通常batch_size = 1 + TopOpt_envs : 多个实例化的TopOpt类 以列表形式存储 通常仅有一个实例化的类 + + Returns + ------- + losses_array : loss_list以列表形式存储对应于每一个batch和TopOpt的obj输出 将其转化为array形式 + """ + loss_list = [TopOpt.objective(x_inputs[i], projection=True) + for i, TopOpt in enumerate(TopOpt_envs)] + losses_array = jnp.stack(loss_list)[0] + return losses_array + + +# @jax.jit +class Normalize_TO(nn.Module): + """ + 自定义Normalize类 + """ + epsilon: float = 1e-6 + + @nn.compact + def __call__(self, input): + # 通过对axis设置可实现不同normalization的效果 + input_mean = jnp.mean(input, axis=(1, 2, 3), keepdims=True) + input_var = jnp.var(input, axis=(1, 2, 3), keepdims=True) + output = input - input_mean + output = output * jax.lax.rsqrt(input_var + self.epsilon) + return output + + +class Add_offset(nn.Module): + scale: int = 10 + + @nn.compact + def __call__(self, input): + add_bias = self.param('add_bias', lambda rng, shape: constant_array( + rng, shape, value=0.0), input.shape) + return input + self.scale * add_bias + + +class BasicBlock(nn.Module): + resize_scale_one: int + conv_output_one: int + norm_use: nn.Module = Normalize_TO + resize_method: str = "bilinear" + kernel_size: int = 5 + conv_kernel_init: partial = fli.variance_scaling( + scale=1.0, mode="fan_in", distribution="truncated_normal") + offset: nn.Module = Add_offset + offset_scale: int = 10 + + @nn.compact + def __call__(self, input): + output = jnp.tanh(input) + output_shape = output.shape + output = ji.resize(output, (output_shape[0], self.resize_scale_one * output_shape[1], + self.resize_scale_one * output_shape[2], output_shape[-1]), method=self.resize_method, antialias=True) + output = self.norm_use()(output) + output = nn.Conv(features=self.conv_output_one, kernel_size=( + self.kernel_size, self.kernel_size), kernel_init=self.conv_kernel_init)(output) + output = self.offset(scale=self.offset_scale)(output) + return output + + +class My_TO_Model(nn.Module): + TopOpt: Any + seed: int = 1 + replace_params: Callable = replace_model_params + extract_params: Callable = extract_model_params + + def loss(self, input_TO): + # 以list形式传入TopOpt 其实tuple形式也可以 + obj_TO = batched_loss(input_TO, [self.TopOpt]) + return obj_TO + + +class TO_CNN(My_TO_Model): + h: int = 5 + w: int = 10 + dense_scale_init: Any = 1.0 + dense_output_channel: int = 1000 + dtype: Any = jnp.float32 + input_init_std: float = 1.0 # 初始化设计变量的标准差 正态分布 + dense_input_channel: int = 128 + conv_input_0: int = 32 + conv_output: tuple = (128, 64, 32, 16, 1) + up_sampling_scale: tuple = (1, 2, 2, 2, 1) + offset_scale: int = 10 + block: nn.Module = BasicBlock + + @nn.compact + def __call__(self, x=None): # [N, H, W, C] + nn_input = self.param('z', lambda rng, shape: constant_array( + rng, shape, value=self.TopOpt.volfrac), self.dense_input_channel) + + output = nn.Dense(features=self.dense_output_channel, kernel_init=fli.orthogonal( + scale=self.dense_scale_init))(nn_input) + output = output.reshape([1, self.h, self.w, self.conv_input_0]) + + output = self.block( + self.up_sampling_scale[0], self.conv_output[0])(output) + output = self.block( + self.up_sampling_scale[1], self.conv_output[1])(output) + output = self.block( + self.up_sampling_scale[2], self.conv_output[2])(output) + output = self.block( + self.up_sampling_scale[3], self.conv_output[3])(output) + output = self.block( + self.up_sampling_scale[4], self.conv_output[4])(output) + + nn_output = jnp.squeeze(output, axis=-1) + self.sow('intermediates', 'model_out', + nn_output) # 在这里是否可以考虑去掉batch维度? + return nn_output + + +if __name__ == '__main__': + + def funA(params): + out = 1 * jnp.sum(params) + return out + + def funB(params): + out = 2 * jnp.sum(params) + return out + + def funC(params): + out = 3 * jnp.sum(params) + return out + + # 测试 直观展示batched_loss + params_1 = jnp.ones([3, 4, 4]) + fun_list = [funA, funB, funC] + losses = [fun_used(params_1[i]) for i, fun_used in enumerate(fun_list)] + loss_out = jnp.stack(losses) + print("losses:", losses) + print("loss_out:", loss_out) diff --git a/Linear_TO/TO_Obj.py b/Linear_TO/TO_Obj.py new file mode 100644 index 0000000..c5e5a46 --- /dev/null +++ b/Linear_TO/TO_Obj.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +""" +Created on Sun Apr 26 01:59:52 2020 + +@author: zzy + +Object: 存储TO计算目标函数过程所需的函数 +1. 所有函数均可jit +2. 在调用时 选择在最外层封装jit 内层函数则将jit注释 +3. 对于部分确定的static variable 应采用numpy定义 +""" + +import jax.numpy as jnp +from jax import jit + + +# @jit +def compliance(young, young_min, ke, xPhys_1D, edofMat, penal, disp): + ce = jnp.sum(jnp.matmul(disp[edofMat], ke) * disp[edofMat], 1) + ce = (young_min + xPhys_1D ** penal * (young - young_min)) * ce + obj = jnp.sum(ce) + return obj diff --git a/Linear_TO/TO_Problem.py b/Linear_TO/TO_Problem.py new file mode 100644 index 0000000..a94b7df --- /dev/null +++ b/Linear_TO/TO_Problem.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 11 14:20:11 2020 + +@author: zzy + +定义设计域 +""" + +import jax.ops as jops +import numpy as np +import dataclasses +from typing import Optional, Union +import jax.numpy as jnp +from jax.config import config +config.update("jax_enable_x64", True) + + +@dataclasses.dataclass +class ToProblem: + L: int + W: int + nelx: int + nely: int + design_condition: int + ele_length: np.float64 + ele_width: np.float64 + ele_thickness: np.float64 + volfrac: np.float64 + fixeddofs: np.ndarray + freedofs: np.ndarray + fext: np.ndarray + filter_width: float + design_map: np.ndarray + + +def cantilever_single(): + L = 0.8 + W = 0.4 + nelx = 80 + nely = 40 + design_condition = 1 + Total_dof = 2*(nelx+1)*(nely+1) + ele_length = L/nelx + ele_width = W/nely + ele_thickness = 1.0 + volfrac = 0.5 + dofs = np.arange(Total_dof) + loaddof = Total_dof-1 + # fext = jnp.zeros(Total_dof) + # fext = jops.index_update(fext, jops.index[loaddof], -1.) + fext = np.zeros(Total_dof) + fext[loaddof] = -1 + fixeddofs = np.array(dofs[0:2*(nely+1):1]) + freedofs = np.setdiff1d(dofs, fixeddofs) + filter_width = 3.0 + design_map = np.ones([nely, nelx], dtype=np.int64) + return ToProblem(L, W, nelx, nely, design_condition, ele_length, ele_width, ele_thickness, volfrac, fixeddofs, freedofs, fext, filter_width, design_map) + + +def cantilever_single_NN(TopOpt): + dense_input_channel = 128 + dense_init_scale = 1.0 + conv_input_0 = 32 + up_sampling_scale = (1, 2, 2, 2, 1) + total_resize = int(np.prod(up_sampling_scale)) + h = TopOpt.nely // total_resize + w = TopOpt.nelx // total_resize + dense_output_channel = h * w * conv_input_0 + dense_scale_init = dense_init_scale * \ + jnp.sqrt(jnp.maximum(dense_output_channel / dense_input_channel, 1)) + model_kwargs = { + 'h': h, + 'w': w, + 'dense_scale_init': dense_scale_init, + 'dense_output_channel': dense_output_channel, + 'dense_input_channel': dense_input_channel, + 'conv_input_0': conv_input_0, + 'up_sampling_scale': up_sampling_scale, + } + return model_kwargs + + +def cantilever_single_big(): + L = 1.6 + W = 0.8 + nelx = 160 + nely = 80 + design_condition = 2 + Total_dof = 2*(nelx+1)*(nely+1) + ele_length = L/nelx + ele_width = W/nely + ele_thickness = 1.0 + volfrac = 0.5 + dofs = np.arange(Total_dof) + loaddof = Total_dof-1 + # fext = jnp.zeros(Total_dof) + # fext = jops.index_update(fext, jops.index[loaddof], -1.) + fext = np.zeros(Total_dof) + fext[loaddof] = -1 + fixeddofs = np.array(dofs[0:2*(nely+1):1]) + freedofs = np.setdiff1d(dofs, fixeddofs) + filter_width = 3.0 + design_map = np.ones([nely, nelx], dtype=np.int64) + return ToProblem(L, W, nelx, nely, design_condition, ele_length, ele_width, ele_thickness, volfrac, fixeddofs, freedofs, fext, filter_width, design_map) + + +def cantilever_single_big_NN(TopOpt): + dense_input_channel = 128 + dense_init_scale = 1.0 + conv_input_0 = 32 + up_sampling_scale = (1, 2, 2, 2, 2) + total_resize = int(np.prod(up_sampling_scale)) + h = TopOpt.nely // total_resize + w = TopOpt.nelx // total_resize + dense_output_channel = h * w * conv_input_0 + dense_scale_init = dense_init_scale * \ + jnp.sqrt(jnp.maximum(dense_output_channel / dense_input_channel, 1)) + model_kwargs = { + 'h': h, + 'w': w, + 'dense_scale_init': dense_scale_init, + 'dense_output_channel': dense_output_channel, + 'dense_input_channel': dense_input_channel, + 'conv_input_0': conv_input_0, + 'up_sampling_scale': up_sampling_scale, + } + return model_kwargs + diff --git a/Linear_TO/TO_Train.py b/Linear_TO/TO_Train.py new file mode 100644 index 0000000..e0db91b --- /dev/null +++ b/Linear_TO/TO_Train.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Oct 19 09:53:14 2021 + +@author: zzy + +Object: AuTONR的训练函数 +1. 基于Flax+JAXopt/Optax实现 +2. 关于train过程的实现可以参考: + learning_optax_training.py + learning_flax_managing_params.py + learning_flax_training.py +""" + +import sys +from pathlib import Path +from absl import logging +import xarray +import jax +import jax.numpy as jnp +from jax import random +from jaxopt import OptaxSolver +import optax +import time +from functools import partial +import matplotlib.pyplot as plt +from jax.config import config +config.update("jax_enable_x64", True) +sys.path.append( + '/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO') +here = Path(__file__).resolve().parent + + +def set_random_seed(seed): + if seed is not None: + rand_key = random.PRNGKey(seed) + return rand_key + + +def optimizer_result_dataset(losses, frames, save_intermediate_designs=False): + # The best design will often but not always be the final one. + best_design = jnp.nanargmin(losses) + logging.info(f'Final loss: {losses[best_design]}') + if save_intermediate_designs: + ds = xarray.Dataset({ + 'loss': (('step',), losses), + 'design': (('step', 'y', 'x'), frames), + }, coords={'step': jnp.arange(len(losses))}) + else: + ds = xarray.Dataset({ + 'loss': (('step',), losses), + 'design': (('y', 'x'), frames[best_design]), + }, coords={'step': jnp.arange(len(losses))}) + return ds + + +def train_TO_Optax(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs): + losses = [] + frames = [] + # Initialize parameters + init_params = model.init(set_random_seed(seed), None) + batch_state, params = init_params.pop("params") + del init_params + # Instantiate Optimizer + design_condition = model.TopOpt.design_condition + if design_condition == 1: + learning_rate = 0.001 + optimizer = optax.adam(learning_rate) + elif design_condition == 2: + learning_rate = 0.001 + optimizer = optax.adam(learning_rate) + # loss + def loss_cal(params): + all_params = {"params": params} + model_out, net_state = model.apply(all_params, None, mutable="intermediates") + loss_out = model.loss(model_out) + return loss_out, net_state["intermediates"] + loss_grad_fn = jax.value_and_grad(loss_cal, has_aux=True) + # Initialize Optimizer State + state = optimizer.init(params) + # Updated + for iter in range(max_iterations): + start_time_epoch = time.perf_counter() + (loss_val, batch_stats), grads = loss_grad_fn(params) + losses.append(jax.device_get(loss_val)) + updates_params, state = optimizer.update(grads, state) + params = optax.apply_updates(params, updates_params) + design_TO = jax.device_get(model.TopOpt.xPhys.aval.val) + frames.append(design_TO) + plt.figure() + plt.imshow(design_TO, cmap='Greys') + plt.show() + whole_time_epoch = time.perf_counter() - start_time_epoch + print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' % + (iter+1, losses[-1], jnp.mean(design_TO), whole_time_epoch)) + + return optimizer_result_dataset(jnp.array(losses), jnp.array(frames), save_intermediate_designs) + + +def train_TO_JAXopt(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs): + losses = [] + frames = [] + # Initialize parameters + init_params = model.init(set_random_seed(seed), None) + # batch_stats, params = init_params.pop("params") + params = init_params["params"] + # batch_stats = init_params["batch_stats"] + del init_params + + # Initialize solver + # @jax.jit + def loss_cal(params): + # all_params = {"params": params, "batch_stats": aux} + # model_out, net_state = model.apply(all_params, None) + all_params = {"params": params} + model_out, net_state = model.apply( + all_params, None, mutable="intermediates") + loss_out = model.loss(model_out) + # design_TO = model.TopOpt.x_calculation( + # model_out, volume_contraint=True, projection=True) + # losses.append(jax.device_get(loss_out)) + return loss_out, net_state["intermediates"] + + design_condition = model.TopOpt.design_condition + + def opt_solver_setting(design_condition, loss_fun, max_iterations): + if design_condition == 1: + learning_rate = 0.001 + opt_use = optax.adam(learning_rate) + solver = OptaxSolver( + fun=loss_fun, opt=opt_use, maxiter=max_iterations, has_aux=True) + elif design_condition == 2: + learning_rate = 0.000375 + elif design_condition == 6: + learning_rate = 0.001 + elif design_condition == 5: + learning_rate = 0.000375 + return solver + + solver = opt_solver_setting(design_condition, loss_cal, max_iterations) + + state = solver.init_state(params) + for iter in range(max_iterations): + start_time_epoch = time.time() + # params, state = solver.update( + # params=params, state=state, aux=batch_stats) + + params, state = solver.update(params=params, state=state) + batch_stats = state.aux + losses.append(jax.device_get(state.value)) # 放在这里应该更好一些 + # design_TO = jax.device_get(batch_stats._dict['model_out'])[-1] + # design_TO = model.TopOpt.x_calculation(design_TO, volume_constraint=True, projection=True) + design_TO = jax.device_get(model.TopOpt.xPhys.aval.val) + frames.append(design_TO) + plt.figure() + plt.imshow(design_TO, cmap='Greys') + plt.show() + end_time_epoch = time.time() + whole_time_epoch = end_time_epoch - start_time_epoch + print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' % + (iter+1, losses[-1], jnp.mean(design_TO), whole_time_epoch)) + + # frames = jax.device_get(batch_stats._dict['intermediates']['model_out']) + return optimizer_result_dataset(jnp.array(losses), jnp.array(frames), save_intermediate_designs) diff --git a/Linear_TO/__pycache__/TO_Cal.cpython-38.pyc b/Linear_TO/__pycache__/TO_Cal.cpython-38.pyc new file mode 100644 index 0000000..b37e534 Binary files /dev/null and b/Linear_TO/__pycache__/TO_Cal.cpython-38.pyc differ diff --git a/Linear_TO/__pycache__/TO_Define.cpython-38.pyc b/Linear_TO/__pycache__/TO_Define.cpython-38.pyc new file mode 100644 index 0000000..d147a90 Binary files /dev/null and b/Linear_TO/__pycache__/TO_Define.cpython-38.pyc differ diff --git a/Linear_TO/__pycache__/TO_FEA.cpython-38.pyc b/Linear_TO/__pycache__/TO_FEA.cpython-38.pyc new file mode 100644 index 0000000..3c899e4 Binary files /dev/null and b/Linear_TO/__pycache__/TO_FEA.cpython-38.pyc differ diff --git a/Linear_TO/__pycache__/TO_Model.cpython-38.pyc b/Linear_TO/__pycache__/TO_Model.cpython-38.pyc new file mode 100644 index 0000000..55fd58e Binary files /dev/null and b/Linear_TO/__pycache__/TO_Model.cpython-38.pyc differ diff --git a/Linear_TO/__pycache__/TO_Obj.cpython-38.pyc b/Linear_TO/__pycache__/TO_Obj.cpython-38.pyc new file mode 100644 index 0000000..0e467a4 Binary files /dev/null and b/Linear_TO/__pycache__/TO_Obj.cpython-38.pyc differ diff --git a/Linear_TO/__pycache__/TO_Private_Autodiff.cpython-38.pyc b/Linear_TO/__pycache__/TO_Private_Autodiff.cpython-38.pyc new file mode 100644 index 0000000..ee8b5b1 Binary files /dev/null and b/Linear_TO/__pycache__/TO_Private_Autodiff.cpython-38.pyc differ diff --git a/Linear_TO/__pycache__/TO_Problem.cpython-38.pyc b/Linear_TO/__pycache__/TO_Problem.cpython-38.pyc new file mode 100644 index 0000000..279ee9c Binary files /dev/null and b/Linear_TO/__pycache__/TO_Problem.cpython-38.pyc differ diff --git a/Linear_TO/__pycache__/TO_Train.cpython-38.pyc b/Linear_TO/__pycache__/TO_Train.cpython-38.pyc new file mode 100644 index 0000000..ed324ff Binary files /dev/null and b/Linear_TO/__pycache__/TO_Train.cpython-38.pyc differ