Spaces:
Running on Zero
Running on Zero
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| """Rotation-space conversion utilities for skeleton motion data.""" | |
| import einops | |
| import torch | |
| from ..tools import ensure_batched | |
| from .kinematics import batch_rigid_transform | |
def global_rots_to_local_rots(global_joint_rots: torch.Tensor, skeleton):
    """Convert global joint rotations to parent-relative (local) rotations.

    Each joint's local rotation is computed as
    ``R_parent_global^T @ R_joint_global``. The root joint is given an
    identity parent, so its local rotation equals its global rotation.

    Args:
        global_joint_rots: Global rotation matrices with shape `(..., J, 3, 3)`.
        skeleton: Skeleton object exposing `joint_parents` (one parent index
            per joint, used to index the joint axis) and `root_idx`.

    Returns:
        Local rotation matrices with the same shape as the input.
    """
    # Indexing the joint axis with a sequence of indices yields a fresh tensor,
    # so the in-place root overwrite below never touches the caller's data.
    parent_rot_mats = global_joint_rots[..., skeleton.joint_parents, :, :]

    # Whatever the root's `joint_parents` entry selected is discarded and
    # replaced by identity. The identity is built directly on the input's
    # device/dtype to avoid an implicit cross-device copy and dtype cast.
    parent_rot_mats[..., skeleton.root_idx, :, :] = torch.eye(
        3,
        device=global_joint_rots.device,
        dtype=global_joint_rots.dtype,
    )

    # The parent rotation is inverted by transposing (valid for rotation
    # matrices); matmul broadcasts over every leading dimension.
    return torch.matmul(parent_rot_mats.transpose(-1, -2), global_joint_rots)
def change_tpose(local_rot_mats: torch.Tensor, global_rot_offsets: torch.Tensor, skeleton):
    """Re-express local rotations relative to a different T-pose.

    The local rotations are first composed into global rotations, each joint's
    global rotation is right-multiplied by the transpose of its offset, and the
    result is converted back to local rotations.

    Args:
        local_rot_mats: Local rotation matrices. The code uses
            `len(local_rot_mats)` and a 4-D einsum pattern, so a single leading
            batch axis, i.e. `(T, J, 3, 3)`, appears to be expected here.
        global_rot_offsets: Per-joint global rotation offsets; the einsum
            pattern consumes them as `(J, 3, 3)`. They are moved to the device
            and dtype of `local_rot_mats` before use.
        skeleton: Skeleton object exposing `joint_parents`, `root_idx`,
            and `nbjoints`.

    Returns:
        Tuple `(new_local_rot_mats, new_global_rot_mats)` expressed with
        respect to the T-pose described by `global_rot_offsets`.
    """
    target_device = local_rot_mats.device
    target_dtype = local_rot_mats.dtype
    offsets = global_rot_offsets.to(device=target_device, dtype=target_dtype)

    # Placeholder joint positions: only the rotation output of the forward
    # pass is used below, so their values are irrelevant.
    placeholder_joints = torch.ones(
        (len(local_rot_mats), skeleton.nbjoints, 3),
        device=target_device,
        dtype=target_dtype,
    )

    # Accumulate the local rotations into global ones; the position output
    # is discarded.
    _, posed_global_rots = batch_rigid_transform(
        local_rot_mats,
        placeholder_joints,
        skeleton.joint_parents,
        skeleton.root_idx,
    )

    # Per joint: R_new = R_global @ offset^T.
    retargeted_global_rots = torch.einsum(
        "T N m n, N o n -> T N m o",
        posed_global_rots,
        offsets,
    )

    # Back to parent-relative rotations under the same hierarchy.
    retargeted_local_rots = global_rots_to_local_rots(retargeted_global_rots, skeleton)
    return retargeted_local_rots, retargeted_global_rots
def to_standard_tpose(local_rot_mats: torch.Tensor, skeleton):
    """Convert local rotations to the skeleton's standard T-pose convention.

    Thin wrapper that forwards `skeleton.global_rot_offsets` to `change_tpose`.

    Args:
        local_rot_mats: Local rotation matrices, passed through unchanged to
            `change_tpose`.
        skeleton: Skeleton object exposing `global_rot_offsets`, `joint_parents`,
            `root_idx`, and `nbjoints`.

    Returns:
        Tuple `(new_local_rot_mats, new_global_rot_mats)` as returned by
        `change_tpose`, i.e. in the standard frame.
    """
    return change_tpose(local_rot_mats, skeleton.global_rot_offsets, skeleton)
def from_standard_tpose(local_rot_mats: torch.Tensor, skeleton):
    """Convert local rotations from the standard T-pose convention back to the
    skeleton's original formulation.

    This is the inverse of the standard-T-pose conversion: the skeleton's
    offsets are matrix-transposed before being handed to `change_tpose`.

    Args:
        local_rot_mats: Local rotation matrices, passed through unchanged to
            `change_tpose`.
        skeleton: Skeleton object exposing `global_rot_offsets`, `joint_parents`,
            `root_idx`, and `nbjoints`.

    Returns:
        Tuple `(new_local_rot_mats, new_global_rot_mats)` as returned by
        `change_tpose`, i.e. in the skeleton's original (non-standard) frame.
    """
    # Transposing the last two axes inverts each offset (assuming they are
    # rotation matrices), which undoes the forward conversion.
    inverse_offsets = skeleton.global_rot_offsets.mT
    return change_tpose(local_rot_mats, inverse_offsets, skeleton)