# -*- coding: utf-8 -*-
'''
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
@ Author : Zeyu Zhang
@ Email : zhangzeyu_work@outlook.com
@ Date : 2022-10-25 09:30:27
@ LastEditTime : 2022-10-25 09:36:54
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Model.py
@
@ Description :
@ Reference :
'''
import jax
import jax.numpy as jnp
import jax.image as ji
import jax.tree_util as jtree
from jax import random
import flax.linen as nn
import flax.linen.initializers as fli
import flax.traverse_util as flu
import flax.core.frozen_dict as fcfd
from functools import partial
from typing import Any, Callable
from jax.config import config
config.update("jax_enable_x64", True)


def set_random_seed(seed):
    # Build a JAX PRNG key from an integer seed; returns None when no seed is given.
    if seed is not None:
        return random.PRNGKey(seed)
    return None


def constant_array(rng_key, shape, dtype=jnp.float64, value=0.0):
    # Flax-style initializer: rng_key is unused, but the signature matches
    # what nn.Module.param expects from an init function.
    return jnp.full(shape, value, dtype)
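

def _demo_constant_array():
    # Minimal usage sketch; the (2, 3) shape and 0.5 fill value are
    # illustrative assumptions, not fixed by this file.
    arr = constant_array(random.PRNGKey(0), (2, 3), value=0.5)
    init_fn = partial(constant_array, value=0.5)  # usable as a Flax param init
    return arr, init_fn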


def extract_model_params(params_use):
    # Flatten a frozen Flax parameter tree into a flat dict keyed by
    # '/'-joined paths, and report each parameter's shape and the total count.
    params_unfreeze = params_use.unfreeze()
    extracted_params = {
        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
    params_shape = jtree.tree_map(jnp.shape, extracted_params)
    params_total_num = sum(p.size for p in jtree.tree_leaves(extracted_params))
    return extracted_params, params_shape, params_total_num
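

def _demo_extract_model_params():
    # Sketch of inspecting an initialized model; the Dense layer and input
    # shape are illustrative assumptions, not fixed by this file.
    variables = nn.Dense(features=8).init(random.PRNGKey(0), jnp.ones((1, 4)))
    flat, shapes, total = extract_model_params(variables['params'])
    # flat: {'kernel': (4, 8) array, 'bias': (8,) array}; total == 40
    return flat, shapes, total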


def replace_model_params(pre_trained_params_dict, params_use):
    # Overwrite selected entries of a Flax parameter tree with pre-trained
    # values. Keys in pre_trained_params_dict use the same '/'-joined paths
    # produced by extract_model_params, so replacement is done by name rather
    # than by leaf position.
    params_unfreeze = params_use.unfreeze()
    extracted_params = {
        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
    for key, val in pre_trained_params_dict.items():
        extracted_params[key] = val
    new_params_use = flu.unflatten_dict(
        {tuple(k.split('/')): v for k, v in extracted_params.items()})
    return fcfd.freeze(new_params_use)
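

def _demo_replace_model_params():
    # Sketch of swapping one parameter by its flattened path; the model and
    # the all-ones bias are illustrative assumptions.
    variables = nn.Dense(features=8).init(random.PRNGKey(0), jnp.ones((1, 4)))
    flat, _, _ = extract_model_params(variables['params'])
    new_params = replace_model_params(
        {'bias': jnp.ones_like(flat['bias'])}, variables['params'])
    return new_params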


def batched_loss(x_inputs, TopOpt_envs):
    # Evaluate each environment's objective on its own design, then stack.
    # The trailing [0] keeps only the first environment's loss, matching how
    # My_TO_Model.loss calls this with a single-environment list.
    loss_list = [TopOpt.objective(x_inputs[i], projection=True)
                 for i, TopOpt in enumerate(TopOpt_envs)]
    return jnp.stack(loss_list)[0]
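

def _demo_batched_loss():
    # Sketch with a stand-in environment; the real TopOpt objects come from
    # the Linear_TO solver and expose objective(x, projection=True) the same way.
    class _StubEnv:
        def objective(self, x, projection=True):
            return jnp.mean(x ** 2)

    return batched_loss(jnp.ones((1, 4, 4)), [_StubEnv()])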


class Normalize_TO(nn.Module):
    # Parameter-free instance normalization: each sample is shifted and
    # scaled to zero mean and unit variance over its (H, W, C) axes.
    epsilon: float = 1e-6

    @nn.compact
    def __call__(self, input):
        input_mean = jnp.mean(input, axis=(1, 2, 3), keepdims=True)
        input_var = jnp.var(input, axis=(1, 2, 3), keepdims=True)
        output = input - input_mean
        output = output * jax.lax.rsqrt(input_var + self.epsilon)
        return output
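

def _demo_normalize():
    # Sketch: after normalization, per-sample mean is ~0 and variance ~1.
    # The (2, 8, 8, 3) input shape is an illustrative assumption.
    x = random.normal(random.PRNGKey(0), (2, 8, 8, 3))
    y = Normalize_TO().apply({}, x)  # no learnable variables are needed
    return jnp.mean(y, axis=(1, 2, 3)), jnp.var(y, axis=(1, 2, 3))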


class Add_offset(nn.Module):
    # Learnable per-element bias, initialized to zero and multiplied by a
    # constant scale when added to the input feature map.
    scale: int = 10

    @nn.compact
    def __call__(self, input):
        add_bias = self.param('add_bias', lambda rng, shape: constant_array(
            rng, shape, value=0.0), input.shape)
        return input + self.scale * add_bias


class BasicBlock(nn.Module):
    # One upsampling stage: tanh activation -> bilinear resize by
    # resize_scale_one -> normalization -> kernel_size x kernel_size
    # convolution -> learnable offset.
    resize_scale_one: int
    conv_output_one: int
    norm_use: nn.Module = Normalize_TO
    resize_method: str = "bilinear"
    kernel_size: int = 5
    conv_kernel_init: partial = fli.variance_scaling(
        scale=1.0, mode="fan_in", distribution="truncated_normal")
    offset: nn.Module = Add_offset
    offset_scale: int = 10

    @nn.compact
    def __call__(self, input):
        output = jnp.tanh(input)
        output_shape = output.shape
        output = ji.resize(
            output,
            (output_shape[0], self.resize_scale_one * output_shape[1],
             self.resize_scale_one * output_shape[2], output_shape[-1]),
            method=self.resize_method, antialias=True)
        output = self.norm_use()(output)
        output = nn.Conv(
            features=self.conv_output_one,
            kernel_size=(self.kernel_size, self.kernel_size),
            kernel_init=self.conv_kernel_init)(output)
        output = self.offset(scale=self.offset_scale)(output)
        return output
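

def _demo_basic_block():
    # Shape sketch: a 2x upsampling block doubles H and W and sets the channel
    # count; the (1, 5, 10, 32) input is an illustrative assumption.
    block = BasicBlock(resize_scale_one=2, conv_output_one=16)
    x = jnp.ones((1, 5, 10, 32))
    variables = block.init(random.PRNGKey(0), x)
    return block.apply(variables, x).shape  # (1, 10, 20, 16)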


class My_TO_Model(nn.Module):
    # Base class tying a network to a topology-optimization environment and
    # exposing the parameter extract/replace helpers defined above.
    TopOpt: Any
    seed: int = 1
    replace_params: Callable = replace_model_params
    extract_params: Callable = extract_model_params

    def loss(self, input_TO):
        # Single-environment objective of the current design field.
        return batched_loss(input_TO, [self.TopOpt])


class TO_CNN(My_TO_Model):
    # Convolutional reparameterization of the design field: a learned latent
    # vector is mapped through a dense layer, reshaped to a coarse grid, and
    # refined by a stack of upsampling BasicBlocks down to one channel.
    h: int = 5
    w: int = 10
    dense_scale_init: Any = 1.0
    # Must equal h * w * conv_input_0 so the reshape below is valid.
    dense_output_channel: int = 1600
    dtype: Any = jnp.float32
    input_init_std: float = 1.0
    dense_input_channel: int = 128
    conv_input_0: int = 32
    conv_output: tuple = (128, 64, 32, 16, 1)
    up_sampling_scale: tuple = (1, 2, 2, 2, 1)
    offset_scale: int = 10
    block: nn.Module = BasicBlock

    @nn.compact
    def __call__(self, x=None):  # [N, H, W, C]
        # Learned latent vector, initialized to the target volume fraction.
        nn_input = self.param('z', lambda rng, shape: constant_array(
            rng, shape, value=self.TopOpt.volfrac), self.dense_input_channel)
        output = nn.Dense(
            features=self.dense_output_channel,
            kernel_init=fli.orthogonal(scale=self.dense_scale_init))(nn_input)
        output = output.reshape([1, self.h, self.w, self.conv_input_0])
        # Five upsampling blocks; identical to writing the calls out one by one.
        for scale, features in zip(self.up_sampling_scale, self.conv_output):
            output = self.block(scale, features)(output)
        nn_output = jnp.squeeze(output, axis=-1)
        self.sow('intermediates', 'model_out', nn_output)
        return nn_output
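

if __name__ == '__main__':
    # Smoke-test sketch with a stand-in environment: the real TopOpt object
    # comes from the Linear_TO solver; only the volfrac field is needed here.
    class _StubTopOpt:
        volfrac = 0.5

    rng = set_random_seed(1)
    model = TO_CNN(TopOpt=_StubTopOpt())
    variables = model.init(rng, None)
    out = model.apply(variables, None)
    # 5x10 coarse grid upsampled 2x three times -> (1, 40, 80) design field.
    print(out.shape)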