Update TO_Model.py

parent a3ba29abd3
commit 19b64de76e
@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 19 00:00:06 2021

@author: zzy

Object: model construction for AuTONR
1. Based on JAX + Flax + JAXopt/Optax
2. For dataclasses and flax.linen, see:
   learning_dataclasses.py
   learning_flax_module.py
"""
'''
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
@ Author       : Zeyu Zhang
@ Email        : zhangzeyu_work@outlook.com
@ Date         : 2022-10-25 09:30:27
@ LastEditTime : 2022-10-25 09:36:54
@ FilePath     : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Model.py
@
@ Description :
@ Reference   :
'''

import jax
import jax.numpy as jnp
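# (The aliases used below -- flu, fcfd, jtree, nn -- come from the import block
#  elided by this hunk; judging by the calls, presumably flax.traverse_util as
#  flu, flax.core.frozen_dict as fcfd, jax.tree_util as jtree and flax.linen as
#  nn, which provide flatten_dict/unflatten_dict, freeze and tree_flatten.)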

@@ -33,37 +33,12 @@ def set_random_seed(seed):

def constant_array(rng_key, shape, dtype=jnp.float64, value=0.0):
    """
    Custom parameter initializer that fills an array with a constant value.
    Parameters
    ----------
    rng_key : PRNG key (unused here; kept for the initializer signature)
    shape : shape of the target array
    dtype : data type. The default is jnp.float64.
    value : fill value. The default is 0.0.

    Returns
    -------
    out : the initialized array
    """
    out = jnp.full(shape, value, dtype)
    # print("i am running")
    return out
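
# Usage sketch (an assumption, not part of this file): since constant_array
# follows the Flax initializer signature (key, shape, dtype), it can be handed
# to a layer through functools.partial, e.g.
#   layer = nn.Dense(16, bias_init=functools.partial(constant_array, value=0.1))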


def extract_model_params(params_use):
    """
    Extract the network parameters.
    Parameters
    ----------
    params_use : FrozenDict, the parameters to extract

    Returns
    -------
    extracted_params : dict, the extracted parameters {'Conv_0/bias': ..., 'Conv_0/kernel': ..., ...}
    params_shape : dict, the shape of each parameter group {'Conv_0/bias': (...), 'Conv_0/kernel': (...), ...}
    params_total_num : total number of parameters
    """
    params_unfreeze = params_use.unfreeze()
    extracted_params = {
        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}

@@ -73,50 +48,27 @@ def extract_model_params(params_use):

def replace_model_params(pre_trained_params_dict, params_use):
    """
    Replace the network parameters.
    Parameters
    ----------
    pre_trained_params_dict : dict, the pre-trained parameters {'Conv_0/bias': ..., 'Conv_0/kernel': ..., ...}
    params_use : FrozenDict, the parameters to be replaced

    Returns
    -------
    new_params_use : FrozenDict, the new parameters
    """
    params_unfreeze = params_use.unfreeze()
    # Either of the two methods below achieves the replacement.
    # Method A: based on flax.traverse_util, replaces entries by key.
    extracted_params = {
        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
    for key, val in pre_trained_params_dict.items():
        extracted_params[key] = val
    new_params_use = flu.unflatten_dict(
        {tuple(k.split('/')): v for k, v in extracted_params.items()})
    # Method B: based on jax.tree, replaces leaves by position; note it assumes
    # pre_trained_params_dict covers every leaf in flattened-tree order.
    params_leaves, params_struct = jtree.tree_flatten(params_unfreeze)
    i = 0
    for key, val in pre_trained_params_dict.items():
        params_leaves[i] = val
        i = i + 1
    new_params_use = jtree.tree_unflatten(params_struct, params_leaves)
    # Re-freeze the new parameters
    new_params_use = fcfd.freeze(new_params_use)
    return new_params_use
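
# Usage sketch with a hypothetical pre-trained entry (keys follow the
# 'Layer_i/param' scheme produced by extract_model_params):
#   pre = {'Conv_0/bias': jnp.zeros((16,))}
#   new_params = replace_model_params(pre, variables['params'])
# With a partial dict like this, only Method A is key-safe; Method B would
# overwrite the first leaves in order (see the comment above).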


def batched_loss(x_inputs, TopOpt_envs):
    """
    Handle a batch dimension together with several instantiated TopOpt objects.
    Parameters
    ----------
    x_inputs : 3D array of shape [1, nely, nelx]; the first axis is the batch,
        the other two are the element relative densities of the 2D TO problem.
        Usually batch_size = 1.
    TopOpt_envs : list of instantiated TopOpt objects; usually a single instance.

    Returns
    -------
    losses_array : loss_list stores the obj output for each batch/TopOpt pair
        as a list; it is converted to array form here.
    """
    loss_list = [TopOpt.objective(x_inputs[i], projection=True)
                 for i, TopOpt in enumerate(TopOpt_envs)]
    losses_array = jnp.stack(loss_list)[0]

@@ -125,14 +77,11 @@ def batched_loss(x_inputs, TopOpt_envs):

# @jax.jit
class Normalize_TO(nn.Module):
    """
    Custom Normalize module.
    """

    epsilon: float = 1e-6

    @nn.compact
    def __call__(self, input):
        # Different normalization variants can be obtained by changing `axis`.
        input_mean = jnp.mean(input, axis=(1, 2, 3), keepdims=True)
        input_var = jnp.var(input, axis=(1, 2, 3), keepdims=True)
        output = input - input_mean

@@ -181,7 +130,6 @@ class My_TO_Model(nn.Module):
    extract_params: Callable = extract_model_params

    def loss(self, input_TO):
        # TopOpt is passed in as a list; a tuple would work just as well.
        obj_TO = batched_loss(input_TO, [self.TopOpt])
        return obj_TO
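
    # Hedged sketch: because loss touches no Flax variables, it can be called
    # directly and differentiated with respect to the input densities, e.g.
    #   d_loss_dx = jax.grad(model.loss)(x)   # x: shape [1, nely, nelx]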


@@ -192,7 +140,7 @@ class TO_CNN(My_TO_Model):
    dense_scale_init: Any = 1.0
    dense_output_channel: int = 1000
    dtype: Any = jnp.float32
-   input_init_std: float = 1.0  # std of the normal distribution that initializes the design variables
+   input_init_std: float = 1.0
    dense_input_channel: int = 128
    conv_input_0: int = 32
    conv_output: tuple = (128, 64, 32, 16, 1)

@@ -222,28 +170,5 @@ class TO_CNN(My_TO_Model):
        nn_output = jnp.squeeze(output, axis=-1)
        self.sow('intermediates', 'model_out',
-                nn_output)  # could the batch dimension be dropped here?
+                nn_output)
        return nn_output
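
# Sketch of reading back the sown value (standard Flax behaviour; the variable
# names here are assumptions):
#   out, state = model.apply(variables, x, mutable=['intermediates'])
#   (model_out,) = state['intermediates']['model_out']   # sow stores a tuple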


if __name__ == '__main__':

    def funA(params):
        out = 1 * jnp.sum(params)
        return out

    def funB(params):
        out = 2 * jnp.sum(params)
        return out

    def funC(params):
        out = 3 * jnp.sum(params)
        return out

    # A quick test that makes batched_loss concrete
    params_1 = jnp.ones([3, 4, 4])
    fun_list = [funA, funB, funC]
    losses = [fun_used(params_1[i]) for i, fun_used in enumerate(fun_list)]
    loss_out = jnp.stack(losses)
    print("losses:", losses)
    print("loss_out:", loss_out)