# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Base motion representation: feature layout, normalization, and conditioning helpers."""

import os
from typing import Optional

import einops
import numpy as np
import torch
from einops import repeat

from ...tools import ensure_batched
from ..conditioning import build_condition_dicts
from ..feature_utils import compute_vel_angle, compute_vel_xyz
from ..stats import Stats


def _require_split_stats_layout(stats_path: str) -> None:
    """Raise if ``stats_path`` lacks the ``global_root``/``local_root``/``body`` split layout.

    Each of the three subdirectories must contain a ``mean.npy`` file.

    Args:
        stats_path: Directory expected to hold the split normalization stats.

    Raises:
        FileNotFoundError: If any of the subdirectories (or its ``mean.npy``) is missing.
    """
    missing = []
    for name in ("global_root", "local_root", "body"):
        subpath = os.path.join(stats_path, name)
        if not os.path.isfile(os.path.join(subpath, "mean.npy")):
            missing.append(f"{subpath}/ (mean.npy)")
    if missing:
        raise FileNotFoundError(
            f"Checkpoint stats must use the split layout with subfolders "
            f"global_root/, local_root/, and body/ under '{stats_path}'. "
            f"Missing or incomplete: {', '.join(missing)}. "
        )


class MotionRepBase:
    """Base class for motion representations used in generation and conditioning.

    Subclasses define:
    - ``size_dict``: feature blocks and their shapes,
    - ``last_root_feature``: last entry of the root block,
    - ``local_root_size_dict``: local-root feature layout,

    and implement transform-specific methods such as ``__call__``, ``inverse``,
    ``rotate``, ``translate_2d`` and ``create_conditions``.
    """

    def __init__(
        self,
        skeleton,
        fps,
        stats_path: Optional[str] = None,
    ):
        """Initialize feature slicing metadata and optional normalization stats.

        Args:
            skeleton: Skeleton object; ``nbjoints`` and ``root_idx`` are read from it.
            fps: Frame rate used when computing velocities.
            stats_path: Optional directory with ``global_root/``, ``local_root/`` and
                ``body/`` stats subfolders. When omitted, ``normalize``/``unnormalize``
                are unavailable.
        """
        self.skeleton = skeleton
        self.fps = fps
        self.nbjoints = skeleton.nbjoints

        # Feature layout: blocks are laid out contiguously in size_dict order.
        self.feature_names = list(self.size_dict.keys())
        self.ps = list(self.size_dict.values())
        self.nfeats_dict = {key: val.numel() for key, val in self.size_dict.items()}
        feats_cumsum = np.cumsum([0] + list(self.nfeats_dict.values())).tolist()
        self.slice_dict = {
            key: slice(feats_cumsum[i], feats_cumsum[i + 1]) for i, key in enumerate(self.feature_names)
        }
        self.motion_rep_dim = sum(self.nfeats_dict.values())

        # Root block = everything up to (and including) last_root_feature; body = the rest.
        self.root_slice = slice(0, self.slice_dict[self.last_root_feature].stop)
        self.body_slice = slice(self.root_slice.stop, self.motion_rep_dim)
        self.body_dim = self.body_slice.stop - self.body_slice.start
        self.global_root_dim = self.root_slice.stop
        self.local_root_dim = sum(val.numel() for val in self.local_root_size_dict.values())

        if stats_path:
            _require_split_stats_layout(stats_path)
            self.global_root_stats = Stats(os.path.join(stats_path, "global_root"))
            self.local_root_stats = Stats(os.path.join(stats_path, "local_root"))
            self.body_stats = Stats(os.path.join(stats_path, "body"))
            # self.stats not set; normalize/unnormalize apply per-part below

    def get_root_pos(self, features: torch.Tensor, fallback_to_smooth: bool = True) -> torch.Tensor:
        """Extract root positions from a feature tensor.

        Supports both ``root_pos`` and ``smooth_root_pos`` representations.

        Args:
            features: Feature tensor whose last dimension is the motion-rep layout.
            fallback_to_smooth: For ``smooth_root_pos`` reps, return the smooth root
                directly. When False, add the root joint's planar (x, z) offset taken
                from ``local_joints_positions`` to the smooth root.

        Returns:
            Root positions with a trailing xyz dimension.

        Raises:
            TypeError: If the rep has neither ``root_pos`` nor ``smooth_root_pos``.
        """
        if "root_pos" in self.slice_dict:
            return features[..., self.slice_dict["root_pos"]]
        if "smooth_root_pos" not in self.slice_dict:
            raise TypeError("This motion rep should have either a root_pos or smooth_root_pos field")

        smooth_root_pos = features[..., self.slice_dict["smooth_root_pos"]]
        if fallback_to_smooth:
            return smooth_root_pos

        # Compute the root pos from the smooth root and the root joint's local offset.
        local_joints_positions_flatten = features[..., self.slice_dict["local_joints_positions"]]
        # The block is flattened, so the root joint's xyz starts at 3 * root_idx
        # (assumes a joint-major [nbjoints, 3] layout — TODO confirm against subclasses).
        root_start = 3 * self.skeleton.root_idx
        hips_offset = local_joints_positions_flatten[..., root_start : root_start + 3]
        # Only the planar (x, z) offset is applied; height stays that of the smooth root.
        return torch.stack(
            [
                smooth_root_pos[..., 0] + hips_offset[..., 0],
                smooth_root_pos[..., 1],
                smooth_root_pos[..., 2] + hips_offset[..., 2],
            ],
            dim=-1,
        )

    @ensure_batched(root_features=3, lengths=1)
    def global_root_to_local_root(
        self,
        root_features: torch.Tensor,
        normalized: bool,
        lengths: Optional[torch.Tensor],
    ):
        """Convert global root features to local-root motion features.

        Args:
            root_features: Root feature tensor containing root position and global
                heading, shaped ``[B, T, D_root]``. It is unpacked using the shapes
                of the first two ``size_dict`` entries.
            normalized: Whether ``root_features`` are normalized. If True, the input
                is unnormalized with the global-root stats and the output is
                normalized with the local-root stats.
            lengths: Optional valid lengths per sequence.

        Returns:
            Tensor ``[B, T, 4]`` with local root rotational velocity, planar
            velocity, and global root height.
        """
        if normalized:
            root_features = self.global_root_stats.unnormalize(root_features)

        [root_pos, global_root_heading] = einops.unpack(root_features, self.ps[:2], "batch time *")
        cos, sin = global_root_heading.unbind(-1)
        heading_angle = torch.arctan2(sin, cos)

        local_root_rot_vel = compute_vel_angle(heading_angle, self.fps, lengths=lengths)
        # Treat the root as a single "joint" for compute_vel_xyz, then keep (x, z).
        local_root_vel = compute_vel_xyz(
            root_pos[..., None, :],
            self.fps,
            lengths=lengths,
        )[..., 0, [0, 2]]
        global_root_y = root_pos[..., 1]

        local_root_motion = torch.cat(
            [
                local_root_rot_vel[..., None],
                local_root_vel,
                global_root_y[..., None],
            ],
            dim=-1,
        )
        if normalized:
            local_root_motion = self.local_root_stats.normalize(local_root_motion)
        return local_root_motion

    def get_root_heading_angle(self, features: torch.Tensor) -> torch.Tensor:
        """Compute the root heading angle (radians) from cosine/sine heading features."""
        global_root_heading = features[..., self.slice_dict["global_root_heading"]]
        cos, sin = global_root_heading.unbind(-1)
        return torch.arctan2(sin, cos)

    @ensure_batched(features=3)
    def rotate_to(
        self,
        features: torch.Tensor,
        target_angle: torch.Tensor,
        return_delta_angle=False,
    ):
        """Rotate each sequence so its frame-0 heading matches ``target_angle``.

        Args:
            features: Batched features ``[B, T, D]``.
            target_angle: Per-sequence target heading ``[B]`` in radians.
            return_delta_angle: Also return the applied rotation angle.

        Returns:
            The rotated features, or ``(rotated_features, delta_angle)``.
        """
        current_first_angle = self.get_root_heading_angle(features)[:, 0]
        delta_angle = target_angle - current_first_angle
        rotated_features = self.rotate(features, delta_angle)
        if return_delta_angle:
            return rotated_features, delta_angle
        return rotated_features

    @ensure_batched(features=3)
    def rotate_to_zero(
        self,
        features: torch.Tensor,
        return_delta_angle=False,
    ):
        """Rotate each sequence so its frame-0 heading becomes zero."""
        target_angle = torch.zeros(len(features), device=features.device)
        return self.rotate_to(features, target_angle, return_delta_angle=return_delta_angle)

    @ensure_batched(features=3)
    def randomize_first_heading(
        self,
        features: torch.Tensor,
        return_delta_angle=False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Rotate each sequence to a random frame-0 heading sampled uniformly in ``[0, 2*pi)``.

        Returns:
            The rotated features, or ``(rotated_features, delta_angle)`` when
            ``return_delta_angle`` is True.
        """
        # Sample on the features' device/dtype so rotate_to's subtraction works on GPU too.
        target_heading_angle = (
            torch.rand(features.shape[0], device=features.device, dtype=features.dtype) * 2 * np.pi
        )
        return self.rotate_to(
            features,
            target_heading_angle,
            return_delta_angle=return_delta_angle,
        )

    @ensure_batched(features=3, target_2d_pos=2)
    def translate_2d_to(
        self,
        features: torch.Tensor,
        target_2d_pos: torch.Tensor,
        return_delta_pos: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Translate each sequence so its frame-0 root ``(x, z)`` matches ``target_2d_pos``.

        Args:
            features: Batched features ``[B, T, D]``.
            target_2d_pos: Per-sequence target ``(x, z)`` positions ``[B, 2]``.
            return_delta_pos: Also return the applied planar offset.

        Returns:
            The translated features, or ``(translated_features, delta_2d_pos)``.
        """
        root_pos = self.get_root_pos(features)
        current_first_2d_pos = root_pos[:, 0, [0, 2]].clone()
        delta_2d_pos = target_2d_pos - current_first_2d_pos
        translated_features = self.translate_2d(features, delta_2d_pos)
        if return_delta_pos:
            return translated_features, delta_2d_pos
        return translated_features

    @ensure_batched(features=3)
    def translate_2d_to_zero(
        self,
        features: torch.Tensor,
        return_delta_pos: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Translate each sequence so its frame-0 root ``(x, z)`` is at the origin."""
        target_2d_pos = torch.zeros(len(features), 2, device=features.device)
        return self.translate_2d_to(features, target_2d_pos, return_delta_pos=return_delta_pos)

    @ensure_batched(features=3)
    def canonicalize(self, features: torch.Tensor):
        """Canonicalize heading (rotate to zero) and planar position (move to origin) at frame 0."""
        rotated_features = self.rotate_to_zero(features)
        return self.translate_2d_to_zero(rotated_features)

    def _split_stats(self):
        """Return the ``(global_root, local_root, body)`` stats, raising if none were loaded."""
        try:
            return self.global_root_stats, self.local_root_stats, self.body_stats
        except AttributeError as err:
            raise AttributeError(
                f"{type(self).__name__} has no normalization stats loaded; "
                "construct it with a stats_path to use normalize/unnormalize."
            ) from err

    def _apply_split_stats(self, features: torch.Tensor, op: str) -> torch.Tensor:
        """Apply ``op`` ("normalize" or "unnormalize") of each per-part stats to its feature slice."""
        global_root_stats, local_root_stats, body_stats = self._split_stats()
        gr = slice(0, self.global_root_dim)
        lr = slice(self.global_root_dim, self.global_root_dim + self.local_root_dim)
        out = torch.empty_like(features)
        out[..., gr] = getattr(global_root_stats, op)(features[..., gr])
        out[..., lr] = getattr(local_root_stats, op)(features[..., lr])
        # NOTE(review): as computed in __init__, body_slice starts at global_root_dim,
        # so it overlaps `lr` and this write overwrites it unless a subclass redefines
        # body_slice. Kept in this order to preserve existing behavior — confirm intent.
        out[..., self.body_slice] = getattr(body_stats, op)(features[..., self.body_slice])
        return out

    def normalize(self, features):
        """Normalize features using per-part stats (global_root, local_root, body)."""
        return self._apply_split_stats(features, "normalize")

    def unnormalize(self, features):
        """Undo feature normalization using per-part stats."""
        return self._apply_split_stats(features, "unnormalize")

    def create_conditions_from_constraints(
        self,
        constraints_lst: list,
        length: int,
        to_normalize: bool,
        device: str,
    ):
        """Create a conditioning tensor and mask from constraint objects.

        Delegates to the subclass's ``create_conditions`` after building the index
        and data dicts with ``build_condition_dicts``.
        """
        index_dict, data_dict = build_condition_dicts(constraints_lst)
        return self.create_conditions(index_dict, data_dict, length, to_normalize, device)

    def create_conditions_from_constraints_batched(
        self,
        constraints_lst: list | list[list],
        lengths: torch.Tensor,
        to_normalize: bool,
        device: str,
    ):
        """Batched version of ``create_conditions_from_constraints``.

        Supports either one shared constraint list for all batch elements, or a
        per-sample list of constraint lists. All outputs are padded to
        ``lengths.max()`` frames.

        Returns:
            ``(observed_motion, motion_mask)``, each stacked along a new batch axis.

        Raises:
            ValueError: If a per-sample list does not have one entry per element of ``lengths``.
        """
        num_samples = len(lengths)
        length = int(lengths.max())

        if not constraints_lst or not isinstance(constraints_lst[0], list):
            # If no constraints, or constraints are shared across the batch,
            # build once and repeat.
            observed_motion, motion_mask = self.create_conditions_from_constraints(
                constraints_lst, length, to_normalize, device
            )
            observed_motion = repeat(observed_motion, "t d -> b t d", b=num_samples)
            motion_mask = repeat(motion_mask, "t d -> b t d", b=num_samples)
            return observed_motion, motion_mask

        if len(constraints_lst) != num_samples:
            raise ValueError(
                f"Expected one constraint list per sample ({num_samples}), got {len(constraints_lst)}."
            )

        conditions = [
            self.create_conditions_from_constraints(sample_constraints, length, to_normalize, device)
            for sample_constraints in constraints_lst
        ]
        observed_motion = torch.stack([observed for observed, _ in conditions], dim=0)
        motion_mask = torch.stack([mask for _, mask in conditions], dim=0)
        return observed_motion, motion_mask