Spaces:
Runtime error
Runtime error
File size: 12,529 Bytes
6d5047c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 | # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Base motion representation: feature layout, normalization, and conditioning helpers."""
import os
from typing import Optional
import einops
import numpy as np
import torch
from einops import repeat
from ...tools import ensure_batched
from ..conditioning import build_condition_dicts
from ..feature_utils import compute_vel_angle, compute_vel_xyz
from ..stats import Stats
def _require_split_stats_layout(stats_path: str) -> None:
"""Raise if stats_path does not contain the required global_root, local_root, body subdirs."""
subdirs = ("global_root", "local_root", "body")
missing = []
for name in subdirs:
subpath = os.path.join(stats_path, name)
mean_path = os.path.join(subpath, "mean.npy")
if not os.path.isfile(mean_path):
missing.append(f"{subpath}/ (mean.npy)")
if missing:
raise FileNotFoundError(
f"Checkpoint stats must use the split layout with subfolders "
f"global_root/, local_root/, and body/ under '{stats_path}'. "
f"Missing or incomplete: {', '.join(missing)}. "
)
class MotionRepBase:
    """Base class for motion representations used in generation and conditioning.

    Subclasses define:

    - ``size_dict``: feature blocks and their shapes,
    - ``last_root_feature``: last entry of the root block,
    - ``local_root_size_dict``: local-root feature layout,

    and implement transform-specific methods such as ``__call__``, ``inverse``,
    ``rotate``, ``translate_2d`` and ``create_conditions``.
    """

    def __init__(
        self,
        skeleton,
        fps,
        stats_path: Optional[str] = None,
    ):
        """Initialize feature slicing metadata and optional normalization stats.

        Args:
            skeleton: Skeleton object; ``nbjoints`` is read here and
                ``root_idx`` is read by ``get_root_pos``.
            fps: Frame rate, forwarded to the velocity helpers.
            stats_path: Optional directory holding ``global_root/``,
                ``local_root/`` and ``body/`` stats subfolders. When falsy, no
                stats are loaded and ``normalize``/``unnormalize`` are unusable.

        Raises:
            FileNotFoundError: If ``stats_path`` is given but does not follow
                the split stats layout.
        """
        self.skeleton = skeleton
        self.fps = fps
        self.nbjoints = skeleton.nbjoints

        # Feature blocks are laid out contiguously in ``size_dict`` order.
        self.feature_names = list(self.size_dict.keys())
        self.ps = list(self.size_dict.values())
        self.nfeats_dict = {key: val.numel() for key, val in self.size_dict.items()}
        feats_cumsum = np.cumsum([0] + list(self.nfeats_dict.values())).tolist()
        self.slice_dict = {
            key: slice(feats_cumsum[i], feats_cumsum[i + 1]) for i, key in enumerate(self.feature_names)
        }
        self.motion_rep_dim = sum(self.nfeats_dict.values())

        # The root block ends at ``last_root_feature``; the body is everything after it.
        self.root_slice = slice(0, self.slice_dict[self.last_root_feature].stop)
        self.body_slice = slice(self.root_slice.stop, self.motion_rep_dim)
        self.body_dim = self.body_slice.stop - self.body_slice.start
        self.global_root_dim = self.root_slice.stop
        self.local_root_dim = sum(val.numel() for val in self.local_root_size_dict.values())

        if stats_path:
            _require_split_stats_layout(stats_path)
            self.global_root_stats = Stats(os.path.join(stats_path, "global_root"))
            self.local_root_stats = Stats(os.path.join(stats_path, "local_root"))
            self.body_stats = Stats(os.path.join(stats_path, "body"))
            # self.stats is not set; normalize/unnormalize apply per-part stats.

    def get_root_pos(self, features: torch.Tensor, fallback_to_smooth: bool = True) -> torch.Tensor:
        """Extract root positions from a feature tensor.

        Supports both ``root_pos`` and ``smooth_root_pos`` representations.

        Args:
            features: Feature tensor with the feature axis last.
            fallback_to_smooth: For ``smooth_root_pos`` reps, return the smooth
                root directly. When False, rebuild the root by adding the root
                joint's x/z offset taken from ``local_joints_positions``.

        Returns:
            Root positions with 3 trailing components.

        Raises:
            TypeError: If the rep has neither ``root_pos`` nor ``smooth_root_pos``.
        """
        if "root_pos" in self.slice_dict:
            return features[..., self.slice_dict["root_pos"]]
        if "smooth_root_pos" not in self.slice_dict:
            raise TypeError("This motion rep should have either a root_pos or smooth_root_pos field")

        smooth_root_pos = features[..., self.slice_dict["smooth_root_pos"]]
        if fallback_to_smooth:
            return smooth_root_pos

        # Rebuild the root position: smooth root plus the root joint's planar
        # (x, z) offset; the height stays the smooth-root height.
        # The flattened block stores each joint's xyz contiguously (hence the
        # 3-wide slice), so joint ``j`` starts at ``3 * j``, not at ``j``.
        local_joints_positions_flatten = features[..., self.slice_dict["local_joints_positions"]]
        start = 3 * self.skeleton.root_idx
        hips_offset = local_joints_positions_flatten[..., start : start + 3]
        return torch.stack(
            [
                smooth_root_pos[..., 0] + hips_offset[..., 0],
                smooth_root_pos[..., 1],
                smooth_root_pos[..., 2] + hips_offset[..., 2],
            ],
            dim=-1,
        )

    @ensure_batched(root_features=3, lengths=1)
    def global_root_to_local_root(
        self,
        root_features: torch.Tensor,
        normalized: bool,
        lengths: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Convert global root features to local-root motion features.

        Args:
            root_features: Root feature tensor containing root position and
                global heading, shaped ``[B, T, D_root]``.
            normalized: Whether ``root_features`` are normalized; if so the
                output is normalized with the local-root stats.
            lengths: Optional valid lengths per sequence.

        Returns:
            Tensor ``[B, T, 4]`` with local root rotational velocity, planar
            velocity, and global root height.
        """
        if normalized:
            root_features = self.global_root_stats.unnormalize(root_features)
        # Assumes the first two feature blocks are root position and heading
        # (cos, sin); unpack fails if the root block holds anything else.
        [root_pos, global_root_heading] = einops.unpack(root_features, self.ps[:2], "batch time *")
        cos, sin = global_root_heading.unbind(-1)
        heading_angle = torch.arctan2(sin, cos)
        local_root_rot_vel = compute_vel_angle(heading_angle, self.fps, lengths=lengths)
        # Velocity of the root treated as a single joint, keeping only x and z.
        local_root_vel = compute_vel_xyz(
            root_pos[..., None, :],
            self.fps,
            lengths=lengths,
        )[..., 0, [0, 2]]
        global_root_y = root_pos[..., 1]
        local_root_motion = torch.cat(
            [
                local_root_rot_vel[..., None],
                local_root_vel,
                global_root_y[..., None],
            ],
            dim=-1,
        )
        if normalized:
            local_root_motion = self.local_root_stats.normalize(local_root_motion)
        return local_root_motion

    def get_root_heading_angle(self, features: torch.Tensor) -> torch.Tensor:
        """Compute the root heading angle (radians) from cosine/sine heading features."""
        global_root_heading = features[..., self.slice_dict["global_root_heading"]]
        cos, sin = global_root_heading.unbind(-1)
        return torch.arctan2(sin, cos)

    @ensure_batched(features=3)
    def rotate_to(
        self,
        features: torch.Tensor,
        target_angle: torch.Tensor,
        return_delta_angle=False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Rotate each sequence so its frame-0 heading matches ``target_angle``.

        Args:
            features: Feature tensor ``[B, T, D]``.
            target_angle: Target frame-0 heading per sequence (radians).
            return_delta_angle: Also return the applied rotation per sequence.

        Returns:
            The rotated features, or ``(rotated_features, delta_angle)``.
        """
        current_first_angle = self.get_root_heading_angle(features)[:, 0]
        delta_angle = target_angle - current_first_angle
        rotated_features = self.rotate(features, delta_angle)
        if return_delta_angle:
            return rotated_features, delta_angle
        return rotated_features

    @ensure_batched(features=3)
    def rotate_to_zero(
        self,
        features: torch.Tensor,
        return_delta_angle=False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Rotate each sequence so its frame-0 heading becomes zero."""
        target_angle = torch.zeros(len(features), device=features.device)
        return self.rotate_to(features, target_angle, return_delta_angle=return_delta_angle)

    @ensure_batched(features=3)
    def randomize_first_heading(
        self,
        features: torch.Tensor,
        return_delta_angle=False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Rotate each sequence to a uniformly random frame-0 heading in ``[0, 2*pi)``."""
        # Sample on the features' device so the subtraction in ``rotate_to``
        # does not mix CPU and accelerator tensors.
        target_heading_angle = torch.rand(features.shape[0], device=features.device) * 2 * np.pi
        return self.rotate_to(
            features,
            target_heading_angle,
            return_delta_angle=return_delta_angle,
        )

    @ensure_batched(features=3, target_2d_pos=2)
    def translate_2d_to(
        self,
        features: torch.Tensor,
        target_2d_pos: torch.Tensor,
        return_delta_pos: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Translate each sequence so its frame-0 root ``(x, z)`` matches a target.

        Args:
            features: Feature tensor ``[B, T, D]``.
            target_2d_pos: Target ``(x, z)`` per sequence, ``[B, 2]``.
            return_delta_pos: Also return the applied planar offset.

        Returns:
            The translated features, or ``(translated_features, delta_2d_pos)``.
        """
        root_pos = self.get_root_pos(features)
        current_first_2d_pos = root_pos[:, 0, [0, 2]].clone()
        delta_2d_pos = target_2d_pos - current_first_2d_pos
        translated_features = self.translate_2d(features, delta_2d_pos)
        if return_delta_pos:
            return translated_features, delta_2d_pos
        return translated_features

    @ensure_batched(features=3)
    def translate_2d_to_zero(
        self,
        features: torch.Tensor,
        return_delta_pos: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Translate each sequence so its frame-0 root ``(x, z)`` is at the origin."""
        target_2d_pos = torch.zeros(len(features), 2, device=features.device)
        return self.translate_2d_to(features, target_2d_pos, return_delta_pos=return_delta_pos)

    @ensure_batched(features=3)
    def canonicalize(self, features: torch.Tensor) -> torch.Tensor:
        """Canonicalize heading (zero) and planar position (origin) at frame 0."""
        rotated_features = self.rotate_to_zero(features)
        return self.translate_2d_to_zero(rotated_features)

    def _apply_part_stats(self, features: torch.Tensor, op: str) -> torch.Tensor:
        """Apply ``Stats.<op>`` (``normalize`` or ``unnormalize``) per feature part."""
        gr = slice(0, self.global_root_dim)
        lr = slice(self.global_root_dim, self.global_root_dim + self.local_root_dim)
        out = torch.empty_like(features)
        out[..., gr] = getattr(self.global_root_stats, op)(features[..., gr])
        # NOTE(review): ``lr`` starts at ``global_root_dim``, which is also where
        # ``body_slice`` starts, so the body write below overwrites the
        # overlapping region. Kept to preserve behavior; confirm the intended layout.
        out[..., lr] = getattr(self.local_root_stats, op)(features[..., lr])
        out[..., self.body_slice] = getattr(self.body_stats, op)(features[..., self.body_slice])
        return out

    def normalize(self, features):
        """Normalize features using per-part stats (global_root, local_root, body)."""
        return self._apply_part_stats(features, "normalize")

    def unnormalize(self, features):
        """Undo feature normalization using per-part stats."""
        return self._apply_part_stats(features, "unnormalize")

    def create_conditions_from_constraints(
        self,
        constraints_lst: list,
        length: int,
        to_normalize: bool,
        device: str,
    ):
        """Create a conditioning tensor and mask from constraint objects."""
        index_dict, data_dict = build_condition_dicts(constraints_lst)
        return self.create_conditions(index_dict, data_dict, length, to_normalize, device)

    def create_conditions_from_constraints_batched(
        self,
        constraints_lst: list | list[list],
        lengths: torch.Tensor,
        to_normalize: bool,
        device: str,
    ):
        """Batched version of ``create_conditions_from_constraints``.

        Supports either one shared constraint list for all batch elements, or a
        per-sample list of constraint lists. Every sample is built with length
        ``max(lengths)``.

        Raises:
            ValueError: If a per-sample list does not have one entry per sample.
        """
        num_samples = len(lengths)
        length = int(lengths.max())
        if not constraints_lst or not isinstance(constraints_lst[0], list):
            # No constraints, or one list shared by the whole batch: build once and repeat.
            observed_motion, motion_mask = self.create_conditions_from_constraints(
                constraints_lst, length, to_normalize, device
            )
            observed_motion = repeat(observed_motion, "t d -> b t d", b=num_samples)
            motion_mask = repeat(motion_mask, "t d -> b t d", b=num_samples)
            return observed_motion, motion_mask

        if len(constraints_lst) != num_samples:
            raise ValueError(
                f"Expected one constraint list per sample ({num_samples}), got {len(constraints_lst)}."
            )

        per_sample = [
            self.create_conditions_from_constraints(sample_constraints, length, to_normalize, device)
            for sample_constraints in constraints_lst
        ]
        observed_motion = torch.stack([motion for motion, _ in per_sample], dim=0)
        motion_mask = torch.stack([mask for _, mask in per_sample], dim=0)
        return observed_motion, motion_mask
|