Torch jit script for transform utils somewhat working

This commit is contained in:
hang-yin 2024-07-25 17:08:30 -07:00
parent c4e89eef4a
commit 08a12ce553
1 changed files with 37 additions and 31 deletions

View File

@ -508,18 +508,19 @@ def quat2mat(quaternion):
inds = th.tensor([3, 0, 1, 2])
input_shape = quaternion.shape[:-1]
q = quaternion.reshape(-1, 4)[:, inds]
# Conduct dot product
n = th.bmm(q.unsqueeze(1), q.unsqueeze(-1)).squeeze(-1).squeeze(-1) # shape (-1)
idx = th.nonzero(n).reshape(-1)
q_ = q.clone() # Copy so we don't have inplace operations that fail to backprop
q_[idx, :] = q[idx, :] * th.sqrt(2.0 / n[idx].unsqueeze(-1))
# Conduct outer product
q2 = th.bmm(q_.unsqueeze(-1), q_.unsqueeze(1)).squeeze(-1).squeeze(-1) # shape (-1, 4 ,4)
# Create return array
ret = th.eye(3, 3, device=q.device).reshape(1, 3, 3).repeat(th.prod(th.tensor(input_shape)), 1, 1)
ret = (
th.eye(3, 3, dtype=quaternion.dtype, device=q.device)
.reshape(1, 3, 3)
.repeat(th.prod(th.tensor(input_shape)), 1, 1)
)
ret[idx, :, :] = th.stack(
[
th.stack(
@ -536,21 +537,18 @@ def quat2mat(quaternion):
),
],
dim=1,
)
).to(dtype=quaternion.dtype)
# Reshape and return output
ret = ret.reshape(list(input_shape) + [3, 3])
return ret
# @th.jit.script
@th.jit.script
def mat2quat(rmat: th.Tensor) -> th.Tensor:
"""
Converts given rotation matrix to quaternion.
Args:
rmat (th.Tensor): (..., 3, 3) rotation matrix
Returns:
th.Tensor: (..., 4) (x,y,z,w) float quaternion angles
"""
@ -569,29 +567,37 @@ def mat2quat(rmat: th.Tensor) -> th.Tensor:
m10, m11, m12 = rmat[..., 1, 0], rmat[..., 1, 1], rmat[..., 1, 2]
m20, m21, m22 = rmat[..., 2, 0], rmat[..., 2, 1], rmat[..., 2, 2]
q = th.where(
m22 >= 0,
th.where(
(m00 > m11) & (m00 > m22),
th.stack([1 + m00 - m11 - m22, m01 + m10, m02 + m20, m21 - m12], dim=-1),
th.where(
m11 > m22,
th.stack([m01 + m10, 1 - m00 + m11 - m22, m12 + m21, m02 - m20], dim=-1),
th.stack([m02 + m20, m12 + m21, 1 - m00 - m11 + m22, m10 - m01], dim=-1),
),
),
th.where(
(m00 < -m11) & (m00 < -m22),
th.stack([m21 - m12, m02 + m20, m10 + m01, 1 + m00 - m11 - m22], dim=-1),
th.where(
-m11 < -m22,
th.stack([m02 - m20, m10 + m01, m21 + m12, 1 - m00 + m11 - m22], dim=-1),
th.stack([m10 - m01, m21 + m12, m02 + m20, 1 - m00 - m11 + m22], dim=-1),
),
),
)
trace = m00 + m11 + m22
quat = 0.5 * q / th.sqrt(th.abs(q[..., 3:4]).clamp(min=1e-6))
if trace > 0:
s = 2.0 * th.sqrt(trace + 1.0)
w = 0.25 * s
x = (m21 - m12) / s
y = (m02 - m20) / s
z = (m10 - m01) / s
elif m00 > m11 and m00 > m22:
s = 2.0 * th.sqrt(1.0 + m00 - m11 - m22)
w = (m21 - m12) / s
x = 0.25 * s
y = (m01 + m10) / s
z = (m02 + m20) / s
elif m11 > m22:
s = 2.0 * th.sqrt(1.0 + m11 - m00 - m22)
w = (m02 - m20) / s
x = (m01 + m10) / s
y = 0.25 * s
z = (m12 + m21) / s
else:
s = 2.0 * th.sqrt(1.0 + m22 - m00 - m11)
w = (m10 - m01) / s
x = (m02 + m20) / s
y = (m12 + m21) / s
z = 0.25 * s
quat = th.stack([x, y, z, w], dim=-1)
# Normalize the quaternion
quat = quat / th.norm(quat, dim=-1, keepdim=True)
# Remove extra dimensions if they were added
if len(original_shape) == 2: