From da7820895e416cb9e184ec0036d4981c2295fef5 Mon Sep 17 00:00:00 2001 From: zhangzeyu Date: Tue, 16 Aug 2022 15:35:58 +0800 Subject: [PATCH] 2022.8.16 Commit --- Linear_TO/TONR_Linear_Stiffness.py | 92 +++++++ Linear_TO/TO_Cal.py | 253 ++++++++++++++++++ Linear_TO/TO_Define.py | 245 +++++++++++++++++ Linear_TO/TO_FEA.py | 133 +++++++++ Linear_TO/TO_Model.py | 249 +++++++++++++++++ Linear_TO/TO_Obj.py | 22 ++ Linear_TO/TO_Problem.py | 129 +++++++++ Linear_TO/TO_Train.py | 164 ++++++++++++ Linear_TO/__pycache__/TO_Cal.cpython-38.pyc | Bin 0 -> 8361 bytes .../__pycache__/TO_Define.cpython-38.pyc | Bin 0 -> 8453 bytes Linear_TO/__pycache__/TO_FEA.cpython-38.pyc | Bin 0 -> 4292 bytes Linear_TO/__pycache__/TO_Model.cpython-38.pyc | Bin 0 -> 9461 bytes Linear_TO/__pycache__/TO_Obj.cpython-38.pyc | Bin 0 -> 770 bytes .../TO_Private_Autodiff.cpython-38.pyc | Bin 0 -> 3016 bytes .../__pycache__/TO_Problem.cpython-38.pyc | Bin 0 -> 2152 bytes Linear_TO/__pycache__/TO_Train.cpython-38.pyc | Bin 0 -> 4210 bytes 16 files changed, 1287 insertions(+) create mode 100644 Linear_TO/TONR_Linear_Stiffness.py create mode 100644 Linear_TO/TO_Cal.py create mode 100644 Linear_TO/TO_Define.py create mode 100644 Linear_TO/TO_FEA.py create mode 100644 Linear_TO/TO_Model.py create mode 100644 Linear_TO/TO_Obj.py create mode 100644 Linear_TO/TO_Problem.py create mode 100644 Linear_TO/TO_Train.py create mode 100644 Linear_TO/__pycache__/TO_Cal.cpython-38.pyc create mode 100644 Linear_TO/__pycache__/TO_Define.cpython-38.pyc create mode 100644 Linear_TO/__pycache__/TO_FEA.cpython-38.pyc create mode 100644 Linear_TO/__pycache__/TO_Model.cpython-38.pyc create mode 100644 Linear_TO/__pycache__/TO_Obj.cpython-38.pyc create mode 100644 Linear_TO/__pycache__/TO_Private_Autodiff.cpython-38.pyc create mode 100644 Linear_TO/__pycache__/TO_Problem.cpython-38.pyc create mode 100644 Linear_TO/__pycache__/TO_Train.cpython-38.pyc diff --git a/Linear_TO/TONR_Linear_Stiffness.py 
# ---- file: Linear_TO/TONR_Linear_Stiffness.py (new file in patch) ----
# -*- coding: utf-8 -*-
'''
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
@ Author : Zeyu Zhang
@ Email : zhangzeyu_work@outlook.com
@ Date : 2021-11-17 10:52:26
@ LastEditTime : 2022-08-16 14:45:12
@ FilePath : /Topology_Optimization/Linear_TO/TONR_Linear_Stiffness.py
@
@ Description : Jax+Flax+optax TONR Linear Stiffness Problem
@ Reference :
'''

from pathlib import Path  # filesystem path helpers
import sys
import pandas as pd
import xarray
import matplotlib.pyplot as plt
import seaborn
import time
import TO_Problem
import TO_Define
import TO_Model
import TO_Train
import jax.numpy as jnp
from jax import random
from jax.config import config
config.update("jax_enable_x64", True)  # double precision; strongly affects numerical accuracy
sys.path.append(
    '/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO')
here = Path(__file__).resolve().parent


# Select the benchmark case and assemble the optimization environment:
# problem definition -> static FE/TO parameters -> TO environment -> NN kwargs.
design_condition = 1
if design_condition == 1:
    problem = TO_Problem.cantilever_single()
    max_iterations = 200
    args = TO_Define.Toparams(problem)
    TopOpt_env = TO_Define.Topology_Optimization(args)
    model_kwargs = TO_Problem.cantilever_single_NN(TopOpt_env)
elif design_condition == 2:
    problem = TO_Problem.cantilever_single_big()
    max_iterations = 200
    args = TO_Define.Toparams(problem)
    TopOpt_env = TO_Define.Topology_Optimization(args)
    model_kwargs = TO_Problem.cantilever_single_big_NN(TopOpt_env)

# CNN that reparameterizes the density field (TONR approach).
TONR_model = TO_Model.TO_CNN(TopOpt=TopOpt_env, **model_kwargs)

if __name__ == '__main__':

    start_time = time.perf_counter()
    seed = 1
    # Optimize with Optax; returns an xarray.Dataset of losses and designs.
    ds_TopOpt = TO_Train.train_TO_Optax(
        TONR_model, max_iterations, save_intermediate_designs=True, seed=1)
    ds_TopOpt.to_netcdf('ds_TopOpt.nc')  # save
    # ds_TopOpt = xarray.open_dataset('ds_TopOpt.nc')  # load
    end_time = time.perf_counter()
    whole_time = end_time - start_time

    obj_value = ds_TopOpt.loss.data
    x_designed = ds_TopOpt.design.data

    # Best (lowest-compliance) iteration; first index if several are tied.
    obj_min = jnp.min(obj_value)
    obj_max = jnp.max(obj_value)
    obj_min_index = jnp.where(obj_value == obj_min)[0][0]

    dims = pd.Index(['TONR'], name='model')
    ds = xarray.concat([ds_TopOpt], dim=dims)
    # Plot the running minimum of the compliance history.
    ds_TopOpt.loss.transpose().to_pandas().cummin().plot(linewidth=2)
    plt.ylim(obj_min * 0.85, obj_value[0] * 1.15)
    plt.ylabel('Compliance (loss)')
    plt.xlabel('Optimization step')
    seaborn.despine()

    final_design = x_designed[obj_min_index, :, :]
    final_obj = obj_min

    plt.show()
    ds_TopOpt.design.sel(step=obj_min_index).plot.imshow(x='x', y='y', size=2,
                                                         aspect=2.5, col_wrap=2, yincrease=False, add_colorbar=False, cmap='Greys')

    def save_gif_movie(images, path, duration=200, loop=0, **kwargs):
        # Save a sequence of PIL images as an animated GIF.
        images[0].save(path, save_all=True, append_images=images[1:],
                       duration=duration, loop=loop, **kwargs)

    # images = [
    #     TO_Support.image_from_design(design, problem)
    #     for design in ds.design.sel(model='DL_TO')[:150]
    # ]

    # save_gif_movie([im.resize((5*120, 5*20)) for im in images], 'movie.gif')


# ---- file: Linear_TO/TO_Cal.py (new file in patch) ----
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 23 15:08:01 2020

@author: zzy

Object: helper functions used during the TO (topology optimization) solve.
1. every function here is jit-compatible
2. apply jit at the outermost call site; keep inner functions un-jitted
3. quantities that are truly static should be built with numpy (CPU only)
"""
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jss
import jax.lax as jl
from jaxopt import Bisection
from functools import partial
import TO_Private_Autodiff as TPA
import TO_FEA
import TO_Obj
from jax.config import config
config.update("jax_enable_x64", True)


@partial(jax.jit, static_argnums=(5, 6, 7))
def design_variable(x, design_area_indices, nondesign_area_indices, filter_kernal, filter_weight, beta_x, volfrac, projection=False):
    """Map the raw NN output to the physical densities used by the FEA.

    Parameters
    ----------
    x : [nely, nelx] 2-D array, reshaped NN output
    design_area_indices : flat element indices of the designable region
    nondesign_area_indices : flat element indices of the non-design region
    filter_kernal : H in TOP88 — convolution kernel of neighbour weights
    filter_weight : HS in TOP88 — per-element total weight, same shape as the grid
    beta_x : projection sharpness parameter
    volfrac : target volume fraction
    projection : whether projection is applied (static jit argument)

    Returns
    -------
    xPhys_1D : [Total_ele] 1-D relative densities fed to the FE solve
    """
    # Density filtering is applied before the transform; empirically effective.
    x = filter(x, filter_kernal, filter_weight)
    x_1D = x.flatten(order='F')  # column-major to match the MATLAB element order
    x_designed = design_and_volume_constraints(
        x_1D[design_area_indices], volfrac, beta_x)
    # Scatter designable densities back; non-design elements are set to 0.
    xPhys_1D = matrix_set_0(
        x_designed, design_area_indices, nondesign_area_indices)
    return xPhys_1D


def projection(x, beta_TONR):
    """Map values in [-z, z] smoothly into (0, 1).

    tanh maps to (-1, 1), so tanh_out * 0.5 + 0.5 maps to (0, 1).
    Note sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5; beta_TONR controls the slope.
    """
    output = jnp.tanh(beta_TONR * 0.5 * x)*.5 + 0.5
    return output


def logit(obj_V):
    """Inverse sigmoid of the target volume fraction.

    Used to bracket the bisection search for the shift `yita` below.
    jnp.clip keeps the argument inside (0, 1); the result is 0 at 0.5 and
    symmetric (opposite sign, same magnitude) around it.
    """
    p = jnp.clip(
        obj_V, 0, 1)
    return jnp.log(p) - jnp.log1p(-p)


def design_and_volume_constraints(x, obj_V, beta_TONR):
    """Turn filtered densities into projected densities satisfying the volume target.

    Solves mean(projection(x + yita, beta)) == obj_V for the scalar shift
    `yita` by bisection (same idea as volume-preserving projection), then
    applies the projection. The sensitivity of `yita` w.r.t. (x, obj_V) is
    handled by jaxopt's implicit differentiation.

    Parameters
    ----------
    x : filtered relative densities of the designable elements (1-D)
    obj_V : target volume fraction
    beta_TONR : projection sharpness parameter

    Returns
    -------
    x_design : projected densities (1-D), shifted by a small positive floor
    """
    def find_yita(y, x, beta):
        projection_out = projection(x + y, beta)
        volume_diff = jnp.mean(projection_out) - obj_V
        return volume_diff
    # BUG FIX: the original used np.max / np.min here. `x` is a JAX tracer
    # (this runs inside the jitted design_variable), and numpy reductions on a
    # tracer raise a TracerArrayConversionError. jnp reductions are correct;
    # stop_gradient still blocks differentiation through the bracket bounds.
    lower_bound = logit(obj_V) - jl.stop_gradient(jnp.max(x))
    upper_bound = logit(obj_V) - jl.stop_gradient(jnp.min(x))
    bisec = Bisection(optimality_fun=find_yita, lower=lower_bound,
                      upper=upper_bound, tol=1e-12, maxiter=64, check_bracket=False)
    yita = bisec.run(x=x, beta=beta_TONR).params
    x_design = projection(x + yita, beta_TONR)
    # Small floor keeps densities strictly positive for the FE solve.
    x_design = x_design + 0.00001
    return x_design


def filter_define(design_map, filter_width):
    """Build the convolution kernel and per-element weights of the density filter.

    Parameters
    ----------
    design_map : design/non-design indicator matrix
    filter_width : filter radius (in elements)

    Returns
    -------
    filter_kernal : H in TOP88 — neighbour weights within the filter radius
    filter_weight : HS in TOP88 — total weight per element, same shape as the grid
    """
    dy, dx = jnp.meshgrid(jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)),
                          jnp.arange(-jnp.ceil(filter_width)+1, jnp.ceil(filter_width)))
    # Linearly decaying (cone) weights, clipped at zero outside the radius.
    filter_kernal = jnp.maximum(0, filter_width - jnp.sqrt(dx ** 2 + dy ** 2))
    filter_weight = jss.convolve2d(design_map, filter_kernal, 'same')
    return filter_kernal, filter_weight


def filter(matrixA, filter_kernal, filter_weight):
    """Apply the density filter (normalized 2-D convolution).

    Parameters
    ----------
    matrixA : matrix to be filtered
    filter_kernal : H in TOP88 (see filter_define)
    filter_weight : HS in TOP88 (see filter_define)

    Returns
    -------
    new_matrix : filtered matrix
    """
    new_matrix = jss.convolve2d(matrixA, filter_kernal, 'same') / filter_weight
    return new_matrix


def inverse_permutation(indices):
    """Return the permutation that undoes `indices` (requires x64 enabled)."""
    inverse_perm = jnp.zeros(len(indices), dtype=jnp.int64)
    inverse_perm = inverse_perm.at[indices].set(
        jnp.arange(len(indices), dtype=jnp.int64))
    return inverse_perm


def matrix_set_0(nonzero_values, nonzero_indices, zero_indices):
    """Scatter known values into a full vector, filling the rest with zeros.

    Used to place designable densities into the full element vector (non-design
    elements = 0), or free-dof displacements into the full dof vector.

    Parameters
    ----------
    nonzero_values : values that already exist (e.g. designable densities)
    nonzero_indices : flat indices of those values; note these are generated
        from design_map with row-major flattening — internal numbering only,
        not the physical TO element numbering
    zero_indices : flat indices of the entries to be zero-filled

    Returns
    -------
    new_matrix : completed 1-D vector of length len(nonzero) + len(zero)
    """
    index_map = inverse_permutation(
        jnp.concatenate([nonzero_indices, zero_indices]))
    u_values = jnp.concatenate([nonzero_values, jnp.zeros(len(zero_indices))])
    new_matrix = u_values[index_map]
    return new_matrix


@partial(jax.jit, static_argnums=(2, 3))
def matrix_set_1(nonzero_values, nonzero_indices, origin_matrix_size, zero_size):
    """Scatter known values into a full vector, filling the rest with ones.

    Parameters
    ----------
    nonzero_values : values that already exist
    nonzero_indices : flat indices of those values
    origin_matrix_size : total number of entries (static jit argument)
    zero_size : number of filled-in entries (static jit argument; setdiff1d
        needs an explicit size to be jit-compatible)

    Returns
    -------
    new_matrix : completed 1-D vector
    """
    all_indices = jnp.arange(origin_matrix_size, dtype=jnp.int64)
    zero_indices = jnp.setdiff1d(
        all_indices, nonzero_indices, size=zero_size, assume_unique=True)
    index_map = inverse_permutation(
        jnp.concatenate([nonzero_indices, zero_indices]))
    u_values = jnp.concatenate([nonzero_values, jnp.ones(len(zero_indices))])
    new_matrix = u_values[index_map]
    return new_matrix


@partial(jax.jit, static_argnums=(1, 2, 10))
def objective(xPhys_1D, young, young_min, freedofs, fext, s_K_predefined, dispTD_predefined, idx, edofMat, ke, penal):
    """Structural compliance of the design: FE solve followed by compliance.

    Parameters
    ----------
    xPhys_1D : [Total_ele] 1-D physical densities
    young : Young's modulus (static)
    young_min : Young's modulus of void elements (static)
    freedofs : free (unconstrained) dofs
    fext : [Total_dof] 1-D external load vector
    s_K_predefined : preallocated global stiffness matrix (zeros)
    dispTD_predefined : preallocated displacement vector (zeros)
    idx : scatter indices for stiffness assembly
    edofMat : [Total_ele, 8] dof numbers per element
    ke : element stiffness matrix without the Young's-modulus factor
    penal : SIMP penalization exponent (static)

    Returns
    -------
    obj : compliance of the structure
    """
    disp = TO_FEA.FEA(young, young_min, ke, xPhys_1D, freedofs,
                      penal, idx, fext, s_K_predefined, dispTD_predefined)
    obj = TO_Obj.compliance(young, young_min, ke,
                            xPhys_1D, edofMat, penal, disp)
    return obj


# ---- file: Linear_TO/TO_Define.py (new file in patch) ----
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 22 17:40:00 2020

@author: zzy

Object: definition of the TO problem.
1. quantities that are truly static should be built with numpy
"""
import numpy as np
import jax.numpy as jnp
import jax.ops as jops
import TO_Cal
import TO_FEA


