114 lines
4.0 KiB
Python
114 lines
4.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
'''
|
|
@ Copyright (c) 2022 by Zeyu Zhang, All Rights Reserved.
|
|
@ Author : Zeyu Zhang
|
|
@ Email : zhangzeyu_work@outlook.com
|
|
@ Date : 2021-11-17 10:52:26
|
|
@ LastEditTime : 2022-11-04 11:24:12
|
|
@ FilePath : \Linear_TO\TO_Cal.py
|
|
@
|
|
@ Description : 存储TO计算过程所需的函数 (functions required by the TO — topology optimization — computation)
|
|
@ Reference :
|
|
'''
|
|
|
|
import numpy as np
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import jax.scipy.signal as jss
|
|
import jax.lax as jl
|
|
from jaxopt import Bisection
|
|
from functools import partial
|
|
import TO_Private_Autodiff as TPA
|
|
import TO_FEA
|
|
import TO_Obj
|
|
from jax.config import config
|
|
# Turn on 64-bit floats/ints globally (JAX defaults to 32-bit); the int64
# index arrays below rely on this. Calling through ``jax.config`` works on
# both old and new JAX releases, whereas the ``jax.config`` submodule import
# path has been deprecated/removed in newer releases.
jax.config.update("jax_enable_x64", True)
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(5, 6, 7))
def design_variable(x, design_area_indices, nondesign_area_indices,
                    filter_kernal, filter_weight, beta_x, volfrac,
                    projection=False):
    """Map a raw 2-D design field to the 1-D physical density vector.

    Pipeline: density filter -> column-major (Fortran-order) flatten ->
    volume-constrained projection of the design-region entries -> scatter
    back into a full-length vector with zeros at the non-design indices.

    Args:
        x: 2-D raw design field.
        design_area_indices: indices into the column-major flattening of the
            filtered field that belong to the design region.
        nondesign_area_indices: indices set to 0 in the output.
        filter_kernal: kernel produced by ``filter_define``.
        filter_weight: normalisation weights produced by ``filter_define``.
        beta_x: projection sharpness (static under ``jit``).
        volfrac: target mean density of the design region (static under
            ``jit``).
        projection: unused; kept for call compatibility. NOTE(review): inside
            this body the name shadows the module-level ``projection``
            function (harmless, since the body never calls it).

    Returns:
        1-D array of length ``len(design_area_indices) +
        len(nondesign_area_indices)``.
    """
    filtered = filter(x, filter_kernal, filter_weight)
    design_values = filtered.flatten(order='F')[design_area_indices]
    projected = design_and_volume_constraints(design_values, volfrac, beta_x)
    return matrix_set_0(projected, design_area_indices, nondesign_area_indices)
|
|
|
|
|
|
def projection(x, beta_TONR):
    """Smoothly project ``x`` into [0, 1] with a tanh step centred at 0.

    Evaluates ``tanh(beta_TONR * 0.5 * x) * 0.5 + 0.5``, which is
    mathematically the logistic sigmoid of ``beta_TONR * x``; a larger
    ``beta_TONR`` gives a sharper transition.
    """
    half_scaled = beta_TONR * 0.5 * x
    step = jnp.tanh(half_scaled)
    return step * .5 + 0.5
|
|
|
|
|
|
def logit(obj_V):
    """Return the log-odds ``log(p) - log(1 - p)`` of ``obj_V`` clipped to [0, 1].

    Because of the clip, inputs <= 0 give ``-inf`` and inputs >= 1 give
    ``+inf``. ``log1p(-p)`` is used for ``log(1 - p)`` for accuracy near 0.
    """
    prob = jnp.clip(obj_V, 0, 1)
    log_p = jnp.log(prob)
    log_not_p = jnp.log1p(-prob)
    return log_p - log_not_p
|
|
|
|
|
|
def design_and_volume_constraints(x, obj_V, beta_TONR):
    """Project ``x`` to densities whose mean equals the target ``obj_V``.

    Finds a scalar shift ``yita`` with
    ``mean(projection(x + yita, beta_TONR)) == obj_V`` using jaxopt's
    ``Bisection``. It then returns ``projection(x + yita, beta_TONR) + 1e-5``.

    Args:
        x: unprojected design values (``design_variable`` passes a 1-D array).
        obj_V: target volume fraction, expected strictly inside (0, 1).
            At exactly 0 or 1, ``logit`` returns an infinite bracket.
        beta_TONR: projection sharpness; must be > 0.

    Returns:
        Projected densities (same shape as ``x``) offset by ``1e-5``.
    """
    def find_yita(y, x, beta):
        # Volume-constraint residual for a candidate shift ``y``.
        return jnp.mean(projection(x + y, beta)) - obj_V

    # projection(z, beta) == sigmoid(beta * z) is increasing in z (beta > 0),
    # so an entry is <= obj_V exactly when z <= logit(obj_V) / beta.
    #   lower: every x + y <= logit(obj_V)/beta  -> residual <= 0
    #   upper: every x + y >= logit(obj_V)/beta  -> residual >= 0
    # Without the 1/beta factor, the bracket is only valid for beta == 1.
    # Because check_bracket=False, an invalid bracket would go unreported.
    centre = logit(obj_V) / beta_TONR
    # The bracket endpoints are treated as constants w.r.t. differentiation.
    lower_bound = centre - jl.stop_gradient(jnp.max(x))
    upper_bound = centre - jl.stop_gradient(jnp.min(x))

    bisec = Bisection(optimality_fun=find_yita, lower=lower_bound,
                      upper=upper_bound, tol=1e-12, maxiter=64,
                      check_bracket=False)
    yita = bisec.run(x=x, beta=beta_TONR).params

    x_design = projection(x + yita, beta_TONR)
    # Small positive offset; presumably keeps densities away from exactly 0
    # for the downstream stiffness interpolation — confirm.
    return x_design + 0.00001
