# -*- coding: utf-8 -*-
'''
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
@ Author : Zeyu Zhang
@ Email : zhangzeyu_work@outlook.com
@ Date : 2022-10-25 09:30:27
@ LastEditTime : 2022-10-25 09:37:03
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Train.py
@ Description : Training routines for linear topology optimization: optimizes
                the TO model's parameters with Optax and packs the loss/design
                history into an xarray.Dataset.
@ Reference :
'''

import sys
import time
from functools import partial
from pathlib import Path

from absl import logging
import xarray
import jax
import jax.numpy as jnp
from jax import random
from jax.config import config
from jaxopt import OptaxSolver
import optax
import matplotlib.pyplot as plt

# Run JAX in double precision (float64).
config.update("jax_enable_x64", True)

sys.path.append(
    '/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO')
here = Path(__file__).resolve().parent


def set_random_seed(seed):
    """Create a JAX PRNG key from an integer seed; returns None if seed is None."""
    if seed is not None:
        rand_key = random.PRNGKey(seed)
        return rand_key
    return None
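
# Usage sketch (illustrative, not from the original file): the returned key can
# be split into independent streams with jax.random.split, e.g.
#
#   key = set_random_seed(42)
#   init_key, data_key = random.split(key)
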
def optimizer_result_dataset(losses, frames, save_intermediate_designs=False):
    """Pack the optimization history into an xarray.Dataset.

    Keeps every intermediate design when `save_intermediate_designs` is True,
    otherwise only the design with the lowest (non-NaN) loss.
    """
    best_design = jnp.nanargmin(losses)
    logging.info(f'Final loss: {losses[best_design]}')
    if save_intermediate_designs:
        ds = xarray.Dataset({
            'loss': (('step',), losses),
            'design': (('step', 'y', 'x'), frames),
        }, coords={'step': jnp.arange(len(losses))})
    else:
        ds = xarray.Dataset({
            'loss': (('step',), losses),
            'design': (('y', 'x'), frames[best_design]),
        }, coords={'step': jnp.arange(len(losses))})
    return ds
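
# A minimal usage sketch (not part of the original pipeline): the returned
# dataset follows plain xarray conventions, so it can be persisted and
# inspected with xarray's own I/O and plotting helpers, e.g.
#
#   ds = optimizer_result_dataset(losses, frames, save_intermediate_designs=True)
#   ds.to_netcdf(here / 'TO_history.nc')   # hypothetical filename; needs a netCDF backend
#   ds['loss'].plot()                      # convergence curve over 'step'
#   plt.show()
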
def train_TO_Optax(model, max_iterations, save_intermediate_designs=True, seed=1, **kwargs):
    """Optimize the TO model with Optax; return the history as an xarray.Dataset."""
    losses = []
    frames = []
    # Initialize parameters; non-'params' collections (e.g. batch stats) are split off.
    init_params = model.init(set_random_seed(seed), None)
    batch_state, params = init_params.pop("params")
    del init_params
    # Instantiate the optimizer; both design conditions currently use Adam with lr = 1e-3.
    design_condition = model.TopOpt.design_condition
    if design_condition == 1:
        learning_rate = 0.001
        optimizer = optax.adam(learning_rate)
    elif design_condition == 2:
        learning_rate = 0.001
        optimizer = optax.adam(learning_rate)
    else:
        raise ValueError(f'Unknown design_condition: {design_condition}')

    # Loss: run the model and return the scalar objective plus the mutable intermediates.
    def loss_cal(params):
        all_params = {"params": params}
        model_out, net_state = model.apply(all_params, None, mutable="intermediates")
        loss_out = model.loss(model_out)
        return loss_out, net_state["intermediates"]

    loss_grad_fn = jax.value_and_grad(loss_cal, has_aux=True)
    # Initialize the optimizer state
    state = optimizer.init(params)
    # Update loop
    for it in range(max_iterations):
        start_time_epoch = time.perf_counter()
        (loss_val, batch_stats), grads = loss_grad_fn(params)  # batch_stats is unused here
        losses.append(jax.device_get(loss_val))
        updates_params, state = optimizer.update(grads, state)
        params = optax.apply_updates(params, updates_params)
        # Pull the current physical density field off the device for logging/plotting.
        design_TO = jax.device_get(model.TopOpt.xPhys.aval.val)
        frames.append(design_TO)
        # Show the current design; note that plt.show() blocks until the window is closed.
        plt.figure()
        plt.imshow(design_TO, cmap='Greys')
        plt.show()
        whole_time_epoch = time.perf_counter() - start_time_epoch
        print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' %
              (it + 1, losses[-1], jnp.mean(design_TO), whole_time_epoch))

    return optimizer_result_dataset(jnp.array(losses), jnp.array(frames), save_intermediate_designs)
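
# A minimal invocation sketch (assumption: `TO_Model` is hypothetical and
# stands in for the Flax module defined elsewhere in Linear_TO; it must expose
# .TopOpt, .apply, and .loss as used above):
#
#   from TO_Model import TO_Model          # hypothetical import
#   model = TO_Model(...)
#   ds = train_TO_Optax(model, max_iterations=200,
#                       save_intermediate_designs=True, seed=1)
#   ds['loss'].plot()
#   plt.show()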