def Toparams(problem):
    """Expand a ToProblem definition into the full static parameter dict.

    Derives mesh counts, FE connectivity (MATLAB 1-based `M_*` variants and
    0-based Python variants), design-region index sets, and material/solver
    constants used throughout the optimization.
    """
    young = 1.0
    rou = 1.0
    nelx, nely, ele_length, ele_width = problem.nelx, problem.nely, problem.ele_length, problem.ele_width
    Total_ele, Total_nod, Total_dof = nelx * \
        nely, (nelx+1)*(nely+1), 2*(nelx+1)*(nely+1)
    gobalcoords0xy, M_elenod, M_edofMat, M_iK, M_jK = coordinate_and_preFEA(
        nelx, nely, ele_length, ele_width, Total_ele, Total_nod)
    # gobalcoords0xy holds absolute coordinates and needs no shift; all index
    # arrays are converted from MATLAB 1-based to Python 0-based numbering
    # (pure renumbering; any element numbering used later must also be 0-based).
    elenod, edofMat, iK, jK = M_elenod-1, M_edofMat-1, M_iK-1, M_jK - \
        1
    edofMat_3D = TO_edofMat_3DTensor(nelx, nely)  # tensor-form edofMat
    design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices = handle_design_region(
        nelx, nely, Total_ele, problem.design_map)
    params = {
        'young': young,
        'young_min': 1e-9*young,
        'rou': rou,
        'poisson': 0.3,
        'g': 0,
        'ele_length': ele_length,
        'ele_width': ele_width,
        'ele_thickness': problem.ele_thickness,
        'nelx': nelx,
        'nely': nely,
        'Total_ele': Total_ele,
        'Total_nod': Total_nod,
        'Total_dof': Total_dof,
        'volfrac': problem.volfrac,
        'xmin': 0.001,
        'xmax': 1.0,
        'design_map': problem.design_map,
        # 'design_map_bool': design_map_bool,
        # 'num_design_ele': num_design_ele,
        # 'num_nondesign_ele': num_nondesign_ele,
        'design_area_indices': design_area_indices,
        'nondesign_area_indices': nondesign_area_indices,
        'freedofs': problem.freedofs,
        'fixeddofs': problem.fixeddofs,
        'fext': problem.fext,
        'penal': 3.0,
        'filter_width': problem.filter_width,
        'gobalcoords0xy': gobalcoords0xy,
        'M_elenod': M_elenod,
        'M_edofMat': M_edofMat,
        'M_iK': M_iK,
        'M_jK': M_jK,
        'elenod': elenod,
        'edofMat': edofMat,
        'edofMat_3D': edofMat_3D,
        'iK': iK,
        'jK': jK,
        'design_condition': problem.design_condition,
    }
    return params


class Topology_Optimization:
    """Stateful TO environment: holds the FE model and evaluates compliance."""

    def __init__(self, args):
        self.args = args
        self.young, self.young_min, self.poisson = args['young'], args['young_min'], args['poisson']
        self.Total_dof, self.Total_ele, self.nelx, self.nely, self.ele_thickness = args[
            'Total_dof'], args['Total_ele'], args['nelx'], args['nely'], args['ele_thickness']
        self.freedofs, self.fext, self.iK, self.jK, self.edofMat = args[
            'freedofs'], args['fext'], args['iK'], args['jK'], args['edofMat']
        self.filter_width, self.design_condition, self.volfrac = args[
            'filter_width'], args['design_condition'], args['volfrac']
        self.design_map, self.design_area_indices, self.nondesign_area_indices = args[
            'design_map'], args['design_area_indices'], args['nondesign_area_indices']
        self.ke = TO_FEA.linear_stiffness_matrix(
            self.young, self.poisson, self.ele_thickness)
        self.filter_kernal, self.filter_weight = TO_Cal.filter_define(
            self.design_map, self.filter_width)
        # Preallocated zero stiffness/displacement arrays and assembly indexer.
        self.s_K_predefined, self.dispTD_predefined, self.idx = jnp.zeros(
            [self.Total_dof, self.Total_dof]), jnp.zeros([self.Total_dof]), jops.index[self.iK, self.jK]
        # loop_1 counts objective evaluations; iCont counts steps since the
        # last beta_x increase (continuation scheme).
        self.loop_1, self.iCont = 0, 0
        self.penal_use, self.beta_x_use = 3.0, 1.0
        self.beta_x_use_max, self.beta_x_deltath1, self.beta_x_deltath2 = 16, 50, 30,
        self.xPhys = self.volfrac * jnp.ones([self.nely, self.nelx])

    def xPhys_transform(self, params, projection=False):
        """Reshape the NN output and map it to physical densities (1-D)."""
        x = params.reshape(self.nely, self.nelx)
        beta_x = self.penal_and_betax()
        xPhys_1D = TO_Cal.design_variable(x, self.design_area_indices, self.nondesign_area_indices,
                                          self.filter_kernal, self.filter_weight, beta_x, self.volfrac, projection=projection)
        return xPhys_1D

    def objective(self, params, projection=False):
        """Evaluate compliance for NN output `params`; caches xPhys as 2-D."""
        self.loop_1 = self.loop_1 + 1
        xPhys_1D = self.xPhys_transform(params, projection=projection)
        Obj = TO_Cal.objective(xPhys_1D, self.young, self.young_min, self.freedofs, self.fext,
                               self.s_K_predefined, self.dispTD_predefined, self.idx, self.edofMat, self.ke, self.penal_use)
        self.xPhys = xPhys_1D.reshape(self.nely, self.nelx, order='F')
        return Obj

    def penal_and_betax(self, *args):
        """Continuation on beta_x: double it every `deltath` steps up to the max.

        Threshold deltath1 (50) applies while beta_x < 2; deltath2 (30) applies
        for 2 <= beta_x < beta_x_use_max (16). Returns the current beta_x.
        """
        self.iCont = self.iCont + 1
        if (self.iCont > self.beta_x_deltath1 and self.beta_x_use < 2):
            self.beta_x_use = self.beta_x_use * 2
            self.iCont = 1
            print(' beta_x increased to :%7.3f\n' % (self.beta_x_use))
        elif (self.iCont > self.beta_x_deltath2 and self.beta_x_use < self.beta_x_use_max and self.beta_x_use >= 2):
            self.beta_x_use = self.beta_x_use * 2
            self.iCont = 1
            print(' beta_x increased to :%7.3f\n' % (self.beta_x_use))
        return self.beta_x_use


def coordinate_and_preFEA(nelx, nely, ele_length, ele_width, Total_ele, Total_nod):
    """FE pre-processing. Elements/nodes numbered left-to-right, top-to-bottom.

    Uses the MATLAB (1-based) convention; within an element, quantities follow
    the bottom-left counter-clockwise node order.

    Parameters
    ----------
    nelx, nely : grid subdivisions in x and y
    ele_length, ele_width : element dimensions
    Total_ele, Total_nod : element and node counts

    Returns
    -------
    gobalcoords0xy : [Total_nod, 2] physical (x, y) coordinates per node
    elenod : [Total_ele, 4] node numbers per element
    edofMat : [Total_ele, 8] dof numbers per element
    iK_int, jK_int : [Total_ele * 64] assembly row/column dof numbers
    """
    i0, j0 = jnp.meshgrid(jnp.arange(nelx+1), jnp.arange(nely+1))
    gobalcoords0x = i0*ele_length
    gobalcoords0y = ele_width*nely-j0*ele_width  # y measured upward from the bottom
    gobalcoords0xy = jnp.hstack((matrix_array(gobalcoords0x), matrix_array(
        gobalcoords0y)))
    gobalcoords0xyT = gobalcoords0xy.T
    gobalcoords0 = matrix_array(gobalcoords0xyT)
    nodegrd = matrix_order_arrange(1, Total_nod, nely+1, nelx+1)
    nodelast = Matlab_reshape(Matlab_matrix_extract(
        nodegrd, 1, -1, 1, -1), Total_ele, 1)
    edofVec = 2*matrix_array(nodelast)+1
    elenod = jnp.tile(nodelast, (1, 4)) + \
        jnp.tile(jnp.array([1, nely+2, nely+1, 0]), (Total_ele, 1))
    edofMat = jnp.tile(edofVec, (1, 8))+jnp.tile(jnp.array(
        [0, 1, 2*nely+2, 2*nely+3, 2*nely+0, 2*nely+1, -2, -1]), (Total_ele, 1))
    # Kronecker expansion reproduces TOP88's iK/jK assembly index pattern.
    iK = jnp.kron(edofMat, jnp.ones((8, 1))).flatten()
    jK = jnp.kron(edofMat, jnp.ones((1, 8))).flatten()
    iK_int = iK.astype(jnp.int32)
    jK_int = jK.astype(jnp.int32)
    return gobalcoords0xy, elenod, edofMat, iK_int, jK_int


