movimento / kimodo /skeleton /transforms.py
Kimodo Bot
Add core kimodo package modules required by native demo
6d5047c
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Rotation-space conversion utilities for skeleton motion data."""
import einops
import torch
from ..tools import ensure_batched
from .kinematics import batch_rigid_transform
def global_rots_to_local_rots(global_joint_rots: torch.Tensor, skeleton):
    """Convert global rotations to local rotations using a skeleton hierarchy.

    Each joint's local rotation is ``R_parent^T @ R_global``. The root joint has
    no parent, so the identity is used as its parent rotation, which makes its
    local rotation equal to its global rotation.

    Args:
        global_joint_rots: Global rotation matrices with shape `(..., J, 3, 3)`.
            Any number of leading dimensions (including none) is supported.
        skeleton: Skeleton object exposing `joint_parents` (a sequence/tensor of
            parent indices, one per joint) and `root_idx`.

    Returns:
        Local rotation matrices with the same shape, device and dtype as the input.
    """
    # Gather each joint's parent rotation. Indexing with the parent-index sequence
    # is advanced indexing and returns a copy, so the in-place root fix-up below
    # never touches the caller's tensor. The root's own entry (whatever its parent
    # index points at) is overwritten next.
    parent_rot_mats = global_joint_rots[..., skeleton.joint_parents, :, :]
    # Build the identity directly on the input's device/dtype: this avoids a
    # host-to-device copy and a dtype cast, and also works when `root_idx` is a
    # tensor (index_put_ requires matching dtypes).
    parent_rot_mats[..., skeleton.root_idx, :, :] = torch.eye(
        3, device=global_joint_rots.device, dtype=global_joint_rots.dtype
    )
    # Rotation matrices are orthonormal, so the inverse is the transpose.
    # matmul broadcasts over all leading dimensions, so no flattening is needed.
    return parent_rot_mats.transpose(-1, -2) @ global_joint_rots
@ensure_batched(local_rot_mats=4)
def change_tpose(local_rot_mats: torch.Tensor, global_rot_offsets: torch.Tensor, skeleton):
    """Re-express local rotations in another T-pose based on global rotation offsets.

    The local rotations are chained through the skeleton hierarchy to obtain global
    rotations ``G``. Each joint's global rotation is then right-multiplied by the
    transpose of its offset (``G @ offset^T``), and the result is converted back to
    local rotations.

    Args:
        local_rot_mats: Local rotation matrices with shape `(..., J, 3, 3)`; inside
            the body they are treated as `(T, J, 3, 3)` (see the `ensure_batched`
            decorator).
        global_rot_offsets: Per-joint global rotation offsets with shape `(J, 3, 3)`,
            or any shape broadcastable against `(T, J, 3, 3)` (e.g. per-frame
            `(T, J, 3, 3)`). Moved to the device/dtype of `local_rot_mats`.
        skeleton: Skeleton object exposing `joint_parents`, `root_idx`, and
            `nbjoints`.

    Returns:
        Tuple `(new_local_rot_mats, new_global_rot_mats)` expressed in the T-pose
        defined by `global_rot_offsets`.
    """
    device, dtype = local_rot_mats.device, local_rot_mats.dtype
    global_rot_offsets = global_rot_offsets.to(device=device, dtype=dtype)
    root_idx = skeleton.root_idx
    joint_parents = skeleton.joint_parents
    # Placeholder joint positions: batch_rigid_transform requires them, but only its
    # global-rotation output is used here, so their values do not matter.
    neutral_joints = torch.ones((len(local_rot_mats), skeleton.nbjoints, 3), device=device, dtype=dtype)
    # Chain the local rotations through the hierarchy to get global rotations (T, J, 3, 3).
    _, global_rot_mats = batch_rigid_transform(local_rot_mats, neutral_joints, joint_parents, root_idx)
    # Apply each joint's offset: G @ offset^T. This equals the einsum
    # "T N m n, N o n -> T N m o" for (J, 3, 3) offsets, and also broadcasts
    # over batched offsets.
    new_global_rot_mats = global_rot_mats @ global_rot_offsets.transpose(-1, -2)
    # Convert the re-expressed global rotations back to local rotations.
    new_local_rot_mats = global_rots_to_local_rots(new_global_rot_mats, skeleton)
    return new_local_rot_mats, new_global_rot_mats
@ensure_batched(local_rot_mats=4)
def to_standard_tpose(local_rot_mats: torch.Tensor, skeleton):
    """Convert local rotations into the skeleton's standard T-pose convention.

    This is a thin wrapper around `change_tpose` that passes the skeleton's own
    `global_rot_offsets` unchanged as the offsets.

    Args:
        local_rot_mats: Local rotation matrices with shape `(..., J, 3, 3)`.
        skeleton: Skeleton object exposing `global_rot_offsets`, `joint_parents`,
            `root_idx`, and `nbjoints`.

    Returns:
        Tuple `(new_local_rot_mats, new_global_rot_mats)` in the standard frame.
    """
    return change_tpose(local_rot_mats, skeleton.global_rot_offsets, skeleton)
@ensure_batched(local_rot_mats=4)
def from_standard_tpose(local_rot_mats: torch.Tensor, skeleton):
    """Convert local rotations from the standard T-pose back to the skeleton's original convention.

    This undoes `to_standard_tpose`. That function applies the skeleton's
    `global_rot_offsets` through `change_tpose`; this one applies their transpose.
    Assuming the offsets are rotation matrices, the transpose is their inverse.

    Args:
        local_rot_mats: Local rotation matrices, expressed in the standard T-pose,
            with shape `(..., J, 3, 3)`.
        skeleton: Skeleton object exposing `global_rot_offsets`, `joint_parents`,
            `root_idx`, and `nbjoints`.

    Returns:
        Tuple `(new_local_rot_mats, new_global_rot_mats)` in the skeleton's
        original (non-standard) T-pose frame.
    """
    global_rot_offsets = skeleton.global_rot_offsets
    # Transposing the last two axes inverts the (orthonormal) rotation offsets.
    global_rot_offsets_T = global_rot_offsets.mT
    return change_tpose(local_rot_mats, global_rot_offsets_T, skeleton)