Update TO_Model.py

This commit is contained in:
p15806732 2022-11-04 11:25:55 +08:00
parent a3ba29abd3
commit 19b64de76e
1 changed file with 17 additions and 92 deletions


@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 19 00:00:06 2021
@author: zzy
Object: model construction for AuTONR
1. Based on JAX + Flax + JAXopt/Optax
2. For dataclasses and flax.linen, see
learning_dataclasses.py
learning_flax_module.py
"""
'''
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
@ Author : Zeyu Zhang
@ Email : zhangzeyu_work@outlook.com
@ Date : 2022-10-25 09:30:27
@ LastEditTime : 2022-10-25 09:36:54
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Model.py
@
@ Description :
@ Reference :
'''
import jax
import jax.numpy as jnp
@@ -33,37 +33,12 @@ def set_random_seed(seed):
def constant_array(rng_key, shape, dtype=jnp.float64, value=0.0):
"""
Custom parameter initialization function: fill an array with a given value
Parameters
----------
rng_key : PRNG key (unused here; kept for the initializer interface)
shape : shape of the target array
dtype : data type. The default is jnp.float64.
value : fill value. The default is 0.0.
Returns
-------
out : the initialized array
"""
out = jnp.full(shape, value, dtype)
# print("i am running")
return out
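A quick usage sketch (not part of this commit; the Dense layer and partial application are illustrative): because constant_array follows the (rng_key, shape, dtype) signature that flax.linen initializers use, it can be partially applied and passed as an initializer:
import functools
import flax.linen as nn
# Hypothetical: fill the bias of a Dense layer with 0.5 at init time
dense = nn.Dense(features=8,
                 bias_init=functools.partial(constant_array, value=0.5))
variables = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
# variables['params']['bias'] is now an array of 0.5s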
def extract_model_params(params_use):
"""
Extract network parameters
Parameters
----------
params_use : FrozenDict the parameters to extract
Returns
-------
extracted_params : dict the extracted parameters {'Conv_0/bias': ..., 'Conv_0/kernel': ..., ...}
params_shape : dict the shape of each extracted parameter group {'Conv_0/bias': (...), 'Conv_0/kernel': (...), ...}
params_total_num : total number of parameters
"""
params_unfreeze = params_use.unfreeze()
extracted_params = {
'/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
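To see what the flatten/join step above produces (a minimal sketch; the nested dict is made up):
# flax.traverse_util.flatten_dict turns nested dicts into tuple-keyed dicts
nested = {'Conv_0': {'bias': 0, 'kernel': 1}}
flat = flu.flatten_dict(nested)
# {('Conv_0', 'bias'): 0, ('Conv_0', 'kernel'): 1}
joined = {'/'.join(k): v for k, v in flat.items()}
# {'Conv_0/bias': 0, 'Conv_0/kernel': 1}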
@@ -73,50 +48,27 @@ def extract_model_params(params_use):
def replace_model_params(pre_trained_params_dict, params_use):
"""
Replace network parameters
Parameters
----------
pre_trained_params_dict : dict the pre-trained parameters {'Conv_0/bias': ..., 'Conv_0/kernel': ..., ...}
params_use : FrozenDict the parameters to be replaced
Returns
-------
new_params_use : FrozenDict the new parameters
"""
params_unfreeze = params_use.unfreeze()
# Either of the two approaches below implements the replacement
# Method A: based on flax.traverse_util
#
extracted_params = {
'/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
for key, val in pre_trained_params_dict.items():
extracted_params[key] = val
new_params_use = flu.unflatten_dict(
{tuple(k.split('/')): v for k, v in extracted_params.items()})
# Method B: based on jax.tree utilities (its result overwrites Method A's)
#
params_leaves, params_struct = jtree.tree_flatten(params_unfreeze)
i = 0
for key, val in pre_trained_params_dict.items():
params_leaves[i] = val
i = i + 1
new_params_use = jtree.tree_unflatten(params_struct, params_leaves)
# Freeze the new parameters
#
new_params_use = fcfd.freeze(new_params_use)
return new_params_use
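A hedged usage sketch (params and the kernel shape are assumed, not from this file):
# Hypothetical: overwrite one parameter group of an initialized model
pre_trained = {'Conv_0/kernel': jnp.zeros((3, 3, 1, 32))}  # assumed shape
new_params = replace_model_params(pre_trained, params)
Note that as written, Method B's result overwrites Method A's, and Method B assumes the iteration order of pre_trained_params_dict matches the order of the flattened leaves.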
def batched_loss(x_inputs, TopOpt_envs):
"""
Handle a batch dimension together with multiple instantiated TopOpt objects
Parameters
----------
x_inputs : shape [1, nely, nelx] (3D array); the first dimension is the batch, the second and third are the element relative densities of the TO_2D problem; usually batch_size = 1
TopOpt_envs : multiple instantiated TopOpt objects stored as a list; usually only one instance
Returns
-------
losses_array : loss_list holds the obj output for each batch and TopOpt as a list; it is converted here to an array
"""
loss_list = [TopOpt.objective(x_inputs[i], projection=True)
for i, TopOpt in enumerate(TopOpt_envs)]
losses_array = jnp.stack(loss_list)[0]
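An observation on the last line (not part of the commit): jnp.stack(loss_list)[0] stacks the per-environment losses and then keeps only the first entry, consistent with the docstring's assumption that there is usually a single TopOpt instance.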
@@ -125,14 +77,11 @@ def batched_loss(x_inputs, TopOpt_envs):
# @jax.jit
class Normalize_TO(nn.Module):
"""
Custom Normalize class
"""
epsilon: float = 1e-6
@nn.compact
def __call__(self, input):
# Different normalization variants can be obtained through the axis setting
input_mean = jnp.mean(input, axis=(1, 2, 3), keepdims=True)
input_var = jnp.var(input, axis=(1, 2, 3), keepdims=True)
output = input - input_mean
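For orientation (a standalone sketch, not the commit's code, which is truncated by the hunk here): normalizing over axes (1, 2, 3) standardizes each sample across all spatial positions and channels, e.g.
x = jnp.ones((2, 8, 8, 3)) * jnp.arange(3.0)       # toy NHWC batch
mean = jnp.mean(x, axis=(1, 2, 3), keepdims=True)  # one mean per sample
var = jnp.var(x, axis=(1, 2, 3), keepdims=True)    # one variance per sample
y = (x - mean) / jnp.sqrt(var + 1e-6)              # per-sample standardization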
@@ -181,7 +130,6 @@ class My_TO_Model(nn.Module):
extract_params: Callable = extract_model_params
def loss(self, input_TO):
# TopOpt is passed in as a list; a tuple would actually work as well
obj_TO = batched_loss(input_TO, [self.TopOpt])
return obj_TO
@@ -192,7 +140,7 @@ class TO_CNN(My_TO_Model):
dense_scale_init: Any = 1.0
dense_output_channel: int = 1000
dtype: Any = jnp.float32
input_init_std: float = 1.0 # std of the normal distribution used to initialize the design variables
input_init_std: float = 1.0
dense_input_channel: int = 128
conv_input_0: int = 32
conv_output: tuple = (128, 64, 32, 16, 1)
@@ -222,28 +170,5 @@ class TO_CNN(My_TO_Model):
nn_output = jnp.squeeze(output, axis=-1)
self.sow('intermediates', 'model_out',
nn_output) # consider whether the batch dimension could be dropped here
nn_output)
return nn_output
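For reference (a sketch; model, variables, and x are assumed names): values recorded with self.sow('intermediates', ...) can be read back by making that collection mutable in apply:
out, state = model.apply(variables, x, mutable=['intermediates'])
model_out = state['intermediates']['model_out']  # a tuple of sown values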
if __name__ == '__main__':
def funA(params):
out = 1 * jnp.sum(params)
return out
def funB(params):
out = 2 * jnp.sum(params)
return out
def funC(params):
out = 3 * jnp.sum(params)
return out
# Test: an intuitive demonstration of batched_loss
params_1 = jnp.ones([3, 4, 4])
fun_list = [funA, funB, funC]
losses = [fun_used(params_1[i]) for i, fun_used in enumerate(fun_list)]
loss_out = jnp.stack(losses)
print("losses:", losses)
print("loss_out:", loss_out)