def handle_design_region(nelx, nely, Total_ele, design_map):
    """Precompute design/non-design region bookkeeping.

    Parameters
    ----------
    nelx, nely : grid subdivisions
    Total_ele : total element count
    design_map : design/non-design indicator matrix

    Returns
    -------
    design_map_bool : boolean indicator matrix
    num_design_ele : number of designable elements
    num_nondesign_ele : number of non-design elements
    design_area_indices : flat indices of designable elements
    nondesign_area_indices : flat indices of non-design elements
    """
    shape = (nely, nelx)
    design_map_bool = np.broadcast_to(design_map, shape) > 0
    num_design_ele = np.sum(design_map)
    num_nondesign_ele = Total_ele - num_design_ele

    design_map_bool_1D = design_map_bool.flatten(order='F')
    all_indices = np.arange(Total_ele, dtype=jnp.int64)
    design_area_indices = jnp.flatnonzero(
        design_map_bool_1D, size=num_design_ele)
    # setdiff1d is not generally jit-compatible; an explicit size is required.
    nondesign_area_indices = jnp.setdiff1d(
        all_indices, design_area_indices, size=num_nondesign_ele, assume_unique=True)
    return design_map_bool, num_design_ele, num_nondesign_ele, design_area_indices, nondesign_area_indices


def Matlab_matrix_array1D(matrixA):
    """1-D array in MATLAB (column-major) element order."""
    return matrixA.T.flatten()


def matrix_array(matrixA):
    """[*, 1] 2-D column vector in MATLAB order — equivalent of matrixA(:)."""
    return matrixA.T.reshape(-1, 1)


def matrix_order_arrange(numbegin, numend, ydim, xdim):
    '''
    [ydim, xdim] matrix filled column-wise with numbegin..numend.
    Note the 1-based/0-based index offset w.r.t. MATLAB.
    '''
    matrixA = jnp.array([[i for i in range(numbegin, numend+1)]])
    matrixA = matrixA.reshape(xdim, ydim)
    matrixA = matrixA.T
    return matrixA


def Matlab_reshape(matrixA, ydim, xdim):
    '''
    [ydim, xdim] reshape with MATLAB (column-major) semantics.
    '''
    return matrixA.reshape((ydim, xdim), order='F')


def Matlab_matrix_extract(matrixA, xbegin, xend, ybegin, yend):
    # 1-based inclusive submatrix extraction, mimicking MATLAB slicing.
    matrixA = matrixA[ybegin-1:yend, xbegin-1:xend]
    return matrixA
def TO_edofMat_3DTensor(nelx, nely):
    """Tensor-form element dof map, built directly from element (x, y) indices."""
    ely, elx = jnp.meshgrid(jnp.arange(nely), jnp.arange(nelx))  # x, y coords
    n1 = (nely+1)*(elx+0) + (ely+0)
    n2 = (nely+1)*(elx+1) + (ely+0)
    n3 = (nely+1)*(elx+1) + (ely+1)
    n4 = (nely+1)*(elx+0) + (ely+1)
    edofMat_3D = jnp.array(
        [2*n4, 2*n4+1, 2*n3, 2*n3+1, 2*n2, 2*n2+1, 2*n1, 2*n1+1])
    return edofMat_3D


# ---- file: Linear_TO/TO_FEA.py (new file in patch) ----
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 25 01:57:01 2020

@author: zzy

Object: functions required by the finite element analysis.
1. every function here is jit-compatible
2. apply jit at the outermost call site; keep inner functions un-jitted
3. quantities that are truly static should be built with numpy (CPU only)
"""

import numpy as np
from numpy import float64
import jax.numpy as jnp
import jax.ops as jops
import jax.scipy.linalg as jsl
from jax import jit
import TO_Private_Autodiff as TPA
from jax.config import config
config.update("jax_enable_x64", True)  # double precision; strongly affects accuracy


def linear_stiffness_matrix(young, poisson, ele_thickness):
    """Element stiffness matrix for a bilinear quad (plane stress).

    Parameters
    ----------
    young : Young's modulus
    poisson : Poisson's ratio
    ele_thickness : element thickness

    Returns
    -------
    stiffness_ele : 8x8 element stiffness matrix. For the TO use case the
        Young's modulus factor is applied later (SIMP), not here.
    """
    young, poisson, h = young, poisson, ele_thickness
    # Same closed-form coefficients as the TOP88 88-line code.
    k = np.array([1/2-poisson/6, 1/8+poisson/8, -1/4-poisson/12, -1/8+3*poisson/8,
                  -1/4+poisson/12, -1/8-poisson/8, poisson/6, 1/8-3*poisson/8])
    stiffness_ele = h/(1-poisson**2)*np.array([[k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]],
                                               [k[1], k[0], k[7], k[6],
                                                k[5], k[4], k[3], k[2]],
                                               [k[2], k[7], k[0], k[5],
                                                k[6], k[3], k[4], k[1]],
                                               [k[3], k[6], k[5], k[0],
                                                k[7], k[2], k[1], k[4]],
                                               [k[4], k[5], k[6], k[7],
                                                k[0], k[1], k[2], k[3]],
                                               [k[5], k[4], k[3], k[2],
                                                k[1], k[0], k[7], k[6]],
                                               [k[6], k[3], k[4], k[1],
                                                k[2], k[7], k[0], k[5]],
                                               [k[7], k[2], k[1], k[4],
                                                k[3], k[6], k[5], k[0]]
                                               ], dtype=float64)
    return stiffness_ele


def density_matrix(rou, ele_length, ele_width, ele_thickness):
    """Element (consistent) mass matrix.

    Parameters
    ----------
    rou : material density
    ele_length, ele_width, ele_thickness : element dimensions

    Returns
    -------
    density_ele : 8x8 consistent mass matrix. For the TO use case the relative
        density factor is applied elsewhere (otherwise the numerator would be
        the element mass).
    """
    ele_volume = ele_length*ele_width*ele_thickness
    density_ele = ele_volume / 36 * np.array([[4, 0, 2, 0, 1, 0, 2, 0],
                                              [0, 4, 0, 2,
                                               0, 1, 0, 2],
                                              [2, 0, 4, 0,
                                               2, 0, 1, 0],
                                              [0, 2, 0, 4,
                                               0, 2, 0, 1],
                                              [1, 0, 2, 0,
                                               4, 0, 2, 0],
                                              [0, 1, 0, 2,
                                               0, 4, 0, 2],
                                              [2, 0, 1, 0,
                                               2, 0, 4, 0],
                                              [0, 2, 0, 1,
                                               0, 2, 0, 4],
                                              ], dtype=float64)
    return density_ele


def FEA(young, young_min, ke, xPhys_1D, freedofs, penal, idx, fext, s_K_predefined, dispTD_predefined):
    """Assemble the SIMP-penalized global stiffness and solve K u = f.

    Parameters
    ----------
    young : Young's modulus
    young_min : Young's modulus of void elements
    ke : element stiffness matrix without the Young's-modulus factor
    xPhys_1D : [Total_ele] 1-D relative densities (this layout is required)
    freedofs : free (unconstrained) dofs
    penal : SIMP penalization exponent
    idx : stiffness assembly indexer
    fext : [Total_dof] 1-D external load vector
    s_K_predefined : [Total_dof, Total_dof] preallocated (zero) stiffness matrix
    dispTD_predefined : [Total_dof] preallocated (zero) displacement vector

    Returns
    -------
    dispTD : [Total_dof] 1-D nodal displacements (fixed dofs stay zero)
    """

    def kuf_solve(fext_1D_free, s_K_free):
        # Dense symmetric solve on the free dofs only.
        # NOTE(review): `sym_pos=` is deprecated in newer jax/scipy in favour
        # of `assume_a='pos'` — confirm against the pinned jax version.
        u_free = jsl.solve(s_K_free, fext_1D_free,
                           sym_pos=True, check_finite=False)
        return u_free

    def young_modulus(xPhys_1D, young, young_min, ke, penal):
        """
        Returns
        -------
        sK : [64 * Total_ele] SIMP-scaled per-element stiffness entries,
            flattened column-wise for assembly
        """
        sK = ((ke.flatten()[jnp.newaxis]).T * (young_min + (xPhys_1D) ** penal * (young-young_min))
              ).flatten(order='F')
        return sK

    # NOTE(review): jax.ops.index_add was removed from JAX (>= 0.2.22-era
    # releases) in favour of x.at[idx].add(...) — this module targets the
    # older API; migrate when upgrading jax.
    sK = young_modulus(xPhys_1D, young, young_min, ke, penal)
    s_K = jops.index_add(s_K_predefined, idx, sK)
    s_K_free = s_K[freedofs, :][:, freedofs]
    fext_free = fext[freedofs]
    u_free = kuf_solve(fext_free, s_K_free)
    dispTD = jops.index_add(dispTD_predefined, freedofs, u_free)
    return dispTD


# ---- file: Linear_TO/TO_Model.py (new file in patch) ----
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 19 00:00:06 2021

@author: zzy

Object: model construction for AuTONR.
1. based on Jax + Flax + JAXopt/Optax
2. for dataclasses and flax.linen background see:
   learning_dataclasses.py
   learning_flax_module.py
"""

import jax
import jax.numpy as jnp
import jax.image as ji
import jax.tree_util as jtree
from jax import random
import flax.linen as nn
import flax.linen.initializers as fli
import flax.traverse_util as flu
import flax.core.frozen_dict as fcfd
from functools import partial
from typing import Any, Callable
from jax.config import config
config.update("jax_enable_x64", True)


def set_random_seed(seed):
    # Returns a PRNG key for a non-None seed (implicitly None otherwise).
    if seed is not None:
        rand_key = random.PRNGKey(seed)
        return rand_key


def constant_array(rng_key, shape, dtype=jnp.float64, value=0.0):
    """
    Custom parameter initializer: fill an array with a constant value.

    Parameters
    ----------
    rng_key : PRNG key (unused; present to satisfy the initializer signature)
    shape : target array shape
    dtype : data type. The default is jnp.float64.
    value : fill value. The default is 0.0.

    Returns
    -------
    out : initialized array
    """
    out = jnp.full(shape, value, dtype)
    # print("i am running")
    return out


