diff --git a/Linear_TO/TONR_Linear_Stiffness.py b/Linear_TO/TONR_Linear_Stiffness.py
new file mode 100644
index 0000000..6e90722
--- /dev/null
+++ b/Linear_TO/TONR_Linear_Stiffness.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+'''
+@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
+@ Author : Zeyu Zhang
+@ Email : zhangzeyu_work@outlook.com
+@ Date : 2021-11-17 10:52:26
+@ LastEditTime : 2022-10-25 09:31:20
+@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TONR_Linear_Stiffness.py
+@
+@ Description : JAX + Flax + Optax TONR linear stiffness problem
+@ Reference :
+'''
+
+import os
+import sys
+import time
+from pathlib import Path
+
+import pandas as pd
+import xarray
+import matplotlib.pyplot as plt
+import seaborn
+import jax.numpy as jnp
+from jax import random
+from jax.config import config
+config.update("jax_enable_x64", True)
+
+# Make the local modules importable before loading them.
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+here = Path(__file__).resolve().parent
+
+import TO_Problem
+import TO_Define
+import TO_Model
+import TO_Train
+
+design_condition = 1
+if design_condition == 1:
+    problem = TO_Problem.cantilever_single()
+    max_iterations = 200
+    args = TO_Define.Toparams(problem)
+    TopOpt_env = TO_Define.Topology_Optimization(args)
+    model_kwargs = TO_Problem.cantilever_single_NN(TopOpt_env)
+elif design_condition == 2:
+    problem = TO_Problem.cantilever_single_big()
+    max_iterations = 200
+    args = TO_Define.Toparams(problem)
+    TopOpt_env = TO_Define.Topology_Optimization(args)
+    model_kwargs = TO_Problem.cantilever_single_big_NN(TopOpt_env)
+
+TONR_model = TO_Model.TO_CNN(TopOpt=TopOpt_env, **model_kwargs)
+
+if __name__ == '__main__':
+
+    start_time = time.perf_counter()
+    seed = 1
+    ds_TopOpt = TO_Train.train_TO_Optax(
+        TONR_model, max_iterations, save_intermediate_designs=True, seed=seed)
+    ds_TopOpt.to_netcdf('ds_TopOpt.nc')  # save
+    # ds_TopOpt = xarray.open_dataset('ds_TopOpt.nc')  # load
+    end_time = time.perf_counter()
+    whole_time = end_time - start_time
+
+    obj_value = ds_TopOpt.loss.data
+    x_designed = ds_TopOpt.design.data
+
+    obj_min = jnp.min(obj_value)
+    obj_max = jnp.max(obj_value)
+    obj_min_index = jnp.where(obj_value == obj_min)[0][0]
+
+    dims = pd.Index(['TONR'], name='model')
+    ds = xarray.concat([ds_TopOpt], dim=dims)
+    ds_TopOpt.loss.transpose().to_pandas().cummin().plot(linewidth=2)
+    plt.ylim(obj_min * 0.85, obj_value[0] * 1.15)
+    plt.ylabel('Compliance (loss)')
+    plt.xlabel('Optimization step')
+    seaborn.despine()
+
+    final_design = x_designed[obj_min_index, :, :]
+    final_obj = obj_min
+
+    plt.show()
+    ds_TopOpt.design.sel(step=obj_min_index).plot.imshow(
+        x='x', y='y', size=2, aspect=2.5, yincrease=False,
+        add_colorbar=False, cmap='Greys')
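
Note (outside the diff): the script above drives TONR, where the optimizer updates network weights and the density field is whatever the network emits, so physics gradients flow through the model into the parameters. A minimal sketch of that pattern, with a made-up `quadratic_compliance` standing in for the FEA-based loss:

```python
import jax
import jax.numpy as jnp
import optax

def quadratic_compliance(x):
    # Hypothetical stand-in for the FEA-based compliance objective.
    return jnp.sum((x - 0.3) ** 2)

params = jnp.zeros(16)                 # "network weights"
model = lambda p: jax.nn.sigmoid(p)    # maps weights to densities in (0, 1)
loss = lambda p: quadratic_compliance(model(p))

optimizer = optax.adam(1e-2)
state = optimizer.init(params)
for _ in range(100):
    grads = jax.grad(loss)(params)     # gradient w.r.t. weights, not densities
    updates, state = optimizer.update(grads, state)
    params = optax.apply_updates(params, updates)
```
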
diff --git a/Linear_TO/TO_Cal.py b/Linear_TO/TO_Cal.py
new file mode 100644
index 0000000..b382f09
--- /dev/null
+++ b/Linear_TO/TO_Cal.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+'''
+@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
+@ Author : Zeyu Zhang
+@ Email : zhangzeyu_work@outlook.com
+@ Date : 2021-11-17 10:52:26
+@ LastEditTime : 2022-10-25 09:33:11
+@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Cal.py
+@
+@ Description : Helper functions used by the TO computation pipeline
+@ Reference :
+'''
+
+import numpy as np
+import jax
+import jax.numpy as jnp
+import jax.scipy.signal as jss
+import jax.lax as jl
+from jaxopt import Bisection
+from functools import partial
+import TO_Private_Autodiff as TPA
+import TO_FEA
+import TO_Obj
+from jax.config import config
+config.update("jax_enable_x64", True)
+
+
+@partial(jax.jit, static_argnums=(5, 6, 7))
+def design_variable(x, design_area_indices, nondesign_area_indices, filter_kernal,
+                    filter_weight, beta_x, volfrac, projection=False):
+    x = filter(x, filter_kernal, filter_weight)
+    x_1D = x.flatten(order='F')
+    x_designed = design_and_volume_constraints(
+        x_1D[design_area_indices], volfrac, beta_x)
+    xPhys_1D = matrix_set_0(
+        x_designed, design_area_indices, nondesign_area_indices)
+    return xPhys_1D
+
+
+def projection(x, beta_TONR):
+    output = jnp.tanh(beta_TONR * 0.5 * x) * 0.5 + 0.5
+    return output
+
+
+def logit(obj_V):
+    # logit maps (0, 1) onto (-inf, inf); p = 0.5 maps to 0, and values
+    # symmetric about 0.5 map to outputs of equal magnitude and opposite sign.
+    p = jnp.clip(obj_V, 0, 1)
+    return jnp.log(p) - jnp.log1p(-p)
+
+
+# @partial(jax.jit, static_argnums=(1, 2))
+def design_and_volume_constraints(x, obj_V, beta_TONR):
+    def find_yita(y, x, beta):
+        projection_out = projection(x + y, beta)
+        volume_diff = jnp.mean(projection_out) - obj_V
+        return volume_diff
+    lower_bound = logit(obj_V) - jl.stop_gradient(jnp.max(x))
+    upper_bound = logit(obj_V) - jl.stop_gradient(jnp.min(x))
+    bisec = Bisection(optimality_fun=find_yita, lower=lower_bound, upper=upper_bound,
+                      tol=1e-12, maxiter=64, check_bracket=False)
+    yita = bisec.run(x=x, beta=beta_TONR).params
+    x_design = projection(x + yita, beta_TONR)
+    x_design = x_design + 0.00001
+    return x_design
+
+
+def filter_define(design_map, filter_width):
+    dy, dx = jnp.meshgrid(jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)),
+                          jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)))
+    filter_kernal = jnp.maximum(0, filter_width - jnp.sqrt(dx ** 2 + dy ** 2))
+    filter_weight = jss.convolve2d(design_map, filter_kernal, 'same')
+    return filter_kernal, filter_weight
+
+
+# @jax.jit
+def filter(matrixA, filter_kernal, filter_weight):
+    new_matrix = jss.convolve2d(matrixA, filter_kernal, 'same') / filter_weight
+    return new_matrix
+
+
+def inverse_permutation(indices):
+    inverse_perm = jnp.zeros(len(indices), dtype=jnp.int64)
+    inverse_perm = inverse_perm.at[indices].set(
+        jnp.arange(len(indices), dtype=jnp.int64))
+    return inverse_perm
+
+
+# @jax.jit
+def matrix_set_0(nonzero_values, nonzero_indices, zero_indices):
+    index_map = inverse_permutation(
+        jnp.concatenate([nonzero_indices, zero_indices]))
+    u_values = jnp.concatenate([nonzero_values, jnp.zeros(len(zero_indices))])
+    new_matrix = u_values[index_map]
+    return new_matrix
+
+
+@partial(jax.jit, static_argnums=(2, 3))
+def matrix_set_1(nonzero_values, nonzero_indices, origin_matrix_size, zero_size):
+    all_indices = jnp.arange(origin_matrix_size, dtype=jnp.int64)
+    zero_indices = jnp.setdiff1d(
+        all_indices, nonzero_indices, size=zero_size, assume_unique=True)
+    index_map = inverse_permutation(
+        jnp.concatenate([nonzero_indices, zero_indices]))
+    u_values = jnp.concatenate([nonzero_values, jnp.ones(len(zero_indices))])
+    new_matrix = u_values[index_map]
+    return new_matrix
+
+
+@partial(jax.jit, static_argnums=(1, 2, 10))
+def objective(xPhys_1D, young, young_min, freedofs, fext, s_K_predefined,
+              dispTD_predefined, idx, edofMat, ke, penal):
+    # xPhys_1D = xPhys.T.flatten()
+    disp = TO_FEA.FEA(young, young_min, ke, xPhys_1D, freedofs,
+                      penal, idx, fext, s_K_predefined, dispTD_predefined)
+    obj = TO_Obj.compliance(young, young_min, ke,
+                            xPhys_1D, edofMat, penal, disp)
+    return obj
\ No newline at end of file
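
Note (outside the diff): the core trick in `design_and_volume_constraints` is to satisfy the volume fraction exactly by shifting the projection input until its mean hits the target. A self-contained rehearsal of that bisection, with made-up input values:

```python
import jax.numpy as jnp
from jaxopt import Bisection

def projection(x, beta):
    return jnp.tanh(beta * 0.5 * x) * 0.5 + 0.5

x = jnp.linspace(-1.0, 1.0, 50)   # fake filtered densities
volfrac, beta = 0.5, 4.0

def volume_gap(y, x, beta):
    # Root of this function is the offset that hits the target volume.
    return jnp.mean(projection(x + y, beta)) - volfrac

bisec = Bisection(optimality_fun=volume_gap, lower=-10.0, upper=10.0,
                  tol=1e-12, maxiter=64, check_bracket=False)
yita = bisec.run(x=x, beta=beta).params
print(jnp.mean(projection(x + yita, beta)))  # ~0.5
```
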
diff --git a/Linear_TO/TO_Define.py b/Linear_TO/TO_Define.py
new file mode 100644
index 0000000..9e93726
--- /dev/null
+++ b/Linear_TO/TO_Define.py
@@ -0,0 +1,199 @@
+# -*- coding: utf-8 -*-
+'''
+@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
+@ Author : Zeyu Zhang
+@ Email : zhangzeyu_work@outlook.com
+@ Date : 2022-10-25 09:30:27
+@ LastEditTime : 2022-10-25 09:37:58
+@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Define.py
+@
+@ Description :
+@ Reference :
+'''
+
+import numpy as np
+import jax.numpy as jnp
+import jax.ops as jops
+import TO_Cal
+import TO_FEA
+
+
+def Toparams(problem):
+    young = 1.0
+    rou = 1.0
+    nelx, nely, ele_length, ele_width = problem.nelx, problem.nely, problem.ele_length, problem.ele_width
+    Total_ele, Total_nod, Total_dof = nelx*nely, (nelx+1)*(nely+1), 2*(nelx+1)*(nely+1)
+    gobalcoords0xy, M_elenod, M_edofMat, M_iK, M_jK = coordinate_and_preFEA(
+        nelx, nely, ele_length, ele_width, Total_ele, Total_nod)
+    # Shift the MATLAB-style 1-based connectivity tables to 0-based indexing.
+    elenod, edofMat, iK, jK = M_elenod-1, M_edofMat-1, M_iK-1, M_jK-1
+    edofMat_3D = TO_edofMat_3DTensor(nelx, nely)
+    design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices = handle_design_region(
+        nelx, nely, Total_ele, problem.design_map)
+    params = {
+        'young': young,
+        'young_min': 1e-9*young,
+        'rou': rou,
+        'poisson': 0.3,
+        'g': 0,
+        'ele_length': ele_length,
+        'ele_width': ele_width,
+        'ele_thickness': problem.ele_thickness,
+        'nelx': nelx,
+        'nely': nely,
+        'Total_ele': Total_ele,
+        'Total_nod': Total_nod,
+        'Total_dof': Total_dof,
+        'volfrac': problem.volfrac,
+        'xmin': 0.001,
+        'xmax': 1.0,
+        'design_map': problem.design_map,
+        # 'design_map_bool': design_map_bool,
+        # 'num_design_ele': num_design_ele,
+        # 'num_nondesign_ele': num_nondesign_ele,
+        'design_area_indices': design_area_indices,
+        'nondesign_area_indices': nondesign_area_indices,
+        'freedofs': problem.freedofs,
+        'fixeddofs': problem.fixeddofs,
+        'fext': problem.fext,
+        'penal': 3.0,
+        'filter_width': problem.filter_width,
+        'gobalcoords0xy': gobalcoords0xy,
+        'M_elenod': M_elenod,
+        'M_edofMat': M_edofMat,
+        'M_iK': M_iK,
+        'M_jK': M_jK,
+        'elenod': elenod,
+        'edofMat': edofMat,
+        'edofMat_3D': edofMat_3D,
+        'iK': iK,
+        'jK': jK,
+        'design_condition': problem.design_condition,
+    }
+    return params
+
+
+class Topology_Optimization:
+
+    def __init__(self, args):
+        self.args = args
+        self.young, self.young_min, self.poisson = args['young'], args['young_min'], args['poisson']
+        self.Total_dof, self.Total_ele, self.nelx, self.nely, self.ele_thickness = args[
+            'Total_dof'], args['Total_ele'], args['nelx'], args['nely'], args['ele_thickness']
+        self.freedofs, self.fext, self.iK, self.jK, self.edofMat = args[
+            'freedofs'], args['fext'], args['iK'], args['jK'], args['edofMat']
+        self.filter_width, self.design_condition, self.volfrac = args[
+            'filter_width'], args['design_condition'], args['volfrac']
+        self.design_map, self.design_area_indices, self.nondesign_area_indices = args[
+            'design_map'], args['design_area_indices'], args['nondesign_area_indices']
+        self.ke = TO_FEA.linear_stiffness_matrix(
+            self.young, self.poisson, self.ele_thickness)
+        self.filter_kernal, self.filter_weight = TO_Cal.filter_define(
+            self.design_map, self.filter_width)
+        self.s_K_predefined = jnp.zeros([self.Total_dof, self.Total_dof])
+        self.dispTD_predefined = jnp.zeros([self.Total_dof])
+        self.idx = jops.index[self.iK, self.jK]
+        self.loop_1, self.iCont = 0, 0
+        self.penal_use, self.beta_x_use = 3.0, 1.0
+        self.beta_x_use_max, self.beta_x_deltath1, self.beta_x_deltath2 = 16, 50, 30
+        self.xPhys = self.volfrac * jnp.ones([self.nely, self.nelx])
+
+    def xPhys_transform(self, params, projection=False):
+        x = params.reshape(self.nely, self.nelx)
+        beta_x = self.penal_and_betax()
+        xPhys_1D = TO_Cal.design_variable(x, self.design_area_indices, self.nondesign_area_indices,
+                                          self.filter_kernal, self.filter_weight, beta_x,
+                                          self.volfrac, projection=projection)
+        return xPhys_1D
+
+    def objective(self, params, projection=False):
+        self.loop_1 = self.loop_1 + 1
+        xPhys_1D = self.xPhys_transform(params, projection=projection)
+        Obj = TO_Cal.objective(xPhys_1D, self.young, self.young_min, self.freedofs, self.fext,
+                               self.s_K_predefined, self.dispTD_predefined, self.idx,
+                               self.edofMat, self.ke, self.penal_use)
+        self.xPhys = xPhys_1D.reshape(self.nely, self.nelx, order='F')
+        return Obj
+
+    def penal_and_betax(self, *args):
+        self.iCont = self.iCont + 1
+        if (self.iCont > self.beta_x_deltath1 and self.beta_x_use < 2):
+            self.beta_x_use = self.beta_x_use * 2
+            self.iCont = 1
+            print(' beta_x increased to :%7.3f\n' % (self.beta_x_use))
+        elif (self.iCont > self.beta_x_deltath2 and self.beta_x_use < self.beta_x_use_max
+              and self.beta_x_use >= 2):
+            self.beta_x_use = self.beta_x_use * 2
+            self.iCont = 1
+            print(' beta_x increased to :%7.3f\n' % (self.beta_x_use))
+        return self.beta_x_use
+
+
+def coordinate_and_preFEA(nelx, nely, ele_length, ele_width, Total_ele, Total_nod):
+    i0, j0 = jnp.meshgrid(jnp.arange(nelx+1), jnp.arange(nely+1))
+    gobalcoords0x = i0*ele_length
+    gobalcoords0y = ele_width*nely - j0*ele_width
+    gobalcoords0xy = jnp.hstack((matrix_array(gobalcoords0x), matrix_array(gobalcoords0y)))
+    gobalcoords0xyT = gobalcoords0xy.T
+    gobalcoords0 = matrix_array(gobalcoords0xyT)
+    nodegrd = matrix_order_arrange(1, Total_nod, nely+1, nelx+1)
+    nodelast = Matlab_reshape(Matlab_matrix_extract(
+        nodegrd, 1, -1, 1, -1), Total_ele, 1)
+    edofVec = 2*matrix_array(nodelast)+1
+    elenod = jnp.tile(nodelast, (1, 4)) + \
+        jnp.tile(jnp.array([1, nely+2, nely+1, 0]), (Total_ele, 1))
+    edofMat = jnp.tile(edofVec, (1, 8)) + jnp.tile(jnp.array(
+        [0, 1, 2*nely+2, 2*nely+3, 2*nely+0, 2*nely+1, -2, -1]), (Total_ele, 1))
+    iK = jnp.kron(edofMat, jnp.ones((8, 1))).flatten()
+    jK = jnp.kron(edofMat, jnp.ones((1, 8))).flatten()
+    iK_int = iK.astype(jnp.int32)
+    jK_int = jK.astype(jnp.int32)
+    return gobalcoords0xy, elenod, edofMat, iK_int, jK_int
+
+
+def handle_design_region(nelx, nely, Total_ele, design_map):
+    shape = (nely, nelx)
+    design_map_bool = np.broadcast_to(design_map, shape) > 0
+    num_design_ele = np.sum(design_map)
+    num_nondesign_ele = Total_ele - num_design_ele
+
+    design_map_bool_1D = design_map_bool.flatten(order='F')
+    all_indices = np.arange(Total_ele, dtype=jnp.int64)
+    design_area_indices = jnp.flatnonzero(
+        design_map_bool_1D, size=num_design_ele)
+    nondesign_area_indices = jnp.setdiff1d(
+        all_indices, design_area_indices, size=num_nondesign_ele,
+        assume_unique=True)  # setdiff1d is not jit-compatible unless a static size is given
+    return design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices
+
+
+def Matlab_matrix_array1D(matrixA):
+    return matrixA.T.flatten()
+
+
+def matrix_array(matrixA):
+    return matrixA.T.reshape(-1, 1)
+
+
+def matrix_order_arrange(numbegin, numend, ydim, xdim):
+    matrixA = jnp.array([[i for i in range(numbegin, numend+1)]])
+    matrixA = matrixA.reshape(xdim, ydim)
+    matrixA = matrixA.T
+    return matrixA
+
+
+def Matlab_reshape(matrixA, ydim, xdim):
+    return matrixA.reshape((ydim, xdim), order='F')
+
+
+def Matlab_matrix_extract(matrixA, xbegin, xend, ybegin, yend):
+    matrixA = matrixA[ybegin-1:yend, xbegin-1:xend]
+    return matrixA
+
+
+def TO_edofMat_3DTensor(nelx, nely):
+    ely, elx = jnp.meshgrid(jnp.arange(nely), jnp.arange(nelx))  # x, y coords
+    n1 = (nely+1)*(elx+0) + (ely+0)
+    n2 = (nely+1)*(elx+1) + (ely+0)
+    n3 = (nely+1)*(elx+1) + (ely+1)
+    n4 = (nely+1)*(elx+0) + (ely+1)
+    edofMat_3D = jnp.array(
+        [2*n4, 2*n4+1, 2*n3, 2*n3+1, 2*n2, 2*n2+1, 2*n1, 2*n1+1])
+    return edofMat_3D
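
Note (outside the diff): the element fields in this module are flattened MATLAB-style throughout (`order='F'`, column-major), so linear index `k` maps to row `k % nely`, column `k // nely`. A tiny check of that convention:

```python
import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6]])            # nely=2, nelx=3
b = a.flatten(order='F')
print(b)                             # [1 4 2 5 3 6]
print(b.reshape(2, 3, order='F'))    # recovers a
```
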
diff --git a/Linear_TO/TO_FEA.py b/Linear_TO/TO_FEA.py
new file mode 100644
index 0000000..4d44791
--- /dev/null
+++ b/Linear_TO/TO_FEA.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+'''
+@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
+@ Author : Zeyu Zhang
+@ Email : zhangzeyu_work@outlook.com
+@ Date : 2022-10-25 09:30:27
+@ LastEditTime : 2022-10-25 09:38:08
+@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_FEA.py
+@
+@ Description :
+@ Reference :
+'''
+
+import numpy as np
+from numpy import float64
+import jax.numpy as jnp
+import jax.ops as jops
+import jax.scipy.linalg as jsl
+from jax import jit
+from jax.config import config
+config.update("jax_enable_x64", True)
+
+
+def linear_stiffness_matrix(young, poisson, ele_thickness):
+    h = ele_thickness
+    k = np.array([1/2-poisson/6, 1/8+poisson/8, -1/4-poisson/12, -1/8+3*poisson/8,
+                  -1/4+poisson/12, -1/8-poisson/8, poisson/6, 1/8-3*poisson/8])
+    stiffness_ele = h/(1-poisson**2)*np.array([
+        [k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]],
+        [k[1], k[0], k[7], k[6], k[5], k[4], k[3], k[2]],
+        [k[2], k[7], k[0], k[5], k[6], k[3], k[4], k[1]],
+        [k[3], k[6], k[5], k[0], k[7], k[2], k[1], k[4]],
+        [k[4], k[5], k[6], k[7], k[0], k[1], k[2], k[3]],
+        [k[5], k[4], k[3], k[2], k[1], k[0], k[7], k[6]],
+        [k[6], k[3], k[4], k[1], k[2], k[7], k[0], k[5]],
+        [k[7], k[2], k[1], k[4], k[3], k[6], k[5], k[0]]], dtype=float64)
+    return stiffness_ele
+
+
+def density_matrix(rou, ele_length, ele_width, ele_thickness):
+    ele_volume = ele_length*ele_width*ele_thickness
+    density_ele = ele_volume / 36 * np.array([
+        [4, 0, 2, 0, 1, 0, 2, 0],
+        [0, 4, 0, 2, 0, 1, 0, 2],
+        [2, 0, 4, 0, 2, 0, 1, 0],
+        [0, 2, 0, 4, 0, 2, 0, 1],
+        [1, 0, 2, 0, 4, 0, 2, 0],
+        [0, 1, 0, 2, 0, 4, 0, 2],
+        [2, 0, 1, 0, 2, 0, 4, 0],
+        [0, 2, 0, 1, 0, 2, 0, 4]], dtype=float64)
+    return density_ele
+
+
+# @jit
+def FEA(young, young_min, ke, xPhys_1D, freedofs, penal, idx, fext, s_K_predefined, dispTD_predefined):
+
+    def kuf_solve(fext_1D_free, s_K_free):
+        u_free = jsl.solve(s_K_free, fext_1D_free,
+                           sym_pos=True, check_finite=False)
+        return u_free
+
+    def young_modulus(xPhys_1D, young, young_min, ke, penal):
+        sK = ((ke.flatten()[jnp.newaxis]).T *
+              (young_min + (xPhys_1D) ** penal * (young-young_min))).flatten(order='F')
+        return sK
+
+    sK = young_modulus(xPhys_1D, young, young_min, ke, penal)
+    s_K = jops.index_add(s_K_predefined, idx, sK)
+    s_K_free = s_K[freedofs, :][:, freedofs]
+    fext_free = fext[freedofs]
+    u_free = kuf_solve(fext_free, s_K_free)
+    dispTD = jops.index_add(dispTD_predefined, freedofs, u_free)
+    return dispTD
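
Note (outside the diff): `young_modulus` applies the SIMP interpolation `E_min + x^p (E - E_min)`, which keeps the global stiffness matrix positive definite as x → 0 while penalizing intermediate densities. Numerically, with the defaults used here:

```python
import jax.numpy as jnp

young, young_min, penal = 1.0, 1e-9, 3.0
x = jnp.array([0.0, 0.5, 1.0])
print(young_min + x ** penal * (young - young_min))  # ~[1e-09  0.125  1.0]
```
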
diff --git a/Linear_TO/TO_Model.py b/Linear_TO/TO_Model.py
new file mode 100644
index 0000000..fb28e7c
--- /dev/null
+++ b/Linear_TO/TO_Model.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+'''
+@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
+@ Author : Zeyu Zhang
+@ Email : zhangzeyu_work@outlook.com
+@ Date : 2022-10-25 09:30:27
+@ LastEditTime : 2022-10-25 09:36:54
+@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Model.py
+@
+@ Description :
+@ Reference :
+'''
+
+import jax
+import jax.numpy as jnp
+import jax.image as ji
+import jax.tree_util as jtree
+from jax import random
+import flax.linen as nn
+import flax.linen.initializers as fli
+import flax.traverse_util as flu
+import flax.core.frozen_dict as fcfd
+from functools import partial
+from typing import Any, Callable
+from jax.config import config
+config.update("jax_enable_x64", True)
+
+
+def set_random_seed(seed):
+    if seed is not None:
+        rand_key = random.PRNGKey(seed)
+    return rand_key
+
+
+def constant_array(rng_key, shape, dtype=jnp.float64, value=0.0):
+    out = jnp.full(shape, value, dtype)
+    return out
+
+
+def extract_model_params(params_use):
+    params_unfreeze = params_use.unfreeze()
+    extracted_params = {
+        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
+    params_shape = jtree.tree_map(jnp.shape, extracted_params)
+    params_total_num = sum(p.size for p in jtree.tree_leaves(extracted_params))
+    return extracted_params, params_shape, params_total_num
+
+
+def replace_model_params(pre_trained_params_dict, params_use):
+    params_unfreeze = params_use.unfreeze()
+    # Route 1: rebuild the nested dict from the flattened, renamed keys.
+    extracted_params = {
+        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
+    for key, val in pre_trained_params_dict.items():
+        extracted_params[key] = val
+    new_params_use = flu.unflatten_dict(
+        {tuple(k.split('/')): v for k, v in extracted_params.items()})
+    # Route 2: overwrite the leading leaves of the flattened pytree in order.
+    params_leaves, params_struct = jtree.tree_flatten(params_unfreeze)
+    i = 0
+    for key, val in pre_trained_params_dict.items():
+        params_leaves[i] = val
+        i = i + 1
+    new_params_use = jtree.tree_unflatten(params_struct, params_leaves)
+
+    new_params_use = fcfd.freeze(new_params_use)
+    return new_params_use
+
+
+def batched_loss(x_inputs, TopOpt_envs):
+    loss_list = [TopOpt.objective(x_inputs[i], projection=True)
+                 for i, TopOpt in enumerate(TopOpt_envs)]
+    losses_array = jnp.stack(loss_list)[0]
+    return losses_array
+
+
+class Normalize_TO(nn.Module):
+
+    epsilon: float = 1e-6
+
+    @nn.compact
+    def __call__(self, input):
+        input_mean = jnp.mean(input, axis=(1, 2, 3), keepdims=True)
+        input_var = jnp.var(input, axis=(1, 2, 3), keepdims=True)
+        output = input - input_mean
+        output = output * jax.lax.rsqrt(input_var + self.epsilon)
+        return output
+
+
+class Add_offset(nn.Module):
+    scale: int = 10
+
+    @nn.compact
+    def __call__(self, input):
+        add_bias = self.param('add_bias', lambda rng, shape: constant_array(
+            rng, shape, value=0.0), input.shape)
+        return input + self.scale * add_bias
+
+
+class BasicBlock(nn.Module):
+    resize_scale_one: int
+    conv_output_one: int
+    norm_use: nn.Module = Normalize_TO
+    resize_method: str = "bilinear"
+    kernel_size: int = 5
+    conv_kernel_init: partial = fli.variance_scaling(
+        scale=1.0, mode="fan_in", distribution="truncated_normal")
+    offset: nn.Module = Add_offset
+    offset_scale: int = 10
+
+    @nn.compact
+    def __call__(self, input):
+        output = jnp.tanh(input)
+        output_shape = output.shape
+        output = ji.resize(output, (output_shape[0],
+                                    self.resize_scale_one * output_shape[1],
+                                    self.resize_scale_one * output_shape[2],
+                                    output_shape[-1]),
+                           method=self.resize_method, antialias=True)
+        output = self.norm_use()(output)
+        output = nn.Conv(features=self.conv_output_one,
+                         kernel_size=(self.kernel_size, self.kernel_size),
+                         kernel_init=self.conv_kernel_init)(output)
+        output = self.offset(scale=self.offset_scale)(output)
+        return output
+
+
+class My_TO_Model(nn.Module):
+    TopOpt: Any
+    seed: int = 1
+    replace_params: Callable = replace_model_params
+    extract_params: Callable = extract_model_params
+
+    def loss(self, input_TO):
+        obj_TO = batched_loss(input_TO, [self.TopOpt])
+        return obj_TO
+
+
+class TO_CNN(My_TO_Model):
+    h: int = 5
+    w: int = 10
+    dense_scale_init: Any = 1.0
+    dense_output_channel: int = 1000
+    dtype: Any = jnp.float32
+    input_init_std: float = 1.0
+    dense_input_channel: int = 128
+    conv_input_0: int = 32
+    conv_output: tuple = (128, 64, 32, 16, 1)
+    up_sampling_scale: tuple = (1, 2, 2, 2, 1)
+    offset_scale: int = 10
+    block: nn.Module = BasicBlock
+
+    @nn.compact
+    def __call__(self, x=None):  # [N, H, W, C]
+        nn_input = self.param('z', lambda rng, shape: constant_array(
+            rng, shape, value=self.TopOpt.volfrac), self.dense_input_channel)
+
+        output = nn.Dense(features=self.dense_output_channel,
+                          kernel_init=fli.orthogonal(scale=self.dense_scale_init))(nn_input)
+        output = output.reshape([1, self.h, self.w, self.conv_input_0])
+
+        output = self.block(self.up_sampling_scale[0], self.conv_output[0])(output)
+        output = self.block(self.up_sampling_scale[1], self.conv_output[1])(output)
+        output = self.block(self.up_sampling_scale[2], self.conv_output[2])(output)
+        output = self.block(self.up_sampling_scale[3], self.conv_output[3])(output)
+        output = self.block(self.up_sampling_scale[4], self.conv_output[4])(output)
+
+        nn_output = jnp.squeeze(output, axis=-1)
+        self.sow('intermediates', 'model_out', nn_output)
+        return nn_output
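
Note (outside the diff): each `BasicBlock` upsamples before convolving, so for the 80x40 cantilever the scales (1, 2, 2, 2, 1) grow the 5x10 seed produced by the dense layer to 40x80. A one-line shape check of the resize step:

```python
import jax.numpy as jnp
import jax.image as ji

x = jnp.zeros((1, 5, 10, 32))  # [N, H, W, C] seed from the dense layer
y = ji.resize(x, (1, 10, 20, 32), method="bilinear", antialias=True)
print(y.shape)                 # (1, 10, 20, 32)
```
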
diff --git a/Linear_TO/TO_Obj.py b/Linear_TO/TO_Obj.py
new file mode 100644
index 0000000..e3b6e2c
--- /dev/null
+++ b/Linear_TO/TO_Obj.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+'''
+@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
+@ Author : Zeyu Zhang
+@ Email : zhangzeyu_work@outlook.com
+@ Date : 2022-10-25 09:30:27
+@ LastEditTime : 2022-10-25 09:34:00
+@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Obj.py
+@
+@ Description :
+@ Reference :
+'''
+import jax.numpy as jnp
+from jax import jit
+
+
+# @jit
+def compliance(young, young_min, ke, xPhys_1D, edofMat, penal, disp):
+    ce = jnp.sum(jnp.matmul(disp[edofMat], ke) * disp[edofMat], 1)
+    ce = (young_min + xPhys_1D ** penal * (young - young_min)) * ce
+    obj = jnp.sum(ce)
+    return obj
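
Note (outside the diff): `disp[edofMat]` gathers the 8 dofs of every element in one shot, so the row-wise sum computes `u_e^T k_e u_e` per element. A shape rehearsal with made-up numbers:

```python
import jax.numpy as jnp

disp = jnp.arange(12.0)                               # fake global displacements
edofMat = jnp.array([[0, 1, 2, 3, 4, 5, 6, 7]] * 3)   # 3 elements, 8 dofs each
ke = jnp.eye(8)                                       # fake element stiffness
ce = jnp.sum(jnp.matmul(disp[edofMat], ke) * disp[edofMat], 1)
print(ce.shape)                                       # (3,) one energy per element
```
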
diff --git a/Linear_TO/TO_Problem.py b/Linear_TO/TO_Problem.py
new file mode 100644
index 0000000..51d6b16
--- /dev/null
+++ b/Linear_TO/TO_Problem.py
@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+'''
+@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
+@ Author : Zeyu Zhang
+@ Email : zhangzeyu_work@outlook.com
+@ Date : 2022-10-25 09:30:27
+@ LastEditTime : 2022-10-25 09:33:20
+@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Problem.py
+@
+@ Description :
+@ Reference :
+'''
+
+import jax.ops as jops
+import numpy as np
+import dataclasses
+from typing import Optional, Union
+import jax.numpy as jnp
+from jax.config import config
+config.update("jax_enable_x64", True)
+
+
+@dataclasses.dataclass
+class ToProblem:
+    L: int
+    W: int
+    nelx: int
+    nely: int
+    design_condition: int
+    ele_length: np.float64
+    ele_width: np.float64
+    ele_thickness: np.float64
+    volfrac: np.float64
+    fixeddofs: np.ndarray
+    freedofs: np.ndarray
+    fext: np.ndarray
+    filter_width: float
+    design_map: np.ndarray
+
+
+def cantilever_single():
+    L = 0.8
+    W = 0.4
+    nelx = 80
+    nely = 40
+    design_condition = 1
+    Total_dof = 2*(nelx+1)*(nely+1)
+    ele_length = L/nelx
+    ele_width = W/nely
+    ele_thickness = 1.0
+    volfrac = 0.5
+    dofs = np.arange(Total_dof)
+    loaddof = Total_dof-1
+    # fext = jnp.zeros(Total_dof)
+    # fext = jops.index_update(fext, jops.index[loaddof], -1.)
+    fext = np.zeros(Total_dof)
+    fext[loaddof] = -1
+    fixeddofs = np.array(dofs[0:2*(nely+1):1])
+    freedofs = np.setdiff1d(dofs, fixeddofs)
+    filter_width = 3.0
+    design_map = np.ones([nely, nelx], dtype=np.int64)
+    return ToProblem(L, W, nelx, nely, design_condition, ele_length, ele_width,
+                     ele_thickness, volfrac, fixeddofs, freedofs, fext,
+                     filter_width, design_map)
+
+
+def cantilever_single_NN(TopOpt):
+    dense_input_channel = 128
+    dense_init_scale = 1.0
+    conv_input_0 = 32
+    up_sampling_scale = (1, 2, 2, 2, 1)
+    total_resize = int(np.prod(up_sampling_scale))
+    h = TopOpt.nely // total_resize
+    w = TopOpt.nelx // total_resize
+    dense_output_channel = h * w * conv_input_0
+    dense_scale_init = dense_init_scale * \
+        jnp.sqrt(jnp.maximum(dense_output_channel / dense_input_channel, 1))
+    model_kwargs = {
+        'h': h,
+        'w': w,
+        'dense_scale_init': dense_scale_init,
+        'dense_output_channel': dense_output_channel,
+        'dense_input_channel': dense_input_channel,
+        'conv_input_0': conv_input_0,
+        'up_sampling_scale': up_sampling_scale,
+    }
+    return model_kwargs
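
Note (outside the diff): dofs are numbered two per node down the first node column, so `dofs[0:2*(nely+1)]` pins exactly the x=0 edge of the cantilever. On a toy 2x2-element mesh:

```python
import numpy as np

nelx, nely = 2, 2
dofs = np.arange(2 * (nelx + 1) * (nely + 1))   # 18 dofs on a 3x3 node grid
fixeddofs = dofs[0:2*(nely+1)]
freedofs = np.setdiff1d(dofs, fixeddofs)
print(fixeddofs)      # [0 1 2 3 4 5] -- the left node column
print(len(freedofs))  # 12 remaining dofs
```
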
diff --git a/Linear_TO/TO_Train.py b/Linear_TO/TO_Train.py
new file mode 100644
index 0000000..b7471bc
--- /dev/null
+++ b/Linear_TO/TO_Train.py
@@ -0,0 +1,95 @@
+# -*- coding: utf-8 -*-
+'''
+@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
+@ Author : Zeyu Zhang
+@ Email : zhangzeyu_work@outlook.com
+@ Date : 2022-10-25 09:30:27
+@ LastEditTime : 2022-10-25 09:37:03
+@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Train.py
+@
+@ Description :
+@ Reference :
+'''
+
+import sys
+from pathlib import Path
+from absl import logging
+import xarray
+import jax
+import jax.numpy as jnp
+from jax import random
+from jaxopt import OptaxSolver
+import optax
+import time
+from functools import partial
+import matplotlib.pyplot as plt
+from jax.config import config
+config.update("jax_enable_x64", True)
+sys.path.append(
+    '/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO')
+here = Path(__file__).resolve().parent
+
+
+def set_random_seed(seed):
+    if seed is not None:
+        rand_key = random.PRNGKey(seed)
+    return rand_key
+
+
+def optimizer_result_dataset(losses, frames, save_intermediate_designs=False):
+    best_design = jnp.nanargmin(losses)
+    logging.info(f'Final loss: {losses[best_design]}')
+    if save_intermediate_designs:
+        ds = xarray.Dataset({
+            'loss': (('step',), losses),
+            'design': (('step', 'y', 'x'), frames),
+        }, coords={'step': jnp.arange(len(losses))})
+    else:
+        ds = xarray.Dataset({
+            'loss': (('step',), losses),
+            'design': (('y', 'x'), frames[best_design]),
+        }, coords={'step': jnp.arange(len(losses))})
+    return ds
+
+
+def train_TO_Optax(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs):
+    losses = []
+    frames = []
+    # Initialize parameters
+    init_params = model.init(set_random_seed(seed), None)
+    batch_state, params = init_params.pop("params")
+    del init_params
+    # Instantiate the optimizer
+    design_condition = model.TopOpt.design_condition
+    if design_condition == 1:
+        learning_rate = 0.001
+        optimizer = optax.adam(learning_rate)
+    elif design_condition == 2:
+        learning_rate = 0.001
+        optimizer = optax.adam(learning_rate)
+
+    # Loss: run the forward pass, then feed the emitted design into the physics objective.
+    def loss_cal(params):
+        all_params = {"params": params}
+        model_out, net_state = model.apply(all_params, None, mutable="intermediates")
+        loss_out = model.loss(model_out)
+        return loss_out, net_state["intermediates"]
+
+    loss_grad_fn = jax.value_and_grad(loss_cal, has_aux=True)
+    # Initialize the optimizer state
+    state = optimizer.init(params)
+    # Update loop
+    for iteration in range(max_iterations):
+        start_time_epoch = time.perf_counter()
+        (loss_val, batch_stats), grads = loss_grad_fn(params)
+        losses.append(jax.device_get(loss_val))
+        updates_params, state = optimizer.update(grads, state)
+        params = optax.apply_updates(params, updates_params)
+        design_TO = jax.device_get(model.TopOpt.xPhys.aval.val)
+        frames.append(design_TO)
+        plt.figure()
+        plt.imshow(design_TO, cmap='Greys')
+        plt.show()
+        whole_time_epoch = time.perf_counter() - start_time_epoch
+        print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' %
+              (iteration+1, losses[-1], jnp.mean(design_TO), whole_time_epoch))
+
+    return optimizer_result_dataset(jnp.array(losses), jnp.array(frames),
+                                    save_intermediate_designs)
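
Note (outside the diff): `train_TO_Optax` follows the standard Flax init/apply plus Optax update pattern. Reduced to a self-contained toy (a one-layer model driven to zero output), the same skeleton reads:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class Toy(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)

model = Toy()
x = jnp.ones((4, 3))
params = model.init(jax.random.PRNGKey(1), x)["params"]
loss_fn = lambda p: jnp.mean(model.apply({"params": p}, x) ** 2)

optimizer = optax.adam(1e-3)
state = optimizer.init(params)
for _ in range(10):
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, state = optimizer.update(grads, state)
    params = optax.apply_updates(params, updates)
```
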