AuTONR/Linear_TO/TO_Cal.py

114 lines
4.0 KiB
Python

# -*- coding: utf-8 -*-
'''
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
@ Author : Zeyu Zhang
@ Email : zhangzeyu_work@outlook.com
@ Date : 2021-11-17 10:52:26
@ LastEditTime : 2022-11-04 11:24:12
@ FilePath : \\Linear_TO\\TO_Cal.py
@
@ Description : Helper functions used by the topology-optimization (TO) computation (存储TO计算过程所需的函数)
@ Reference :
'''
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jss
import jax.lax as jl
from jaxopt import Bisection
from functools import partial
import TO_Private_Autodiff as TPA
import TO_FEA
import TO_Obj
from jax.config import config
config.update("jax_enable_x64", True)
@partial(jax.jit, static_argnums=(5, 6, 7))
def design_variable(x, design_area_indices, nondesign_area_indices, filter_kernal, filter_weight, beta_x, volfrac, projection=False):
    """Map raw design variables to the 1-D physical density vector.

    Steps: density filter -> column-major flatten -> volume-constrained
    projection of the designable entries -> scatter back into a full-length
    vector that is zero at the non-design entries.

    Args:
        x: 2-D array of raw design variables (it is passed to ``convolve2d``).
        design_area_indices: flat (column-major) indices of designable cells.
        nondesign_area_indices: flat indices of cells that are set to zero.
        filter_kernal: filter kernel from ``filter_define``.
        filter_weight: normalisation map from ``filter_define``.
        beta_x: projection sharpness; static under jit.
        volfrac: target volume fraction; static under jit.
        projection: unused, kept for interface compatibility. Inside this
            body it shadows the module-level ``projection`` function, but the
            body references neither.

    Returns:
        1-D physical density vector.
    """
    filtered = filter(x, filter_kernal, filter_weight)
    # Column-major ('F') flattening; presumably matches the element numbering
    # used by the FEA code -- confirm against callers.
    flat = filtered.flatten(order='F')
    designable = flat[design_area_indices]
    projected = design_and_volume_constraints(designable, volfrac, beta_x)
    return matrix_set_0(projected, design_area_indices, nondesign_area_indices)
def projection(x, beta_TONR):
    """Smooth step projection of ``x`` onto the open interval (0, 1).

    Evaluates ``0.5 * tanh(0.5 * beta_TONR * x) + 0.5``. This is
    algebraically the logistic sigmoid of ``beta_TONR * x``. A larger
    ``beta_TONR`` gives a sharper step around ``x = 0``.
    """
    half_slope = beta_TONR * 0.5
    return jnp.tanh(half_slope * x) * .5 + 0.5
def logit(obj_V):
    """Return the log-odds ``log(p / (1 - p))`` of ``obj_V``.

    ``obj_V`` is first clipped to ``[0, 1]``. The endpoints therefore map to
    ``-inf`` (for 0) and ``+inf`` (for 1). The result is computed as
    ``log(p) - log1p(-p)``.
    """
    prob = jnp.clip(obj_V, 0, 1)
    log_num = jnp.log(prob)
    log_den = jnp.log1p(-prob)
    return log_num - log_den
def design_and_volume_constraints(x, obj_V, beta_TONR):
    """Project ``x`` into (0, 1) so the projected mean hits ``obj_V``.

    Uses jaxopt's ``Bisection`` to find a scalar shift ``yita`` such that
    ``mean(projection(x + yita, beta_TONR)) == obj_V``. It then returns
    ``projection(x + yita, beta_TONR) + 1e-5``.

    Bracket: ``projection(z, beta)`` equals ``sigmoid(beta * z)``, which is
    monotone in ``z`` for ``beta > 0``. This gives two bounds:

    * If ``x + yita <= logit(obj_V) / beta`` element-wise, every projected
      entry is ``<= obj_V``.
    * If ``x + yita >= logit(obj_V) / beta`` element-wise, every projected
      entry is ``>= obj_V``.

    So the root lies in
    ``[logit(obj_V)/beta - max(x), logit(obj_V)/beta - min(x)]``.
    The ``/ beta`` term is required whenever ``beta != 1``. Because
    ``check_bracket=False``, a non-bracketing interval would not be reported.

    Args:
        x: 1-D array of (filtered) design variables.
        obj_V: target volume fraction. Assumed to lie strictly in (0, 1);
            at 0 or 1 ``logit`` is infinite and so is the bracket.
        beta_TONR: projection sharpness. Assumed > 0 (TODO confirm callers).

    Returns:
        Projected densities, offset by +1e-5.
    """
    def find_yita(y, x, beta):
        # Volume-constraint residual for a candidate shift ``y``.
        projection_out = projection(x + y, beta)
        return jnp.mean(projection_out) - obj_V

    # Use jnp (not np) reductions: x may be a JAX tracer under jit.
    # stop_gradient keeps the data-dependent search bounds out of the
    # autodiff graph; they only steer where bisection looks.
    center = logit(obj_V) / beta_TONR
    lower_bound = center - jl.stop_gradient(jnp.max(x))
    upper_bound = center - jl.stop_gradient(jnp.min(x))
    bisec = Bisection(optimality_fun=find_yita, lower=lower_bound, upper=upper_bound,
                      tol=1e-12, maxiter=64, check_bracket=False)
    yita = bisec.run(x=x, beta=beta_TONR).params
    x_design = projection(x + yita, beta_TONR)
    # Small positive offset so no density is exactly zero; presumably to keep
    # downstream stiffness terms well-defined -- confirm.
    return x_design + 0.00001
