Torch jit script for transform utils somewhat working
This commit is contained in:
parent
c4e89eef4a
commit
08a12ce553
|
@ -508,18 +508,19 @@ def quat2mat(quaternion):
|
|||
inds = th.tensor([3, 0, 1, 2])
|
||||
input_shape = quaternion.shape[:-1]
|
||||
q = quaternion.reshape(-1, 4)[:, inds]
|
||||
|
||||
# Conduct dot product
|
||||
n = th.bmm(q.unsqueeze(1), q.unsqueeze(-1)).squeeze(-1).squeeze(-1) # shape (-1)
|
||||
idx = th.nonzero(n).reshape(-1)
|
||||
q_ = q.clone() # Copy so we don't have inplace operations that fail to backprop
|
||||
q_[idx, :] = q[idx, :] * th.sqrt(2.0 / n[idx].unsqueeze(-1))
|
||||
|
||||
# Conduct outer product
|
||||
q2 = th.bmm(q_.unsqueeze(-1), q_.unsqueeze(1)).squeeze(-1).squeeze(-1) # shape (-1, 4 ,4)
|
||||
|
||||
# Create return array
|
||||
ret = th.eye(3, 3, device=q.device).reshape(1, 3, 3).repeat(th.prod(th.tensor(input_shape)), 1, 1)
|
||||
ret = (
|
||||
th.eye(3, 3, dtype=quaternion.dtype, device=q.device)
|
||||
.reshape(1, 3, 3)
|
||||
.repeat(th.prod(th.tensor(input_shape)), 1, 1)
|
||||
)
|
||||
ret[idx, :, :] = th.stack(
|
||||
[
|
||||
th.stack(
|
||||
|
@ -536,21 +537,18 @@ def quat2mat(quaternion):
|
|||
),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
).to(dtype=quaternion.dtype)
|
||||
# Reshape and return output
|
||||
ret = ret.reshape(list(input_shape) + [3, 3])
|
||||
return ret
|
||||
|
||||
|
||||
# @th.jit.script
|
||||
@th.jit.script
|
||||
def mat2quat(rmat: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Converts given rotation matrix to quaternion.
|
||||
|
||||
Args:
|
||||
rmat (th.Tensor): (..., 3, 3) rotation matrix
|
||||
|
||||
Returns:
|
||||
th.Tensor: (..., 4) (x,y,z,w) float quaternion angles
|
||||
"""
|
||||
|
@ -569,29 +567,37 @@ def mat2quat(rmat: th.Tensor) -> th.Tensor:
|
|||
m10, m11, m12 = rmat[..., 1, 0], rmat[..., 1, 1], rmat[..., 1, 2]
|
||||
m20, m21, m22 = rmat[..., 2, 0], rmat[..., 2, 1], rmat[..., 2, 2]
|
||||
|
||||
q = th.where(
|
||||
m22 >= 0,
|
||||
th.where(
|
||||
(m00 > m11) & (m00 > m22),
|
||||
th.stack([1 + m00 - m11 - m22, m01 + m10, m02 + m20, m21 - m12], dim=-1),
|
||||
th.where(
|
||||
m11 > m22,
|
||||
th.stack([m01 + m10, 1 - m00 + m11 - m22, m12 + m21, m02 - m20], dim=-1),
|
||||
th.stack([m02 + m20, m12 + m21, 1 - m00 - m11 + m22, m10 - m01], dim=-1),
|
||||
),
|
||||
),
|
||||
th.where(
|
||||
(m00 < -m11) & (m00 < -m22),
|
||||
th.stack([m21 - m12, m02 + m20, m10 + m01, 1 + m00 - m11 - m22], dim=-1),
|
||||
th.where(
|
||||
-m11 < -m22,
|
||||
th.stack([m02 - m20, m10 + m01, m21 + m12, 1 - m00 + m11 - m22], dim=-1),
|
||||
th.stack([m10 - m01, m21 + m12, m02 + m20, 1 - m00 - m11 + m22], dim=-1),
|
||||
),
|
||||
),
|
||||
)
|
||||
trace = m00 + m11 + m22
|
||||
|
||||
quat = 0.5 * q / th.sqrt(th.abs(q[..., 3:4]).clamp(min=1e-6))
|
||||
if trace > 0:
|
||||
s = 2.0 * th.sqrt(trace + 1.0)
|
||||
w = 0.25 * s
|
||||
x = (m21 - m12) / s
|
||||
y = (m02 - m20) / s
|
||||
z = (m10 - m01) / s
|
||||
elif m00 > m11 and m00 > m22:
|
||||
s = 2.0 * th.sqrt(1.0 + m00 - m11 - m22)
|
||||
w = (m21 - m12) / s
|
||||
x = 0.25 * s
|
||||
y = (m01 + m10) / s
|
||||
z = (m02 + m20) / s
|
||||
elif m11 > m22:
|
||||
s = 2.0 * th.sqrt(1.0 + m11 - m00 - m22)
|
||||
w = (m02 - m20) / s
|
||||
x = (m01 + m10) / s
|
||||
y = 0.25 * s
|
||||
z = (m12 + m21) / s
|
||||
else:
|
||||
s = 2.0 * th.sqrt(1.0 + m22 - m00 - m11)
|
||||
w = (m10 - m01) / s
|
||||
x = (m02 + m20) / s
|
||||
y = (m12 + m21) / s
|
||||
z = 0.25 * s
|
||||
|
||||
quat = th.stack([x, y, z, w], dim=-1)
|
||||
|
||||
# Normalize the quaternion
|
||||
quat = quat / th.norm(quat, dim=-1, keepdim=True)
|
||||
|
||||
# Remove extra dimensions if they were added
|
||||
if len(original_shape) == 2:
|
||||
|
|
Loading…
Reference in New Issue