# -*- coding: utf-8 -*-
'''
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
@ Author : Zeyu Zhang
@ Email : zhangzeyu_work@outlook.com
@ Date : 2022-10-25 09:30:27
@ LastEditTime : 2022-10-25 09:36:54
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Model.py
@ Description : CNN-based design parameterization for linear topology optimization (JAX / Flax).
@ Reference :
'''

import jax
import jax.numpy as jnp
import jax.image as ji
import jax.tree_util as jtree
from jax import random
import flax.linen as nn
import flax.linen.initializers as fli
import flax.traverse_util as flu
import flax.core.frozen_dict as fcfd
from functools import partial
from typing import Any, Callable
from jax.config import config

config.update("jax_enable_x64", True)


def set_random_seed(seed):
    # Build a JAX PRNG key from an integer seed; returns None when no seed is given.
    if seed is not None:
        rand_key = random.PRNGKey(seed)
        return rand_key


def constant_array(rng_key, shape, dtype=jnp.float64, value=0.0):
    # Flax-style initializer: rng_key is accepted for interface compatibility
    # but unused, because the array is filled with a constant value.
    out = jnp.full(shape, value, dtype)
    return out


def extract_model_params(params_use):
    # Flatten a (frozen) Flax parameter tree into a dict keyed by 'path/to/param',
    # and report each parameter's shape and the total parameter count.
    params_unfreeze = params_use.unfreeze()
    extracted_params = {
        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
    params_shape = jtree.tree_map(jnp.shape, extracted_params)
    params_total_num = sum(p.size for p in jtree.tree_leaves(extracted_params))
    return extracted_params, params_shape, params_total_num


def replace_model_params(pre_trained_params_dict, params_use):
    # Replace entries of a Flax parameter tree with pre-trained values,
    # matching them by their flattened 'path/to/param' keys, and return
    # a frozen tree with the original structure.
    params_unfreeze = params_use.unfreeze()
    extracted_params = {
        '/'.join(k): v for k, v in flu.flatten_dict(params_unfreeze).items()}
    for key, val in pre_trained_params_dict.items():
        extracted_params[key] = val
    new_params_use = flu.unflatten_dict(
        {tuple(k.split('/')): v for k, v in extracted_params.items()})
    new_params_use = fcfd.freeze(new_params_use)
    return new_params_use
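

# The sketch below is illustrative only (it is not called anywhere in this
# file): it shows how extract_model_params / replace_model_params are meant to
# round-trip a Flax parameter tree. The two-layer Dense model is a made-up
# stand-in, not part of the TO pipeline.
def _example_param_surgery():
    class _TinyNet(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Dense(features=4)(x)
            return nn.Dense(features=2)(x)

    variables = _TinyNet().init(random.PRNGKey(0), jnp.ones((1, 3)))
    params = variables['params']
    # Keys are flattened paths such as 'Dense_0/kernel' and 'Dense_1/bias'.
    extracted, shapes, total = extract_model_params(params)
    # Overwrite one entry by key and rebuild a frozen parameter tree.
    new_params = replace_model_params(
        {'Dense_0/bias': jnp.zeros(shapes['Dense_0/bias'])}, params)
    return extracted, total, new_params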


def batched_loss(x_inputs, TopOpt_envs):
    # Evaluate each TO environment's objective on its corresponding design;
    # the trailing [0] keeps the first (single-environment) loss as a scalar.
    loss_list = [TopOpt.objective(x_inputs[i], projection=True)
                 for i, TopOpt in enumerate(TopOpt_envs)]
    losses_array = jnp.stack(loss_list)[0]
    return losses_array


class Normalize_TO(nn.Module):
    epsilon: float = 1e-6

    @nn.compact
    def __call__(self, input):
        # Normalize each sample to zero mean and unit variance over (H, W, C).
        input_mean = jnp.mean(input, axis=(1, 2, 3), keepdims=True)
        input_var = jnp.var(input, axis=(1, 2, 3), keepdims=True)
        output = input - input_mean
        output = output * jax.lax.rsqrt(input_var + self.epsilon)
        return output


class Add_offset(nn.Module):
    scale: int = 10

    @nn.compact
    def __call__(self, input):
        # Learnable bias with the same shape as the input, scaled by `scale`.
        add_bias = self.param('add_bias', lambda rng, shape: constant_array(
            rng, shape, value=0.0), input.shape)
        return input + self.scale * add_bias


class BasicBlock(nn.Module):
    resize_scale_one: int
    conv_output_one: int
    norm_use: nn.Module = Normalize_TO
    resize_method: str = "bilinear"
    kernel_size: int = 5
    conv_kernel_init: Callable = fli.variance_scaling(
        scale=1.0, mode="fan_in", distribution="truncated_normal")
    offset: nn.Module = Add_offset
    offset_scale: int = 10

    @nn.compact
    def __call__(self, input):
        # tanh activation -> spatial up-sampling -> normalization -> convolution -> learnable offset
        output = jnp.tanh(input)
        output_shape = output.shape
        output = ji.resize(output, (output_shape[0], self.resize_scale_one * output_shape[1],
                                    self.resize_scale_one * output_shape[2], output_shape[-1]),
                           method=self.resize_method, antialias=True)
        output = self.norm_use()(output)
        output = nn.Conv(features=self.conv_output_one, kernel_size=(
            self.kernel_size, self.kernel_size), kernel_init=self.conv_kernel_init)(output)
        output = self.offset(scale=self.offset_scale)(output)
        return output
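

# Illustrative sketch only (not called anywhere in this file): with
# resize_scale_one=2, one BasicBlock doubles the spatial resolution and maps
# the channel count to conv_output_one. The input shape below is a made-up
# example matching the [N, H, W, C] layout used elsewhere in this file.
def _example_basicblock_shapes():
    block = BasicBlock(resize_scale_one=2, conv_output_one=64)
    dummy = jnp.zeros((1, 5, 10, 32))
    variables = block.init(random.PRNGKey(0), dummy)
    out = block.apply(variables, dummy)
    # out.shape == (1, 10, 20, 64): spatial dims doubled, 64 output channels.
    return out.shape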


class My_TO_Model(nn.Module):
    TopOpt: Any
    seed: int = 1
    replace_params: Callable = replace_model_params
    extract_params: Callable = extract_model_params

    def loss(self, input_TO):
        # Objective value of the TO environment for the network's output design.
        obj_TO = batched_loss(input_TO, [self.TopOpt])
        return obj_TO


class TO_CNN(My_TO_Model):
    h: int = 5
    w: int = 10
    dense_scale_init: Any = 1.0
    dense_output_channel: int = 1000
    dtype: Any = jnp.float32
    input_init_std: float = 1.0
    dense_input_channel: int = 128
    conv_input_0: int = 32
    conv_output: tuple = (128, 64, 32, 16, 1)
    up_sampling_scale: tuple = (1, 2, 2, 2, 1)
    offset_scale: int = 10
    block: nn.Module = BasicBlock

    @nn.compact
    def __call__(self, x=None):  # [N, H, W, C]
        # Learnable latent vector 'z', initialized to the prescribed volume fraction.
        nn_input = self.param('z', lambda rng, shape: constant_array(
            rng, shape, value=self.TopOpt.volfrac), self.dense_input_channel)

        # Dense layer followed by a reshape to a coarse image; note that
        # dense_output_channel must equal h * w * conv_input_0 for this reshape.
        output = nn.Dense(features=self.dense_output_channel, kernel_init=fli.orthogonal(
            scale=self.dense_scale_init))(nn_input)
        output = output.reshape([1, self.h, self.w, self.conv_input_0])

        # Five up-sampling/convolution blocks progressively refine the design field.
        output = self.block(
            self.up_sampling_scale[0], self.conv_output[0])(output)
        output = self.block(
            self.up_sampling_scale[1], self.conv_output[1])(output)
        output = self.block(
            self.up_sampling_scale[2], self.conv_output[2])(output)
        output = self.block(
            self.up_sampling_scale[3], self.conv_output[3])(output)
        output = self.block(
            self.up_sampling_scale[4], self.conv_output[4])(output)

        # Drop the channel axis and record the raw network output.
        nn_output = jnp.squeeze(output, axis=-1)
        self.sow('intermediates', 'model_out', nn_output)
        return nn_output
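

# Illustrative sketch only (not called anywhere in this file): how TO_CNN is
# typically built and evaluated. `topopt_env` is a placeholder for the
# topology-optimization environment assumed by this file, i.e. an object
# exposing `volfrac` and `objective(design, projection=True)`. The explicit
# dense_output_channel keeps the Dense output consistent with the
# h * w * conv_input_0 reshape in __call__.
def _example_run_to_cnn(topopt_env, seed=1):
    model = TO_CNN(TopOpt=topopt_env, h=5, w=10, conv_input_0=32,
                   dense_output_channel=5 * 10 * 32)
    rand_key = set_random_seed(seed)
    variables = model.init(rand_key)   # no input tensor needed; x defaults to None
    design = model.apply(variables)    # density field of shape [1, H_out, W_out]
    loss_value = model.loss(design)    # objective value from the TO environment
    return design, loss_value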