def filter_define(design_map, filter_width):
    """Build a cone-shaped filter kernel and its normalisation map.

    The kernel is sampled on integer offsets from ``-ceil(filter_width) + 1``
    to ``ceil(filter_width) - 1`` along each axis. Each weight is
    ``max(0, filter_width - distance_from_centre)``.

    Args:
        design_map: 2-D array that is convolved with the kernel.
        filter_width: filter radius, in cells.

    Returns:
        ``(filter_kernal, filter_weight)``, where ``filter_weight`` is
        ``convolve2d(design_map, filter_kernal, 'same')``.
    """
    reach = jnp.ceil(filter_width)
    offsets = jnp.arange(-reach + 1, reach)
    dy, dx = jnp.meshgrid(offsets, offsets)
    distance = jnp.sqrt(dx ** 2 + dy ** 2)
    kernel = jnp.maximum(0, filter_width - distance)
    weights = jss.convolve2d(design_map, kernel, 'same')
    return kernel, weights
def filter(matrixA, filter_kernal, filter_weight):
    """Apply the normalised density filter produced by ``filter_define``.

    Note: this name shadows Python's builtin ``filter`` at module scope. It
    is kept for compatibility with existing callers.

    Returns:
        ``convolve2d(matrixA, filter_kernal, 'same') / filter_weight``.
        Entries where ``filter_weight`` is 0 become inf or nan.
    """
    smoothed = jss.convolve2d(matrixA, filter_kernal, 'same')
    return smoothed / filter_weight
def inverse_permutation(indices):
    """Return the inverse of the permutation ``indices``.

    The result ``inv`` satisfies ``inv[indices[k]] == k`` for every ``k``.
    The output dtype is ``jnp.int64``, which is only truly 64-bit when
    ``jax_enable_x64`` is enabled.
    """
    count = len(indices)
    positions = jnp.arange(count, dtype=jnp.int64)
    return jnp.zeros(count, dtype=jnp.int64).at[indices].set(positions)
def matrix_set_0(nonzero_values, nonzero_indices, zero_indices):
    """Assemble a 1-D vector: given values at given indices, zeros elsewhere.

    Assumes ``nonzero_indices`` and ``zero_indices`` together form a
    permutation of ``0 .. len(nonzero_indices) + len(zero_indices) - 1``
    (TODO confirm at call sites).

    Returns:
        Vector ``v`` with ``v[nonzero_indices[k]] == nonzero_values[k]`` and
        ``v[zero_indices] == 0``.
    """
    padded = jnp.concatenate([nonzero_values, jnp.zeros(len(zero_indices))])
    ordering = jnp.concatenate([nonzero_indices, zero_indices])
    # Gathering through the inverse permutation puts each value at its index.
    return padded[inverse_permutation(ordering)]
@partial(jax.jit, static_argnums=(2, 3))
def matrix_set_1(nonzero_values, nonzero_indices, origin_matrix_size, zero_size):
    """Assemble a 1-D vector: given values at given indices, ones elsewhere.

    The complementary indices are ``range(origin_matrix_size)`` minus
    ``nonzero_indices``, computed with ``jnp.setdiff1d``, and are filled with 1.

    Args:
        nonzero_values: values to place, one per entry of ``nonzero_indices``.
        nonzero_indices: target indices (treated as unique: ``assume_unique``).
        origin_matrix_size: size of the index range; static under jit.
        zero_size: number of complementary indices. It must be static because
            ``setdiff1d`` needs a fixed output size under jit.

    Returns:
        1-D vector of length ``len(nonzero_indices) + zero_size``. This
        presumably equals ``origin_matrix_size`` at call sites -- confirm.
    """
    complement = jnp.setdiff1d(
        jnp.arange(origin_matrix_size, dtype=jnp.int64),
        nonzero_indices,
        size=zero_size,
        assume_unique=True,
    )
    padded = jnp.concatenate([nonzero_values, jnp.ones(len(complement))])
    ordering = jnp.concatenate([nonzero_indices, complement])
    return padded[inverse_permutation(ordering)]
@partial(jax.jit, static_argnums=(1, 2, 10))
def objective(xPhys_1D, young, young_min, freedofs, fext, s_K_predefined, dispTD_predefined, idx, edofMat, ke, penal):
    """Compliance objective for a 1-D physical density vector.

    Calls ``TO_FEA.FEA`` to obtain ``disp``, presumably the displacement
    solution. It then returns ``TO_Obj.compliance`` evaluated on it.
    ``young``, ``young_min`` and ``penal`` (positions 1, 2, 10) are static
    under jit, so changing them triggers recompilation.
    """
    displacement = TO_FEA.FEA(young, young_min, ke, xPhys_1D, freedofs,
                              penal, idx, fext, s_K_predefined,
                              dispTD_predefined)
    return TO_Obj.compliance(young, young_min, ke, xPhys_1D, edofMat, penal,
                             displacement)