AuTONR/Linear_TO/TO_Train.py

96 lines
3.2 KiB
Python

# -*- coding: utf-8 -*-
'''
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
@ Author : Zeyu Zhang
@ Email : zhangzeyu_work@outlook.com
@ Date : 2022-10-25 09:30:27
@ LastEditTime : 2022-10-25 09:37:03
@ FilePath : /ZZY_CODE/Env_JAX/IDRL/Linear_TO/TO_Train.py
@
@ Description : Optax (Adam) training loop for the linear topology-optimization model.
@ Reference :
'''
import sys
from pathlib import Path
from absl import logging
import xarray
import jax
import jax.numpy as jnp
from jax import random
from jaxopt import OptaxSolver
import optax
import time
from functools import partial
import matplotlib.pyplot as plt
from jax.config import config
# Switch JAX to 64-bit floats/ints for the whole process (JAX recommends
# setting this at startup, before arrays are created).
config.update("jax_enable_x64", True)
# NOTE(review): hard-coded, user-specific directory made importable; it is not
# this file's own directory, so it will not exist on other machines — confirm
# what is imported from it and consider making it configurable.
sys.path.append('/home/zzy/ZZY_CODE/Env_JAX/Topology_Optimization/Linear_TO')
# Absolute path of the directory containing this script.
here = Path(__file__).resolve().parent
def set_random_seed(seed):
    """Create a JAX PRNG key from an integer seed.

    Args:
        seed: Integer seed. ``None`` requests a non-deterministic seed taken
            from the current wall-clock time in nanoseconds (low 32 bits).

    Returns:
        ``jax.random.PRNGKey(seed)``.
    """
    if seed is None:
        # Previously ``None`` skipped the assignment and the return raised
        # UnboundLocalError; derive a 32-bit seed from the clock instead so a
        # usable key is always returned.
        seed = time.time_ns() & 0xFFFFFFFF
    return random.PRNGKey(seed)
def optimizer_result_dataset(losses, frames, save_intermediate_designs=False):
    """Pack an optimisation history into an ``xarray.Dataset``.

    Args:
        losses: 1-D array of per-step loss values; NaNs are ignored when
            selecting the best step.
        frames: Per-step designs, stacked along the first axis; each
            ``frames[step]`` is stored with dims ``('y', 'x')``.
        save_intermediate_designs: If True, keep every design along a
            ``'step'`` dimension; otherwise keep only the lowest-loss design.

    Returns:
        Dataset with a ``'loss'`` variable over ``'step'`` and a ``'design'``
        variable (all steps, or the best step only).
    """
    idx_best = jnp.nanargmin(losses)
    # Note: this logs the *lowest* (NaN-ignoring) loss, not the last one.
    logging.info(f'Final loss: {losses[idx_best]}')
    if save_intermediate_designs:
        design_var = (('step', 'y', 'x'), frames)
    else:
        design_var = (('y', 'x'), frames[idx_best])
    data_vars = {
        'loss': (('step',), losses),
        'design': design_var,
    }
    return xarray.Dataset(data_vars, coords={'step': jnp.arange(len(losses))})
def _make_optimizer(design_condition, learning_rate=None):
    """Return the Optax Adam optimizer for ``design_condition``.

    Args:
        design_condition: Design case identifier; only 1 and 2 are supported.
        learning_rate: Optional step size overriding the per-case default.

    Raises:
        ValueError: if ``design_condition`` is not a supported case.
    """
    # Per-case default step sizes (currently identical for both cases).
    default_learning_rates = {1: 0.001, 2: 0.001}
    if design_condition not in default_learning_rates:
        # Previously an unknown case left ``optimizer`` unbound and crashed
        # later with UnboundLocalError.
        raise ValueError(
            f'Unsupported design_condition {design_condition!r}; '
            f'expected one of {sorted(default_learning_rates)}')
    if learning_rate is None:
        learning_rate = default_learning_rates[design_condition]
    return optax.adam(learning_rate)


def _show_design(design):
    """Display one design frame in greyscale, then release its figure."""
    fig = plt.figure()
    plt.imshow(design, cmap='Greys')
    plt.show()
    # pyplot keeps every open figure alive and one is created per iteration,
    # so close it explicitly to avoid accumulating figures over a long run.
    plt.close(fig)


def train_TO_Optax(model, max_iterations, save_intermediate_designs=True, seed=1,
                   learning_rate=None, show_designs=True, **kwargs):
    """Optimise ``model``'s parameters with Adam and record loss/design history.

    Args:
        model: Module exposing ``init``/``apply`` (presumably a Flax module —
            confirm), a ``loss`` method applied to the forward output, and a
            ``TopOpt`` attribute with ``design_condition`` and ``xPhys``.
        max_iterations: Number of optimisation steps; must be >= 1.
        save_intermediate_designs: Passed to ``optimizer_result_dataset``;
            keep every design (True) or only the lowest-loss one (False).
        seed: Seed for parameter initialisation, passed to ``set_random_seed``.
        learning_rate: Optional Adam step size overriding the per-case default.
        show_designs: If True (the original behaviour), plot the design after
            every step.
        **kwargs: Accepted for backward compatibility; unused.

    Returns:
        The dataset built by ``optimizer_result_dataset``.

    Raises:
        ValueError: if ``max_iterations < 1`` or the design condition is not
            supported.
    """
    if max_iterations < 1:
        # An empty history would otherwise fail inside nanargmin with an
        # unrelated error message.
        raise ValueError(f'max_iterations must be >= 1, got {max_iterations}')
    losses = []
    frames = []

    # Index instead of ``.pop``: works whether ``init`` returns a FrozenDict
    # or a plain dict (newer Flax); the non-"params" collections are unused.
    params = model.init(set_random_seed(seed), None)["params"]

    optimizer = _make_optimizer(model.TopOpt.design_condition, learning_rate)

    def loss_cal(params):
        """Forward pass; returns (loss, sown intermediates)."""
        model_out, net_state = model.apply(
            {"params": params}, None, mutable="intermediates")
        return model.loss(model_out), net_state["intermediates"]

    loss_grad_fn = jax.value_and_grad(loss_cal, has_aux=True)
    opt_state = optimizer.init(params)

    for step in range(max_iterations):
        start_time_epoch = time.perf_counter()
        (loss_val, _), grads = loss_grad_fn(params)
        losses.append(jax.device_get(loss_val))
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        # NOTE(review): reads the design that the forward pass presumably
        # stores on ``model.TopOpt.xPhys`` (so frames[i] pairs with
        # losses[i]); ``.aval.val`` assumes a traced value with a concrete
        # abstract value, which is an internal/older-JAX detail — confirm.
        design_TO = jax.device_get(model.TopOpt.xPhys.aval.val)
        frames.append(design_TO)
        if show_designs:
            _show_design(design_TO)
        whole_time_epoch = time.perf_counter() - start_time_epoch
        print(' It.:%5i Obj.:%11.4f Vol.:%7.3f Time.:%7.3f\n' %
              (step + 1, losses[-1], jnp.mean(design_TO), whole_time_epoch))
    return optimizer_result_dataset(
        jnp.array(losses), jnp.array(frames), save_intermediate_designs)