|
|
|
|
|
|
def filter_define(design_map, filter_width):
    """Build a linearly decaying (cone-shaped) density-filter kernel and its weights.

    The kernel value at integer offset ``(i, j)`` is
    ``max(0, filter_width - sqrt(i**2 + j**2))``. Offsets run from
    ``-ceil(filter_width) + 1`` to ``ceil(filter_width) - 1`` on both axes.

    Args:
        design_map: 2-D array convolved with the kernel to obtain the
            per-cell normalisation weights.
        filter_width: filter radius (presumably in grid cells).

    Returns:
        ``(filter_kernal, filter_weight)``, where ``filter_weight`` is
        ``convolve2d(design_map, filter_kernal, 'same')``.

    NOTE(review): ``filter`` divides by ``filter_weight``. Any cell where the
    weight is 0 (e.g. ``design_map`` is zero over a whole kernel footprint)
    would produce inf/NaN — confirm that callers rule this out.
    """
    reach = jnp.ceil(filter_width)
    offsets = jnp.arange(-reach + 1, reach)
    dy, dx = jnp.meshgrid(offsets, offsets)
    distance = jnp.sqrt(dx ** 2 + dy ** 2)
    filter_kernal = jnp.maximum(0, filter_width - distance)
    filter_weight = jss.convolve2d(design_map, filter_kernal, 'same')
    return filter_kernal, filter_weight
|
|
|
|
|
|
def filter(matrixA, filter_kernal, filter_weight):
    """Apply the normalised density filter to a 2-D field.

    Returns ``convolve2d(matrixA, filter_kernal, 'same') / filter_weight``
    (element-wise division). ``filter_kernal`` and ``filter_weight`` come
    from ``filter_define``.

    NOTE(review): this name shadows the builtin ``filter`` at module scope.
    It is kept because callers refer to it by this name.
    """
    smoothed = jss.convolve2d(matrixA, filter_kernal, 'same')
    return smoothed / filter_weight
|
|
|
|
|
|
def inverse_permutation(indices):
    """Return ``inv`` such that ``inv[indices[i]] == i`` for every position ``i``.

    When ``indices`` is a permutation of ``range(len(indices))``, this is its
    inverse permutation. The result has length ``len(indices)`` and an int64
    dtype (int32 if JAX's x64 mode is off).
    """
    count = len(indices)
    positions = jnp.arange(count, dtype=jnp.int64)
    blank = jnp.zeros(count, dtype=jnp.int64)
    return blank.at[indices].set(positions)
|
|
|
|
|
|
def matrix_set_0(nonzero_values, nonzero_indices, zero_indices):
    """Scatter ``nonzero_values`` into a dense 1-D vector padded with zeros.

    The output has length ``len(nonzero_indices) + len(zero_indices)``.
    Entry ``nonzero_indices[i]`` holds ``nonzero_values[i]``, and every
    position in ``zero_indices`` holds 0. The two index arrays are assumed to
    form a permutation of ``range(output_length)``; this is not checked.
    """
    order = jnp.concatenate([nonzero_indices, zero_indices])
    padding = jnp.zeros(len(zero_indices))
    stacked = jnp.concatenate([nonzero_values, padding])
    # Gathering through the inverse permutation moves stacked[i] to order[i].
    return stacked[inverse_permutation(order)]
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(2, 3))
def matrix_set_1(nonzero_values, nonzero_indices, origin_matrix_size, zero_size):
    """Scatter ``nonzero_values`` into a length-``origin_matrix_size`` vector of ones.

    Entry ``nonzero_indices[i]`` holds ``nonzero_values[i]``, and every other
    index holds 1. The complementary indices come from
    ``jnp.setdiff1d(..., size=zero_size)``, which keeps the shape static
    under ``jit``. ``origin_matrix_size`` and ``zero_size`` are static
    arguments.

    NOTE(review): ``zero_size`` presumably must equal
    ``origin_matrix_size - len(nonzero_indices)`` so that the two index sets
    form a permutation. This is not checked — confirm with callers.
    """
    complement = jnp.setdiff1d(
        jnp.arange(origin_matrix_size, dtype=jnp.int64),
        nonzero_indices,
        size=zero_size,
        assume_unique=True,
    )
    order = jnp.concatenate([nonzero_indices, complement])
    stacked = jnp.concatenate([nonzero_values, jnp.ones(len(complement))])
    # Gathering through the inverse permutation moves stacked[i] to order[i].
    return stacked[inverse_permutation(order)]
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(1, 2, 10))
def objective(xPhys_1D, young, young_min, freedofs, fext, s_K_predefined,
              dispTD_predefined, idx, edofMat, ke, penal):
    """Evaluate the compliance objective for a physical density vector.

    Solves for displacements with ``TO_FEA.FEA`` and passes them to
    ``TO_Obj.compliance``. ``young``, ``young_min`` and ``penal`` are static
    under ``jit`` (static_argnums 1, 2, 10), so changing any of them triggers
    a retrace.
    """
    displacement = TO_FEA.FEA(
        young, young_min, ke, xPhys_1D, freedofs,
        penal, idx, fext, s_K_predefined, dispTD_predefined,
    )
    return TO_Obj.compliance(
        young, young_min, ke, xPhys_1D, edofMat, penal, displacement,
    )