def extract_model_params(params_use):
    """
    Extract network parameters from a Flax FrozenDict.

    Parameters
    ----------
    params_use : FrozenDict of parameters to extract

    Returns
    -------
    extracted_params : dict keyed by joined path, e.g. {'Conv_0/bias': ..., 'Conv_0/kernel': ...}
    params_shape : dict mapping the same keys to each parameter's shape
    params_total_num : total number of scalar parameters
    """
    params_unfreeze = params_use.unfreeze()
    extracted_params = {
        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
    params_shape = jtree.tree_map(jnp.shape, extracted_params)
    params_total_num = sum(p.size for p in jtree.tree_leaves(extracted_params))
    return extracted_params, params_shape, params_total_num


def replace_model_params(pre_trained_params_dict, params_use):
    """
    Replace network parameters with pre-trained values.

    Parameters
    ----------
    pre_trained_params_dict : dict keyed by joined path, e.g. {'Conv_0/bias': ...}
    params_use : FrozenDict whose parameters are to be replaced

    Returns
    -------
    new_params_use : FrozenDict with the replaced parameters
    """
    params_unfreeze = params_use.unfreeze()
    # Two equivalent implementations are kept for reference; Method B's result
    # is the one actually returned (it overwrites new_params_use from Method A).
    # Method A: via flax.traverse_util flatten/unflatten.
    extracted_params = {
        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
    for key, val in pre_trained_params_dict.items():
        extracted_params[key] = val
    new_params_use = flu.unflatten_dict(
        {tuple(k.split('/')): v for k, v in extracted_params.items()})
    # Method B: via jax tree flatten/unflatten.
    # NOTE(review): this assumes dict iteration order matches tree_flatten's
    # leaf order and replaces ALL leaves in sequence — confirm when the
    # pre-trained dict is a strict subset of the parameters.
    params_leaves, params_struct = jtree.tree_flatten(params_unfreeze)
    i = 0
    for key, val in pre_trained_params_dict.items():
        params_leaves[i] = val
        i = i + 1
    new_params_use = jtree.tree_unflatten(params_struct, params_leaves)
    # Re-freeze the updated parameters.
    new_params_use = fcfd.freeze(new_params_use)
    return new_params_use


def batched_loss(x_inputs, TopOpt_envs):
    """
    Handle a batch dimension together with multiple TopOpt instances.

    Parameters
    ----------
    x_inputs : [1, nely, nelx] 3-D array; the leading dim is the batch
        (usually batch_size = 1), the rest are element densities
    TopOpt_envs : list of instantiated TopOpt environments (usually one)

    Returns
    -------
    losses_array : objective of the first (batch, TopOpt) pair as an array
        (jnp.stack(...)[0] selects element 0 of the stacked losses)
    """
    loss_list = [TopOpt.objective(x_inputs[i], projection=True)
                 for i, TopOpt in enumerate(TopOpt_envs)]
    losses_array = jnp.stack(loss_list)[0]
    return losses_array


class Normalize_TO(nn.Module):
    """
    Custom normalization layer: zero mean, unit variance over axes (1, 2, 3).
    Different normalization variants can be obtained by changing `axis`.
    """
    epsilon: float = 1e-6

    @nn.compact
    def __call__(self, input):
        input_mean = jnp.mean(input, axis=(1, 2, 3), keepdims=True)
        input_var = jnp.var(input, axis=(1, 2, 3), keepdims=True)
        output = input - input_mean
        output = output * jax.lax.rsqrt(input_var + self.epsilon)
        return output


class Add_offset(nn.Module):
    """Learnable additive bias, scaled by a constant factor."""
    scale: int = 10

    @nn.compact
    def __call__(self, input):
        add_bias = self.param('add_bias', lambda rng, shape: constant_array(
            rng, shape, value=0.0), input.shape)
        return input + self.scale * add_bias


class BasicBlock(nn.Module):
    """tanh -> bilinear upsample -> normalize -> conv -> learnable offset."""
    resize_scale_one: int          # spatial upsampling factor for this block
    conv_output_one: int           # number of conv output channels
    norm_use: nn.Module = Normalize_TO
    resize_method: str = "bilinear"
    kernel_size: int = 5
    conv_kernel_init: partial = fli.variance_scaling(
        scale=1.0, mode="fan_in", distribution="truncated_normal")
    offset: nn.Module = Add_offset
    offset_scale: int = 10

    @nn.compact
    def __call__(self, input):
        output = jnp.tanh(input)
        output_shape = output.shape
        output = ji.resize(output, (output_shape[0], self.resize_scale_one * output_shape[1],
                                    self.resize_scale_one * output_shape[2], output_shape[-1]), method=self.resize_method, antialias=True)
        output = self.norm_use()(output)
        output = nn.Conv(features=self.conv_output_one, kernel_size=(
            self.kernel_size, self.kernel_size), kernel_init=self.conv_kernel_init)(output)
        output = self.offset(scale=self.offset_scale)(output)
        return output


class My_TO_Model(nn.Module):
    """Base model: couples a Flax module with a TopOpt environment."""
    TopOpt: Any
    seed: int = 1
    replace_params: Callable = replace_model_params
    extract_params: Callable = extract_model_params

    def loss(self, input_TO):
        # The TopOpt env is passed as a list (a tuple would also work).
        obj_TO = batched_loss(input_TO, [self.TopOpt])
        return obj_TO


class TO_CNN(My_TO_Model):
    """CNN reparameterization of the density field (dense seed + upsampling blocks)."""
    h: int = 5                        # height of the dense-to-conv seed grid
    w: int = 10                       # width of the dense-to-conv seed grid
    dense_scale_init: Any = 1.0
    dense_output_channel: int = 1000
    dtype: Any = jnp.float32
    input_init_std: float = 1.0       # std of the (normal) design-variable init
    dense_input_channel: int = 128
    conv_input_0: int = 32
    conv_output: tuple = (128, 64, 32, 16, 1)
    up_sampling_scale: tuple = (1, 2, 2, 2, 1)
    offset_scale: int = 10
    block: nn.Module = BasicBlock

    @nn.compact
    def __call__(self, x=None):  # [N, H, W, C]
        # Learned latent input 'z', initialized to the target volume fraction.
        nn_input = self.param('z', lambda rng, shape: constant_array(
            rng, shape, value=self.TopOpt.volfrac), self.dense_input_channel)

        output = nn.Dense(features=self.dense_output_channel, kernel_init=fli.orthogonal(
            scale=self.dense_scale_init))(nn_input)
        output = output.reshape([1, self.h, self.w, self.conv_input_0])

        output = self.block(
            self.up_sampling_scale[0], self.conv_output[0])(output)
        output = self.block(
            self.up_sampling_scale[1], self.conv_output[1])(output)
        output = self.block(
            self.up_sampling_scale[2], self.conv_output[2])(output)
        output = self.block(
            self.up_sampling_scale[3], self.conv_output[3])(output)
        output = self.block(
            self.up_sampling_scale[4], self.conv_output[4])(output)

        nn_output = jnp.squeeze(output, axis=-1)
        # Could the batch dimension be dropped here as well?
        self.sow('intermediates', 'model_out',
                 nn_output)
        return nn_output


if __name__ == '__main__':

    def funA(params):
        out = 1 * jnp.sum(params)
        return out

    def funB(params):
        out = 2 * jnp.sum(params)
        return out

    def funC(params):
        out = 3 * jnp.sum(params)
        return out

    # Demo: illustrates how batched_loss pairs batch entries with functions.
    params_1 = jnp.ones([3, 4, 4])
    fun_list = [funA, funB, funC]
    losses = [fun_used(params_1[i]) for i, fun_used in enumerate(fun_list)]
    loss_out = jnp.stack(losses)
    print("losses:", losses)
    print("loss_out:", loss_out)


# ---- file: Linear_TO/TO_Obj.py (new file in patch) ----
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 26 01:59:52 2020

@author: zzy

Object: functions for evaluating the TO objective.
1. every function here is jit-compatible
2. apply jit at the outermost call site; keep inner functions un-jitted
3. quantities that are truly static should be built with numpy
"""
对于部分确定的static variable 应采用numpy定义 +""" + +import jax.numpy as jnp +from jax import jit + + +# @jit +def compliance(young, young_min, ke, xPhys_1D, edofMat, penal, disp): + ce = jnp.sum(jnp.matmul(disp[edofMat], ke) * disp[edofMat], 1) + ce = (young_min + xPhys_1D ** penal * (young - young_min)) * ce + obj = jnp.sum(ce) + return obj diff --git a/Linear_TO/TO_Problem.py b/Linear_TO/TO_Problem.py new file mode 100644 index 0000000..a94b7df --- /dev/null +++ b/Linear_TO/TO_Problem.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 11 14:20:11 2020 + +@author: zzy + +定义设计域 +""" + +import jax.ops as jops +import numpy as np +import dataclasses +from typing import Optional, Union +import jax.numpy as jnp +from jax.config import config +config.update("jax_enable_x64", True) + + +@dataclasses.dataclass +class ToProblem: + L: int + W: int + nelx: int + nely: int + design_condition: int + ele_length: np.float64 + ele_width: np.float64 + ele_thickness: np.float64 + volfrac: np.float64 + fixeddofs: np.ndarray + freedofs: np.ndarray + fext: np.ndarray + filter_width: float + design_map: np.ndarray + + +def cantilever_single(): + L = 0.8 + W = 0.4 + nelx = 80 + nely = 40 + design_condition = 1 + Total_dof = 2*(nelx+1)*(nely+1) + ele_length = L/nelx + ele_width = W/nely + ele_thickness = 1.0 + volfrac = 0.5 + dofs = np.arange(Total_dof) + loaddof = Total_dof-1 + # fext = jnp.zeros(Total_dof) + # fext = jops.index_update(fext, jops.index[loaddof], -1.) 
+ fext = np.zeros(Total_dof) + fext[loaddof] = -1 + fixeddofs = np.array(dofs[0:2*(nely+1):1]) + freedofs = np.setdiff1d(dofs, fixeddofs) + filter_width = 3.0 + design_map = np.ones([nely, nelx], dtype=np.int64) + return ToProblem(L, W, nelx, nely, design_condition, ele_length, ele_width, ele_thickness, volfrac, fixeddofs, freedofs, fext, filter_width, design_map) + + +def cantilever_single_NN(TopOpt): + dense_input_channel = 128 + dense_init_scale = 1.0 + conv_input_0 = 32 + up_sampling_scale = (1, 2, 2, 2, 1) + total_resize = int(np.prod(up_sampling_scale)) + h = TopOpt.nely // total_resize + w = TopOpt.nelx // total_resize + dense_output_channel = h * w * conv_input_0 + dense_scale_init = dense_init_scale * \ + jnp.sqrt(jnp.maximum(dense_output_channel / dense_input_channel, 1)) + model_kwargs = { + 'h': h, + 'w': w, + 'dense_scale_init': dense_scale_init, + 'dense_output_channel': dense_output_channel, + 'dense_input_channel': dense_input_channel, + 'conv_input_0': conv_input_0, + 'up_sampling_scale': up_sampling_scale, + } + return model_kwargs + + +def cantilever_single_big(): + L = 1.6 + W = 0.8 + nelx = 160 + nely = 80 + design_condition = 2 + Total_dof = 2*(nelx+1)*(nely+1) + ele_length = L/nelx + ele_width = W/nely + ele_thickness = 1.0 + volfrac = 0.5 + dofs = np.arange(Total_dof) + loaddof = Total_dof-1 + # fext = jnp.zeros(Total_dof) + # fext = jops.index_update(fext, jops.index[loaddof], -1.) 
+ fext = np.zeros(Total_dof) + fext[loaddof] = -1 + fixeddofs = np.array(dofs[0:2*(nely+1):1]) + freedofs = np.setdiff1d(dofs, fixeddofs) + filter_width = 3.0 + design_map = np.ones([nely, nelx], dtype=np.int64) + return ToProblem(L, W, nelx, nely, design_condition, ele_length, ele_width, ele_thickness, volfrac, fixeddofs, freedofs, fext, filter_width, design_map) + + +def cantilever_single_big_NN(TopOpt): + dense_input_channel = 128 + dense_init_scale = 1.0 + conv_input_0 = 32 + up_sampling_scale = (1, 2, 2, 2, 2) + total_resize = int(np.prod(up_sampling_scale)) + h = TopOpt.nely // total_resize + w = TopOpt.nelx // total_resize + dense_output_channel = h * w * conv_input_0 + dense_scale_init = dense_init_scale * \ + jnp.sqrt(jnp.maximum(dense_output_channel / dense_input_channel, 1)) + model_kwargs = { + 'h': h, + 'w': w, + 'dense_scale_init': dense_scale_init, + 'dense_output_channel': dense_output_channel, + 'dense_input_channel': dense_input_channel, + 'conv_input_0': conv_input_0, + 'up_sampling_scale': up_sampling_scale, + } + return model_kwargs + diff --git a/Linear_TO/TO_Train.py b/Linear_TO/TO_Train.py new file mode 100644 index 0000000..e0db91b --- /dev/null +++ b/Linear_TO/TO_Train.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Oct 19 09:53:14 2021 + +@author: zzy + +Object: AuTONR的训练函数 +1. 基于Flax+JAXopt/Optax实现 +2. 
关于train过程的实现可以参考: + learning_optax_training.py + learning_flax_managing_params.py + learning_flax_training.py +""" + +import sys +from pathlib import Path +from absl import logging +import xarray +import jax +import jax.numpy as jnp +from jax import random +from jaxopt import OptaxSolver +import optax +import time +from functools import partial +import matplotlib.pyplot as plt +from jax.config import config +config.update("jax_enable_x64", True) +sys.path.append( + '/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO') +here = Path(__file__).resolve().parent + + +def set_random_seed(seed): + if seed is not None: + rand_key = random.PRNGKey(seed) + return rand_key + + +def optimizer_result_dataset(losses, frames, save_intermediate_designs=False): + # The best design will often but not always be the final one. + best_design = jnp.nanargmin(losses) + logging.info(f'Final loss: {losses[best_design]}') + if save_intermediate_designs: + ds = xarray.Dataset({ + 'loss': (('step',), losses), + 'design': (('step', 'y', 'x'), frames), + }, coords={'step': jnp.arange(len(losses))}) + else: + ds = xarray.Dataset({ + 'loss': (('step',), losses), + 'design': (('y', 'x'), frames[best_design]), + }, coords={'step': jnp.arange(len(losses))}) + return ds + + +def train_TO_Optax(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs): + losses = [] + frames = [] + # Initialize parameters + init_params = model.init(set_random_seed(seed), None) + batch_state, params = init_params.pop("params") + del init_params + # Instantiate Optimizer + design_condition = model.TopOpt.design_condition + if design_condition == 1: + learning_rate = 0.001 + optimizer = optax.adam(learning_rate) + elif design_condition == 2: + learning_rate = 0.001 + optimizer = optax.adam(learning_rate) + # loss + def loss_cal(params): + all_params = {"params": params} + model_out, net_state = model.apply(all_params, None, mutable="intermediates") + loss_out = model.loss(model_out) + return 
loss_out, net_state["intermediates"] + loss_grad_fn = jax.value_and_grad(loss_cal, has_aux=True) + # Initialize Optimizer State + state = optimizer.init(params) + # Updated + for iter in range(max_iterations): + start_time_epoch = time.perf_counter() + (loss_val, batch_stats), grads = loss_grad_fn(params) + losses.append(jax.device_get(loss_val)) + updates_params, state = optimizer.update(grads, state) + params = optax.apply_updates(params, updates_params) + design_TO = jax.device_get(model.TopOpt.xPhys.aval.val) + frames.append(design_TO) + plt.figure() + plt.imshow(design_TO, cmap='Greys') + plt.show() + whole_time_epoch = time.perf_counter() - start_time_epoch + print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' % + (iter+1, losses[-1], jnp.mean(design_TO), whole_time_epoch)) + + return optimizer_result_dataset(jnp.array(losses), jnp.array(frames), save_intermediate_designs) + + +def train_TO_JAXopt(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs): + losses = [] + frames = [] + # Initialize parameters + init_params = model.init(set_random_seed(seed), None) + # batch_stats, params = init_params.pop("params") + params = init_params["params"] + # batch_stats = init_params["batch_stats"] + del init_params + + # Initialize solver + # @jax.jit + def loss_cal(params): + # all_params = {"params": params, "batch_stats": aux} + # model_out, net_state = model.apply(all_params, None) + all_params = {"params": params} + model_out, net_state = model.apply( + all_params, None, mutable="intermediates") + loss_out = model.loss(model_out) + # design_TO = model.TopOpt.x_calculation( + # model_out, volume_contraint=True, projection=True) + # losses.append(jax.device_get(loss_out)) + return loss_out, net_state["intermediates"] + + design_condition = model.TopOpt.design_condition + + def opt_solver_setting(design_condition, loss_fun, max_iterations): + if design_condition == 1: + learning_rate = 0.001 + opt_use = optax.adam(learning_rate) + solver = 
OptaxSolver( + fun=loss_fun, opt=opt_use, maxiter=max_iterations, has_aux=True) + elif design_condition == 2: + learning_rate = 0.000375 + elif design_condition == 6: + learning_rate = 0.001 + elif design_condition == 5: + learning_rate = 0.000375 + return solver + + solver = opt_solver_setting(design_condition, loss_cal, max_iterations) + + state = solver.init_state(params) + for iter in range(max_iterations): + start_time_epoch = time.time() + # params, state = solver.update( + # params=params, state=state, aux=batch_stats) + + params, state = solver.update(params=params, state=state) + batch_stats = state.aux + losses.append(jax.device_get(state.value)) # 放在这里应该更好一些 + # design_TO = jax.device_get(batch_stats._dict['model_out'])[-1] + # design_TO = model.TopOpt.x_calculation(design_TO, volume_constraint=True, projection=True) + design_TO = jax.device_get(model.TopOpt.xPhys.aval.val) + frames.append(design_TO) + plt.figure() + plt.imshow(design_TO, cmap='Greys') + plt.show() + end_time_epoch = time.time() + whole_time_epoch = end_time_epoch - start_time_epoch + print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' % + (iter+1, losses[-1], jnp.mean(design_TO), whole_time_epoch)) + + # frames = jax.device_get(batch_stats._dict['intermediates']['model_out']) + return optimizer_result_dataset(jnp.array(losses), jnp.array(frames), save_intermediate_designs) diff --git a/Linear_TO/__pycache__/TO_Cal.cpython-38.pyc b/Linear_TO/__pycache__/TO_Cal.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b37e534d59cc4f5d61cc20946ae9c0bb4407a81b GIT binary patch literal 8361 zcmds6S#uoMb?)xj2MZTa+_be6+riKz2+*>M#j+GKqF5kblrmsWQ4}CV9$RDi2Y zE$qE1x(RPz;;H23=w?ySHqUQS{c0dD;d{O2>EEgb)zGx0zMwX!jbq}d7~Q5esm)_T zbi2AvU61$E>IQWq-p{Bl>L$FORi9EfgJD<8-h@}8f_9Y5XIzIR`HhrFkK zPkT7LU&&g1X=9%}G&C3vcXsz{aciIK+&u3bU+C&A-pUtleN_B*ys+?|dupU~d8BZ0 z%o$&H&)x}lw8Gxg8R`|a%tq0``!%{xtB+r ztLL1%$DKRBFMfIgZDr@=3Cc5y&cp@h&PkNIUo4cyC&NDho!j?U7EhN(7o3Tch5Rk& z)X*ODx1s* 
zzIn1oCGbtgb3`$eq-JS`$+Eku56ce7`((3E$!PoDN@ixdT$EL(3F^-QuVlIf>QC;H{i1~H%GjkKI@zqEj&2q%zRz= z@)PI5S=KR?PCdGB>B`67qkE5@z*1sIH6x`YFk)Bdk-d9Y?%l+IeYq$`jsXXZHR z1F`K`I?-b&aa0#B-E!y0SC&7-+E`Vl+8(v;`PO^LSfeMSDKYIQmgpR$_I}0 z4VtlzgXmu7eOKC25-}~Iz1`XY)?*%+Iux>WVHQl}dspM>rF+mO=k`ggq;&odjHE3d zf^_@RhN>A6ukGjiX9s!`ie+gj``HT9Q)&!L%O?(ci{i>pb}S?$=YtVn09do84XMOgfS79gKBmEIp|Y0Z`JZ z-EZhAO)+9!ox8g_W3MTRw#=ZtuDqZM+B{8KPV|LMjZ&j1iGD#6r0={QNyMAJg2pqb zjYE^Gj|X2>hj(LBPI0YpPR|u)Ps)~(>MM>eJ6GpG zf|UHr`C7+a9n@2HF3r0)ud=4Gbk>r0$n9-Em48erz12|g4SS(v*~5m+V?U7d%21tqBM2{#|^xyyyq z(M_c)cp5;IsO6pGJ%cOA`{^X~K~&g|*c6X@kU{kOk~NVcaxm9f3F ztr3s09u;<|IzQXbX){se2}zEz1<8K=O%exT-FA_c#KG* z-4Q^p-0jX@2HIn(&cqex`g<$)CRgsgS2#Px=%q4k@r0gfRd&isEUBm1O94Y6M+)_s zIY&Y`@wt2MCaiqxsz$)t@4<35S@5AAHRUJ6p+S3Aq9o!qp2f@d(kyKc^xF0z>BxXb zK1SIVGiw;ZTIklB=w~(}5gNP|{gsTrMm|nDTX{nGAd^pJ<`9x`p@*6EvEZ=B5+>k~ zS>n88dFH)2Foak#>|IsvtCssLKRE|EpXyP)s!#P#`-lBGe}4c@gQNzgrT*Y>zzQK3 z*n*N#X*ie*szG1_-e7bnd%c1UI+3|uID_A%@}j8_dYl>I286M+0=`i9j@_*t(C@NT$-Ct4 z)`6We+xtj_1Ja+?dG5OK0k}KB)vs*nA7Q^9B{pqrZdFBuP#vjS-tU76)m6btKeTb! 
ztNN{@-Kv|G?99Gj8lQ2N;Wy2&X{?eWQfiW*7LpW+CzQGDAqmQP^5TR*(v$Oo z*WzIy3qeL%z>}Jiiui1Cki(6aRvB%sagQ)N+EBLCs@~IcLo|L!-N|4X1VDDE2aaEC zP`4BVS(Ch6L2SlvsJ&vkA8G8{ewIkbm4x|fTa`fXhJWAn`wJhv^Gazyh#ZRe?dCL@ zQzfBWgR!1$ij?eK&9HshOa?EHl}^}!q%xqx8MK??eOmlztQ+pqQO$}-Y)QrrQ0o29 zzCTX<;lEa1jf9yh+ERbQZZfTOCe~{xs;;FhTLScYKnjVTG6*tlkAV^WaPXC+$u=z# z;xyR~JdY;EPHNY{CKc;WXH%-(z^2ACQC~M56{23+ooKKk2z>K?O#K;zH{(mi-z3Nv znj5`a#jWBdalNlu42lNTuCo2K!-O9U&U5%6I42P20GKrVy1eu|X;>o6 znfIzN19hJK;N#XA^@`W7t5zTNV*djt$BohAu_ziwJ9w|0DDx)HW>*<6P4ZFK8Q(31$A( zwRw^1NdOMtz`;|s1~XbDz@^+bq#0=w{u>USKAbOKe#w>;NHtEE#(oS6Z!S;{TG)+x z>KI}pS`2n2xn?1DOMPXQf<@`n0xgV15&_j=apQMh$@n$OR!J`{jUqm5c`xI`5|~j& zjHXx0BZddyB_mCx9?J{o46Yv4kh#uzwg~`QADQj!Ay3&CFXlfhEzfd!rDzt;V(_nd z-kll;7CKXl6na^{)s6%FZSH*4!)5W?OU3*YtH=4vnVBq39xoifzp^}CxW43mz2r>J zb3(iv+QF-sL_Jx#a4e@ij~!DISr!R$zWx`A{2@}u&XrY7lbvhFt8h-GbxseCW^#|A zX==L1qy;XXL4*1^vWZlNszseleMoy{&1eYgU}SmiG?=zmY4flY4^gd}*qRnH8AdA zMA==wdDXc(#zQW$bN?%FuyA?Kz4_VJO zU*#cxuL>l;K?X<&5J8G-31B47hW&H~6ydwUZ=3cG`z(~fmlJcoGSv84o(!np3Ls!y zfto-rkn;m-_-W9aov#?uKLcty7?1ytz=6?Pv)?3uYp~?;O8l2E#Ktcm16?p+UQ#K$ zNij`|rDjw5JK2ht$84POOEk1YFnM%(>rlDu|05N8(<#m5@*Xlsri|YriEeyY`=cQy z>};q8C6TCv`i&^{dW(|ZQSu0ZUlr2r_*b;zvCeht#d3TE5l+8Urw4Jg!6B8ouS^-p z$i^FxRGub$kZ)0ERVnzHlW_Y%_atEmG>#`Ze?C0iU`Z6!D%S;3*N_hkH|83dRZYV+ z{vMKB2?Rq3|GYqXy;)=94C0K~3o`DN1^3R3yO77}zE(l8FcS0!ckAvV)-Pax2LQ_GkkXhdeTf%eWN31?f0gnEM0+%5-Nod5JYMV}BpZ7#jD~%}m$9$G7HQLM5=Ot8}MPrBGNJ!v$TPjUn^?3x$PB`Mp#1?qR3A zm)|_OQ>O|iuJf%Z$J?2_Bz)@zj*zHqw7PyI-E{=>AFTDQ{12V7y zbd%(c-Y6_zMA+K7Is0T|gIE#T^8h0c81pchjDR;xHynY&(ueNHW8jWEccyaPL$@Nh z3^?@qfv5)w%XD`4M{&$bW)ixRiffT3V;jWV_AzO(L+nMeVFxSRx4pz)TqCYxpS|HB z9@!o+!VZ*o+m?=MwvR2<_EKBhgH^1tQFaJrbP;t-TfsC;x_l5eH}hyFzmbr_l^pIV z3={PcT>Si%9f-x$bUYTbeeCAcAeR@XNJLdaNZNQIg;t~MDR0L0%wQYcASwyl$8xrx zXW6kM8lYAQ3d$fcULtCK6d@?YBQznzgyL2@oiGhDnb8*P>=8ph25TNWkhRi;c~MW- zkpqT?YNLk;go!4O+Wu@tMF-4Nx7bY!50FvJmIYtfp_Cn6ICzqwHVCEI!s|A# zcGM02ESXlb3GG!ndYgNYcpHK+@{Qs~{OLzI$~SuO|3R=NxFPs_@TuVXU?{k~Q4DWw 
G6#o|)#2E%xK5L7ETuD9J2Q2*f1yn06GLV{l5-h=33-8uOsa`_ ziN9l$VqO-*GG$WD$0J)MzX_C?b|!FZ)hkxGBh%z>-!7S?u2FU%g)Ws>iV(qs#$dRWgDM-Hg>Fe?N;-{ zcUFG&-pbD`|P4Vyo;*t-M!qLztOmGZh7$*Dwf|pUS5gsea$7Mg} z%f8~PUd)sG1uw>=ErKT#>=#%p%%eOii3yV6$9>IDxH9$d6i@Zyo;I%@SA5;?@{@kb zPkUWV+A4TS4pSVa3G(^6{fyt^_qq!8pO^eTm(Pt+glc?V#yI`{fIsMO@wa*+lRXi$ zOSGGOE}^RNc`>J#K#v6Wmev#Zws>3d){(l9l1u|kdAir-CB2lF_PV`{*XQ+ky(mw! z#QP8X+q`YqSN-U6$hUucu1Q>MN&w}n|M1vK*RGyt%X|3q_s`i>SN_kN|NeJJ#$S3a z>Z0w>Jsc>r)^xobB>2mitki>;ZA}O9DXZc*Ry`2Qfo|4Jqh{92?nICv{M8Brnj+Ml zs1#4sO~(o3x>=jya5hMcS#F_bpg{|*TdajLs}#gfTD6j0C6bBubSTlW|lonKN!)nKJ8zT97PNYOZNV zl&8zqc%fFbESot)GqZtqgh=Hi0z>a31vik77?r~WtA_(Ms*blR^$0Qcj0O^#NobUb zNGpQKid(Vjxo$fRw%gqVx(Flzf&o5prH+ee6za^Fvdu#;90>Zxj*n82e19wR_^HRz5A#MT+(mQ@SVxaBRX=@|4-er@|uWN)F~ zTdB7?Rj><_&P~A{LiRsi+&y7Un!6!=cIWfIZyY=}`oiuP>L-ow9QeuZF>A`IS>;*d z*pyqDtjs~iSoPiCt<=qeZHyh;4JMA7C1iI^VPDajk4&p_P5}s_jNf)q5~Wp1ObflM zx&$0@A^NSRrGC_}+B-YRRgZoTC|-C>V_h9VKt z`Lq=WE#DOKA`?QT0r4G!xR;r72Fimanabk86sEBR>Qtsf1f+t2O}SDeQP6q_r!&B% z>IFyx4|2xl6-E4&xkZ+k}#GTtO=tiQXy_B%z_9 zSAw_EIqc#v$ziJ93lh+P#4{lAj3iPDz0$l*H;0)vHEGm#BW2K{hu8IT*w=2+gIdT& zqtESM7=Vl%6< z7*8;PH1pB$a>lVcc7g!EgZxUa-Bo_u`9k!;d!&wVS93wGK04bu~-xTKJH`7O2Z+B$BfR3aS$$?w{gnLs*W@ zwR|d)+>GDqpp`vg+9cOo*;ma z;tu1tx|M|1=K0mkS{-nRI|08{HN;kvcRDQ;zHXlTa_yje7eT0;DB# zs(qC3-2}b|5QsB*HQcyB;~QfPk8a$K?y$nO3w5Vt*^{&(PCtMEzG-5bb5{a~C2)A| z(arg|VT^U$q}xb$n{%@S1l)ne3oDkSGs!U#ypk?~udcKJixy?XK}g#TeC@CvQI}fm z@?qmvAvvp%oYf5Z7MDma^>{LPt{k;#Q|*@k0`U+wK!G)6+PPkPl*)fZfH)i^EbahQ zPMRHZVNDVppoY&Ape2OH!EX_Eh`{#=aQlagGCI^bWy40;KSXIh20vp1_w76?Do74l z2?EY)S{w||Go5LiYFQc-=N{ag>us{vc^4%3XyGp+Ij;kNH=X?EJU=1ZqbLX8PY5lH zOKhB4+7~wl-vaf6!n`O7_2+2}BpLnAQ8|)=o*c&*N(=kyF{l@IMUHkw24QpCv*GE? 
zR_aAq!4CX2*UIjD^r>BYN;;173ut8j9sycRpiJ47x@&)j3VxTs5dvgUtce%<2gtsH z6rJHD_?d)!$w0!;J1u_I8QyUDj^@Ne^pBIksuw0r!w7W4n6%h5tW_gO8pbQrg<4o+ zA0s08iJ~3lM060wNhINC?*SnC9WqokNgLGTy2eY3Bz0O3skio>abe52!h(ndzSGTt z<|Z7L4dJBiebtY{)06rIpZqV*f_CL|7fQP3 zML+4P*moryF`Q#w_ET=$)w~q>?v&;gc$wmS%uh2KFFlW0hyuSGGbFq;lNTkzkaUbD zvKa0-r;#aTc~PX^m>Zs}pW!{ancC4q!5o}&>K9@0Hm0;O%n0L9Dr@UT?7@g#9zjGA zh$jmvNa)^`YjOG3wdMEjG|t{`p1+c9T)K+V#+BDrUOT;f`m^OP{-kl`)8^vc#@#FT?%cii z`Q^r^AFZ6f-MDf)i_E<{uOV~q&L8tZ!(ahCv07mq<1StO`O@vH6l%br_wJmEhCX+8 z`QpvxH_p*)ORrzVcb5} z-dz6j&83Ae8W-L|IHOfNi`wP8zuc@AzD(pDqfJOw^Xjh}cR%8Nsh0c8?aIcjFI$yu zH@RK8bn0`gl=loP;gzB?}Rge!2YlnUy~~8!>MA;%5kAMT|Rr0}ET(WR~jx%Pd@%H&^l4 zq&i^bla6Z=-ZJ}n0m;=S$Nmd#(X2o%s34cIe+REXn}p$2wks@9abfCZ7~&I%F~xHkpUMBMER&OVloK(AiCmK5EW3u!Q>M*q8f*3 z{&=ZYa9y*G=^b}=3Qe#{dq#4DUmv$X5M~HLn4x@X-T7aio#lazeD}I0vt#*WXF)!W zGi{b_me&ZWfs8y4`~1i(ZWjN@tKmV)tKn96C^oLsc3@S1h~$usBn&(Pt1$yzE3Imh z3LN&#TKEBWIYJJwYBHiOK7x0zI6_uFY|BBesgt5j;}vORTMxyQNIM=xa$ErN>`M~t zZ)inmJLnxIouuVve+fj*C6p_PjZ#23OiAPUR2<%YM#mz ziU=jw2HN!fFs=B-f{0$Z^;ZztjqBHUH-Gv;I}_@K#+B2M zaA63!`Qe!;5KL<1Hz%UMR?e=4u1L_fC~v&=dME86nakI2LJa79nO`M<*M{SfFpNn7 zvBsAlHomw*g6@Aa0dZq}01d-5u6z=8=32EwQ~X+N4XxLR-`F}LurspSxp2@JNoB6B zmwy4muksbIh3_`>U7PFr^{?(8Mj?KT6KekiAP1L&(G>?GJC!+e4Dx6i6+x=tIMXmG zr|Xqhrp;U;kl@uP$8D>?inxx3YXvbhuq_(OB)~-%Ns1@;23;gUKq^H6f`ntb4AF&Q zmP>X>5Q=;7C~BaDZstxZ_%?xE0Qh#&jzsfKglCxVD%5JNNOXsnxxv7G0^M&SIS&B{ zy>!`iQG=tM2`@T##9OgORkEr5ItK;jEFy5xb=HBvHNZt)q)?GKO|crj7m><+;^pk{ zXqHRSNG(Cm!eGKKEM12Ea*a0zV<*s&GYp@vv%@cEN4}QmY((IWeW9UY8yF;1Q_)JeUrAzh z!y4=|%KjFqMKR^qiA!?0<4Bz1CxMWdBzn&XE*yx^FrSx3;o4CU5Kgqr-EORh@TyFZ zT(#+w=-j!oSyq|cnVkXbSYPAT2g{d#9@8)Wt;kqrLtsxCguAM^%yNpeES@YpzyeYi`QrB3NlLBeiP#{*; zifmF+f&NUb;va6CXKGflP;;II(&`rOj>2!EJxOCK z+;tX zY>D&Fxdh3>V2w>IJ)N!id;mct{u(g-6~8mHkJ<6=+87s1=!a37rmWif6+$I9N`f$aC=o`*mD+8q>11cT9(#9Z zH8V@RXjUTFU_)h08wlp5Vv2|z3EMPj3Ivz@3H>~uD%G=(FRiM6t16}E+`C@yx`5g$ zv)cPO_dM>obAI=n`B5w;5qQe4-c$eDMaW-p()gbwm=HwHqBOlCL^1$nI 
zU;p6i-y7`fmk0U=`XuRXwcw1I)}VaqR8f+K-yPRe&YbT)Z13j{TbE&#ERk^%WzWSM5JAJs$-<&ld5H??`Cw_e|)(%GXutQg>1eEJR6VBOEvJH97}Lscr;_G z&Vl_f6hJq0p_G{$HPT*uT%A<(9O_Ue59}XF2`E#4LeO`@{}B9bnJ2_0F6Z(B;Yo>< zxRT(I+nmO^@V@twZlVb=3hX7ji57%Wh|?cG4$wy> zm|n)qo_|z=?PuG5LE@SIQ3>Yr zw*7*{GyO}8Y0=@e_Bp4C`u`yV@q-tB`d#Kt}Xl?Of8kC=jH0@sp_>~`Lk;y!?mj`wL1%P z`Tkx1?sWNgivaQZDzyvS_MgVU%XWlw`PBOT8nPY zzIa+g%+)w-0As|WBfM2>L2Sa%EOTB^bBcL=S@8H=-V3Uhr4}t@PlUW6JxwoMKRYkh za#$WW=5Z6rSj#yq6Kf9$|8uB!%*^V&aP9RblRr@o4G$geJ)AqK{NUh^dq>Q?nK9Ev zWjOB`S>qJkIA*T*9V4f!mNGKjJ2I>sIef4uU-Wio7;(!)R&^|6vKvg;8E8nfgX`lY zB9Rb@e;MM3xNa^4WB7f^^Iy^UAACFicd^6T3-(fYewz=zi5>8#_tdsmpKjQJxlBT~8COp{GHsqbt;eu5jTn z#0Cws7o*EE3mE9D=MWr-J`9MAo=ZDp4FJ@xJZ)Zm&(N9{{gw6R#s3d@nx3-_r`QTR zx+ey^eH#x?S7$!0ehQgxo{zBZsJ`_dys*kF z(F3Wy0%2ew3@O~KHXLD{W3^+0(E`8{;2=F|>+rk?j_FZZyMQH0*G`(5LRL?!aF6Bo^hq&WRikDHVQzg|BG5i3J&Kq-gzy*5JW>j(DJ0$iV=s zwP{k4Fy^6ebK6{j<#{(a6)Z(D$Gf5%a-)kvDdxriFF`GMmLG-e!bWZY$ECO%b_3A2 zyOP`P#;5pF2bGPu0k;EWcepzsLxc)xxV8R^7=CVspbJXYz#QQ}dB73@y!qx-B2xns z+pl}P>o*?Gv75Y^pR!LLA1m5Q{}AZ4$JD$&_|p;7Q8O67Ki}6sl#u=TGaC;tZmS0| z=4ei64n#Dxuln>_<)b-&{SKyzg;oFQ+3LObfjvR3jas^{nWLD?8X94-I`eVm@`LJy zGnF%sC|@3`vkWp#EmzjB`!{YwL^JmcZ4!@wCNeNs`Ft7*%F2@u{rNvI3;Lw9-50P} z_XUd33pXgAtsGVImagfeaEoitkobPP}+4*1R46ffwsrsOB|n2B1j_AYpwIM-dunsM>lJ=OHG$-9-Uh z_!+F*c+-y8-V`@;5 zd&0=^gPw$%*te*KC7Ty2JaXNi16}npDL%viKhLHxcDF)|!Dfk&Dt_ NC$w}dh%bvT{~MKDKSuxn literal 0 HcmV?d00001 diff --git a/Linear_TO/__pycache__/TO_Model.cpython-38.pyc b/Linear_TO/__pycache__/TO_Model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55fd58ef27c791b20bf31746bbc977711a5aef70 GIT binary patch literal 9461 zcmb_iS#%rMd7cfx00cppq!!9n46hQrkdnNl)=?B~wtKQrt}Um@B<^&?dr1Kr0Mwm< zEV1CJ5>l3&n3faAOJ&y~RmX`MZ)xjRk?i!XFFk$kIrCb96i-i2`qZA|7xnx90l-DF zT&IO}@7(*}J9qv6?_Y*bcXwwr{2nj9_Q6O)-_FNjwW z)2Nt^rN^{HCF!JEWxJAgGF&!&z0&1$wdmPOx07p?dn(U&aJv8Z=EfOl(6Rv(9=sK)+Mu#Z7qLC3;04o_9a4iGHzWS`%wzuDof=bT)&sR;&YM 
zU5io#WxW^xWuQg52b2xsW>9W!QSJrh7O@eOjq+ZMa-Y0keCe?9lCEaOt)hUL2E}dS zc07l~9pX+rw@7PttGG+tE$$vO=X7VA9F#+H%b4jrAU283)0%uxJ}{;4)24LiA@qG% zK7?LhmVY3M@}Vj06s3FQm#HL-r?ld|^?${EW%erRg;Er1)xv?gEbJc%3tJy4Y}vA% z{y$vUwq@JaOlFr?4-eJU_QJ%(@l0m_!Ll3)w-7A6decs zQdL%)XFj-m>F3QeKe%#Y>USTX%@hg+UwW!qs*btK-WeP}zM_c+aIucMFR5t}ThoS% zdX!XNRn#gdrN%uKmOQ^`MCR`5@yOom`M!71m(**dRvj&kMY%FYkyWC)#~yz0zzCSR z^D|H~k@g}58;~`?=!es4ZK%!a4Q)=J)EoL!pgpCX*N2OSq8Fmn6Hg95{&jgA(^`R) zBC=^k?qPZSyr#0~@#$kjhiVl$grgmDobR}M_wRdj=+Wv?7w0x~pf+CfYh%Ytvd7#9F(Wr1N(;#?03IOhmSo+J%J3}c)Kpmw2{`CM5gJsp3zkfr3I?eZ^Vk9 zoRjcU;+^%Pz&p|#X2X~>C#{ABcG2ti;4tXS3s+A4b@TnRmw)p-ZfX9N51Xg|viXZ2 zHD7qOdFm5X3(fc6T6popg`b^YeEp;5+^?D^Pg187AM?9Uc*?8DP^y5d_ji6=Rk&oF zOksP0tHB{}T%xk@#{2WLr{WR8ZcOG5B0SC=Fv0u``h48{_)OuzAz2V|)T{epp%fI# z)$zemzvhJxKRC!Q9rb)L`u4)|wXJ?>%ixy54EKIghILh4v0A3r>LHEM+51Z`VrB6X zxt-VyQ%BZl-S>-$DAn3ql;Gt@30`}Y;MGJXhGTAsvu@l$lmk-+p;rxEPbu&C9iRn# zYq^}E)}gk7J6GJ{6KHK{v)Y_C$pAd9h5D?~(C747b50|Om^4CbHW4PvDcrF!X@<5i zg@u0BY#J1X5%jVe2Drc)HNgRi`q4HA%%6FwdFIuHPky{`>CKo4c+WS!T*#kS3EX-> zhyR$WO~~rL65wX>%@bGN`F?B2v(xc-g&Uu~JQgZ%1i0nKqhdrsW5uI1PhG@ivJ@V-Kuc37&a^xIT4P4Y%sOKtPtvY|2V{0C<)UpTRFsWm~H zAq%S~+Kl0f*O{tCUp-k^3q9Yh)+@wTELig=E`hOmK`XABK!~pUDWu>QWF=hoGyUH) zCULC-*G?DRFkUu}Kyc6NpXx*B^~fsMO4ajbWDIVJ^uv*UbTzwWcF;CQYL)Sw`)FW7 zj-;N{Y;9t5>$nMErQbU8!5w~W#PfrlgG(>2SSG~gHLSwoRziTDrO?PUL93ZPRPg{ z_3Kd<8Y2v4)umlWiBc#lL6j(m=n~nKyA^L-(gRhiUtSuLy;5jySmkIXe< z62|md6dF3?fk6gY3lAXY9Xv3B7mN+h)nY9xU#$@CT9 z*e#I>ZSLe&UaHgG67(20I@KUlbs&#kT6G&y|MgD(>zE_B9hsK3%~L5gtk_z?G4O_PF0IQsJ6$_izVh3b zVJj7yA3oPS`S({oc^B7r`NBm$?jc{O4OJjDb9F;|rmA-le14hYsb+A|LHb2DT!4`352knCQ+|G^4bHfh83@URLeE8+f6IisT@& zlfvAg$y6DNRV1d#t4|A3lc^{>TvHX#hxEY(tRot;$}Q_O6HAf9Ha7DgAI<*vKYh3} zO3Cq{|U6jSjvSQY722|X2@Chu5e!kdp0 zDmbD-^(bD2tp(@VABCMp7~y32;nG`9Ixf?@NqA9cUJa%vg=UH2r0if_lvGhnXO zC)RXM3x_$_?m0{bqw7}TtGaHKaa*AUl(VjTr0)5xmXzy?+KB7&Dx(DZnNg4Hdeth- zdE)f|?-76~%2qReg~+}h|UA_a@YrQm;@j9kGDbX6U0AdN$gWPJB90QIttkN~&= zHx6adudRcw$8ATphoQ;ES}_=1J`Kk90GX}OHyDJ5n<4#T4yzte&+BRp9z~t~+VgOo 
z;VrY>z3ik|+Uy;^S2-xWoj*l?J}E;t;0YL4C+=F>T<0E_V?q2LKp?>iIUfEQn$}eN z@lfP7t3RQPFJla?8j z6DTL0F7A~=uWY$n*c1*GT{EVWlRfB@1s#Fc*4y29Bj@z;8$D=AmHXoGC|Imt^x{7I zqRbvID2?oaZ+JM$D@itj%S`6hsxs<9Oyj5_ z*U%cTR4s+T4HgOYJ(?*SSK>vlU}jP`DWttf!4HweVKO){kc2V?QGkXNFcDCIkPqOZ zFb1^#b{S?B$_TRzXrmS(rbvk7jD@$g+0>lYFgPq@m*Ga0Ge}*OM$uZ%iqs5(anz1l zH&Ra6GeBXcP&79^%qor{S3QQcEndZsAU7EwW0XcD3E@vjC`7;>Mbx*Bb5=h3c21X0 z1G30M(bIxvM4%4EF0fH9A)d-tsz&jeDpPidGU95-qCkrpqdJYM6lDQrAu=f0B`NzO z5OC5A*DjlT*OR%o(TnX>4Q6c;Zf;gfUNv-h{?M#h7`&EdQT?i=8<5CeeWGuP`#Kmf zTwJqSCIu~u#c>~%I$Ma6h^8Wh%sYe2AFB9#5HTo4d8n9@SA|PN{ET>Xyg|ZCMdG=% z)z6V{iOzq4#Ic(UwqffzJ&S+X218j6W&=w-^}I&vX~JFf`W|H0NMIyEW+T=RxwbbFdofY29hHddD+GjA zMrkin(2I<`3oxrbrcW6pB<*;Q!A22J)dF9~j@9^q{Z105A(pd>AtH;6uz#SDNe(p} z(L{dPx(BIwxY)h=s;P5))wCVOH%|_dnx>4fg<~h`B$b|{Yzh3AG{6)ZevQP|16>(i zSJqHOA3<92qbYYG1%HEV>4M0%Yinm=0+LBn<6=a@Nl6>!B+4m(Gtq6juyU7^~Dxe@O^=m5hGHRI7n)0cm)p=N05JmEos50$WCiy z{;tC-F=PYepA|8E`p38Q$}(Er^~*${=-(Sx^*sV>XLg zw@3|OWWYF{b5I%dPNyx|YQyuIHoMh^=QVBVRvVtzv}Iatc(&WZo^l@Qs7qvL5b?w8 zH;HbMYglupuJJdYp5aKJSaCWA4unpKy}w=*VtWod(aVS|vQ1|UD~i)Vu<0WO3%Qdz zTJuK{u~&Zz5;P2Jf+g`f2rg)!i0zVF$>sHRjX=UP9EEyg_xGGUfq_2*o_` zI7*8Q0d)$r({0lF!B64^f^1*E(*BTu@|Bzx3m4Y(d2gRR-niKGv|U>RYg zc$A5rRR>?h=ggZ|11oQr;gl6AZ=4Q9@p1Ek#d~I_=z(nv1?mOr@yC?$8S_c;`D`H? 
z9~T)5N>av0NY|jqFZvv9?`#T!FNnjflx-cC35jswPF z$&b>VqS{O&68AV6`cs7XZyBNlc{J4+N!BgV2U%-!1V;r?c6p6K#_8q0_?U4NAIhZ* zzr}aDS5!ID?X2N`BQ+%lNBM(l{58uWiqo}>q#okU;_yboTFnm>T}hP0=X0`SSz_se zq6}>iS`~i`p-2lx$@)0{>O%4#4@9Zfck~$8htSTk=Gd5H5#2{LHilV%r>IKe7J~;d zy&{HG#4M2ok1$ZbKry~niq@&!)RK&EP3@tIO%=jJIPQx4=q0A-5_7)210W8|eJ9-k yTgFzlWu>vL){C4{o@%C@%=g)A?Az?V-h6MScYSX*zd2*r7XBqDtlo>1(f>mMx&Bpp%p&wR zOwJ2nXdJYB1B{FcNM>YKV|2C!=Wu*;$l~xdMZbFlj52qsZ7a z@T_a#3CqTrTXo6;uG?yK zTDzT(tzh4${;%L|yZhVkG`E9VwO6lp_SR29VKjvS3hK4foPP4@o5Pao2$>Y7^=7xd z*4b$WyFYQSS_`%h01E0=`sE${zDnCKyPr3p7Sqj*Fz=6{TOVk96VTvCv**7OZbF~K zqm$#;y|pIY+Uzu1w6zEG%Z}uzMZ754sx+r5m>%!+d>@7xu5MMJ=H#d*^ng4xlVnH) z8W#C5JCg+X80aKu`5qVt?K7aP7xTDl$YUF4(GalRU{`pL@giRI3i6Q045J}*6?yy+ zTH(WU9^YVA;$9p!z|Qnw11K4Rk%mPgWmhLLU2=3+OD0IPV!FmWN%Wnls|I0TC?vl0 zXrWRTN2ds{$ma8hl0%}FVn`a{WVLJ&wy2PZIafMB_C@IZZz{cD>Pi~+mM#>YiIcgh z>GZU*DBhoVn$DY+shRT?F=sidt}X*frjdT28j56#`CK}m6JgaUt3rgLsau*V8AauC zNJ1P8VlXqz2#g&fNuKbcD4Ru5B#Dx=lFaz9w<^&26 literal 0 HcmV?d00001 diff --git a/Linear_TO/__pycache__/TO_Private_Autodiff.cpython-38.pyc b/Linear_TO/__pycache__/TO_Private_Autodiff.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee8b5b1a2240181a54fa8c0251baac013e5d66eb GIT binary patch literal 3016 zcmb7G-EUk+6`z^=v3KvTKa<*V5@3r^VL`Ry(4b<{G#@Ua2-$_W0C+qx?iY-C3wPgJTXAlc#p45CtBs)iT-2Oj4Mi8P6M<&6i*IkW4moe~hUnloq4 zoO@>GeEeqL92s#5w8ek@I`~70kiX&Ncr#&g6}tQj5KcV8DQ7WFm}gMrjo3^q&(gdJ zyzSYVw}3Br1cSpkeb7tue*N`NTE_of$sOC;6#u>Ye4M`E%REI~UIE zoe$53=k^)wz;2T2na=m&V(twQ23g1}X{+*05mlCJS!MoGW&Xm#OBWU{yF4et zm?leAs@z&fW@+L#yX|AMZnJ$rmDvyh2Wg=lh%5@?Dm!HG*`Wah#AB_r#lfB*h(@cE z$QnlEs1<>+px=NlZvn|kZh%MWPg$MKV0?hf4di=fPM^7EWu!@SGq*V7tiiZ(&$z{U zcCV1z9KUsP2Yh9AuD}*<1=_>C`w#EEzq0)BgP%P9?QiNL40&3Y;D^Isy^XRz|MKzv z&mMgAHuwaD58t^D68+(lpkM>8JtjM=GR*wd4L@q}aJ|YDld&?xsIgWpDkljt5v}{z zo^aPyaV5=y*oU2zbF`%zIa2gd^kjOyKP`c|u+EhO!6@eFl@8!4K)Qv2t^P_n<4jyyPkpVBu|1V#fB2oN&mUecE#~bI;0M$e`lB(Q|jm4p`R1Z&uwuM>3C4dwCz~()=TKl_&R&kwS`f=Cv%Fg^h%%@P 
z^tv*dt+ui@Ly<~lHq*9LMV+yJz_~JHC#jBzanznb?Lu2b37An%7PMfuOo~xl97BT9 z7I;-L1;jgjEFm!mbO%c{Qbor|Ki=-uN+6y0yV|oTs$Qfg)(ck2+ zuj)fUye)z51Z1FfMb}kvHOvC2xux%^Qk>oig}<71T3nSn?e>#NF-v18yjCq#qe-yt zM;XkpW~HK$q8okxVcGZnH?kVF0~u>mY8LVKa-AAo7q%+4P>A4gabm2 z!!G~>p1?Z-_9Ci+R|l6fQ-k$e@1 zGSW`wmD+)newO;HcLpgiyu!qZEBN(0{B`(FVhR%KYKVa-Gna~&K~lrgJ8B3oHgpX! zF9DO&TgF`k7Xu4^hwT`3e6AT3y#W_vTNEA@6G5l{tJ^wsH7xR@hQ_dC(o1|kv?-&G zuM`@hebX>ExCwIxo`bi07mSzCi!kmQTlAbp0x)kaoiLAE;e)8nW6AE)ob0k)vSke1 zt78m4M!!~#zGtleuK@ANQLEj_{CR(fQI~O1W7Ja#3ch^<)+AOBBq*CS9V%W0zRJXx zVGv(=n*WR{_yF-W-&e(Eus#cair|g4^dRsB0l*H{1wJeSGf!Y%s{*Pbb{jwh0->ix zaj>ooB&->$BKQIR9Hg-n=m}MpwWz%@8v`oD4Q2O7{W2UaN%+$yLuC&}CZ&G5mCeIC zyeS$;b$RP@(-+fg7#|((A<^D(E+;AP#Nib@zCWTY*rOR3a7a3Td@?#?HF?;m)qp zB-)jbQjkC@cPdbuBbVN}aO8hrPDoI-zHx<1d2j8y0cJG6dGluG?aaL2dwZi^cLo$o;{m)V$G{n zvKF~<-K$f=>yH{tXU35Ndh;{ITkx9P7%qNDn8oaSgxQD6A@xo$ht=*8?<8|s9r_ki zJ|UgP_zE1wxh!~)bJk9i_V$>!FO3K7<>mJB>V=i13m~j4tt>gtMljBIQgNYuaIo(< zPaZ$`4jR<}Efps;I#(8K-5iwc7cXxX;_oOl!-Ou0e0$~8tWPJ3n6tFvC4^*Z$W(14Z3 zTr%g9HJ9v4GGU!Fmuhp#1<8WB`drhvNZ0~ATT|NG>2;C9MPf;>N&2bO5+3bUxG$TG zXW=06;o2<3$B+((D@QyTg%=^;y{X63zn$4(TlFJ3( z50WI!13X;@Ym{Up&3=>y`P!;96BY;&>`T4U{7A%W;DZYO{;0c?#=Hw0=z88Q|MJ$A ztKF-~uK)3+o89enlt$@b-$x3>;X!4+do4_OApGsE?)H{Hd$jYTePM#lQ|RDx0FGkO z7uBK1t3|7o|Ki6Y>{x(y8Mprez-URBQbNd-3e!rJ=_UBSP?-J-X^{fP*q3@i3gs** zje-^`^d>V)i&>?elSj1Bk7(%>I@3>+vIZKZ(zwhn+yYnu8ntjBjEoodnEs`DhQO{m z%xQ&F)Q^<1!D{6~PKJ7+7Ro>Y$zbk%rEKOVe4Q2w=(WNq%)%<{qFyw>`a;n>Yyhw7 zn}Y}6VVSJYgkRSu*WrGXw*iEKk^=A)8k9fR2UFp<^~rN6RW!|1)&>Ng&hq^c?^xn2 ztdm9{f@Hv@e!xYVNqy>LJL5SE`~77mwG?P5bs)&vYW3)yMFDw?2pa)&o2VhU03Jj@ z9z;G+Q`Aw}KtQL6CV|sWEQSiuki;65G5P&R@cY!D>5JfAs zoa7XotsT)riWC{Dh5jm%gbcMxGSOW~BcPB~D7pT~fQAmFX9?4flp_i=cI8wXwdi7I zd~X7s9B`|-Ir##jflI}=)7c)tcaX;$?5XoT@=CU=_ zq~m-v&i&pFgfWlg$*DF>jw@ZrKfACCbxU$_Jo2+39!2n~Q-i>ZGx)`d9NI{vZ0cTg zv{Z+X5wz@%$fX$vdtp3|J2h!Qh(mCTcVUS)I~*QsYEZ19x%Y6pn^)`*kO+Zo%cx&L zSVefJS^@cqNB-@*ff!`}r9ta|H1Ib!*I`?Rf$)+EjVe{RRH3csnyNrQuYVb~z#W}a 
z()59*J>U0C_yL_yN10R+v>~FNjAO_~)g#nckO#de$TEln17Zi>N?^bV6+qg!eY*U8 z96LCk{@Vop+odrcfdwwk&fNAN~~?%Yh9cs><{pJ*=YdHax0oQ Pf!a_NOVwy|(K`JP6oeqd literal 0 HcmV?d00001 diff --git a/Linear_TO/__pycache__/TO_Train.cpython-38.pyc b/Linear_TO/__pycache__/TO_Train.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed324ffb0c0f3fe4028f756bce8998f5166b7d23 GIT binary patch literal 4210 zcmb_fNpBp-74Cg{dKM0MQr^bS0*&pErYvua<0Y2mEhZIf62uJ~jpkGjIisHLNp+7* z65WOoDzJe-_94h6>o7o!=n%(1ZgztFha3yJ#fqe3u0~?zd(|{W*@+K9W>BwQRlVBY z^406_)M}1`-__&q`tQyw%HODQ^k<;)1WNK(5T>{aQ<)a3u8LX@wSn&HD)kwmF)&?I z@>*!QmXs^GC0W~pvg^n(Gb{}%ZbkA|SRK^d+QFQp*$9cjZrCTvyZSL;wIL`V;Kj@^Z%_m3g^|hC-{Pe=r_kMTf!@bLI{^{>O zdEZ%CYF_@;2OoX#_UX{yzUTQTzZVbF<@I6eZ(n}zmsj3?-#Gz+AN~OXDffHP)jz*^ z<+ne=BE{(CcYgoThrhY}&JV7>vA5cEn$2by_&n-Gn;zEjw>>!t&86YaO`~1x<_-MF z-z3?g&;3F2A7zd%a19zz)>^7C&iLsT>QWeQ(gI*D?nf*hfUC&UpN+%w0Vf$Op7#8Z zWIA!w?QOdCJ~j^`ecGV>MIu)m+Wgv*?3C z)wPz!$$#C_nU~g{c_G+oX~IZ?fVno^#XA@5TvB)$5+9%b-_N()UiaE$IcQf*BthyG zit=cuyXc%2h@#r6qw)&4(;d3Or#}-+2}R!M4KRl{(Nk@#Gu$fE7nPqWztG6NGmWpM zM&D#wrgSx?XDX=ibB&ppXQib~*;Dt`p31C?T5d#3M`TNxN-KSld}csz_S$asbT9J5 zW*8^QY7=%fk~A2$bY8{zgt{ZtZBiu-3yTFwZ!?lZg$Kl(CD42gvb+w`vW4D{hN2w7 z8~Ns-7jaU6FnUop7S^`UxxXVyr~K51_JjpjjWz?Jhe6abg+;pr*ry8*M!lH$=L4@7 zr2!uVtcO_jiW4MmWg|$^gO<=(B28Nm0y(jOh#?0HxW|KJ6hgMx^EgJ6CWxXr>WpgO zZ>iQNhTiy>Z4&QD-ZJ=Ykhw+<*&)aJ^tS-UMj`7a&trH~$xRgNEoH2v(E6UrlzlBP zWu>vkV4qjC+|J9dm)=tPdf&*M%wgI&jX%ruypoz@D+83})u@Vo1O07g=9M>TbRcGWE1-&F^e_K@kCb6W^vv!$;mS|`jdUg=X04gu4Z#IFP$FG z>?>6FXR|rxu*yY1F?#1@uiQPGTu}1)be>hS2Gi!1%w)AiB?Va18rgi;n1jEn%J<($ zkD>RL%-mOHZDCZ~)8UQvHDML##czjtul)1UKmPLK#+k>j5fohc`1SrD(;q+l_)Q=| zAw^{xB{=~?LDT~@;3`T6Q!oV1MIG1GavdhZR=Apx8JW&B3XrqPC1Z`>fthaYsHY_@ z)h!K1DdE~BO_=_081C>1$Xzm}xN!V1JOEp_JcwBkdhsZA%MoB9N#U%nP5w-smsDX& zi6b6b40inRSvoqo6ye4UkYl$gaT<6?FzIhEJsQT4OCDR2>w#?}`bt}V;`yU(VLroy zodjDLoq<36_}=Do>C);~9_TgkR)JYrSvuKmem4%mKeTjzx7qFuf&<>Ub{)UgtrVZV zR3etphw)HY0BY+-JfpV%k9b 
z1)@3(co)-06vT-4;V_7pa9D7@*9p8$!hq2bp&7U}sbQ~ZCBoc3v$d08AFM8PNroY$ z@#KtnAgtaX*@|BSS`jse9Jb?A(4jB}fgj1>h0zFZS_`~I1N8wOsveFZADx8Xf=*6) z7uLDg09FZKq#k-{M1}P6sc|bCe%je8Y${F9Pm)aS@O8$DfhZp$4!;}2ZdDqKG~#uk zHt7-2GpV7022VXwkQWT&&enma&@GQ_mbfNOPu%)}QVtFz$_G|yuZzZOTQMHq8%q@) zcARb^fv|qFFFk^RX_L z^Rt+L6DrQ2gH$O|ap4P5k*SMHZogiV7>D2l#+4DQT3^qBfXuw0yG9ayLTS?NRhIB9&m0-V<^X=LSB%xL6DfKc)E@=V604Obi&wFLlR2c>WY-qjiSHG82zMc}+e7yon9kEko$tQ6zK(nJ6KUDUSF( z7~*#lq1@{lo#`4q2FP9>pxy_G$PP<}>_M1eA^t24y#V5tu%tIi04a&gazC3%Owo8T zgfuTrks;@pSL6^BG6)Z?&?fY0BbY4oQZdxsd;Y6tKffySD0+tYi$v}v@--sMMCdu> z_kpw)`3mtAX#9R6ghj3+@k(M0f0%kH&iOY$gpN0Yf1R2S5uw-@ri5(%7KSJB2N<`T;->UP)@|30cp)hP%ZR$h@|{E8b5?l{ya71r3ez`QMpUp%S7Zw z*N9sua)!vaLEJjth@+795{*b_zJY5KpV4dwG70J80N6BJ(~N2r_NZ5p=>gD;DOhDn zzX4|j>;cdW{wn6R)YC1kEmeJ3-z7?($KsCX2|dAcgzukpE9`A>%1B(p-$+84;PP3t zrGiBnZ{kuk8sH1TlvSZV`Eg=l42G9JS-2?KNtEG{oivWagp&`8$$_5^!#Kr8_|PB@ zd6BH8jf6v66rW$5z9$Mp8+0~lX<>U_w})rJ6D4H7