Spaces: Running on Zero
Kimodo Bot committed on
Commit · 6d5047c
Parent(s): d6cb863
Add core kimodo package modules required by native demo
This view is limited to 50 files because it contains too many changes. See raw diff
- kimodo/__init__.py +11 -0
- kimodo/assets.py +19 -0
- kimodo/constraints.py +625 -0
- kimodo/exports/__init__.py +65 -0
- kimodo/exports/bvh.py +282 -0
- kimodo/exports/motion_convert_lib.py +155 -0
- kimodo/exports/motion_formats.py +78 -0
- kimodo/exports/motion_io.py +443 -0
- kimodo/exports/mujoco.py +588 -0
- kimodo/exports/smplx.py +251 -0
- kimodo/geometry.py +216 -0
- kimodo/meta.py +80 -0
- kimodo/metrics/__init__.py +39 -0
- kimodo/metrics/base.py +66 -0
- kimodo/metrics/constraints.py +87 -0
- kimodo/metrics/foot_skate.py +232 -0
- kimodo/metrics/tmr.py +530 -0
- kimodo/model/__init__.py +31 -0
- kimodo/model/backbone.py +312 -0
- kimodo/model/cfg.py +133 -0
- kimodo/model/common.py +48 -0
- kimodo/model/diffusion.py +133 -0
- kimodo/model/kimodo_model.py +605 -0
- kimodo/model/llm2vec/README.md +1 -0
- kimodo/model/llm2vec/__init__.py +11 -0
- kimodo/model/llm2vec/llm2vec.py +477 -0
- kimodo/model/llm2vec/llm2vec_wrapper.py +73 -0
- kimodo/model/llm2vec/models/__init__.py +4 -0
- kimodo/model/llm2vec/models/attn_mask_utils.py +181 -0
- kimodo/model/llm2vec/models/bidirectional_llama.py +224 -0
- kimodo/model/llm2vec/models/utils.py +32 -0
- kimodo/model/load_model.py +194 -0
- kimodo/model/loading.py +81 -0
- kimodo/model/registry.py +473 -0
- kimodo/model/text_encoder_api.py +74 -0
- kimodo/model/tmr.py +382 -0
- kimodo/model/twostage_denoiser.py +153 -0
- kimodo/motion_rep/__init__.py +11 -0
- kimodo/motion_rep/conditioning.py +28 -0
- kimodo/motion_rep/feature_utils.py +212 -0
- kimodo/motion_rep/feet.py +60 -0
- kimodo/motion_rep/reps/__init__.py +13 -0
- kimodo/motion_rep/reps/base.py +300 -0
- kimodo/motion_rep/reps/kimodo_motionrep.py +301 -0
- kimodo/motion_rep/reps/tmr_motionrep.py +222 -0
- kimodo/motion_rep/smooth_root.py +234 -0
- kimodo/motion_rep/stats.py +123 -0
- kimodo/pipeline/__init__.py +28 -0
- kimodo/pipeline/blend_quality.py +116 -0
- kimodo/pipeline/scheduler_runtime.py +139 -0
kimodo/__init__.py
ADDED
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Kimodo: text-driven and constrained motion generation model."""
+
+from .model.load_model import AVAILABLE_MODELS, DEFAULT_MODEL, load_model
+
+__all__ = [
+    "AVAILABLE_MODELS",
+    "DEFAULT_MODEL",
+    "load_model",
+]
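The package `__init__` keeps the public surface to three names. A minimal usage sketch follows; the exact signature of `load_model` lives in kimodo/model/load_model.py, whose diff is not shown in this view, so the call below is an assumption:

import kimodo

print(kimodo.AVAILABLE_MODELS)  # registered model names (assumed to be an iterable of str)
# Assumed: load_model takes a model name and returns a ready-to-use model object.
model = kimodo.load_model(kimodo.DEFAULT_MODEL)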
kimodo/assets.py
ADDED
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from pathlib import Path
+
+PACKAGE_ROOT = Path(__file__).resolve().parent
+ASSETS_ROOT = PACKAGE_ROOT / "assets"
+DEMO_ASSETS_ROOT = ASSETS_ROOT / "demo"
+DEMO_EXAMPLES_ROOT = DEMO_ASSETS_ROOT / "examples"
+SKELETONS_ROOT = ASSETS_ROOT / "skeletons"
+SOMA_ASSETS_ROOT = ASSETS_ROOT / "SOMA"
+
+
+def skeleton_asset_path(*parts: str) -> Path:
+    return SKELETONS_ROOT.joinpath(*parts)
+
+
+def demo_asset_path(*parts: str) -> Path:
+    return DEMO_ASSETS_ROOT.joinpath(*parts)
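Because every root is derived from the installed package location via `__file__`, asset lookups work regardless of the current working directory. A small sketch (the file names are hypothetical placeholders):

from kimodo.assets import DEMO_EXAMPLES_ROOT, demo_asset_path, skeleton_asset_path

skel_file = skeleton_asset_path("soma77.json")      # hypothetical file name
example = demo_asset_path("examples", "walk.npz")   # hypothetical file name
assert example == DEMO_EXAMPLES_ROOT / "walk.npz"   # both resolve under assets/demo/examples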
kimodo/constraints.py
ADDED
@@ -0,0 +1,625 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Constraint sets for conditioning motion generation (root 2D, full body, end-effectors)."""
+
+from typing import Optional, Union
+
+import torch
+from torch import Tensor
+
+from kimodo.motion_rep.feature_utils import compute_heading_angle
+from kimodo.skeleton import SkeletonBase, SOMASkeleton30, SOMASkeleton77
+from kimodo.tools import ensure_batched, load_json, save_json
+
+from .geometry import axis_angle_to_matrix, matrix_to_axis_angle
+
+
+def _convert_constraint_local_rots_to_skeleton(local_rot_mats: Tensor, skeleton: SkeletonBase) -> Tensor:
+    """Convert loaded local rotation matrices to match the skeleton's joint count.
+
+    Handles SOMA 30↔77: constraint files may have been saved with 30 or 77 joints while the session
+    skeleton (e.g. from the SOMA30 model) uses SOMASkeleton77.
+    """
+    n_joints = local_rot_mats.shape[-3]
+    skeleton_joints = skeleton.nbjoints
+    if n_joints == skeleton_joints:
+        return local_rot_mats
+    if n_joints == 77 and skeleton_joints == 30 and isinstance(skeleton, SOMASkeleton30):
+        return skeleton.from_SOMASkeleton77(local_rot_mats)
+    if n_joints == 30 and skeleton_joints == 77 and isinstance(skeleton, SOMASkeleton77):
+        skel30 = SOMASkeleton30()
+        return skel30.to_SOMASkeleton77(local_rot_mats)
+    raise ValueError(
+        f"Constraint joint count ({n_joints}) does not match skeleton joint count "
+        f"({skeleton_joints}). Only SOMA 30↔77 conversion is supported."
+    )
+
+
+def create_pairs(tensor_A: Tensor, tensor_B: Tensor) -> Tensor:
+    """Form all (a, b) pairs from two 1D tensors; output shape (len(A)*len(B), 2)."""
+    pairs = torch.stack(
+        (
+            tensor_A[:, None].expand(-1, len(tensor_B)),
+            tensor_B.expand(len(tensor_A), -1),
+        ),
+        dim=-1,
+    ).reshape(-1, 2)
+    return pairs
+
+
+def compute_global_heading(global_joints_positions: Tensor, skeleton: SkeletonBase) -> Tensor:
+    """Compute global root heading (cos, sin) from global joint positions using the skeleton."""
+    root_heading_angle = compute_heading_angle(global_joints_positions, skeleton)
+    global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1)
+    return global_root_heading
+
+
+def _tensor_to(
+    t: Tensor,
+    device: Optional[Union[str, torch.device]] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> Tensor:
+    """Move tensor to device and/or dtype.
+
+    Returns the same tensor if neither is given.
+    """
+    if device is not None and dtype is not None:
+        return t.to(device=device, dtype=dtype)
+    if device is not None:
+        return t.to(device=device)
+    if dtype is not None:
+        return t.to(dtype=dtype)
+    return t
+
+
+class Root2DConstraintSet:
+    """Constraint set fixing root (x, z) trajectory and optionally global heading on given
+    frames."""
+
+    name = "root2d"
+
+    def __init__(
+        self,
+        skeleton: SkeletonBase,
+        frame_indices: Tensor,
+        smooth_root_2d: Tensor,
+        to_crop: bool = False,
+        global_root_heading: Optional[Tensor] = None,
+    ) -> None:
+        self.skeleton = skeleton
+
+        # if we pass the full smooth root 3D as input
+        if smooth_root_2d.shape[-1] == 3:
+            smooth_root_2d = smooth_root_2d[..., [0, 1]]
+
+        if to_crop:
+            smooth_root_2d = smooth_root_2d[frame_indices]
+            if global_root_heading is not None:
+                global_root_heading = global_root_heading[frame_indices]
+        else:
+            assert len(smooth_root_2d) == len(
+                frame_indices
+            ), "The number of smooth root 2d should match the number of frames"
+            if global_root_heading is not None:
+                assert len(global_root_heading) == len(
+                    frame_indices
+                ), "The number of global root heading should match the number of frames"
+
+        self.smooth_root_2d = smooth_root_2d
+        self.global_root_heading = global_root_heading
+        self.frame_indices = frame_indices
+
+    def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
+        """Append this constraint's smooth_root_2d (and optional global_root_heading) to data/index
+        dicts."""
+        data_dict["smooth_root_2d"].append(self.smooth_root_2d)
+        index_dict["smooth_root_2d"].append(self.frame_indices)
+
+        if self.global_root_heading is not None:
+            # constrain the global heading
+            data_dict["global_root_heading"].append(self.global_root_heading)
+            index_dict["global_root_heading"].append(self.frame_indices)
+
+    def crop_move(self, start: int, end: int) -> "Root2DConstraintSet":
+        """Return a new constraint set for the cropped frame range [start, end)."""
+        mask = (self.frame_indices >= start) & (self.frame_indices < end)
+
+        if self.global_root_heading is not None:
+            masked_global_root_heading = self.global_root_heading[mask]
+        else:
+            masked_global_root_heading = None
+
+        return Root2DConstraintSet(
+            self.skeleton,
+            self.frame_indices[mask] - start,
+            self.smooth_root_2d[mask],
+            global_root_heading=masked_global_root_heading,
+        )
+
+    def get_save_info(self) -> dict:
+        """Return a dict suitable for JSON serialization (frame_indices, smooth_root_2d, optional
+        global_root_heading)."""
+        out = {
+            "type": self.name,
+            "frame_indices": self.frame_indices,
+            "smooth_root_2d": self.smooth_root_2d,
+        }
+        if self.global_root_heading is not None:
+            out["global_root_heading"] = self.global_root_heading
+        return out
+
+    def to(
+        self,
+        device: Optional[Union[str, torch.device]] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> "Root2DConstraintSet":
+        self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
+        self.frame_indices = _tensor_to(self.frame_indices, device, dtype)
+        if self.global_root_heading is not None:
+            self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
+        if device is not None and hasattr(self.skeleton, "to"):
+            self.skeleton = self.skeleton.to(device)
+        return self
+
+    @classmethod
+    def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "Root2DConstraintSet":
+        """Build a Root2DConstraintSet from a dict (e.g. loaded from JSON)."""
+        device = skeleton.device if hasattr(skeleton, "device") else "cpu"
+
+        if "global_root_heading" in dico:
+            global_root_heading = torch.tensor(dico["global_root_heading"], device=device)
+        else:
+            global_root_heading = None
+
+        return cls(
+            skeleton,
+            frame_indices=torch.tensor(dico["frame_indices"]),
+            smooth_root_2d=torch.tensor(dico["smooth_root_2d"], device=device),
+            global_root_heading=global_root_heading,
+        )
+
+
+class FullBodyConstraintSet:
+    """Constraint set fixing full-body global positions and rotations on given keyframes."""
+
+    name = "fullbody"
+
+    def __init__(
+        self,
+        skeleton: SkeletonBase,
+        frame_indices: Tensor,
+        global_joints_positions: Tensor,
+        global_joints_rots: Tensor,
+        smooth_root_2d: Optional[Tensor] = None,
+        to_crop: bool = False,
+    ):
+        self.skeleton = skeleton
+        self.frame_indices = frame_indices
+
+        # if we pass the full smooth root 3D as input
+        if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3:
+            smooth_root_2d = smooth_root_2d[..., [0, 1]]
+
+        if to_crop:
+            global_joints_positions = global_joints_positions[frame_indices]
+            global_joints_rots = global_joints_rots[frame_indices]
+            if smooth_root_2d is not None:
+                smooth_root_2d = smooth_root_2d[frame_indices]
+        else:
+            assert len(global_joints_positions) == len(
+                frame_indices
+            ), "The number of global positions should match the number of frames"
+            assert len(global_joints_rots) == len(
+                frame_indices
+            ), "The number of global joint rotations should match the number of frames"
+
+            if smooth_root_2d is not None:
+                assert len(smooth_root_2d) == len(
+                    frame_indices
+                ), "The number of smooth root 2d (if specified) should match the number of frames"
+
+        if smooth_root_2d is None:
+            # substitute the smooth root 2d with the real root
+            smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]]
+
+        # root y: from smooth or pelvis is the same
+        self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1]
+
+        self.global_joints_positions = global_joints_positions
+        self.global_joints_rots = global_joints_rots
+        self.global_root_heading = compute_global_heading(global_joints_positions, skeleton)
+        self.smooth_root_2d = smooth_root_2d
+
+    def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
+        """Append global positions, smooth root 2D, root y, and global heading to data/index
+        dicts."""
+        nbjoints = self.skeleton.nbjoints
+        indices_lst = create_pairs(
+            self.frame_indices,
+            torch.arange(nbjoints, device=self.frame_indices.device),
+        )
+        data_dict["global_joints_positions"].append(
+            self.global_joints_positions.reshape(-1, 3)
+        )  # flatten the global positions
+        index_dict["global_joints_positions"].append(indices_lst)
+
+        # global rotations are not used here
+
+        # as we use smooth root, also constrain the smooth root to get the same full body
+        # maybe keep storing the hips offset, if we smooth it ourselves
+        data_dict["smooth_root_2d"].append(self.smooth_root_2d)
+        index_dict["smooth_root_2d"].append(self.frame_indices)
+
+        # constrain the y pos of the root
+        data_dict["root_y_pos"].append(self.root_y_pos)
+        index_dict["root_y_pos"].append(self.frame_indices)
+
+        # constrain the global heading
+        data_dict["global_root_heading"].append(self.global_root_heading)
+        index_dict["global_root_heading"].append(self.frame_indices)
+
+    def crop_move(self, start: int, end: int) -> "FullBodyConstraintSet":
+        """Return a new FullBodyConstraintSet for the cropped frame range [start, end)."""
+        mask = (self.frame_indices >= start) & (self.frame_indices < end)
+        return FullBodyConstraintSet(
+            self.skeleton,
+            self.frame_indices[mask] - start,
+            self.global_joints_positions[mask],
+            self.global_joints_rots[mask],
+            self.smooth_root_2d[mask],
+        )
+
+    def get_save_info(self) -> dict:
+        """Return a dict for JSON save: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d."""
+        local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots)
+        if isinstance(self.skeleton, SOMASkeleton30):
+            local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot)
+        local_joints_rot = matrix_to_axis_angle(local_joints_rot)
+
+        root_positions = self.global_joints_positions[:, self.skeleton.root_idx]
+        return {
+            "type": self.name,
+            "frame_indices": self.frame_indices,
+            "local_joints_rot": local_joints_rot,
+            "root_positions": root_positions,
+            "smooth_root_2d": self.smooth_root_2d,
+        }
+
+    def to(
+        self,
+        device: Optional[Union[str, torch.device]] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> "FullBodyConstraintSet":
+        self.frame_indices = _tensor_to(self.frame_indices, device, dtype)
+        self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype)
+        self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype)
+        self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype)
+        self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
+        self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
+        if device is not None and hasattr(self.skeleton, "to"):
+            self.skeleton = self.skeleton.to(device)
+        return self
+
+    @classmethod
+    def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "FullBodyConstraintSet":
+        """Build a FullBodyConstraintSet from a dict (e.g. loaded from JSON)."""
+        frame_indices = torch.tensor(dico["frame_indices"])
+        device = skeleton.device if hasattr(skeleton, "device") else "cpu"
+        local_rot = torch.tensor(dico["local_joints_rot"], device=device)
+        local_rot_mats = axis_angle_to_matrix(local_rot)
+        local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton)
+        global_joints_rots, global_joints_positions, _ = skeleton.fk(
+            local_rot_mats,
+            torch.tensor(dico["root_positions"], device=device),
+        )
+        smooth_root_2d = None
+        if "smooth_root_2d" in dico:
+            smooth_root_2d = torch.tensor(dico["smooth_root_2d"], device=device)
+
+        return cls(
+            skeleton,
+            frame_indices=frame_indices,
+            global_joints_positions=global_joints_positions,
+            global_joints_rots=global_joints_rots,
+            smooth_root_2d=smooth_root_2d,
+        )
+
+
+class EndEffectorConstraintSet:
+    """Constraint set fixing selected end-effector positions and rotations on given frames."""
+
+    name = "end-effector"
+
+    def __init__(
+        self,
+        skeleton: SkeletonBase,
+        frame_indices: Tensor,
+        global_joints_positions: Tensor,
+        global_joints_rots: Tensor,
+        smooth_root_2d: Optional[Tensor],
+        *,
+        joint_names: list[str],
+        to_crop: bool = False,
+    ) -> None:
+        self.skeleton = skeleton
+        self.frame_indices = frame_indices
+        self.joint_names = joint_names
+
+        # joint_names are constant for all the frames
+        rot_joint_names, pos_joint_names = self.skeleton.expand_joint_names(self.joint_names)
+        # indexing works for motion_rep with smooth root only (contains pelvis index)
+        self.pos_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in pos_joint_names])
+        self.rot_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in rot_joint_names])
+
+        # if we pass the full smooth root 3D as input
+        if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3:
+            smooth_root_2d = smooth_root_2d[..., [0, 1]]
+
+        if to_crop:
+            global_joints_positions = global_joints_positions[frame_indices]
+            global_joints_rots = global_joints_rots[frame_indices]
+            if smooth_root_2d is not None:
+                smooth_root_2d = smooth_root_2d[frame_indices]
+        else:
+            assert len(global_joints_positions) == len(
+                frame_indices
+            ), "The number of global positions should match the number of frames"
+            assert len(global_joints_rots) == len(
+                frame_indices
+            ), "The number of global joint rotations should match the number of frames"
+            if smooth_root_2d is not None:
+                assert len(smooth_root_2d) == len(
+                    frame_indices
+                ), "The number of smooth root 2d (if specified) should match the number of frames"
+
+        if smooth_root_2d is None:
+            # substitute the smooth root 2d with the real root
+            smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]]
+
+        # root y: from smooth or pelvis is the same
+        self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1]
+
+        self.global_joints_positions = global_joints_positions
+        self.global_root_heading = compute_global_heading(global_joints_positions, skeleton)
+        self.global_joints_rots = global_joints_rots
+        self.smooth_root_2d = smooth_root_2d
+
+    def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
+        """Append constrained joint positions/rots, smooth root 2D, root y, and heading to
+        data/index dicts."""
+        crop_frames_indexing = torch.arange(len(self.frame_indices), device=self.frame_indices.device)
+
+        # constrain positions
+        pos_indices_real = create_pairs(
+            self.frame_indices,
+            self.pos_indices,
+        )
+        pos_indices_crop = create_pairs(
+            crop_frames_indexing,
+            self.pos_indices,
+        )
+        data_dict["global_joints_positions"].append(self.global_joints_positions[tuple(pos_indices_crop.T)])
+        index_dict["global_joints_positions"].append(pos_indices_real)
+
+        # constrain rotations
+        rot_indices_real = create_pairs(
+            self.frame_indices,
+            self.rot_indices,
+        )
+        rot_indices_crop = create_pairs(
+            crop_frames_indexing,
+            self.rot_indices,
+        )
+        data_dict["global_joints_rots"].append(self.global_joints_rots[tuple(rot_indices_crop.T)])
+        index_dict["global_joints_rots"].append(rot_indices_real)
+
+        # as we use smooth root, also constrain the smooth root to get the same full body
+        # maybe keep storing the hips offset, if we smooth it ourselves
+        data_dict["smooth_root_2d"].append(self.smooth_root_2d)
+        index_dict["smooth_root_2d"].append(self.frame_indices)
+
+        # constrain the y pos of the root
+        data_dict["root_y_pos"].append(self.root_y_pos)
+        index_dict["root_y_pos"].append(self.frame_indices)
+
+        # constrain the global heading
+        data_dict["global_root_heading"].append(self.global_root_heading)
+        index_dict["global_root_heading"].append(self.frame_indices)
+
+    def crop_move(self, start: int, end: int) -> "EndEffectorConstraintSet":
+        """Return a new EndEffectorConstraintSet for the cropped frame range [start, end)."""
+        mask = (self.frame_indices >= start) & (self.frame_indices < end)
+
+        cls = type(self)
+        kwargs = {}
+        if not hasattr(cls, "joint_names"):
+            kwargs["joint_names"] = self.joint_names
+
+        return cls(
+            self.skeleton,
+            self.frame_indices[mask] - start,
+            self.global_joints_positions[mask],
+            self.global_joints_rots[mask],
+            self.smooth_root_2d[mask],
+            **kwargs,
+        )
+
+    def get_save_info(self) -> dict:
+        """Return a dict for JSON save: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d, joint_names."""
+        local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots)
+        if isinstance(self.skeleton, SOMASkeleton30):
+            local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot)
+        local_joints_rot = matrix_to_axis_angle(local_joints_rot)
+
+        root_positions = self.global_joints_positions[:, self.skeleton.root_idx]
+        output = {
+            "type": self.name,
+            "frame_indices": self.frame_indices,
+            "local_joints_rot": local_joints_rot,
+            "root_positions": root_positions,
+            "smooth_root_2d": self.smooth_root_2d,
+        }
+        if not hasattr(self.__class__, "joint_names"):
+            # save the joint_names for this base class
+            # but not for children
+            output["joint_names"] = self.joint_names
+        return output
+
+    def to(
+        self,
+        device: Optional[Union[str, torch.device]] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> "EndEffectorConstraintSet":
+        self.frame_indices = _tensor_to(self.frame_indices, device, dtype)
+        self.pos_indices = _tensor_to(self.pos_indices, device, dtype)
+        self.rot_indices = _tensor_to(self.rot_indices, device, dtype)
+        self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype)
+        self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype)
+        self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
+        self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype)
+        self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
+        if device is not None and hasattr(self.skeleton, "to"):
+            self.skeleton = self.skeleton.to(device)
+        return self
+
+    @classmethod
+    def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "EndEffectorConstraintSet":
+        """Build an EndEffectorConstraintSet from a dict (e.g. loaded from JSON)."""
+        frame_indices = torch.tensor(dico["frame_indices"])
+        device = skeleton.device if hasattr(skeleton, "device") else "cpu"
+        local_rot = torch.tensor(dico["local_joints_rot"], device=device)
+        local_rot_mats = axis_angle_to_matrix(local_rot)
+        local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton)
+        global_joints_rots, global_joints_positions, _ = skeleton.fk(
+            local_rot_mats,
+            torch.tensor(dico["root_positions"], device=device),
+        )
+        smooth_root_2d = None
+        if "smooth_root_2d" in dico:
+            smooth_root_2d = torch.tensor(dico["smooth_root_2d"], device=device)
+
+        kwargs = {}
+        if not hasattr(cls, "joint_names"):
+            kwargs["joint_names"] = dico["joint_names"]
+
+        return cls(
+            skeleton,
+            frame_indices=frame_indices,
+            global_joints_positions=global_joints_positions,
+            global_joints_rots=global_joints_rots,
+            smooth_root_2d=smooth_root_2d,
+            **kwargs,
+        )
+
+
+class LeftHandConstraintSet(EndEffectorConstraintSet):
+    """End-effector constraint for the left hand only."""
+
+    name = "left-hand"
+    joint_names: list[str] = ["LeftHand"]
+
+    def __init__(self, *args, **kwargs: dict):
+        super().__init__(*args, joint_names=self.joint_names, **kwargs)
+
+
+class RightHandConstraintSet(EndEffectorConstraintSet):
+    """End-effector constraint for the right hand only."""
+
+    name = "right-hand"
+    joint_names: list[str] = ["RightHand"]
+
+    def __init__(self, *args, **kwargs: dict):
+        super().__init__(*args, joint_names=self.joint_names, **kwargs)
+
+
+class LeftFootConstraintSet(EndEffectorConstraintSet):
+    """End-effector constraint for the left foot only."""
+
+    name = "left-foot"
+    joint_names: list[str] = ["LeftFoot"]
+
+    def __init__(self, *args, **kwargs: dict):
+        super().__init__(*args, joint_names=self.joint_names, **kwargs)
+
+
+class RightFootConstraintSet(EndEffectorConstraintSet):
+    """End-effector constraint for the right foot only."""
+
+    name = "right-foot"
+    joint_names: list[str] = ["RightFoot"]
+
+    def __init__(self, *args, **kwargs: dict):
+        super().__init__(*args, joint_names=self.joint_names, **kwargs)
+
+
+TYPE_TO_CLASS = {
+    "root2d": Root2DConstraintSet,
+    "fullbody": FullBodyConstraintSet,
+    "left-hand": LeftHandConstraintSet,
+    "right-hand": RightHandConstraintSet,
+    "left-foot": LeftFootConstraintSet,
+    "right-foot": RightFootConstraintSet,
+    "end-effector": EndEffectorConstraintSet,
+}
+
+
+def load_constraints_lst(
+    path_or_data: str | list,
+    skeleton: SkeletonBase,
+    device: Optional[Union[str, torch.device]] = None,
+    dtype: Optional[torch.dtype] = None,
+):
+    """Load a list of constraints from a JSON path or a list of dicts.
+
+    Args:
+        path_or_data: Path to constraints.json or list of constraint dicts.
+        skeleton: Skeleton instance (used for from_dict).
+        device: If set, move all constraint tensors and skeleton to this device.
+        dtype: If set, cast constraint tensors to this dtype.
+    """
+    if isinstance(path_or_data, str):
+        saved = load_json(path_or_data)
+    else:
+        saved = path_or_data
+
+    constraints_lst = []
+    for el in saved:
+        cls = TYPE_TO_CLASS[el["type"]]
+        c = cls.from_dict(skeleton, el)
+        if device is not None or dtype is not None:
+            c.to(device=device, dtype=dtype)
+        constraints_lst.append(c)
+    return constraints_lst
+
+
+def save_constraints_lst(path: str, constraints_lst: list) -> list | None:
+    """Save a list of constraint sets to a JSON file.
+
+    Returns None if the list is empty.
+    """
+    if not constraints_lst:
+        print("The constraints list is empty; skipping save.")
+        return
+
+    to_save = []
+
+    def tensor_to_list(obj):
+        """Recursively convert tensors to lists for JSON serialization."""
+        if isinstance(obj, Tensor):
+            return obj.cpu().tolist()
+        elif isinstance(obj, dict):
+            return {k: tensor_to_list(v) for k, v in obj.items()}
+        elif isinstance(obj, list):
+            return [tensor_to_list(v) for v in obj]
+        else:
+            return obj
+
+    for constraint in constraints_lst:
+        constraint_info = constraint.get_save_info()
+        # Convert all tensors to lists for JSON serialization
+        constraint_info = tensor_to_list(constraint_info)
+        to_save.append(constraint_info)
+
+    save_json(path, to_save)
+    print(f"Saved constraints to {path}")
+    return to_save
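Putting the module together: a constraint set is built against a skeleton, contributes its tensors via update_constraints, and round-trips through JSON with save_constraints_lst / load_constraints_lst, where the "type" tag selects the class from TYPE_TO_CLASS on reload. A minimal sketch, assuming SOMASkeleton30() can be constructed with no arguments as it is in _convert_constraint_local_rots_to_skeleton above:

import torch

from kimodo.constraints import Root2DConstraintSet, load_constraints_lst, save_constraints_lst
from kimodo.skeleton import SOMASkeleton30

skeleton = SOMASkeleton30()  # assumed no-arg construction, as used in the module above

# Pin the root's 2D trajectory on frames 0, 30, and 60.
frame_indices = torch.tensor([0, 30, 60])
smooth_root_2d = torch.tensor([[0.0, 0.0], [0.5, 0.0], [1.0, 0.5]])
constraint = Root2DConstraintSet(skeleton, frame_indices, smooth_root_2d)

save_constraints_lst("constraints.json", [constraint])  # tensors become JSON lists
restored = load_constraints_lst("constraints.json", skeleton, device="cpu")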
kimodo/exports/__init__.py
ADDED
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Export utilities: MuJoCo, BVH, SMPLX/AMASS, and motion I/O helpers."""
+
+from .bvh import bvh_to_kimodo_motion, motion_to_bvh_bytes, read_bvh_frame_time_seconds, save_motion_bvh
+from .motion_convert_lib import convert_motion_files
+from .motion_formats import (
+    infer_npz_kind,
+    infer_source_format_from_path,
+    infer_target_format_from_path,
+    resolve_source_fps,
+)
+from .motion_io import (
+    KIMODO_CONVERT_TARGET_FPS,
+    amass_npz_to_bytes,
+    complete_motion_dict,
+    g1_csv_to_bytes,
+    kimodo_npz_to_bytes,
+    load_amass_npz,
+    load_g1_csv,
+    load_kimodo_npz,
+    load_kimodo_npz_as_torch,
+    load_motion_file,
+    motion_dict_to_numpy,
+    save_kimodo_npz,
+    save_kimodo_npz_at_target_fps,
+)
+from .mujoco import MujocoQposConverter, apply_g1_real_robot_projection
+from .smplx import (
+    AMASSConverter,
+    amass_npz_to_kimodo_motion,
+    get_amass_parameters,
+    kimodo_y_up_to_amass_coord_rotation_matrix,
+)
+
+__all__ = [
+    "AMASSConverter",
+    "KIMODO_CONVERT_TARGET_FPS",
+    "MujocoQposConverter",
+    "amass_npz_to_bytes",
+    "amass_npz_to_kimodo_motion",
+    "apply_g1_real_robot_projection",
+    "bvh_to_kimodo_motion",
+    "complete_motion_dict",
+    "convert_motion_files",
+    "g1_csv_to_bytes",
+    "get_amass_parameters",
+    "infer_npz_kind",
+    "infer_source_format_from_path",
+    "infer_target_format_from_path",
+    "kimodo_npz_to_bytes",
+    "kimodo_y_up_to_amass_coord_rotation_matrix",
+    "load_amass_npz",
+    "load_g1_csv",
+    "load_kimodo_npz",
+    "load_kimodo_npz_as_torch",
+    "load_motion_file",
+    "motion_dict_to_numpy",
+    "motion_to_bvh_bytes",
+    "read_bvh_frame_time_seconds",
+    "resolve_source_fps",
+    "save_kimodo_npz",
+    "save_kimodo_npz_at_target_fps",
+    "save_motion_bvh",
+]
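The flat re-export lets downstream code stay independent of the submodule layout:

# Equivalent; the second form does not depend on where convert_motion_files is defined.
from kimodo.exports.motion_convert_lib import convert_motion_files
from kimodo.exports import convert_motion_files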
kimodo/exports/bvh.py
ADDED
@@ -0,0 +1,282 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Export utilities for converting internal motion representations into common file formats.
+
+This module is intended to hold lightweight serialization / export helpers that can be reused
+outside of interactive demos.
+"""
+
+import os
+import tempfile
+from pathlib import Path
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+
+from kimodo.geometry import matrix_to_quaternion as _matrix_to_quaternion
+
+
+def _strip_end_site_blocks(bvh_text: str) -> str:
+    """Remove all 'End Site { ... }' blocks from BVH text so the output matches the original format.
+
+    bvhio adds an End Site for every leaf joint when writing; we do not set EndSite on joints, so we
+    post-process the string to remove these blocks for Blender/original compatibility.
+    """
+    lines = bvh_text.splitlines(keepends=True)
+    result = []
+    i = 0
+    while i < len(lines):
+        line = lines[i]
+        if "End Site" in line:
+            # Skip this line and the following block { ... }; brace-count to find closing }
+            i += 1
+            if i < len(lines) and "{" in lines[i]:
+                i += 1
+                depth = 1
+                while i < len(lines) and depth > 0:
+                    if "{" in lines[i]:
+                        depth += 1
+                    if "}" in lines[i]:
+                        depth -= 1
+                    i += 1
+            continue
+        result.append(line)
+        i += 1
+    return "".join(result)
+
+
+def _coerce_batch(name: str, x: torch.Tensor, *, expected_ndim: int) -> torch.Tensor:
+    """Coerce (T, ...) or (1, T, ...) into (T, ...)."""
+    if x.ndim == expected_ndim:
+        return x
+    if x.ndim == expected_ndim + 1:
+        if int(x.shape[0]) != 1:
+            raise ValueError(
+                f"{name} has batch dimension B={int(x.shape[0])}, but BVH export " "only supports a single clip (B==1)."
+            )
+        return x[0]
+    raise ValueError(f"{name} must have shape (T, ...) or (1, T, ...); got {tuple(x.shape)}")
+
+
+def motion_to_bvh(
+    local_rot_mats: torch.Tensor,
+    root_positions: torch.Tensor,
+    *,
+    skeleton,
+    fps: float,
+) -> str:
+    """Convert local rotations and root positions to BVH format; return UTF-8 string.
+
+    Args:
+        local_rot_mats: (T, J, 3, 3) or (1, T, J, 3, 3) local rotation matrices.
+        root_positions: (T, 3) or (1, T, 3) root joint positions (e.g. from posed joints).
+        skeleton: Skeleton with bone_order_names, bvh_neutral_joints, etc.
+        fps: Frames per second for the motion.
+
+    Notes:
+        BVH is plain-text. Root is named "Root" with ZYX rotation order; leaf joints
+        have no End Site block.
+    """
+    try:
+        import bvhio  # type: ignore[import-not-found]
+        import glm  # type: ignore[import-not-found]
+        from SpatialTransform import Pose  # type: ignore[import-not-found]
+    except Exception as e:  # pragma: no cover
+        raise ImportError(
+            "BVH export requires `bvhio` (and its deps `PyGLM` + `SpatialTransform`). "
+            "Install with: `pip install bvhio`."
+        ) from e
+
+    local_rot_mats = local_rot_mats.detach()
+    root_positions = root_positions.detach()
+    # SOMA: accept either somaskel30 (convert to 77) or somaskel77 (use as-is)
+    if skeleton.name == "somaskel30":
+        local_rot_mats = skeleton.to_SOMASkeleton77(local_rot_mats)
+        skeleton = skeleton.somaskel77
+
+    local_rot_mats, _ = skeleton.from_standard_tpose(local_rot_mats)
+
+    neutral = skeleton.bvh_neutral_joints.detach().cpu().numpy()
+    joint_names = list(skeleton.bone_order_names)
+    parents = skeleton.joint_parents.detach().cpu().numpy().astype(int)
+    root_idx = int(skeleton.root_idx)
+
+    local_rot_mats = _coerce_batch("local_rot_mats", local_rot_mats, expected_ndim=4)
+    T, J = local_rot_mats.shape[:2]
+    q_wxyz = _matrix_to_quaternion(local_rot_mats).detach().cpu().numpy()  # [T, J, 4]
+
+    root_xyz = _coerce_batch("root_positions", root_positions, expected_ndim=2)
+    root_xyz = root_xyz.cpu().numpy()  # [T, 3]
+
+    # Build BVH hierarchy: Root (wrapper at origin) -> Hips (pelvis with offset in meters) -> ...
+    # Offsets are in meters to match the original format.
+    children: dict[int, list[int]] = {i: [] for i in range(J)}
+    for i, p in enumerate(parents):
+        if p >= 0:
+            children[int(p)].append(int(i))
+
+    _ROOT_CHANNELS = [
+        "Xposition",
+        "Yposition",
+        "Zposition",
+        "Zrotation",
+        "Yrotation",
+        "Xrotation",
+    ]
+    _JOINT_CHANNELS = ["Zrotation", "Yrotation", "Xrotation"]
+
+    # Scale from meters to centimeters (match original BVH scale).
+    neutral = neutral * 100
+    root_xyz = root_xyz * 100
+
+    # Hips offset from Root: use skeleton neutral; if root is at origin (zeros), use a
+    # nominal pelvis height so the hierarchy is non-degenerate in Blender.
+    hips_offset = neutral[root_idx]
+    if (hips_offset == 0).all():
+        hips_offset = np.array([0.0, 100.0, 0.0], dtype=neutral.dtype)  # 1 m in cm
+
+    def _make_joint(i: int) -> "bvhio.BvhJoint":
+        name = joint_names[i]
+        j = bvhio.BvhJoint(name, offset=glm.vec3(0, 0, 0))
+        if i == root_idx:
+            # Hips: offset from Root (origin) in cm
+            off = hips_offset
+            j.Offset = glm.vec3(float(off[0]), float(off[1]), float(off[2]))
+            j.Channels = _ROOT_CHANNELS.copy()
+        else:
+            p = int(parents[i])
+            off = neutral[i] - neutral[p]
+            j.Offset = glm.vec3(float(off[0]), float(off[1]), float(off[2]))
+            j.Channels = _JOINT_CHANNELS.copy()
+
+        for c in children[i]:
+            j.Children.append(_make_joint(c))
+        return j
+
+    # Wrapper Root at origin; single child is Hips (skeleton root).
+    root_wrapper = bvhio.BvhJoint("Root", offset=glm.vec3(0.0, 0.0, 0.0))
+    root_wrapper.Channels = _ROOT_CHANNELS.copy()
+    root_wrapper.Children.append(_make_joint(root_idx))
+    root_joint = root_wrapper
+
+    # Populate keyframes: Root = identity/zero, Hips = root motion, others = local rotation.
+    bvh_layout = root_joint.layout()
+    name_to_id = {n: idx for idx, n in enumerate(joint_names)}
+    ordered_joint_ids = []
+    for bj, _, _ in bvh_layout:
+        if bj.Name == "Root":
+            ordered_joint_ids.append(None)
+        else:
+            ordered_joint_ids.append(name_to_id[bj.Name])
+
+    bvh_joints = [bj for bj, _, _ in bvh_layout]
+    for bj in bvh_joints:
+        bj.Keyframes = [None] * T  # type: ignore[list-item]
+
+    identity_quat = glm.quat(1.0, 0.0, 0.0, 0.0)
+    zero_vec = glm.vec3(0.0, 0.0, 0.0)
+    for t in range(T):
+        for bj, jid in zip(bvh_joints, ordered_joint_ids):
+            if jid is None:
+                position = zero_vec
+                rotation = identity_quat
+            elif jid == root_idx:
+                pos = root_xyz[t]
+                position = glm.vec3(float(pos[0]), float(pos[1]), float(pos[2]))
+                qw, qx, qy, qz = q_wxyz[t, jid]
+                rotation = glm.quat(float(qw), float(qx), float(qy), float(qz))
+            else:
+                position = zero_vec
+                qw, qx, qy, qz = q_wxyz[t, jid]
+                rotation = glm.quat(float(qw), float(qx), float(qy), float(qz))
+            bj.Keyframes[t] = Pose(position, rotation)  # type: ignore[index]
+
+    container = bvhio.BvhContainer(root_joint, frameCount=T, frameTime=1.0 / float(fps))
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".bvh", delete=False, encoding="utf-8") as f:
+        tmp_path = f.name
+    try:
+        bvhio.writeBvh(tmp_path, container, percision=6)
+        bvh_text = Path(tmp_path).read_text(encoding="utf-8")
+        return _strip_end_site_blocks(bvh_text)
+    finally:
+        try:
+            os.remove(tmp_path)
+        except Exception:
+            pass
+
+
+def motion_to_bvh_bytes(
+    local_rot_mats: torch.Tensor,
+    root_positions: torch.Tensor,
+    *,
+    skeleton,
+    fps: float,
+) -> bytes:
+    """Convert local rotations and root positions to BVH bytes (UTF-8).
+
+    Convenience wrapper around :func:`motion_to_bvh`.
+    """
+    return motion_to_bvh(local_rot_mats, root_positions, skeleton=skeleton, fps=fps).encode("utf-8")
+
+
+def save_motion_bvh(
+    path: Union[str, Path],
+    local_rot_mats: torch.Tensor,
+    root_positions: torch.Tensor,
+    *,
+    skeleton,
+    fps: float,
+) -> None:
+    """Write local rotations and root positions to a BVH file at the given path."""
+    Path(path).write_text(
+        motion_to_bvh(local_rot_mats, root_positions, skeleton=skeleton, fps=fps),
+        encoding="utf-8",
+    )
+
+
+def read_bvh_frame_time_seconds(path: Union[str, Path]) -> float:
+    """Read ``Frame Time`` from a BVH file (seconds per frame)."""
+    with open(path, encoding="utf-8") as f:
+        for line in f:
+            if "Frame Time:" in line:
+                parts = line.split()
+                return float(parts[-1])
+    raise ValueError(f"Could not find 'Frame Time:' in {path}")
+
+
+def bvh_to_kimodo_motion(
+    path: Union[str, Path],
+    skeleton=None,
+) -> Tuple:
+    """Load a Kimodo-style SOMA BVH into a Kimodo motion dict.
+
+    Expects the same hierarchy as :func:`save_motion_bvh` (``Root`` wrapper + SOMA77 joints).
+    The frame rate is always read from the BVH ``Frame Time`` header. Callers
+    that need a different playback rate should resample the returned motion dict
+    (see :func:`~kimodo.exports.motion_io.resample_motion_dict_to_kimodo_fps`).
+
+    Returns:
+        ``(motion_dict, source_fps)`` where ``source_fps`` is the native BVH
+        frame rate read from the file header.
+    """
+    from kimodo.exports.motion_io import complete_motion_dict
+    from kimodo.skeleton.bvh import parse_bvh_motion
+    from kimodo.skeleton.registry import build_skeleton
+
+    if skeleton is None:
+        skeleton = build_skeleton(77)
+    device = skeleton.neutral_joints.device
+
+    local_rot_mats, root_trans, bvh_fps = parse_bvh_motion(str(path))
+    local_rot_mats = local_rot_mats.to(device=device)
+    root_trans = root_trans.to(device=device)
+
+    if int(local_rot_mats.shape[1]) != int(skeleton.nbjoints):
+        raise ValueError(
+            f"BVH has {local_rot_mats.shape[1]} joints but skeleton has {skeleton.nbjoints}; "
+            "use a Kimodo-exported SOMA BVH or matching skeleton."
+        )
+    local_rot_mats, _ = skeleton.to_standard_tpose(local_rot_mats)
+
+    return complete_motion_dict(local_rot_mats, root_trans, skeleton, float(bvh_fps)), bvh_fps
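A round trip through these helpers, sketched under the module's own assumptions (build_skeleton(77), skeleton.nbjoints, and the tensor shapes all come from the code above; the optional bvhio dependency must be installed):

import torch

from kimodo.exports.bvh import bvh_to_kimodo_motion, read_bvh_frame_time_seconds, save_motion_bvh
from kimodo.skeleton.registry import build_skeleton

skeleton = build_skeleton(77)
T, J = 120, int(skeleton.nbjoints)
local_rot_mats = torch.eye(3).expand(T, J, 3, 3).clone()  # identity local rotations, (T, J, 3, 3)
root_positions = torch.zeros(T, 3)                        # static root, (T, 3)

save_motion_bvh("clip.bvh", local_rot_mats, root_positions, skeleton=skeleton, fps=30.0)
# Frame Time is written at fixed precision, so compare with a loose tolerance.
assert abs(read_bvh_frame_time_seconds("clip.bvh") - 1.0 / 30.0) < 1e-4
motion_dict, source_fps = bvh_to_kimodo_motion("clip.bvh", skeleton=skeleton)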
kimodo/exports/motion_convert_lib.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Library API for converting between Kimodo NPZ, AMASS NPZ, SOMA BVH, and G1 MuJoCo CSV."""

from __future__ import annotations

import warnings

import numpy as np

from kimodo.exports.bvh import bvh_to_kimodo_motion, save_motion_bvh
from kimodo.exports.motion_formats import (
    infer_source_format_from_path,
    infer_target_format_from_path,
    resolve_source_fps,
)
from kimodo.exports.motion_io import (
    load_amass_npz,
    load_g1_csv,
    load_kimodo_npz_as_torch,
    save_kimodo_npz_at_target_fps,
)
from kimodo.exports.mujoco import MujocoQposConverter
from kimodo.exports.smplx import AMASSConverter
from kimodo.skeleton.registry import build_skeleton


def convert_motion_files(
    input_path: str,
    output_path: str,
    *,
    from_fmt: str | None = None,
    to_fmt: str | None = None,
    source_fps: float | None = None,
    z_up: bool = True,
    mujoco_rest_zero: bool = False,
) -> None:
    """Convert a motion file between Kimodo-supported formats.

    Supported pairs (hub-and-spoke through Kimodo NPZ):

    - amass <-> kimodo
    - soma-bvh <-> kimodo
    - g1-csv <-> kimodo

    Args:
        input_path: Source file (``.npz``, ``.bvh``, or ``.csv``).
        output_path: Destination file.
        from_fmt: Source format; inferred from extension/contents when ``None``.
        to_fmt: Target format; inferred from extension when ``None``.
        source_fps: Source motion frame rate (Hz). If provided, trusted as-is.
            If ``None``, auto-detected from BVH ``Frame Time``, AMASS
            ``mocap_frame_rate``, or default 30.
        z_up: For AMASS conversions, apply the Z-up <-> Kimodo Y-up transform.
        mujoco_rest_zero: For G1 CSV, joint angles relative to MuJoCo rest pose.
    """
    from_fmt = from_fmt or infer_source_format_from_path(input_path)
    to_fmt = to_fmt or infer_target_format_from_path(output_path, from_fmt)

    _validate_output_extension(to_fmt, output_path)

    pair = (from_fmt, to_fmt)

    if pair == ("amass", "kimodo"):
        sk = build_skeleton(22)
        effective_source = source_fps
        if effective_source is None:
            with np.load(input_path, allow_pickle=True) as z:
                effective_source = float(z["mocap_frame_rate"]) if "mocap_frame_rate" in z.files else 30.0
        motion = load_amass_npz(input_path, source_fps=effective_source, z_up=z_up)
        save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path)
        return

    if pair == ("kimodo", "amass"):
        data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False)
        if J != 22:
            raise ValueError(f"Kimodo→AMASS requires 22 joints (SMPL-X); this file has J={J}.")
        sk = build_skeleton(22)
        effective_source = resolve_source_fps(source_fps, "kimodo", input_path, None)
        converter = AMASSConverter(fps=effective_source, skeleton=sk)
        converter.convert_save_npz(data, output_path, z_up=z_up)
        return

    if pair == ("soma-bvh", "kimodo"):
        sk = build_skeleton(77)
        motion, bvh_fps = bvh_to_kimodo_motion(input_path, skeleton=sk)
        effective_source = source_fps if source_fps is not None else bvh_fps
        save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path)
        return

    if pair == ("kimodo", "soma-bvh"):
        data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False)
        if J == 30:
            warnings.warn(
                "Input has 30 joints (somaskel30); expanding to somaskel77 for BVH export.",
                UserWarning,
                stacklevel=2,
            )
            sk = build_skeleton(30)
        elif J == 77:
            sk = build_skeleton(77)
        else:
            raise ValueError(f"Kimodo→BVH requires a SOMA skeleton (30 or 77 joints); this file has J={J}.")
        effective_source = resolve_source_fps(source_fps, "kimodo", input_path, None)
        save_motion_bvh(
            output_path,
            data["local_rot_mats"],
            data["root_positions"],
            skeleton=sk,
            fps=effective_source,
        )
        return

    if pair == ("g1-csv", "kimodo"):
        sk = build_skeleton(34)
        effective_source = resolve_source_fps(source_fps, "g1-csv", input_path, None)
        motion = load_g1_csv(input_path, source_fps=effective_source, mujoco_rest_zero=mujoco_rest_zero)
        save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path)
        return

    if pair == ("kimodo", "g1-csv"):
        data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False)
        if J != 34:
            raise ValueError(f"Kimodo→CSV requires G1 with 34 joints; this file has J={J}.")
        sk = build_skeleton(34)
        effective_source = resolve_source_fps(source_fps, "kimodo", input_path, None)
        converter = MujocoQposConverter(sk)
        qpos = converter.dict_to_qpos(
            {k: v for k, v in data.items() if k in ("local_rot_mats", "root_positions")},
            device=str(sk.neutral_joints.device),
            numpy=True,
            mujoco_rest_zero=mujoco_rest_zero,
        )
        converter.save_csv(qpos, output_path)
        return

    raise ValueError(
        f"Unsupported conversion {from_fmt!r} → {to_fmt!r}. "
        "Supported: amass↔kimodo (SMPL-X NPZ), soma-bvh↔kimodo, g1-csv↔kimodo."
    )


def _validate_output_extension(to_fmt: str, output_path: str) -> None:
    lower = output_path.lower()
    if to_fmt == "kimodo":
        if not lower.endswith(".npz"):
            raise ValueError("Kimodo output must use a .npz path.")
    elif to_fmt == "amass":
        if not lower.endswith(".npz"):
            raise ValueError("AMASS output must use a .npz path.")
    elif to_fmt == "soma-bvh":
        if not lower.endswith(".bvh"):
            raise ValueError("SOMA BVH output must use a .bvh path.")
    elif to_fmt == "g1-csv":
        if not lower.endswith(".csv"):
            raise ValueError("G1 CSV output must use a .csv path.")
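
As a usage reference, here is a minimal sketch of the conversion API above. The file names are hypothetical, and the G1 example assumes the input NPZ stores a 34-joint G1 clip:

from kimodo.exports.motion_convert_lib import convert_motion_files

# AMASS SMPL-X NPZ -> Kimodo NPZ; formats are inferred from extensions/contents
convert_motion_files("walk_amass.npz", "walk_kimodo.npz")

# Kimodo NPZ (34-joint G1) -> MuJoCo qpos CSV, overriding the source rate
convert_motion_files("g1_clip.npz", "g1_clip.csv", source_fps=30.0)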
kimodo/exports/motion_formats.py
ADDED
@@ -0,0 +1,78 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Infer motion file formats from paths and NPZ contents."""

from __future__ import annotations

import os
from typing import Literal

import numpy as np

MotionSourceFormat = Literal["amass", "kimodo", "soma-bvh", "g1-csv"]
MotionTargetFormat = Literal["amass", "kimodo", "soma-bvh", "g1-csv"]
NpzMotionKind = Literal["amass", "kimodo"]


def infer_npz_kind(path: str) -> NpzMotionKind:
    """Classify a ``.npz`` as AMASS SMPL-X or Kimodo from required array keys."""
    with np.load(path, allow_pickle=False) as z:
        keys = set(z.files)
    if "trans" in keys and "pose_body" in keys and "root_orient" in keys:
        return "amass"
    if "local_rot_mats" in keys or "posed_joints" in keys:
        return "kimodo"
    raise ValueError(
        f"Unrecognized NPZ {path!r}: expected AMASS keys (trans, pose_body, ...) "
        "or Kimodo keys (local_rot_mats, posed_joints, ...)."
    )


def infer_source_format_from_path(path: str) -> MotionSourceFormat:
    """Infer converter input format from file extension and NPZ contents when needed."""
    ext = os.path.splitext(path)[1].lower()
    if ext == ".bvh":
        return "soma-bvh"
    if ext == ".csv":
        return "g1-csv"
    if ext == ".npz":
        return infer_npz_kind(path)  # type: ignore[return-value]
    raise ValueError(f"Cannot infer format from extension of {path!r}")


def infer_target_format_from_path(path: str, from_fmt: MotionSourceFormat) -> MotionTargetFormat:
    """Infer converter output format from destination path and source format."""
    ext = os.path.splitext(path)[1].lower()
    if ext == ".bvh":
        return "soma-bvh"
    if ext == ".csv":
        return "g1-csv"
    if ext == ".npz":
        if from_fmt == "amass":
            return "kimodo"
        if from_fmt == "kimodo":
            return "amass"
        if from_fmt in ("g1-csv", "soma-bvh"):
            return "kimodo"
        raise ValueError(
            "Ambiguous .npz output: set --to to 'kimodo' or 'amass' when the input format is not amass/kimodo."
        )
    raise ValueError(f"Cannot infer output format from extension of {path!r}")


def resolve_source_fps(
    fps: float | None,
    from_kind: str,
    input_path: str,
    data: dict | None,
) -> float:
    """Resolve source frame rate (Hz) for conversion when ``fps`` is not overridden."""
    if fps is not None:
        return float(fps)
    if data is not None and "mocap_frame_rate" in data:
        return float(np.asarray(data["mocap_frame_rate"]).item())
    if from_kind == "soma-bvh":
        from kimodo.exports.bvh import read_bvh_frame_time_seconds

        return 1.0 / read_bvh_frame_time_seconds(input_path)
    return 30.0
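
Taken together, these helpers drive the converter's auto-detection. A minimal sketch of how they compose (the paths are hypothetical):

from kimodo.exports.motion_formats import (
    infer_source_format_from_path,
    infer_target_format_from_path,
    resolve_source_fps,
)

src = infer_source_format_from_path("clip.bvh")        # "soma-bvh"
dst = infer_target_format_from_path("clip.npz", src)   # "kimodo"
fps = resolve_source_fps(None, src, "clip.bvh", None)  # read from the BVH Frame Time header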
kimodo/exports/motion_io.py
ADDED
@@ -0,0 +1,443 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Assemble Kimodo NPZ-compatible motion dicts from local rotations + root trajectory."""

from __future__ import annotations

import os
import warnings
from typing import Any, Dict, Tuple

import numpy as np
import torch

from kimodo.geometry import matrix_to_quaternion, quaternion_to_matrix
from kimodo.motion_rep.feature_utils import compute_heading_angle, compute_vel_xyz
from kimodo.motion_rep.feet import foot_detect_from_pos_and_vel
from kimodo.motion_rep.smooth_root import get_smooth_root_pos
from kimodo.skeleton import SkeletonBase
from kimodo.skeleton.registry import build_skeleton
from kimodo.tools import to_numpy

# Default motion rate for Kimodo NPZ produced by format conversion (matches common model FPS).
KIMODO_CONVERT_TARGET_FPS = 30.0


def _quaternion_slerp(q0: torch.Tensor, q1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Spherical linear interpolation; ``q0``, ``q1`` (..., 4) wxyz; ``t`` broadcastable
    to (..., 1)."""
    if t.dim() < q0.dim():
        t = t.unsqueeze(-1)
    dot = (q0 * q1).sum(dim=-1, keepdim=True)
    q1 = torch.where(dot < 0, -q1, q1)
    dot = torch.abs(dot).clamp(-1.0, 1.0)
    theta_0 = torch.acos(dot)
    sin_theta = torch.sin(theta_0)
    s0 = torch.sin((1.0 - t) * theta_0) / sin_theta.clamp(min=1e-8)
    s1 = torch.sin(t * theta_0) / sin_theta.clamp(min=1e-8)
    q = s0 * q0 + s1 * q1
    return q / torch.linalg.norm(q, dim=-1, keepdim=True).clamp(min=1e-8)
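
# Sanity-check sketch (comment only): halfway between the identity quaternion and a
# 90-degree rotation about x, slerp gives the 45-degree rotation about x:
#
#   q0 = torch.tensor([1.0, 0.0, 0.0, 0.0])             # identity, wxyz
#   q1 = torch.tensor([0.70711, 0.70711, 0.0, 0.0])     # 90 deg about x
#   _quaternion_slerp(q0, q1, torch.tensor([0.5]))
#   # -> approximately [0.92388, 0.38268, 0.0, 0.0]     # 45 deg about x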


def resample_motion_dict_to_kimodo_fps(
    motion_dict: Dict[str, torch.Tensor],
    skeleton: SkeletonBase,
    source_fps: float,
    target_fps: float = KIMODO_CONVERT_TARGET_FPS,
) -> Tuple[Dict[str, torch.Tensor], bool]:
    """Resample a Kimodo motion dict to ``target_fps``.

    When the fps ratio is close to an integer (e.g. 120 / 30 = 4), the faster
    stepping method is used (take every *step*-th frame). Otherwise falls back
    to linear interpolation (root) + quaternion slerp (joints).

    Re-runs :func:`complete_motion_dict` at the target rate so derived channels stay consistent.

    Returns:
        The motion dict and ``True`` if time resampling was applied, else ``False`` (already at
        ``target_fps`` with matching frame count; only re-derived via FK).
    """
    local_rot_mats = motion_dict["local_rot_mats"]
    root_positions = motion_dict["root_positions"]
    local_rot_mats, root_positions = _coerce_time_local_root(local_rot_mats, root_positions)
    t_in = int(local_rot_mats.shape[0])
    if t_in < 1:
        raise ValueError("Motion must have at least one frame.")
    if source_fps <= 0:
        raise ValueError(f"source_fps must be positive; got {source_fps}")

    t_out = max(1, int(round(t_in * target_fps / source_fps)))
    if t_out == t_in and abs(float(source_fps) - float(target_fps)) < 1e-3:
        return complete_motion_dict(local_rot_mats, root_positions, skeleton, float(target_fps)), False

    ratio = source_fps / target_fps
    step = round(ratio)
    if step >= 2 and abs(ratio - step) < 0.05:
        local_out = local_rot_mats[::step]
        root_out = root_positions[::step]
    else:
        device = local_rot_mats.device
        dtype = local_rot_mats.dtype
        u = torch.linspace(0, t_in - 1, t_out, device=device, dtype=dtype)
        i0 = u.floor().long().clamp(0, t_in - 1)
        i1 = torch.minimum(i0 + 1, torch.tensor(t_in - 1, device=device))
        tau_1d = (u - i0.float()).unsqueeze(-1)
        rp0 = root_positions[i0]
        rp1 = root_positions[i1]
        root_out = (1.0 - tau_1d) * rp0 + tau_1d * rp1

        quats = matrix_to_quaternion(local_rot_mats)
        q0 = quats[i0]
        q1 = quats[i1]
        tau_q = (u - i0.float()).view(t_out, 1, 1)
        quat_out = _quaternion_slerp(q0, q1, tau_q)
        local_out = quaternion_to_matrix(quat_out)

    return complete_motion_dict(local_out, root_out, skeleton, float(target_fps)), True
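
# Worked example of the two paths above (comment only):
#   120 Hz -> 30 Hz: ratio = 4.0, so every 4th frame is kept (t_out ~= t_in / 4);
#   100 Hz -> 30 Hz: ratio ~= 3.33, so root positions are lerped and joint rotations
#   slerped at t_out = round(t_in * 30 / 100) uniformly spaced sample times.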


def warn_kimodo_npz_framerate(source_fps: float, t_before: int, t_after: int) -> None:
    """Emit a warning after time resampling for Kimodo NPZ (linear root, quaternion slerp
    per joint)."""
    warnings.warn(
        f"Resampled motion to {KIMODO_CONVERT_TARGET_FPS:.0f} Hz for Kimodo NPZ "
        f"(source ~{source_fps:.4g} Hz, {t_before} input frames → {t_after} output frames). "
        "Pass --source-fps if the detected source rate is wrong.",
        UserWarning,
        stacklevel=3,
    )


def _coerce_time_local_root(
    local_rot_mats: torch.Tensor,
    root_positions: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Normalize to shapes (T, J, 3, 3) and (T, 3)."""
    if local_rot_mats.dim() == 5:
        if int(local_rot_mats.shape[0]) != 1:
            raise ValueError(f"local_rot_mats batch size must be 1 for single clip; got {local_rot_mats.shape[0]}")
        local_rot_mats = local_rot_mats[0]
    if root_positions.dim() == 3:
        if int(root_positions.shape[0]) != 1:
            raise ValueError(f"root_positions batch size must be 1; got {root_positions.shape[0]}")
        root_positions = root_positions[0]
    if local_rot_mats.dim() != 4:
        raise ValueError(f"local_rot_mats must be (T,J,3,3); got {tuple(local_rot_mats.shape)}")
    if root_positions.dim() != 2 or int(root_positions.shape[-1]) != 3:
        raise ValueError(f"root_positions must be (T,3); got {tuple(root_positions.shape)}")
    if int(local_rot_mats.shape[0]) != int(root_positions.shape[0]):
        raise ValueError("local_rot_mats and root_positions must have the same number of frames")
    return local_rot_mats, root_positions


def complete_motion_dict(
    local_rot_mats: torch.Tensor,
    root_positions: torch.Tensor,
    skeleton: SkeletonBase,
    fps: float,
) -> Dict[str, torch.Tensor]:
    """Build the Kimodo motion output dict from local rotations and root positions.

    Matches keys written by CLI generation (see docs/source/user_guide/output_formats.md).

    Args:
        local_rot_mats: (T, J, 3, 3) or (1, T, J, 3, 3) local rotation matrices.
        root_positions: (T, 3) or (1, T, 3) root / pelvis world positions (meters).
        skeleton: Skeleton instance (SOMA77, G1, SMPL-X, etc.).
        fps: Sampling rate (Hz).

    Returns:
        Dict with tensors ``posed_joints``, ``global_rot_mats``, ``local_rot_mats``,
        ``foot_contacts``, ``smooth_root_pos``, ``root_positions``, ``global_root_heading``.
    """
    device = local_rot_mats.device
    dtype = local_rot_mats.dtype
    local_rot_mats, root_positions = _coerce_time_local_root(
        local_rot_mats.to(device=device, dtype=dtype),
        root_positions.to(device=device, dtype=dtype),
    )

    global_rot_mats, posed_joints, _ = skeleton.fk(local_rot_mats, root_positions)

    smooth_root_pos = get_smooth_root_pos(root_positions.unsqueeze(0)).squeeze(0)

    lengths = torch.tensor([posed_joints.shape[0]], device=device)
    velocities = compute_vel_xyz(posed_joints.unsqueeze(0), fps, lengths=lengths).squeeze(0)

    heading_angle = compute_heading_angle(posed_joints.unsqueeze(0), skeleton).squeeze(0)
    global_root_heading = torch.stack([torch.cos(heading_angle), torch.sin(heading_angle)], dim=-1)

    foot_contacts = foot_detect_from_pos_and_vel(
        posed_joints.unsqueeze(0),
        velocities.unsqueeze(0),
        skeleton,
        0.15,
        0.10,
    ).squeeze(0)

    return {
        "posed_joints": posed_joints,
        "global_rot_mats": global_rot_mats,
        "local_rot_mats": local_rot_mats,
        "foot_contacts": foot_contacts,
        "smooth_root_pos": smooth_root_pos,
        "root_positions": root_positions,
        "global_root_heading": global_root_heading,
    }
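
# Minimal usage sketch (comment only; assumes a CPU-resident skeleton from build_skeleton):
#
#   sk = build_skeleton(22)
#   T = 60
#   local = torch.eye(3).expand(T, sk.nbjoints, 3, 3).contiguous()  # rest pose
#   root = torch.zeros(T, 3)
#   motion = complete_motion_dict(local, root, sk, fps=30.0)
#   # motion keys: posed_joints, global_rot_mats, local_rot_mats, foot_contacts, ...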


def motion_dict_to_numpy(d: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """Convert motion dict values to numpy arrays for ``np.savez``."""
    out: Dict[str, np.ndarray] = {}
    for k, v in d.items():
        if hasattr(v, "detach"):
            out[k] = to_numpy(v)
        elif isinstance(v, np.ndarray):
            out[k] = v
        else:
            out[k] = np.asarray(v)
    return out


def save_kimodo_npz(path: str, motion_dict: Dict[str, Any]) -> None:
    """Save a Kimodo-compatible motion dict to ``.npz`` (numpy arrays)."""
    np.savez(path, **motion_dict_to_numpy(motion_dict))


def load_kimodo_npz(path: str) -> Dict[str, np.ndarray]:
    """Load arrays from a Kimodo ``.npz`` file."""
    with np.load(path, allow_pickle=False) as data:
        return {k: np.asarray(data[k]) for k in data.files}


def load_g1_csv(
    path: str,
    source_fps: float = KIMODO_CONVERT_TARGET_FPS,
    *,
    mujoco_rest_zero: bool = False,
) -> Dict[str, torch.Tensor]:
    """Load a G1 MuJoCo ``qpos`` CSV (``(T, 36)``) into a Kimodo motion dict.

    Args:
        path: CSV path (comma-separated, no header).
        source_fps: Source frame rate (Hz) of the CSV data.
        mujoco_rest_zero: Must match how the CSV was written (see :class:`MujocoQposConverter`).
    """
    from kimodo.exports.mujoco import MujocoQposConverter

    qpos = np.loadtxt(path, delimiter=",")
    if qpos.ndim != 2 or qpos.shape[-1] != 36:
        raise ValueError(f"Expected G1 CSV with shape (T, 36); got {qpos.shape}")
    sk = build_skeleton(34)
    converter = MujocoQposConverter(sk)
    return converter.qpos_to_motion_dict(qpos, float(source_fps), mujoco_rest_zero=mujoco_rest_zero)


def load_amass_npz(
    path: str,
    source_fps: float | None = None,
    *,
    z_up: bool = True,
) -> Dict[str, torch.Tensor]:
    """Load an AMASS-style SMPL-X ``.npz`` into a Kimodo motion dict (22 joints).

    Args:
        path: NPZ with ``trans``, ``root_orient``, ``pose_body``, etc.
        source_fps: Source frame rate (Hz); if ``None``, uses ``mocap_frame_rate``
            from the file when present, else 30 Hz.
        z_up: If ``True``, apply AMASS Z-up to Kimodo Y-up transform (same as CLI).
    """
    from kimodo.exports.smplx import amass_npz_to_kimodo_motion

    sk = build_skeleton(22)
    return amass_npz_to_kimodo_motion(path, sk, source_fps=source_fps, z_up=z_up)


def load_kimodo_npz_as_torch(
    path: str,
    source_fps: float = KIMODO_CONVERT_TARGET_FPS,
    *,
    ensure_complete: bool = True,
) -> tuple[Dict[str, torch.Tensor], int]:
    """Load a Kimodo NPZ and return all arrays as torch tensors on the skeleton device.

    Args:
        path: Kimodo NPZ file path.
        source_fps: Source frame rate (Hz) used for derived channels when
            ``ensure_complete=True``.
        ensure_complete: If ``True`` and the NPZ lacks derived channels
            (``posed_joints``, ``global_rot_mats``, …), run :func:`complete_motion_dict`
            to fill them from ``local_rot_mats`` + ``root_positions``.
            If ``False``, load all arrays verbatim (requires ``local_rot_mats``).

    Returns:
        ``(tensor_dict, num_joints)``
    """
    raw = load_kimodo_npz(path)
    if "local_rot_mats" in raw:
        j = int(raw["local_rot_mats"].shape[1])
    elif "posed_joints" in raw:
        j = int(raw["posed_joints"].shape[1])
    else:
        raise ValueError("Kimodo NPZ must contain 'local_rot_mats' or 'posed_joints'.")
    sk = build_skeleton(j)
    device = sk.neutral_joints.device
    dtype = torch.float32

    if not ensure_complete:
        if "local_rot_mats" not in raw:
            raise ValueError("Kimodo NPZ must contain 'local_rot_mats' (and typically 'root_positions').")
        out: Dict[str, torch.Tensor] = {}
        for k, v in raw.items():
            out[k] = torch.from_numpy(np.asarray(v)).to(device=device, dtype=dtype)
        return out, j

    if "posed_joints" in raw and "global_rot_mats" in raw:
        out = {}
        for k, v in raw.items():
            out[k] = torch.from_numpy(np.asarray(v)).to(device=device, dtype=dtype)
        return out, j

    if "local_rot_mats" not in raw or "root_positions" not in raw:
        raise ValueError("Kimodo NPZ must contain posed_joints+global_rot_mats, or local_rot_mats+root_positions.")
    local = torch.from_numpy(np.asarray(raw["local_rot_mats"])).to(device=device, dtype=dtype)
    root = torch.from_numpy(np.asarray(raw["root_positions"])).to(device=device, dtype=dtype)
    return complete_motion_dict(local, root, sk, float(source_fps)), j
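
# The three NPZ layouts accepted above, summarized (comment only):
#   1. full dict (posed_joints + global_rot_mats + ...): returned verbatim as float32 tensors;
#   2. minimal dict (local_rot_mats + root_positions) with ensure_complete=True:
#      derived channels are rebuilt via complete_motion_dict at source_fps;
#   3. ensure_complete=False: arrays are returned as-is (local_rot_mats required).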


def save_kimodo_npz_at_target_fps(
    motion: Dict[str, torch.Tensor],
    skeleton: SkeletonBase,
    source_fps: float,
    output_path: str,
    target_fps: float = KIMODO_CONVERT_TARGET_FPS,
) -> None:
    """Resample a motion dict to ``target_fps`` when needed, then save Kimodo NPZ."""
    t_before = int(motion["local_rot_mats"].shape[0])
    motion, did_resample = resample_motion_dict_to_kimodo_fps(motion, skeleton, source_fps, target_fps)
    t_after = int(motion["local_rot_mats"].shape[0])
    if did_resample:
        warn_kimodo_npz_framerate(source_fps, t_before, t_after)
    save_kimodo_npz(output_path, motion)


def kimodo_npz_to_bytes(motion_dict: Dict[str, Any]) -> bytes:
    """Serialize a Kimodo motion dict to in-memory NPZ bytes."""
    import io

    buf = io.BytesIO()
    np.savez(buf, **motion_dict_to_numpy(motion_dict))
    return buf.getvalue()


def g1_csv_to_bytes(motion_dict: Dict[str, Any], skeleton: SkeletonBase, device: Any) -> bytes:
    """Convert a motion dict to G1 MuJoCo CSV bytes via :class:`MujocoQposConverter`."""
    import io

    from kimodo.exports.mujoco import MujocoQposConverter

    converter = MujocoQposConverter(skeleton)
    qpos = converter.dict_to_qpos(
        {k: v for k, v in motion_dict.items() if k in ("local_rot_mats", "root_positions")},
        device,
        numpy=True,
    )
    buf = io.StringIO()
    np.savetxt(buf, qpos, delimiter=",")
    return buf.getvalue().encode("utf-8")


def amass_npz_to_bytes(motion_dict: Dict[str, Any], skeleton: SkeletonBase, fps: float) -> bytes:
    """Convert a motion dict to AMASS NPZ bytes via :class:`AMASSConverter`."""
    import io

    from kimodo.exports.smplx import AMASSConverter

    converter = AMASSConverter(skeleton=skeleton, fps=fps)
    buf = io.BytesIO()
    converter.convert_save_npz(
        {k: v for k, v in motion_dict.items() if k in ("local_rot_mats", "root_positions")},
        buf,
    )
    return buf.getvalue()


def _read_amass_source_fps(path: str) -> float:
    """Read the source frame rate from an AMASS NPZ, defaulting to 30 Hz."""
    with np.load(path, allow_pickle=True) as z:
        if "mocap_frame_rate" in z.files:
            return float(z["mocap_frame_rate"])
    return 30.0


def load_motion_file(
    path: str,
    source_fps: float | None = None,
    target_fps: float | None = None,
    *,
    z_up: bool = True,
    mujoco_rest_zero: bool = False,
) -> tuple[Dict[str, torch.Tensor], int]:
    """Load a motion file and return a Kimodo motion dict plus joint count.

    Supports SOMA BVH (``.bvh``), G1 MuJoCo CSV (``.csv``), Kimodo NPZ, and AMASS SMPL-X NPZ
    (``.npz``).

    The motion is loaded at its native (or overridden) source rate, then
    resampled to ``target_fps`` when they differ.

    Args:
        path: Path to ``.bvh``, ``.csv``, or ``.npz``.
        source_fps: Source frame rate (Hz). If provided, trusted as-is.
            If ``None``, auto-detected per format: BVH ``Frame Time`` header,
            AMASS ``mocap_frame_rate``, or :data:`KIMODO_CONVERT_TARGET_FPS`
            (30 Hz) for CSV / Kimodo NPZ.
        target_fps: Desired output frame rate (Hz). Defaults to
            :data:`KIMODO_CONVERT_TARGET_FPS` (30 Hz). The motion is
            resampled when ``source_fps`` and ``target_fps`` differ.
        z_up: AMASS NPZ only; passed to :func:`load_amass_npz`.
        mujoco_rest_zero: G1 CSV only; passed to :func:`load_g1_csv`.

    Returns:
        ``(motion_dict, num_joints)`` with the same keys as :func:`complete_motion_dict`.
    """
    from kimodo.exports.motion_formats import infer_npz_kind

    if target_fps is None:
        target_fps = KIMODO_CONVERT_TARGET_FPS

    ext = os.path.splitext(path)[1].lower()
    if ext == ".bvh":
        from kimodo.exports.bvh import bvh_to_kimodo_motion

        motion_dict, bvh_fps = bvh_to_kimodo_motion(path)
        effective_source = source_fps if source_fps is not None else bvh_fps
        num_joints = int(motion_dict["local_rot_mats"].shape[1])
    elif ext == ".csv":
        effective_source = source_fps if source_fps is not None else KIMODO_CONVERT_TARGET_FPS
        motion_dict = load_g1_csv(path, source_fps=effective_source, mujoco_rest_zero=mujoco_rest_zero)
        num_joints = 34
    elif ext == ".npz":
        kind = infer_npz_kind(path)
        if kind == "amass":
            effective_source = source_fps if source_fps is not None else _read_amass_source_fps(path)
            motion_dict = load_amass_npz(path, source_fps=effective_source, z_up=z_up)
            num_joints = 22
        else:
            effective_source = source_fps if source_fps is not None else KIMODO_CONVERT_TARGET_FPS
            motion_dict, num_joints = load_kimodo_npz_as_torch(path, source_fps=effective_source)
    else:
        raise ValueError(f"Unsupported motion file {path!r}; expected .bvh, .csv, or .npz")

    if abs(effective_source - target_fps) > 0.5:
        sk = build_skeleton(num_joints)
        motion_dict, did_resample = resample_motion_dict_to_kimodo_fps(motion_dict, sk, effective_source, target_fps)
        if did_resample:
            t_out = int(motion_dict["local_rot_mats"].shape[0])
            warnings.warn(
                f"Resampled motion from {effective_source:.4g} Hz to {target_fps:.0f} Hz ({t_out} frames).",
                UserWarning,
                stacklevel=2,
            )

    return motion_dict, num_joints
kimodo/exports/mujoco.py
ADDED
@@ -0,0 +1,588 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Convert kimodo motion (y-up, z-forward) to MuJoCo qpos (z-up, x-forward) for G1 skeleton."""

import os
import xml.etree.ElementTree as ET
from typing import Optional

import numpy as np
import torch
from scipy.spatial.transform import Rotation

from kimodo.assets import skeleton_asset_path
from kimodo.geometry import (
    axis_angle_to_matrix,
    matrix_to_axis_angle,
    matrix_to_quaternion,
    quaternion_to_matrix,
)
from kimodo.skeleton import G1Skeleton34, SkeletonBase, global_rots_to_local_rots
from kimodo.tools import ensure_batched, to_numpy, to_torch

# Cache so that the same (skeleton, xml_path) returns the same converter instance.
_converter_cache: dict[tuple[int, str], "MujocoQposConverter"] = {}


class MujocoQposConverter:
    """Fast batch converter from our dictionary format to mujoco qpos with precomputed transforms.

    In mujoco, the coordinate system is z up and x forward, right handed.

    Features (30 joints):
    - root (pelvis, 7 = translation + rotation) + 29 dof joints (29)

    In kimodo, the coordinate system is y up and z forward, right handed.
    Features (34 joints):
    - root (pelvis) + (34 - 1) joints; among these joints, 4 are end-effector joints added by kimodo.

    Cached by (input_skeleton id, xml_path); repeated calls with the same args return the same instance.
    """

    def __new__(
        cls,
        input_skeleton: SkeletonBase,
        xml_path: str = str(skeleton_asset_path("g1skel34", "xml", "g1.xml")),
    ):
        key = (id(input_skeleton), xml_path)
        if key not in _converter_cache:
            inst = object.__new__(cls)
            _converter_cache[key] = inst
        return _converter_cache[key]

    def __init__(
        self,
        input_skeleton: SkeletonBase,
        xml_path: str = str(skeleton_asset_path("g1skel34", "xml", "g1.xml")),
    ):
        """Initialize converter with precomputed transforms.

        Args:
            input_skeleton: Source kimodo skeleton (G1, 34 joints).
            xml_path: Path to the mujoco XML file containing joint definitions.
        """
        if getattr(self, "_initialized", False):
            return
        self.xml_path = xml_path
        self.skeleton = input_skeleton
        self._prepare_transforms()
        self._subtree_joints = {}
        self._initialized = True

    def _prepare_transforms(self):
        """Precompute all necessary transforms for efficient batch processing."""
        # Define coordinate transformations between mujoco and kimodo space:
        # 1) R_zup_to_yup: rotation around x-axis by -90 degrees
        # 2) x_forward_to_y_forward: rotation around z-axis by -90 degrees
        # Combined transformation matrix: mujoco_to_kimodo = R_zup_to_yup * x_forward_to_y_forward
        self.mujoco_to_kimodo_matrix = torch.tensor(
            [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], dtype=torch.float32
        )
        self.kimodo_to_mujoco_matrix = self.mujoco_to_kimodo_matrix.T  # Inverse transformation: kimodo_to_mujoco
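
        # The mapping encoded above, spelled out (comment only): this matrix sends
        # MuJoCo basis vectors to kimodo basis vectors,
        #   mujoco x (forward) -> kimodo z (forward)
        #   mujoco y           -> kimodo x
        #   mujoco z (up)      -> kimodo y (up)
        # e.g. mujoco_to_kimodo_matrix @ [1, 0, 0]^T = [0, 0, 1]^T.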
| 81 |
+
|
| 82 |
+
# Parse XML once and extract joint information
|
| 83 |
+
tree = ET.parse(self.xml_path)
|
| 84 |
+
root = tree.getroot()
|
| 85 |
+
|
| 86 |
+
xml_classes = [x for x in tree.findall(".//default") if "class" in x.attrib]
|
| 87 |
+
joint_axes = dict()
|
| 88 |
+
class_ranges: dict[str, tuple[float, float]] = {}
|
| 89 |
+
for xml_class in xml_classes:
|
| 90 |
+
j = xml_class.findall("joint")
|
| 91 |
+
if j:
|
| 92 |
+
joint_axes[xml_class.get("class")] = j[0].get("axis")
|
| 93 |
+
range_str = j[0].get("range")
|
| 94 |
+
if range_str:
|
| 95 |
+
range_vals = [float(x) for x in range_str.split()]
|
| 96 |
+
if len(range_vals) == 2:
|
| 97 |
+
class_ranges[xml_class.get("class")] = (
|
| 98 |
+
range_vals[0],
|
| 99 |
+
range_vals[1],
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
mujoco_hinge_joints = root.find("worldbody").findall(".//joint") # skip the base joint
|
| 103 |
+
self._mujoco_joint_axis_values_kimodo_space = torch.zeros(
|
| 104 |
+
(len(mujoco_hinge_joints), 3), dtype=torch.float32
|
| 105 |
+
) # mujoco order but kimodo space
|
| 106 |
+
self._mujoco_joint_axis_values_mujoco_space = torch.zeros(
|
| 107 |
+
(len(mujoco_hinge_joints), 3), dtype=torch.float32
|
| 108 |
+
) # mujoco order but mujoco space
|
| 109 |
+
|
| 110 |
+
# for the below indices, mujoco_indices_to_kimodo_indices does not include mujoco root (30 - 1 = 29 elements),
|
| 111 |
+
# while kimodo_indices_to_mujoco_indices inclues the kimodo root (32 elements).
|
| 112 |
+
self._mujoco_indices_to_kimodo_indices = torch.zeros((len(mujoco_hinge_joints),), dtype=torch.int32)
|
| 113 |
+
self._kimodo_indices_to_mujoco_indices = (
|
| 114 |
+
torch.ones((self.skeleton.nbjoints,), dtype=torch.int32) * -1
|
| 115 |
+
) # -1 means not in the csv skeleton
|
| 116 |
+
|
| 117 |
+
self._nb_joints_mujoco = len(mujoco_hinge_joints) + 1
|
| 118 |
+
self._nb_joints_kimodo = self.skeleton.nbjoints
|
| 119 |
+
self._mujoco_joint_including_root_parent_list = torch.full(
|
| 120 |
+
(len(mujoco_hinge_joints) + 1,), -1, dtype=torch.int32
|
| 121 |
+
)
|
| 122 |
+
self._mujoco_joint_including_root_list = ["pelvis_skel"]
|
| 123 |
+
|
| 124 |
+
for joint_id_in_csv, joint in enumerate(mujoco_hinge_joints):
|
| 125 |
+
joint_name_in_skeleton = joint.get("name").replace("_joint", "_skel")
|
| 126 |
+
joint_parent_name_in_skeleton = self.skeleton.bone_parents[joint_name_in_skeleton]
|
| 127 |
+
|
| 128 |
+
self._mujoco_joint_including_root_list.append(joint_name_in_skeleton)
|
| 129 |
+
self._mujoco_joint_including_root_parent_list[joint_id_in_csv + 1] = (
|
| 130 |
+
self._mujoco_joint_including_root_list.index(joint_parent_name_in_skeleton)
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
joint_idx_in_kimodo_skeleton = self.skeleton.bone_order_names.index(joint_name_in_skeleton)
|
| 134 |
+
axis_values = [float(x) for x in (joint.get("axis") or joint_axes[joint.get("class")]).split(" ")]
|
| 135 |
+
|
| 136 |
+
# the mapped axis in kimodo skeleton space is calculated as bones_axis = mujoco_to_kimodo.apply(axis_values)
|
| 137 |
+
# [1, 0, 0] -> [0, 0, 1]; [0, 1, 0] -> [1, 0, 0]; [0, 0, 1] -> [0, 1, 0]
|
| 138 |
+
mujoco_joint_axis_mapping_kimodo_space = [
|
| 139 |
+
torch.tensor([0, 0, 1]),
|
| 140 |
+
torch.tensor([1, 0, 0]),
|
| 141 |
+
torch.tensor([0, 1, 0]),
|
| 142 |
+
][np.argmax(axis_values)]
|
| 143 |
+
|
| 144 |
+
self._mujoco_joint_axis_values_kimodo_space[joint_id_in_csv] = mujoco_joint_axis_mapping_kimodo_space
|
| 145 |
+
self._mujoco_joint_axis_values_mujoco_space[joint_id_in_csv] = torch.tensor(axis_values)
|
| 146 |
+
|
| 147 |
+
self._mujoco_indices_to_kimodo_indices[joint_id_in_csv] = joint_idx_in_kimodo_skeleton
|
| 148 |
+
self._kimodo_indices_to_mujoco_indices[joint_idx_in_kimodo_skeleton] = (
|
| 149 |
+
joint_id_in_csv + 1
|
| 150 |
+
) # +1 for the root
|
| 151 |
+
self._kimodo_indices_to_mujoco_indices[0] = 0 # the root joint mapping
|
| 152 |
+
|
| 153 |
+
# Joint limits (min, max) in radians for each mujoco hinge, for clamping
|
| 154 |
+
self._joint_limits_min = torch.full((len(mujoco_hinge_joints),), float("-inf"), dtype=torch.float32)
|
| 155 |
+
self._joint_limits_max = torch.full((len(mujoco_hinge_joints),), float("inf"), dtype=torch.float32)
|
| 156 |
+
for joint_id_in_csv, joint in enumerate(mujoco_hinge_joints):
|
| 157 |
+
range_vals = None
|
| 158 |
+
if joint.get("range"):
|
| 159 |
+
range_vals = [float(x) for x in joint.get("range").split()]
|
| 160 |
+
elif joint.get("class") and joint.get("class") in class_ranges:
|
| 161 |
+
lo, hi = class_ranges[joint.get("class")]
|
| 162 |
+
range_vals = [lo, hi]
|
| 163 |
+
if range_vals is not None and len(range_vals) == 2:
|
| 164 |
+
self._joint_limits_min[joint_id_in_csv] = range_vals[0]
|
| 165 |
+
self._joint_limits_max[joint_id_in_csv] = range_vals[1]
|
| 166 |
+
|
| 167 |
+
# load the offset matrices from the xml
|
| 168 |
+
R_zup_to_yup = Rotation.from_euler("x", -90, degrees=True)
|
| 169 |
+
x_forward_to_y_forward = Rotation.from_euler("z", -90, degrees=True)
|
| 170 |
+
mujoco_to_kimodo = R_zup_to_yup * x_forward_to_y_forward
|
| 171 |
+
|
| 172 |
+
self._rot_offsets_q2t = torch.zeros(len(self._kimodo_indices_to_mujoco_indices), 3, 3, dtype=torch.float32)
|
| 173 |
+
self._rot_offsets_q2t[...] = torch.eye(3)[None]
|
| 174 |
+
|
| 175 |
+
self._rot_offsets_f2q = torch.zeros(len(self._kimodo_indices_to_mujoco_indices), 3, 3, dtype=torch.float32)
|
| 176 |
+
self._rot_offsets_f2q[...] = torch.eye(3)[None]
|
| 177 |
+
parent_map = {child: parent for parent in root.iter() for child in parent}
|
| 178 |
+
for i, joint in enumerate(mujoco_hinge_joints):
|
| 179 |
+
body = parent_map[joint]
|
| 180 |
+
if "quat" in body.attrib:
|
| 181 |
+
rot = Rotation.from_quat(
|
| 182 |
+
[float(x) for x in body.get("quat").strip().split(" ")],
|
| 183 |
+
scalar_first=True,
|
| 184 |
+
)
|
| 185 |
+
idx = self._mujoco_indices_to_kimodo_indices[i]
|
| 186 |
+
self._rot_offsets_q2t[idx] = torch.from_numpy(rot.as_matrix())
|
| 187 |
+
rot = mujoco_to_kimodo * rot * mujoco_to_kimodo.inv()
|
| 188 |
+
self._rot_offsets_f2q[idx] = torch.from_numpy(rot.as_matrix().T)
|
| 189 |
+
|
| 190 |
+
# Hinge axis in f2q space so extraction uses the same frame as joint_rot_f2q.
|
| 191 |
+
# Then extract(offset) gives the angle s.t. axis_angle(angle * axis_f2q) = offset, and
|
| 192 |
+
# reconstruction R_local = offset.T @ axis_angle(angle * axis_f2q) = I when input is identity.
|
| 193 |
+
axis_kimodo = self._mujoco_joint_axis_values_kimodo_space
|
| 194 |
+
self._mujoco_joint_axis_values_f2q_space = torch.zeros_like(axis_kimodo)
|
| 195 |
+
for i in range(len(mujoco_hinge_joints)):
|
| 196 |
+
j = self._mujoco_indices_to_kimodo_indices[i].item()
|
| 197 |
+
axis_f2q = torch.mv(self._rot_offsets_f2q[j], axis_kimodo[i])
|
| 198 |
+
n = axis_f2q.norm()
|
| 199 |
+
if n > 1e-8:
|
| 200 |
+
axis_f2q = axis_f2q / n
|
| 201 |
+
self._mujoco_joint_axis_values_f2q_space[i] = axis_f2q
|
| 202 |
+
|
| 203 |
+
# Rest-pose DOFs: angle we extract when R_local = I (t-pose). MuJoCo limits are
|
| 204 |
+
# relative to joint zero (rest pose), so we must clamp in MuJoCo space: convert
|
| 205 |
+
# joint_dofs to mujoco_angle = joint_dofs - rest_dofs, clamp, then back.
|
| 206 |
+
rest_rot_f2q = self._rot_offsets_f2q[self._mujoco_indices_to_kimodo_indices]
|
| 207 |
+
rest_rot_f2q = rest_rot_f2q.unsqueeze(0).unsqueeze(0)
|
| 208 |
+
self._rest_dofs = self._local_rots_f2q_to_joint_dofs(rest_rot_f2q).squeeze(0).squeeze(0)
|
| 209 |
+
# Axis-angle rest DOFs: angle s.t. axis_angle(angle * axis_f2q) = offset. Used in
|
| 210 |
+
# project_to_real_robot_rotations so extract+reconstruct round-trip and t-pose is preserved.
|
| 211 |
+
rest_rot_f2q_flat = self._rot_offsets_f2q[self._mujoco_indices_to_kimodo_indices]
|
| 212 |
+
full_aa = matrix_to_axis_angle(rest_rot_f2q_flat)
|
| 213 |
+
self._rest_dofs_axis_angle = (full_aa * self._mujoco_joint_axis_values_f2q_space).sum(dim=-1)
|
| 214 |
+
|
| 215 |
+
def dict_to_qpos(
|
| 216 |
+
self,
|
| 217 |
+
output: dict,
|
| 218 |
+
device: Optional[str] = None,
|
| 219 |
+
root_quat_w_first: bool = True,
|
| 220 |
+
numpy: bool = True,
|
| 221 |
+
mujoco_rest_zero: bool = False,
|
| 222 |
+
):
|
| 223 |
+
"""Convert kimodo output dict to mujoco qpos format.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
output: dict with keys "local_rot_mats" and "root_positions".
|
| 227 |
+
device: device to use for the output.
|
| 228 |
+
root_quat_w_first: If True, quaternion in qpos is (w,x,y,z).
|
| 229 |
+
numpy: If True, convert the output to numpy array.
|
| 230 |
+
mujoco_rest_zero: If True, joint angles are written so that kimodo rest (t-pose)
|
| 231 |
+
maps to q=0 in MuJoCo. If False, write raw joint_dofs.
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
qpos: (B, T, 7+J) mujoco qpos format.
|
| 235 |
+
"""
|
| 236 |
+
local_rot_mats = to_torch(output["local_rot_mats"], device)
|
| 237 |
+
root_positions = to_torch(output["root_positions"], device)
|
| 238 |
+
|
| 239 |
+
qpos = self.to_qpos(
|
| 240 |
+
local_rot_mats,
|
| 241 |
+
root_positions,
|
| 242 |
+
root_quat_w_first=root_quat_w_first,
|
| 243 |
+
mujoco_rest_zero=mujoco_rest_zero,
|
| 244 |
+
)
|
| 245 |
+
if numpy:
|
| 246 |
+
qpos = to_numpy(qpos)
|
| 247 |
+
return qpos
|
| 248 |
+
|
| 249 |
+
def qpos_to_motion_dict(
|
| 250 |
+
self,
|
| 251 |
+
qpos: torch.Tensor | np.ndarray,
|
| 252 |
+
source_fps: float,
|
| 253 |
+
*,
|
| 254 |
+
root_quat_w_first: bool = True,
|
| 255 |
+
mujoco_rest_zero: bool = False,
|
| 256 |
+
):
|
| 257 |
+
"""Inverse of :meth:`to_qpos` / :meth:`dict_to_qpos` for MuJoCo CSV ``(T, 36)`` rows.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
qpos: Shape ``(T, 36)`` or ``(1, T, 36)`` (root xyz, root quat wxyz, 29 joint angles).
|
| 261 |
+
source_fps: Source frame rate (Hz) of the qpos data.
|
| 262 |
+
root_quat_w_first: Must match how the CSV was written (default ``True``).
|
| 263 |
+
mujoco_rest_zero: Must match :meth:`dict_to_qpos` / :meth:`to_qpos`.
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
Kimodo motion dict (see :func:`kimodo.exports.motion_io.complete_motion_dict`).
|
| 267 |
+
"""
|
| 268 |
+
from kimodo.exports.motion_io import complete_motion_dict
|
| 269 |
+
|
| 270 |
+
qpos = to_torch(qpos, None)
|
| 271 |
+
if qpos.dim() == 2:
|
| 272 |
+
qpos = qpos.unsqueeze(0)
|
| 273 |
+
device = qpos.device
|
| 274 |
+
dtype = qpos.dtype
|
| 275 |
+
batch_size, num_frames, ncols = qpos.shape
|
| 276 |
+
if ncols != 36:
|
| 277 |
+
raise ValueError(f"Expected qpos last dim 36; got {ncols}")
|
| 278 |
+
|
| 279 |
+
kimodo_to_mujoco_matrix = self.kimodo_to_mujoco_matrix.to(device=device, dtype=dtype)
|
| 280 |
+
mujoco_to_kimodo_matrix = kimodo_to_mujoco_matrix.T
|
| 281 |
+
|
| 282 |
+
root_mujoco = qpos[..., :3]
|
| 283 |
+
root_positions = torch.matmul(mujoco_to_kimodo_matrix[None, None, ...], root_mujoco[..., None]).squeeze(-1)
|
| 284 |
+
|
| 285 |
+
quat = qpos[..., 3:7]
|
| 286 |
+
if root_quat_w_first:
|
| 287 |
+
root_rot_mujoco = quaternion_to_matrix(quat)
|
| 288 |
+
else:
|
| 289 |
+
quat_wxyz = quat[..., [3, 0, 1, 2]]
|
| 290 |
+
root_rot_mujoco = quaternion_to_matrix(quat_wxyz)
|
| 291 |
+
|
| 292 |
+
O0 = self._rot_offsets_f2q[0].to(device=device, dtype=dtype)
|
| 293 |
+
# root_rot_mujoco is (..., 3, 3) after optional batch unsqueeze (e.g. (1, T, 3, 3)).
|
| 294 |
+
# Use ``...il`` so ``k`` sums with ``kl``; ``...ik`` incorrectly keeps ``k`` in the output.
|
| 295 |
+
R_f2q_root = torch.einsum(
|
| 296 |
+
"ij,...jk,kl->...il",
|
| 297 |
+
mujoco_to_kimodo_matrix,
|
| 298 |
+
root_rot_mujoco,
|
| 299 |
+
kimodo_to_mujoco_matrix,
|
| 300 |
+
)
|
| 301 |
+
R_kimodo_root = torch.einsum("ij,...jk->...ik", O0.T, R_f2q_root)
|
| 302 |
+
|
| 303 |
+
joint_dofs = qpos[..., 7:]
|
| 304 |
+
if mujoco_rest_zero:
|
| 305 |
+
rest_dofs = self._rest_dofs.to(device=device, dtype=dtype)
|
| 306 |
+
angles = joint_dofs + rest_dofs[None, None, :]
|
| 307 |
+
use_relative = True
|
| 308 |
+
else:
|
| 309 |
+
angles = joint_dofs
|
| 310 |
+
use_relative = False
|
| 311 |
+
|
| 312 |
+
nb_joints = self.skeleton.nbjoints
|
| 313 |
+
template = torch.eye(3, device=device, dtype=dtype).expand(batch_size, num_frames, nb_joints, 3, 3).contiguous()
|
| 314 |
+
template[:, :, 0] = R_kimodo_root
|
| 315 |
+
|
| 316 |
+
local_rot_mats = self._joint_dofs_to_local_rot_mats(
|
| 317 |
+
angles,
|
| 318 |
+
template,
|
| 319 |
+
device,
|
| 320 |
+
dtype,
|
| 321 |
+
use_relative=use_relative,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
if batch_size != 1:
|
| 325 |
+
raise ValueError(f"Only a single clip is supported; got batch_size={batch_size}")
|
| 326 |
+
|
| 327 |
+
return complete_motion_dict(local_rot_mats[0], root_positions[0], self.skeleton, source_fps)
|
| 328 |
+
|
| 329 |
+
def save_csv(self, qpos: torch.Tensor | np.ndarray, csv_path):
|
| 330 |
+
# comment this
|
| 331 |
+
qpos = to_numpy(qpos)
|
| 332 |
+
shape = qpos.shape
|
| 333 |
+
if len(shape) == 2:
|
| 334 |
+
# only one motion: save it
|
| 335 |
+
np.savetxt(csv_path, qpos, delimiter=",")
|
| 336 |
+
if len(shape) == 3:
|
| 337 |
+
# batch of motions
|
| 338 |
+
if shape[0] == 1:
|
| 339 |
+
# if only one motion, just save it
|
| 340 |
+
np.savetxt(csv_path, qpos[0], delimiter=",")
|
| 341 |
+
else:
|
| 342 |
+
csv_path_base, ext = os.path.splitext(csv_path)
|
| 343 |
+
for i in range(shape[0]):
|
| 344 |
+
self.save_csv(qpos[i], csv_path_base + "_" + str(i).zfill(2) + ext)
|
| 345 |
+
|
| 346 |
+
def _local_rots_to_joint_dofs(
|
| 347 |
+
self,
|
| 348 |
+
local_rot_mats: torch.Tensor,
|
| 349 |
+
axis_vals: torch.Tensor,
|
| 350 |
+
) -> torch.Tensor:
|
| 351 |
+
"""Extract per-joint single-DoF angles (radians) via Euler projection (for to_qpos/f2q)."""
|
| 352 |
+
x_joint_dof = torch.atan2(local_rot_mats[..., 2, 1], local_rot_mats[..., 2, 2])
|
| 353 |
+
y_joint_dof = torch.atan2(local_rot_mats[..., 0, 2], local_rot_mats[..., 0, 0])
|
| 354 |
+
z_joint_dof = torch.atan2(local_rot_mats[..., 1, 0], local_rot_mats[..., 1, 1])
|
| 355 |
+
xyz_joint_dofs = torch.stack([x_joint_dof, y_joint_dof, z_joint_dof], dim=-1)
|
| 356 |
+
axis_vals = axis_vals.to(device=local_rot_mats.device, dtype=local_rot_mats.dtype)
|
| 357 |
+
joint_dofs = (xyz_joint_dofs * axis_vals[None, None, :, :]).sum(dim=-1)
|
| 358 |
+
return joint_dofs
|
| 359 |
+
|
| 360 |
+
def _local_rots_to_joint_dofs_axis_angle(
|
| 361 |
+
self,
|
| 362 |
+
local_rot_mats: torch.Tensor,
|
| 363 |
+
axis_vals: torch.Tensor,
|
| 364 |
+
) -> torch.Tensor:
|
| 365 |
+
"""Extract per-joint single-DoF angles (radians) via axis-angle; round-trips with
|
| 366 |
+
axis_angle_to_matrix.
|
| 367 |
+
|
| 368 |
+
Args:
|
| 369 |
+
local_rot_mats: (..., num_hinges, 3, 3) in same frame as axis_vals.
|
| 370 |
+
axis_vals: (num_hinges, 3) unit axis per hinge.
|
| 371 |
+
Returns:
|
| 372 |
+
joint_dofs: (..., num_hinges) signed angle = dot(axis_angle(R), axis).
|
| 373 |
+
"""
|
| 374 |
+
axis_vals = axis_vals.to(device=local_rot_mats.device, dtype=local_rot_mats.dtype)
|
| 375 |
+
full_aa = matrix_to_axis_angle(local_rot_mats)
|
| 376 |
+
joint_dofs = (full_aa * axis_vals).sum(dim=-1)
|
| 377 |
+
return joint_dofs
|
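
    # Worked note on the axis-angle extraction above (comment only): for a rotation of
    # 0.3 rad about a unit hinge axis a, matrix_to_axis_angle(R) returns 0.3 * a and
    # dot(0.3 * a, a) = 0.3; any off-axis component (violating the 1-DoF hinge
    # assumption) is projected away by the dot product.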

    def _local_rots_f2q_to_joint_dofs(self, local_rot_mats_f2q: torch.Tensor) -> torch.Tensor:
        """Extract per-joint single-DoF angles from local rotations in f2q space (for to_qpos)."""
        axis_vals = self._mujoco_joint_axis_values_f2q_space
        return self._local_rots_to_joint_dofs(local_rot_mats_f2q, axis_vals)

    def _clamp_to_limits(self, joint_dofs: torch.Tensor) -> torch.Tensor:
        """Clamp joint angles to XML limits (radians).

        Angles are in kimodo convention (0 = rest).
        """
        device = joint_dofs.device
        lo = self._joint_limits_min.to(device=device, dtype=joint_dofs.dtype)
        hi = self._joint_limits_max.to(device=device, dtype=joint_dofs.dtype)
        return torch.clamp(joint_dofs, lo[None, None, :], hi[None, None, :])

    def _clamp_joint_dofs(self, joint_dofs: torch.Tensor, rest_dofs: torch.Tensor) -> torch.Tensor:
        """Clamp joint angles to MuJoCo limits (radians), with rest_dofs conversion."""
        device = joint_dofs.device
        rest_dofs = rest_dofs.to(device=device, dtype=joint_dofs.dtype)
        mujoco_dofs = joint_dofs - rest_dofs[None, None, :]
        lo = self._joint_limits_min.to(device=device, dtype=joint_dofs.dtype)
        hi = self._joint_limits_max.to(device=device, dtype=joint_dofs.dtype)
        mujoco_dofs = torch.clamp(mujoco_dofs, lo[None, None, :], hi[None, None, :])
        return mujoco_dofs + rest_dofs[None, None, :]

    def _joint_dofs_to_local_rot_mats(
        self,
        joint_dofs: torch.Tensor,
        original_local_rot_mats: torch.Tensor,
        device: torch.device,
        dtype: torch.dtype,
        use_relative: bool = False,
    ) -> torch.Tensor:
        """Reconstruct full local rotation matrices from 1-DoF angles."""
        out = original_local_rot_mats.clone()
        axis_kimodo = self._mujoco_joint_axis_values_kimodo_space.to(device=device, dtype=dtype)
        for i in range(joint_dofs.shape[-1]):
            j = self._mujoco_indices_to_kimodo_indices[i].item()
            angle = joint_dofs[..., i]
            axis = axis_kimodo[i]
            if use_relative:
                axis_angle = angle[..., None] * axis[None, None, :]
                R_local = axis_angle_to_matrix(axis_angle)
            else:
                rot_offsets_f2q = self._rot_offsets_f2q.to(device=device, dtype=dtype)
                axis_in_f2q = torch.mv(rot_offsets_f2q[j], axis)
                axis_angle = angle[..., None] * axis_in_f2q[None, None, :]
                R_f2q = axis_angle_to_matrix(axis_angle)
                R_local = torch.einsum("ij,btjk->btik", rot_offsets_f2q[j].T, R_f2q)
            out[:, :, j, :, :] = R_local
        return out

    @ensure_batched(local_rot_mats=5, root_positions=3, lengths=1)
    def project_to_real_robot_rotations(
        self,
        local_rot_mats: torch.Tensor,
        root_positions: torch.Tensor,
        clamp_to_limits: bool = True,
        mujoco_rest_zero: bool = False,
    ) -> dict:
        """Project full 3D local rotations to G1 real robot DoF and back to 3D for viz.

        Joint angles are extracted along each hinge axis, optionally clamped to XML limits, then
        reconstructed to 3D rotations. When mujoco_rest_zero=False (default), raw angles are used
        (baked-with-quat). When True, angles are relative to rest (0 = T-pose in MuJoCo).
        """
        device = local_rot_mats.device
        dtype = local_rot_mats.dtype

        # Transform to f2q frame and extract 1-DoF angles (axis-angle projection).
        local_rot_f2q = torch.matmul(self._rot_offsets_f2q.to(device=device, dtype=dtype), local_rot_mats)
        hinge_rots = local_rot_f2q[:, :, self._mujoco_indices_to_kimodo_indices, :, :]
        axis_f2q = self._mujoco_joint_axis_values_f2q_space.to(device=device, dtype=dtype)
        joint_dofs = self._local_rots_to_joint_dofs_axis_angle(hinge_rots, axis_f2q)

        # Optionally express angles relative to rest (MuJoCo q=0 at T-pose).
        if mujoco_rest_zero:
            rest_dofs = self._rest_dofs_axis_angle.to(device=device, dtype=dtype)
            angles = joint_dofs - rest_dofs[None, None, :]
            use_relative = True
        else:
            angles = joint_dofs
            use_relative = False

        if clamp_to_limits:
            if mujoco_rest_zero:
                angles = self._clamp_to_limits(angles)
            else:
                rest_dofs_aa = self._rest_dofs_axis_angle.to(device=device, dtype=dtype)
                angles = self._clamp_joint_dofs(angles, rest_dofs_aa)

        # Reconstruct 3D local rotations from 1-DoF angles and run FK.
        local_rot_mats_proj = self._joint_dofs_to_local_rot_mats(
            angles, local_rot_mats, device, dtype, use_relative=use_relative
        )
        global_rot_mats, posed_joints, _ = self.skeleton.fk(local_rot_mats_proj, root_positions)
        return {
            "local_rot_mats": local_rot_mats_proj,
            "global_rot_mats": global_rot_mats,
            "posed_joints": posed_joints,
            "root_positions": root_positions,
        }

    @ensure_batched(local_rot_mats=5, root_positions=3, lengths=1)
    def to_qpos(
        self,
        local_rot_mats: torch.Tensor,
        root_positions: torch.Tensor,
        root_quat_w_first: bool = True,
        mujoco_rest_zero: bool = False,
    ) -> torch.Tensor:
        """Fast batch conversion from kimodo features to mujoco qpos format.

        Args:
            local_rot_mats: (B, T, J, 3, 3) local rotation matrices (kimodo convention).
            root_positions: (B, T, 3) root positions.
            root_quat_w_first: If True, quaternion in qpos is (w,x,y,z).
            mujoco_rest_zero: If True, joint angles are written so that kimodo rest (t-pose)
                maps to q=0 in MuJoCo. If False, write raw joint_dofs.

        Returns:
            torch.Tensor of shape [batch, numFrames, 36] containing mujoco qpos data:
            - root_trans (3) + root_quat (4) + joint_dofs (29) = 36 columns
        """
        batch_size, num_frames, nb_joints = local_rot_mats.shape[:3]
        device, dtype = local_rot_mats.device, local_rot_mats.dtype
|
| 506 |
+
|
| 507 |
+
local_rot_mats = torch.matmul(self._rot_offsets_f2q.to(device), local_rot_mats)
|
| 508 |
+
|
| 509 |
+
batch_size, num_frames = root_positions.shape[0], root_positions.shape[1]
|
| 510 |
+
|
| 511 |
+
# Move precomputed matrices to the same device/dtype
|
| 512 |
+
kimodo_to_mujoco_matrix = self.kimodo_to_mujoco_matrix.to(device=device, dtype=dtype)
|
| 513 |
+
|
| 514 |
+
# Initialize output tensor: [batch, numFrames, 36]
|
| 515 |
+
qpos = torch.zeros((batch_size, num_frames, 36), dtype=dtype, device=device)
|
| 516 |
+
|
| 517 |
+
# Convert root translation: apply coordinate transformation
|
| 518 |
+
root_positions_mujoco = torch.matmul(kimodo_to_mujoco_matrix[None, None, ...], root_positions[..., None])
|
| 519 |
+
qpos[:, :, :3] = root_positions_mujoco.view(batch_size, num_frames, 3)
|
| 520 |
+
|
| 521 |
+
# Convert root rotation: apply coordinate transformation to rotation matrix
|
| 522 |
+
root_rot = local_rot_mats[:, :, 0, :] # [batch, numFrames, 3, 3]
|
| 523 |
+
|
| 524 |
+
# Apply coordinate transformation: R_mujoco = kimodo_to_mujoco * R_kimodo * kimodo_to_mujoco^T
|
| 525 |
+
mujoco_to_kimodo_matrix = kimodo_to_mujoco_matrix.T
|
| 526 |
+
root_rot_mujoco = torch.matmul(
|
| 527 |
+
torch.matmul(kimodo_to_mujoco_matrix[None, None, ...], root_rot),
|
| 528 |
+
mujoco_to_kimodo_matrix[None, None, ...],
|
| 529 |
+
)
|
| 530 |
+
root_rot_quat = matrix_to_quaternion(root_rot_mujoco) # [w, x, y, z]
|
| 531 |
+
if root_quat_w_first:
|
| 532 |
+
qpos[:, :, 3:7] = root_rot_quat[:, :, [0, 1, 2, 3]] # [w, x, y, z]
|
| 533 |
+
else:
|
| 534 |
+
qpos[:, :, 3:7] = root_rot_quat[:, :, [1, 2, 3, 0]] # [w, x, y, z] -> [x, y, z, w]
|
| 535 |
+
|
| 536 |
+
# Joint DOFs: raw angles or relative to rest (rest = q=0 in MuJoCo).
|
| 537 |
+
joint_rot_f2q = local_rot_mats[:, :, self._mujoco_indices_to_kimodo_indices, :, :]
|
| 538 |
+
joint_dofs = self._local_rots_f2q_to_joint_dofs(joint_rot_f2q)
|
| 539 |
+
if mujoco_rest_zero:
|
| 540 |
+
rest_dofs = self._rest_dofs.to(device=device, dtype=dtype)
|
| 541 |
+
qpos[:, :, 7:] = joint_dofs - rest_dofs[None, None, :]
|
| 542 |
+
else:
|
| 543 |
+
qpos[:, :, 7:] = joint_dofs
|
| 544 |
+
return qpos
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def apply_g1_real_robot_projection(
|
| 548 |
+
skeleton: G1Skeleton34,
|
| 549 |
+
joints_pos: torch.Tensor,
|
| 550 |
+
joints_rot: torch.Tensor,
|
| 551 |
+
clamp_to_limits: bool = True,
|
| 552 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 553 |
+
"""Project G1 motion to real robot DoF (1-DoF per joint) with optional axis limits.
|
| 554 |
+
|
| 555 |
+
Extracts a single angle per hinge along its axis (1-DoF), optionally clamps to
|
| 556 |
+
joint limits from the MuJoCo XML (when clamp_to_limits=True), then reconstructs
|
| 557 |
+
3D rotations and runs FK. T-pose (identity local rotations) is preserved.
|
| 558 |
+
|
| 559 |
+
Args:
|
| 560 |
+
skeleton: G1 skeleton instance.
|
| 561 |
+
joints_pos: (T, J, 3) or (B, T, J, 3) joint positions in global space.
|
| 562 |
+
joints_rot: (T, J, 3, 3) or (B, T, J, 3, 3) global rotation matrices.
|
| 563 |
+
clamp_to_limits: If True, clamp joint angles to XML axis limits (default True).
|
| 564 |
+
|
| 565 |
+
Returns:
|
| 566 |
+
(posed_joints, global_rot_mats) as tensors, same shape as inputs (batch preserved).
|
| 567 |
+
"""
|
| 568 |
+
|
| 569 |
+
local_rot_mats = global_rots_to_local_rots(joints_rot, skeleton)
|
| 570 |
+
root_positions = joints_pos[..., skeleton.root_idx, :]
|
| 571 |
+
|
| 572 |
+
# Converter expects batch dim (B, T, ...); add and remove if single sequence.
|
| 573 |
+
single_sequence = local_rot_mats.dim() == 4
|
| 574 |
+
if single_sequence:
|
| 575 |
+
local_rot_mats = local_rot_mats.unsqueeze(0)
|
| 576 |
+
root_positions = root_positions.unsqueeze(0)
|
| 577 |
+
|
| 578 |
+
converter = MujocoQposConverter(skeleton)
|
| 579 |
+
projected = converter.project_to_real_robot_rotations(
|
| 580 |
+
local_rot_mats, root_positions, clamp_to_limits=clamp_to_limits
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
out_pos = projected["posed_joints"]
|
| 584 |
+
out_rot = projected["global_rot_mats"]
|
| 585 |
+
if single_sequence:
|
| 586 |
+
out_pos = out_pos.squeeze(0)
|
| 587 |
+
out_rot = out_rot.squeeze(0)
|
| 588 |
+
return out_pos, out_rot
|
kimodo/exports/smplx.py
ADDED
@@ -0,0 +1,251 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Convert kimodo motion to AMASS/SMPL-X compatible parameters (axis-angle, Y-up or Z-up)."""

import os
from typing import Optional

import einops
import numpy as np
import torch

from kimodo.assets import skeleton_asset_path
from kimodo.geometry import axis_angle_to_matrix, matrix_to_axis_angle
from kimodo.tools import ensure_batched, to_numpy, to_torch


def kimodo_y_up_to_amass_coord_rotation_matrix() -> np.ndarray:
    """3x3 rotation mapping Kimodo Y-up (+Z forward) to AMASS Z-up (+Y forward).

    Used by :func:`get_amass_parameters` and :func:`amass_arrays_to_kimodo_motion` (inverse).
    """
    y_up_to_z_up = np.array(
        [
            [1.0, 0.0, 0.0],
            [0.0, 0.0, -1.0],
            [0.0, 1.0, 0.0],
        ],
        dtype=np.float32,
    )
    rot_z_180 = np.array(
        [
            [-1.0, 0.0, 0.0],
            [0.0, -1.0, 0.0],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )
    return np.matmul(rot_z_180, y_up_to_z_up).astype(np.float32)


@ensure_batched(local_rot_mats=5, root_positions=3, lengths=1)
def get_amass_parameters(
    local_rot_mats,
    root_positions,
    skeleton,
    z_up=True,
):
    """Convert local rot mats and root positions to AMASS-style trans and pose_body; optional z_up
    coordinate transform.

    Our method generates motions with Y-up and +Z forward; if z_up=True, transform to Z-up and +Y
    forward as in AMASS.
    """
    # Remove the root offset; SMPL-X FK adds pelvis offset back.
    pelvis_offset = skeleton.neutral_joints[skeleton.root_idx].cpu().numpy()
    trans = root_positions - pelvis_offset

    root_rot_mats = to_numpy(local_rot_mats[:, :, 0])
    local_rot_axis_angle = to_numpy(matrix_to_axis_angle(to_torch(local_rot_mats)))
    pose_body = einops.rearrange(local_rot_axis_angle[:, :, 1:], "b t j d -> b t (j d)")

    # Optionally convert from Y-up to Z-up coordinates.
    if z_up:
        y_up_to_z_up = kimodo_y_up_to_amass_coord_rotation_matrix()
        root_rot_mats = np.matmul(y_up_to_z_up, root_rot_mats)
        trans = np.matmul(trans + pelvis_offset, y_up_to_z_up.T) - pelvis_offset

    root_orient = to_numpy(matrix_to_axis_angle(to_torch(root_rot_mats)))
    return trans, root_orient, pose_body


def amass_arrays_to_kimodo_motion(
    trans: np.ndarray,
    root_orient: np.ndarray,
    pose_body: np.ndarray,
    skeleton,
    source_fps: float,
    *,
    z_up: bool = True,
):
    """Inverse of :func:`get_amass_parameters` for a single sequence (AMASS → Kimodo motion dict).

    Args:
        trans: ``(T, 3)`` AMASS root translation (same as ``trans`` in AMASS NPZ).
        root_orient: ``(T, 3)`` axis-angle root orientation in AMASS coordinates (z-up when ``z_up``).
        pose_body: ``(T, 63)`` body pose axis-angle (21 joints × 3).
        skeleton: :class:`~kimodo.skeleton.definitions.SMPLXSkeleton22` instance.
        source_fps: Source frame rate (Hz) of the AMASS recording.
        z_up: If ``True``, invert the same Y-up↔Z-up transform as ``get_amass_parameters(..., z_up=True)``.

    Returns:
        Motion dict compatible with :func:`kimodo.exports.motion_io.save_kimodo_npz`.
    """
    from kimodo.exports.motion_io import complete_motion_dict

    trans = np.asarray(trans, dtype=np.float32)
    root_orient = np.asarray(root_orient, dtype=np.float32)
    pose_body = np.asarray(pose_body, dtype=np.float32)
    if trans.ndim != 2 or trans.shape[-1] != 3:
        raise ValueError(f"trans must be (T, 3); got {trans.shape}")
    if root_orient.shape != trans.shape:
        raise ValueError(f"root_orient shape {root_orient.shape} must match trans {trans.shape}")
    t = trans.shape[0]
    if pose_body.shape != (t, 63):
        raise ValueError(f"pose_body must be (T, 63); got {pose_body.shape}")

    pelvis_offset = skeleton.neutral_joints[skeleton.root_idx].detach().cpu().numpy().astype(np.float32)
    device = skeleton.neutral_joints.device
    dtype = torch.float32

    Y_np = kimodo_y_up_to_amass_coord_rotation_matrix()
    if z_up:
        y_up_to_z_up = torch.from_numpy(Y_np).to(device=device, dtype=dtype)
        # trans_amass = root_kimodo @ Y.T - pelvis_offset  =>  root_kimodo = (trans_amass + pelvis_offset) @ Y
        root_positions_np = (trans + pelvis_offset) @ Y_np
    else:
        root_positions_np = trans + pelvis_offset

    root_positions = torch.from_numpy(root_positions_np).to(device=device, dtype=dtype)

    R_amass_root = axis_angle_to_matrix(torch.from_numpy(root_orient).to(device=device, dtype=dtype))
    if z_up:
        R_kimodo_root = torch.einsum("ij,tjk->tik", y_up_to_z_up.T, R_amass_root)
    else:
        R_kimodo_root = R_amass_root

    nb = skeleton.nbjoints
    if nb != 22:
        raise ValueError(f"Expected SMPL-X body skeleton with 22 joints; got {nb}")

    local_rot_mats = torch.zeros((t, nb, 3, 3), device=device, dtype=dtype)
    local_rot_mats[:, 0] = R_kimodo_root

    pose_aa = torch.from_numpy(pose_body.reshape(t, 21, 3)).to(device=device, dtype=dtype)
    local_rot_mats[:, 1:] = axis_angle_to_matrix(pose_aa.reshape(-1, 3)).reshape(t, 21, 3, 3)

    return complete_motion_dict(local_rot_mats, root_positions, skeleton, source_fps)


def amass_npz_to_kimodo_motion(npz_path: str, skeleton, source_fps: Optional[float] = None, *, z_up: bool = True):
    """Load an AMASS-style ``.npz`` and return a Kimodo motion dict.

    Args:
        npz_path: Path to AMASS NPZ (``trans``, ``root_orient``, ``pose_body``, ...).
        skeleton: SMPL-X skeleton instance.
        source_fps: Source frame rate (Hz); if ``None``, uses ``mocap_frame_rate``
            from the file when present, else ``30.0``.
        z_up: Same meaning as :func:`amass_arrays_to_kimodo_motion`.
    """
    with np.load(npz_path, allow_pickle=True) as data:
        trans = np.asarray(data["trans"], dtype=np.float32)
        root_orient = np.asarray(data["root_orient"], dtype=np.float32)
        pose_body = np.asarray(data["pose_body"], dtype=np.float32)
        if source_fps is None:
            source_fps = float(data["mocap_frame_rate"]) if "mocap_frame_rate" in data.files else 30.0

    return amass_arrays_to_kimodo_motion(trans, root_orient, pose_body, skeleton, source_fps, z_up=z_up)


class AMASSConverter:
    def __init__(
        self,
        fps,
        skeleton,
        beta_path=str(skeleton_asset_path("smplx22", "beta.npy")),
        mean_hands_path=str(skeleton_asset_path("smplx22", "mean_hands.npy")),
    ):
        self.fps = fps
        self.skeleton = skeleton
        # Load betas
        if os.path.exists(beta_path):
            # only use first 16 betas to match AMASS
            betas = np.load(beta_path)[:16]
        else:
            betas = np.zeros(16)

        # Load mean hands
        if os.path.exists(mean_hands_path):
            mean_hands = np.load(mean_hands_path)
        else:
            mean_hands = np.zeros(90)

        self.default_frame_params = {
            "pose_jaw": np.zeros(3),
            "pose_eye": np.zeros(6),
            "pose_hand": mean_hands,
        }
        self.output_dict_base = {
            "gender": "neutral",
            "surface_model_type": "smplx",
            "betas": betas,
            "num_betas": len(betas),
            "mocap_frame_rate": float(fps),
        }

    def convert_save_npz(self, output: dict, npz_path, z_up=True):
        trans, root_orient, pose_body = get_amass_parameters(
            output["local_rot_mats"],
            output["root_positions"],
            self.skeleton,
            z_up=z_up,
        )
        nb_frames = trans.shape[-2]

        amass_output_base = self.output_dict_base.copy()
        for key, val in self.default_frame_params.items():
            amass_output_base[key] = einops.repeat(val, "d -> t d", t=nb_frames)

        amass_output_base["mocap_time_length"] = nb_frames / self.fps
        self.save_npz(trans, root_orient, pose_body, amass_output_base, npz_path)

    def save_npz(self, trans, root_orient, pose_body, base_output, npz_path):
        shape = trans.shape
        if len(shape) == 3 and shape[0] == 1:
            # if only one motion, squeeze the data
            trans = trans[0]
            root_orient = root_orient[0]
            pose_body = pose_body[0]
            shape = trans.shape
        if len(shape) == 2:
            amass_output = {
                "trans": trans,
                "root_orient": root_orient,
                "pose_body": pose_body,
            } | base_output
            np.savez(npz_path, **amass_output)

        elif len(shape) == 3:
            # real batch of motions
            npz_path_base, ext = os.path.splitext(npz_path)
            for i in range(shape[0]):
                npz_path_i = npz_path_base + "_" + str(i).zfill(2) + ext
                self.save_npz(trans[i], root_orient[i], pose_body[i], base_output, npz_path_i)


# Reference AMASS NPZ layout:
# amass_output = {
#     "gender": "neutral",
#     "surface_model_type": "smplx",
#     "mocap_frame_rate": float(fps),
#     "mocap_time_length": len(motion) / float(fps),
#     "trans": trans,
#     "betas": betas,
#     "num_betas": len(betas),
#     "root_orient": np.array([T, 3]),  # axis angle
#     "pose_body": np.array([T, 63]),  # 63=21*3, axis angle, 21 = 22 - root
#     "pose_hand": np.array([T, 90]),  # 90=30*3=15*2*3 axis angle (load from mean_hands)
#     "pose_jaw": np.array([T, 3]),  # all zeros is fine
#     "pose_eye": np.array([T, 6]),  # all zeros is fine
# }
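
As a quick sanity check on the convention encoded by `kimodo_y_up_to_amass_coord_rotation_matrix`: it should be a proper rotation that sends Kimodo's +Z forward to AMASS's +Y forward and +Y up to +Z up. This check is self-contained (numpy plus this module) and follows directly from the matrix definition above:

import numpy as np

M = kimodo_y_up_to_amass_coord_rotation_matrix()
assert np.allclose(M @ M.T, np.eye(3), atol=1e-6)                   # orthonormal
assert np.isclose(np.linalg.det(M), 1.0)                            # proper rotation
assert np.allclose(M @ np.array([0.0, 0.0, 1.0]), [0.0, 1.0, 0.0])  # +Z forward -> +Y forward
assert np.allclose(M @ np.array([0.0, 1.0, 0.0]), [0.0, 0.0, 1.0])  # +Y up -> +Z up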
kimodo/geometry.py
ADDED
@@ -0,0 +1,216 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Rotation and representation conversions: axis-angle, quaternion, matrix, 6D continuous."""

import torch
import torch.nn.functional as F


def angle_to_Y_rotation_matrix(angle: torch.Tensor) -> torch.Tensor:
    """Build a rotation matrix around the Y axis from a scalar angle (radians).

    Shape: angle.shape + (3, 3).
    """
    cos, sin = torch.cos(angle), torch.sin(angle)
    one, zero = torch.ones_like(angle), torch.zeros_like(angle)
    mat = torch.stack((cos, zero, sin, zero, one, zero, -sin, zero, cos), -1)
    mat = mat.reshape(angle.shape + (3, 3))
    return mat


def matrix_to_cont6d(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrix to 6D continuous representation (first two columns).

    Shape: (..., 3, 3) -> (..., 6).
    """
    cont_6d = torch.concat([matrix[..., 0], matrix[..., 1]], dim=-1)
    return cont_6d


def cont6d_to_matrix(cont6d: torch.Tensor) -> torch.Tensor:
    """Convert 6D continuous representation to rotation matrix (Gram–Schmidt on two columns).

    Last dim must be 6.
    """
    assert cont6d.shape[-1] == 6, "The last dimension must be 6"
    x_raw = cont6d[..., 0:3]
    y_raw = cont6d[..., 3:6]

    x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
    z = torch.cross(x, y_raw, dim=-1)
    z = z / torch.norm(z, dim=-1, keepdim=True)

    y = torch.cross(z, x, dim=-1)

    x = x[..., None]
    y = y[..., None]
    z = z[..., None]

    mat = torch.cat([x, y, z], dim=-1)
    return mat


def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """Convert axis-angle to rotation matrix.

    Args:
        axis_angle: (..., 3) axis-angle vectors (angle = norm, axis = normalized)
    Returns:
        rotmat: (..., 3, 3) rotation matrices
    """
    eps = 1e-6
    angle = torch.norm(axis_angle, dim=-1, keepdim=True)  # (..., 1)
    axis = axis_angle / (angle + eps)

    x, y, z = axis.unbind(-1)

    zero = torch.zeros_like(x)
    K = torch.stack([zero, -z, y, z, zero, -x, -y, x, zero], dim=-1).reshape(*axis.shape[:-1], 3, 3)

    eye = torch.eye(3, device=axis.device, dtype=axis.dtype)
    eye = eye.expand(*axis.shape[:-1], 3, 3)

    sin = torch.sin(angle)[..., None]
    cos = torch.cos(angle)[..., None]

    R = eye + sin * K + (1 - cos) * (K @ K)
    return R


def matrix_to_axis_angle(R: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrix to axis-angle via quaternions (more numerically stable).

    Args:
        R: (..., 3, 3) rotation matrices
    Returns:
        axis_angle: (..., 3)
    """
    # Go through quaternions for numerical stability
    quat = matrix_to_quaternion(R)  # (..., 4) with (w, x, y, z)
    return quaternion_to_axis_angle(quat)


def quaternion_to_axis_angle(quat: torch.Tensor) -> torch.Tensor:
    """Convert quaternion to axis-angle representation.

    Args:
        quat: (..., 4) quaternions with real part first (w, x, y, z)
    Returns:
        axis_angle: (..., 3)
    """
    eps = 1e-6

    # Ensure canonical form to avoid sign ambiguity.
    # Primary: prefer w > 0. When w ≈ 0 (angle ≈ π), prefer first nonzero xyz > 0.
    w = quat[..., 0:1]
    xyz = quat[..., 1:]

    # Find first significant component of xyz for tie-breaking when w ≈ 0
    first_significant = xyz[..., 0:1]  # use x component as tie-breaker

    # Flip if: w < 0, OR (w ≈ 0 AND first xyz component < 0)
    should_flip = (w < -eps) | ((w.abs() <= eps) & (first_significant < 0))
    quat = torch.where(should_flip, -quat, quat)

    w = quat[..., 0]
    xyz = quat[..., 1:]

    # sin(angle/2) = ||xyz||
    sin_half_angle = xyz.norm(dim=-1)

    # angle = 2 * atan2(sin(angle/2), cos(angle/2))
    # This is more stable than 2 * acos(w) near angle=0
    angle = 2.0 * torch.atan2(sin_half_angle, w)

    # axis = xyz / sin(angle/2), but handle small angles
    # For small angles: axis-angle ≈ 2 * xyz (since sin(x) ≈ x for small x)
    small_angle = sin_half_angle.abs() < eps

    # Safe division
    scale = torch.where(
        small_angle,
        2.0 * torch.ones_like(angle),  # small angle: axis_angle ≈ 2 * xyz
        angle / sin_half_angle.clamp(min=eps),
    )

    return xyz * scale.unsqueeze(-1)


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """Returns torch.sqrt(torch.max(0, x)); the subgradient is zero where x is 0."""
    return torch.sqrt(x * (x > 0).to(x.dtype))


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    return (
        (F.one_hot(q_abs.argmax(dim=-1), num_classes=4)[..., None] * quat_candidates)
        .sum(dim=-2)
        .reshape(batch_dim + (4,))
    )


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))
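
A round-trip self-check of the conversions above; it is runnable as-is with this module imported. Axis-angle vectors are only unique up to 2π wrapping and a sign flip at π, so rotations are compared as matrices rather than as vectors:

import torch

aa = torch.randn(1024, 3)  # random axis-angle vectors, arbitrary magnitudes
R = axis_angle_to_matrix(aa)

# matrix -> axis-angle -> matrix should reproduce the rotation
assert torch.allclose(axis_angle_to_matrix(matrix_to_axis_angle(R)), R, atol=1e-4)
# matrix -> quaternion -> matrix
assert torch.allclose(quaternion_to_matrix(matrix_to_quaternion(R)), R, atol=1e-4)
# matrix -> 6D -> matrix (Gram–Schmidt is exact on true rotation matrices)
assert torch.allclose(cont6d_to_matrix(matrix_to_cont6d(R)), R, atol=1e-4)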
kimodo/meta.py
ADDED
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Parse and normalize prompt text/duration data from meta dicts."""

import os
from typing import Any, Optional

from kimodo.tools import load_json

from .sanitize import sanitize_text, sanitize_texts


def load_prompts_from_meta(meta_path: str, **kwargs):
    """Load prompts from a meta dict or file. If fps is provided, the durations are converted to
    frames.

    Args:
        meta_path: Path to the meta file.
        **kwargs: Additional arguments to pass to parse_prompts_from_meta.

    Returns:
        texts: List of texts.
        durations: List of durations in seconds or frames.
    """
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"meta.json not found in input folder: {meta_path}")

    meta = load_json(meta_path)
    return parse_prompts_from_meta(meta, **kwargs)


def parse_prompts_from_meta(
    meta: dict[str, Any],
    fps: Optional[float] = None,
    sanitize: bool = False,
) -> tuple[list[str], list[float]]:
    """Parse prompt texts and durations from a meta dict into normalized lists. If fps is provided,
    the durations are converted to frames.

    Accepts either:
    - Single prompt: "text" (str) and "duration" (float) in seconds.
    - Multiple prompts: "texts" (list of str) and "durations" (list of float) in seconds.

    Returns:
        (texts, durations): texts as list of str, durations as list of float (seconds or frames).
        Lengths of both lists are equal.

    Raises:
        ValueError: If meta does not contain a recognized format.
    """
    # Single prompt
    if "text" in meta and "duration" in meta:
        text = meta["text"]
        duration = float(meta["duration"])
        if fps is not None:
            duration = int(duration * fps)
        if isinstance(text, list):
            raise ValueError("meta has 'text' but it is a list; use 'texts' for multiple prompts")

        if sanitize:
            text = sanitize_text(text)
        return ([text], [duration])

    # Multiple prompts
    if "texts" in meta and "durations" in meta:
        texts = meta["texts"]
        durations = meta["durations"]
        if not isinstance(texts, list) or not isinstance(durations, list):
            raise ValueError("meta 'texts' and 'durations' must be lists")
        if len(texts) != len(durations):
            raise ValueError(f"meta 'texts' and 'durations' length mismatch: {len(texts)} vs {len(durations)}")
        durations = [float(d) for d in durations]
        if fps is not None:
            durations = [int(d * fps) for d in durations]

        if sanitize:
            texts = sanitize_texts(texts)
        return texts, durations

    raise ValueError("meta must contain either 'text' and 'duration', or 'texts' and 'durations'.")
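
For reference, the two accepted meta layouts and what `parse_prompts_from_meta` returns for them (values are illustrative):

single = {"text": "a person walks forward", "duration": 4.0}
multi = {"texts": ["walk forward", "turn left"], "durations": [4.0, 2.5]}

parse_prompts_from_meta(single, fps=30)  # (["a person walks forward"], [120]), durations in frames
parse_prompts_from_meta(multi)           # (["walk forward", "turn left"], [4.0, 2.5]), seconds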
kimodo/metrics/__init__.py
ADDED
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Evaluation metrics for motion quality (foot skate, contact consistency, constraint following)."""

from .base import (
    Metric,
    aggregate_metrics,
    clear_metrics,
    compute_metrics,
)
from .constraints import ContraintFollow
from .foot_skate import (
    FootContactConsistency,
    FootSkateFromContacts,
    FootSkateFromHeight,
    FootSkateRatio,
)
from .tmr import (
    TMR_EmbeddingMetric,
    TMR_Metric,
    compute_tmr_per_sample_retrieval,
    compute_tmr_retrieval_metrics,
)

__all__ = [
    "Metric",
    "ContraintFollow",
    "FootContactConsistency",
    "FootSkateFromContacts",
    "FootSkateFromHeight",
    "FootSkateRatio",
    "TMR_EmbeddingMetric",
    "TMR_Metric",
    "aggregate_metrics",
    "clear_metrics",
    "compute_metrics",
    "compute_tmr_per_sample_retrieval",
    "compute_tmr_retrieval_metrics",
]
kimodo/metrics/base.py
ADDED
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Base metric class and batch/aggregate helpers."""

from __future__ import annotations

from collections import defaultdict
from typing import Dict, List

import torch


class Metric:
    """Base class for metrics that accumulate results over multiple __call__ and expose
    aggregate()."""

    def __init__(self, **kwargs):
        self.clear()

    def __call__(self, *args, **kwargs):
        """Compute metric for current batch, append to saved_metrics, and return the batch
        result."""
        metrics = self._compute(*args, **kwargs)
        for key, val in metrics.items():
            self.saved_metrics[key].append(val.detach().cpu().float())
        return metrics

    def _compute(self, **kwargs):
        """Subclasses implement this to compute metric dict from batch inputs."""
        raise NotImplementedError()

    def clear(self):
        """Reset all accumulated metric values."""
        self.saved_metrics = defaultdict(list)

    def aggregate(self):
        """Return a dict of concatenated/stacked tensors over all accumulated batches."""
        output = {}
        for key, lst in self.saved_metrics.items():
            try:
                output[key] = torch.cat(lst)
            except RuntimeError:
                output[key] = torch.stack(lst)
        return output


def compute_metrics(metrics_list: List[Metric], metrics_in: Dict) -> Dict:
    """Run each metric on metrics_in and return the combined dict of batch results."""
    metrics_out = {}
    for metric in metrics_list:
        metrics_out.update(metric(**metrics_in))
    return metrics_out


def aggregate_metrics(metrics_list: List[Metric]) -> Dict:
    """Return combined aggregated results (concatenated over batches) for all metrics."""
    metrics_out = {}
    for metric in metrics_list:
        metrics_out.update(metric.aggregate())
    return metrics_out


def clear_metrics(metrics_list: List[Metric]) -> None:
    """Clear accumulated values for all metrics in the list."""
    for metric in metrics_list:
        metric.clear()
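
To illustrate the `Metric` contract (return a dict of per-sample tensors from `_compute`; accumulation and aggregation come from the base class), here is a hypothetical subclass that is not part of the package. It uses only this module's imports:

# Hypothetical example, not in the package: a per-sample mean-jerk metric.
class MeanJerk(Metric):
    """Hypothetical example: mean jerk magnitude per sample."""

    def __init__(self, fps: float, **kwargs):
        super().__init__(**kwargs)
        self.fps = fps

    def _compute(self, posed_joints: torch.Tensor, **kwargs) -> Dict:
        # posed_joints: (B, T, J, 3); third finite difference approximates jerk.
        jerk = torch.diff(posed_joints, n=3, dim=1) * self.fps**3
        return {"mean_jerk": jerk.norm(dim=-1).mean(dim=(1, 2))}  # (B,)

m = MeanJerk(fps=30.0)
m(posed_joints=torch.randn(2, 64, 22, 3))  # one accumulated batch
assert m.aggregate()["mean_jerk"].shape == (2,)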
kimodo/metrics/constraints.py
ADDED
@@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Constraint-following metrics."""

from __future__ import annotations

from collections import defaultdict
from typing import Dict, List, Optional

import torch
from torch import Tensor

from kimodo.constraints import (
    EndEffectorConstraintSet,
    FullBodyConstraintSet,
    Root2DConstraintSet,
)
from kimodo.tools import ensure_batched

from .base import Metric


class ContraintFollow(Metric):
    """Constraint-following metric dispatcher for kimodo constraint sets."""

    def __init__(
        self,
        skeleton,
        root_threshold: float = 0.10,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.skeleton = skeleton
        self.root_threshold = root_threshold

    @ensure_batched(posed_joints=4, constraints_lst=2, lengths=1)
    def _compute(
        self,
        posed_joints: Tensor,
        constraints_lst: Optional[List],
        lengths: Optional[Tensor] = None,
        **kwargs,
    ) -> Dict:
        if not constraints_lst:
            return {}

        root_idx = self.skeleton.root_idx
        output = defaultdict(list)

        for posed_joints_s, constraint_lst_s, lengths_s in zip(posed_joints, constraints_lst, lengths):
            output_seq = defaultdict(list)
            for constraint in constraint_lst_s:
                frame_idx = constraint.frame_indices.to(device=posed_joints_s.device, dtype=torch.long)
                if frame_idx.numel() == 0:
                    continue
                assert frame_idx.max() < lengths_s, "The constraint is defined outside the length of the motion."

                if isinstance(constraint, Root2DConstraintSet):
                    pred_root2d = posed_joints_s[frame_idx, root_idx][:, [0, 2]]
                    target = constraint.smooth_root_2d.to(posed_joints_s.device)

                    dist = torch.norm(pred_root2d - target, dim=-1)
                    output_seq["constraint_root2d_err"].append(dist)
                    hit = (dist <= self.root_threshold).float()
                    output_seq["constraint_root2d_acc"].append(hit)

                elif isinstance(constraint, FullBodyConstraintSet):
                    pred = posed_joints_s[frame_idx]
                    target = constraint.global_joints_positions.to(posed_joints_s.device)
                    err = torch.norm(pred - target, dim=-1)
                    output_seq["constraint_fullbody_keyframe"].append(err)

                elif isinstance(constraint, EndEffectorConstraintSet):
                    pos_idx = constraint.pos_indices.to(device=posed_joints_s.device, dtype=torch.long)
                    pred = posed_joints_s[frame_idx].index_select(1, pos_idx)
                    target = constraint.global_joints_positions.to(posed_joints_s.device).index_select(1, pos_idx)
                    err = torch.norm(pred - target, dim=-1)
                    output_seq["constraint_end_effector"].append(err)

            # in case we have several of the same constraint type in the list
            for key, val in output_seq.items():
                output[key].append(torch.cat(val).mean())

        reduced = {}
        for key, vals in output.items():
            reduced[key] = torch.stack(vals, dim=0)
        return reduced
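
A usage sketch for the dispatcher above; the motion tensors and constraint-set objects are assumed to exist elsewhere in the pipeline and are not constructed here:

# Sketch with assumed inputs: posed_joints (B, T, J, 3), lengths (B,), and
# constraints_lst, a per-sample list of constraint-set objects.
metric = ContraintFollow(skeleton, root_threshold=0.10)
batch_out = metric(posed_joints=posed_joints, constraints_lst=constraints_lst, lengths=lengths)
# For Root2D constraints: batch_out["constraint_root2d_err"] is the mean 2D root
# error (meters) per sample; batch_out["constraint_root2d_acc"] is the fraction
# of constrained frames within root_threshold.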
kimodo/metrics/foot_skate.py
ADDED
@@ -0,0 +1,232 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Foot skate and contact consistency metrics."""

from __future__ import annotations

from typing import Dict, Optional

import torch
from torch import Tensor

from kimodo.motion_rep.feature_utils import compute_vel_xyz
from kimodo.motion_rep.feet import foot_detect_from_pos_and_vel
from kimodo.skeleton import SkeletonBase
from kimodo.tools import ensure_batched

from .base import Metric


class FootSkateFromHeight(Metric):
    """Mean velocity of the toes over frames where the toe joint is near the floor."""

    def __init__(
        self,
        skeleton: SkeletonBase,
        fps: float,
        height_thresh: float = 0.05,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.height_thresh = height_thresh
        self.skeleton = skeleton
        self.fps = fps

    @ensure_batched(posed_joints=4, lengths=1)
    def _compute(
        self,
        posed_joints: Tensor,
        lengths: Optional[Tensor] = None,
        **kwargs,
    ) -> Dict:
        fidx = self.skeleton.foot_joint_idx
        if len(fidx) != 4:
            raise ValueError("FootSkateFromHeight expects four foot joints (heel/toe per foot)")

        feet_pos = posed_joints[:, :, fidx]
        toe_pos = feet_pos[:, :, [1, 3]]

        toe_on_floor = (toe_pos[..., 1] < self.height_thresh)[:, :-1]  # y-up; [B, T-1, 2] = [left, right]

        dt = 1.0 / self.fps
        toe_vel = torch.norm(toe_pos[:, 1:] - toe_pos[:, :-1], dim=-1) / dt  # [B, nframes-1, 2]

        # compute err
        contact_toe_vel = toe_vel * toe_on_floor  # vel when corresponding toe is on ground

        # account for generated length
        # since they are velocities, use length-1 to avoid inaccurate vel going one frame past len
        device = toe_on_floor.device
        len_mask = torch.arange(toe_on_floor.shape[1], device=device)[None, :, None].expand(toe_on_floor.shape) < (
            lengths[:, None, None] - 1
        )
        toe_on_floor = toe_on_floor * len_mask
        contact_toe_vel = contact_toe_vel * len_mask

        mean_vel = torch.sum(contact_toe_vel, (1, 2)) / (torch.sum(toe_on_floor, (1, 2)) + 1e-6)
        return {"foot_skate_from_height": mean_vel}


class FootSkateFromContacts(Metric):
    """Measures velocity of the toes and ankles when predicted to be in contact."""

    def __init__(
        self,
        skeleton: SkeletonBase,
        fps: float,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.skeleton = skeleton
        self.fps = fps

    @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1)
    def _compute(
        self,
        posed_joints: Tensor,
        foot_contacts: Tensor,
        lengths: Optional[Tensor] = None,
        **kwargs,
    ) -> Dict:
        fidx = self.skeleton.foot_joint_idx
        feet_pos = posed_joints[:, :, fidx]
        dt = 1.0 / self.fps
        foot_vel = torch.norm(feet_pos[:, 1:] - feet_pos[:, :-1], dim=-1) / dt

        foot_contacts = foot_contacts[:, :-1]
        vel_err = foot_vel * foot_contacts

        # account for generated length
        # since they are velocities, use length-1 to avoid inaccurate vel going one frame past len
        device = foot_contacts.device
        len_mask = torch.arange(foot_contacts.shape[1], device=device)[None, :, None].expand(foot_contacts.shape) < (
            lengths[:, None, None] - 1
        )
        foot_contacts = foot_contacts * len_mask
        vel_err = vel_err * len_mask

        mean_vel = torch.sum(vel_err, (1, 2)) / (torch.sum(foot_contacts, (1, 2)) + 1e-6)  # mean over contacting frames

        # Compute max velocity error across all feet and frames (per batch)
        max_vel = vel_err.amax(dim=(1, 2))  # [B]

        return {
            "foot_skate_from_pred_contacts": mean_vel,
            "foot_skate_max_vel": max_vel,
        }


class FootSkateRatio(Metric):
    """Compute fraction of frames where the foot skates when it is on the ground.

    Inspired by GMD: https://github.com/korrawe/guided-motion-diffusion/blob/main/data_loaders/humanml/utils/metrics.py#L204
    """

    def __init__(
        self,
        skeleton: SkeletonBase,
        fps: float,
        height_thresh=0.05,
        vel_thresh=0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.height_thresh = height_thresh
        self.vel_thresh = vel_thresh

        self.skeleton = skeleton
        self.fps = fps

    @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1)
    def _compute(
        self,
        posed_joints: Tensor,
        foot_contacts: Tensor,
        lengths: Optional[Tensor] = None,
        **kwargs,
    ) -> Dict:
        fidx = self.skeleton.foot_joint_idx
        assert len(fidx) == 4, "This metric assumes 4 foot joints: heel, toe, heel, toe"

        feet_pos = posed_joints[:, :, fidx]
        toe_pos = feet_pos[:, :, [1, 3]]

        toe_on_floor = toe_pos[..., 1] < self.height_thresh  # y-up; [B, T, 2] = [left, right]
        # current and next frame on floor to consider it in contact
        toe_on_floor = torch.logical_and(toe_on_floor[:, :-1], toe_on_floor[:, 1:])  # [B, T-1, 2]

        dt = 1.0 / self.fps
        toe_vel = torch.norm(toe_pos[:, 1:] - toe_pos[:, :-1], dim=-1) / dt  # [B, nframes-1, 2]

        # compute err
        contact_toe_vel = toe_vel * toe_on_floor  # vel when corresponding toe is on ground

        # account for generated length
        # since they are velocities, use length-1 to avoid inaccurate vel going one frame past len
        device = toe_on_floor.device
        len_mask = torch.arange(toe_on_floor.shape[1], device=device)[None, :, None].expand(toe_on_floor.shape) < (
            lengths[:, None, None] - 1
        )
        toe_on_floor = toe_on_floor * len_mask
        contact_toe_vel = contact_toe_vel * len_mask

        # skating if velocity during contact > thresh
        toe_skate = contact_toe_vel > self.vel_thresh
        skate_ratio = torch.sum(toe_skate, (1, 2)) / (torch.sum(toe_on_floor, (1, 2)) + 1e-6)
        return {"foot_skate_ratio": skate_ratio}


class FootContactConsistency(Metric):
    """Measures consistency between heuristically detected foot contacts (from height and velocity)
    and predicted foot contacts.

    I.e., the accuracy of the predicted contacts, treating the heuristic as ground truth.
    """

    def __init__(
        self,
        skeleton: SkeletonBase,
        fps: float,
        vel_thresh: float = 0.15,
        height_thresh: float = 0.10,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vel_thresh = vel_thresh
        self.height_thresh = height_thresh

        self.skeleton = skeleton
        self.fps = fps

    @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1)
    def _compute(
        self,
        posed_joints: Tensor,
        foot_contacts: Tensor,
        lengths: Optional[Tensor] = None,
        **kwargs,
    ) -> Dict:
        velocity = compute_vel_xyz(posed_joints, float(self.fps), lengths=lengths)
        heuristic_contacts = foot_detect_from_pos_and_vel(
            posed_joints,
            velocity,
            self.skeleton,
            self.vel_thresh,
            self.height_thresh,
        )

        # compute accuracy of predicted, treating heuristic as ground truth
        num_contacts = foot_contacts.shape[-1]
        incorrect = torch.logical_xor(heuristic_contacts, foot_contacts)
        # account for generated length
        # since they are velocities, use length-1 to avoid inaccurate vel going one frame past len
        device = foot_contacts.device
        len_mask = torch.arange(foot_contacts.shape[1], device=device)[None, :, None].expand(foot_contacts.shape) < (
            lengths[:, None, None] - 1
        )
        incorrect = incorrect * len_mask

        incorrect_ratio = torch.sum(incorrect, (1, 2)) / (num_contacts * (lengths - 1))
        accuracy = 1 - incorrect_ratio

        return {"foot_contact_consistency": accuracy}
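
Finally, a sketch of running these skate metrics together through `compute_metrics` from `kimodo.metrics.base`; `skeleton`, `posed_joints`, `foot_contacts`, and `lengths` are assumed inputs (y-up with floor at y = 0, and `foot_joint_idx` ordered [heel, toe, heel, toe]):

# Sketch with assumed inputs; metric keys follow the _compute methods above.
metrics = [
    FootSkateRatio(skeleton, fps=30.0),
    FootSkateFromContacts(skeleton, fps=30.0),
    FootContactConsistency(skeleton, fps=30.0),
]
batch = dict(posed_joints=posed_joints, foot_contacts=foot_contacts, lengths=lengths)
out = compute_metrics(metrics, batch)
# Per-sample tensors, e.g. out["foot_skate_ratio"] and out["foot_contact_consistency"].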
kimodo/metrics/tmr.py
ADDED
@@ -0,0 +1,530 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""TMR evaluation metrics: text-motion retrieval, R-Precision, and related scores."""

from __future__ import annotations

from collections import defaultdict
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from scipy import linalg
from torch import Tensor

from kimodo.model.tmr import TMR

from .base import Metric


# Scores are between 0 and 1
def get_score_matrix_unit(x, y):
    sim_matrix = np.einsum("b i, c i -> b c", x, y)
    scores = sim_matrix / 2 + 0.5
    return scores


def get_scores_unit(x, y):
    similarity = np.einsum("... i, ... i", x, y)
    scores = similarity / 2 + 0.5
    return scores

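Both helpers assume unit-norm embeddings: the dot product is the cosine similarity in [-1, 1], and `s / 2 + 0.5` maps it into [0, 1]. A minimal sanity check (illustrative only, not part of the commit):

import numpy as np

x = np.array([[1.0, 0.0], [0.0, 1.0]])  # two unit vectors
print(get_scores_unit(x, x))        # [1. 1.]: identical pairs score 1.0
print(get_scores_unit(x, -x))       # [0. 0.]: opposite pairs score 0.0
print(get_score_matrix_unit(x, x))  # [[1.  0.5] [0.5 1. ]]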

def compute_tmr_per_sample_retrieval(
    motion_emb: np.ndarray,
    text_emb: np.ndarray,
    sample_ids: List[str],
    texts: List[str],
    top_k: int = 5,
) -> List[Dict[str, Any]]:
    """For each sample (text query i), compute the t2m rank of motion i and the top-k retrieved
    motions with their ids and texts.

    Returns a list of dicts: [{"rank": int, "top_k": [{"id": str, "text": str}, ...]}, ...].
    """
    motion_emb = np.asarray(motion_emb).squeeze()
    text_emb = np.asarray(text_emb).squeeze()
    if motion_emb.ndim == 1:
        motion_emb = motion_emb[np.newaxis, :]
    if text_emb.ndim == 1:
        text_emb = text_emb[np.newaxis, :]
    n = motion_emb.shape[0]
    assert text_emb.shape[0] == n and len(sample_ids) == n and len(texts) == n
    scores = get_score_matrix_unit(text_emb, motion_emb)
    out: List[Dict[str, Any]] = []
    for i in range(n):
        row = np.asarray(scores[i])
        order = np.argsort(row)[::-1]
        rank = int(np.where(order == i)[0][0]) + 1
        top_indices = order[:top_k]
        top_k_list = [{"id": sample_ids[j], "text": texts[j]} for j in top_indices]
        out.append({"rank": rank, "top_k": top_k_list})
    return out

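A short usage sketch for the per-sample helper (ids and texts are made up for illustration):

import numpy as np

emb = np.eye(3)  # pretend text/motion embeddings are already unit vectors
results = compute_tmr_per_sample_retrieval(
    motion_emb=emb,
    text_emb=emb,
    sample_ids=["a", "b", "c"],
    texts=["walk", "run", "jump"],
    top_k=2,
)
print(results[0]["rank"])   # 1: each text retrieves its own motion first
print(results[0]["top_k"])  # [{'id': 'a', 'text': 'walk'}, ...]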

class TMR_Metric(Metric):
    def __init__(
        self,
        tmr_model: TMR,
        ranks: List = [1, 2, 3, 5, 10],
        ranks_rounding=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tmr_model = tmr_model
        self.ranks = ranks
        self.ranks_rounding = ranks_rounding

    def clear(self):
        self.saved_metrics = defaultdict(list)
        self.saved_text_latents = []
        self.saved_motion_gen_latents = []
        self.saved_motion_gt_latents = []

    def _compute(
        self,
        motion_rep,
        pred_joints_output: Dict,
        gt_joints_output: Dict,
        text_x_dict: Dict,
        lengths: Tensor,
        **kwargs,
    ) -> Dict:
        pred_posed_joints = pred_joints_output["posed_joints"]
        original_skeleton = motion_rep.skeleton if motion_rep is not None else None
        latents_motion = self.tmr_model.encode_motion(
            pred_posed_joints,
            lengths=lengths,
            original_skeleton=original_skeleton,
            unit_vector=True,
        )
        latents_motion = latents_motion.cpu().numpy()

        if isinstance(text_x_dict, dict) and "texts" in text_x_dict:
            latents_text = self.tmr_model.encode_raw_text(text_x_dict["texts"], unit_vector=True)
        else:
            latents_text = self.tmr_model.encode_text(text_x_dict, unit_vector=True)
        if latents_text.dim() == 1:
            latents_text = latents_text.unsqueeze(0)
        latents_text = latents_text.cpu().numpy()

        self.saved_text_latents.append(latents_text)
        self.saved_motion_gen_latents.append(latents_motion)

        scores_text = get_scores_unit(latents_motion, latents_text)
        output = {"TMR/t2m_sim": scores_text}

        if gt_joints_output is not None and "posed_joints" in gt_joints_output:
            gt_posed_joints = gt_joints_output["posed_joints"]
            gt_latents_motion = self.tmr_model.encode_motion(
                gt_posed_joints,
                lengths=lengths,
                original_skeleton=original_skeleton,
                unit_vector=True,
            )
            gt_latents_motion = gt_latents_motion.cpu().numpy()
            self.saved_motion_gt_latents.append(gt_latents_motion)

            gt_scores_text = get_scores_unit(gt_latents_motion, latents_text)
            scores_motion = get_scores_unit(latents_motion, gt_latents_motion)

            output["TMR/t2m_gt_sim"] = gt_scores_text
            output["TMR/m2m_sim"] = scores_motion

        # convert everything to pytorch tensors
        for key, val in output.items():
            output[key] = torch.tensor(val)
        return output

    def aggregate(self):
        output = {}
        for key, lst in self.saved_metrics.items():
            output[key] = np.concatenate(lst)

        assert self.saved_text_latents, "Should call the metric at least once."

        text_latents = np.concatenate(self.saved_text_latents)
        motion_gen_latents = np.concatenate(self.saved_motion_gen_latents)

        batch_size = len(text_latents)
        assert text_latents.shape == motion_gen_latents.shape

        scores_t2m = get_score_matrix_unit(text_latents, motion_gen_latents)
        scores_t2t = get_score_matrix_unit(text_latents, text_latents)

        t2m_metrics = contrastive_metrics(
            scores=scores_t2m,
            scores_t2t=scores_t2t,
            threshold=0.99,
            rounding=self.ranks_rounding,
        )

        for key, val in t2m_metrics.items():
            output["TMR/t2m_R/" + key] = val

        mu_gen, cov_gen = calculate_activation_statistics(motion_gen_latents)
        mu_text, cov_text = calculate_activation_statistics(text_latents)

        fid_gen_text = calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text)
        output["TMR/FID/gen_text"] = fid_gen_text

        if self.saved_motion_gt_latents:
            motion_gt_latents = np.concatenate(self.saved_motion_gt_latents)
            assert motion_gt_latents.shape == motion_gen_latents.shape

            scores_m2gm = get_score_matrix_unit(motion_gen_latents, motion_gt_latents)
            scores_t2gm = get_score_matrix_unit(text_latents, motion_gt_latents)

            m2gm_metrics = contrastive_metrics(
                scores=scores_m2gm,
                scores_t2t=scores_t2t,
                threshold=0.99,
                rounding=self.ranks_rounding,
            )
            for key, val in m2gm_metrics.items():
                output["TMR/m2m_R/" + key] = val

            t2gm_metrics = contrastive_metrics(
                scores=scores_t2gm,
                scores_t2t=scores_t2t,
                threshold=0.99,
                rounding=self.ranks_rounding,
            )
            for key, val in t2gm_metrics.items():
                output["TMR/t2m_gt_R/" + key] = val

            mu_gt_motion, cov_gt_motion = calculate_activation_statistics(motion_gt_latents)
            fid_gen_motion = calculate_frechet_distance(
                mu_gen,
                cov_gen,
                mu_gt_motion,
                cov_gt_motion,
            )
            output["TMR/FID/gen_gt"] = fid_gen_motion

            fid_gt_text = calculate_frechet_distance(
                mu_gt_motion,
                cov_gt_motion,
                mu_text,
                cov_text,
            )
            output["TMR/FID/gt_text"] = fid_gt_text

        for key, val in output.items():
            if isinstance(val, (int, float, np.integer, np.floating)):
                val = torch.tensor([val for _ in range(batch_size)])

            if isinstance(val, np.ndarray):
                val = torch.from_numpy(val)

            output[key] = val.cpu().float()
        return output


class TMR_EmbeddingMetric(Metric):
    """TMR metrics from precomputed motion and text embeddings (no model load).

    Use in the loop: pass motion_emb and text_emb per sample; aggregate() computes the
    retrieval metrics.
    """

    def __init__(self, ranks_rounding: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.ranks_rounding = ranks_rounding

    def clear(self):
        self.saved_metrics = defaultdict(list)
        self.saved_text_latents = []
        self.saved_motion_gen_latents = []
        self.saved_motion_gt_latents = []

    def _compute(
        self,
        motion_emb=None,
        text_emb=None,
        gt_motion_emb=None,
        **kwargs,
    ) -> Dict:
        if motion_emb is None or text_emb is None:
            return {}
        motion_emb = np.asarray(motion_emb)
        text_emb = np.asarray(text_emb)
        if motion_emb.ndim == 1:
            motion_emb = motion_emb[np.newaxis, :]
        if text_emb.ndim == 1:
            text_emb = text_emb[np.newaxis, :]
        self.saved_text_latents.append(text_emb)
        self.saved_motion_gen_latents.append(motion_emb)
        if gt_motion_emb is not None:
            gt_motion_emb = np.asarray(gt_motion_emb)
            if gt_motion_emb.ndim == 1:
                gt_motion_emb = gt_motion_emb[np.newaxis, :]
            self.saved_motion_gt_latents.append(gt_motion_emb)
        scores = get_scores_unit(motion_emb, text_emb)
        return {"TMR/t2m_sim": torch.tensor(scores, dtype=torch.float32)}

    def aggregate(self):
        output = {}
        for key, lst in self.saved_metrics.items():
            output[key] = np.concatenate(lst)
        if not self.saved_text_latents:
            return output
        text_latents = np.concatenate(self.saved_text_latents)
        motion_gen_latents = np.concatenate(self.saved_motion_gen_latents)
        batch_size = len(text_latents)
        assert text_latents.shape == motion_gen_latents.shape
        scores_t2m = get_score_matrix_unit(text_latents, motion_gen_latents)
        scores_t2t = get_score_matrix_unit(text_latents, text_latents)
        t2m_metrics = contrastive_metrics(
            scores=scores_t2m,
            scores_t2t=scores_t2t,
            threshold=0.99,
            rounding=self.ranks_rounding,
        )
        for key, val in t2m_metrics.items():
            output["TMR/t2m_R/" + key] = val
        mu_gen, cov_gen = calculate_activation_statistics(motion_gen_latents)
        mu_text, cov_text = calculate_activation_statistics(text_latents)
        output["TMR/FID/gen_text"] = calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text)
        if self.saved_motion_gt_latents:
            motion_gt_latents = np.concatenate(self.saved_motion_gt_latents)
            assert motion_gt_latents.shape == motion_gen_latents.shape
            scores_m2gm = get_score_matrix_unit(motion_gen_latents, motion_gt_latents)
            scores_t2gm = get_score_matrix_unit(text_latents, motion_gt_latents)
            m2gm_metrics = contrastive_metrics(
                scores=scores_m2gm,
                scores_t2t=scores_t2t,
                threshold=0.99,
                rounding=self.ranks_rounding,
            )
            for key, val in m2gm_metrics.items():
                output["TMR/m2m_R/" + key] = val
            t2gm_metrics = contrastive_metrics(
                scores=scores_t2gm,
                scores_t2t=scores_t2t,
                threshold=0.99,
                rounding=self.ranks_rounding,
            )
            for key, val in t2gm_metrics.items():
                output["TMR/t2m_gt_R/" + key] = val
            mu_gt_motion, cov_gt_motion = calculate_activation_statistics(motion_gt_latents)
            output["TMR/FID/gen_gt"] = calculate_frechet_distance(mu_gen, cov_gen, mu_gt_motion, cov_gt_motion)
            output["TMR/FID/gt_text"] = calculate_frechet_distance(mu_gt_motion, cov_gt_motion, mu_text, cov_text)
        for key, val in output.items():
            if isinstance(val, (int, float, np.integer, np.floating)):
                val = torch.tensor([val for _ in range(batch_size)])
            if isinstance(val, np.ndarray):
                val = torch.from_numpy(val)
            output[key] = val.cpu().float()
        return output

def compute_tmr_retrieval_metrics(
    motion_emb: np.ndarray,
    text_emb: np.ndarray,
    gt_motion_emb: Optional[np.ndarray] = None,
    rounding: int = 2,
) -> Dict[str, float]:
    """Compute TMR retrieval metrics from precomputed embeddings."""
    if motion_emb.shape != text_emb.shape:
        raise ValueError(f"Expected same shape for motion/text embeddings, got {motion_emb.shape} vs {text_emb.shape}")

    scores_t2m = get_score_matrix_unit(text_emb, motion_emb)
    scores_t2t = get_score_matrix_unit(text_emb, text_emb)

    output: Dict[str, float] = {}
    t2m_metrics = contrastive_metrics(
        scores=scores_t2m,
        scores_t2t=scores_t2t,
        threshold=0.99,
        rounding=rounding,
    )
    for key, val in t2m_metrics.items():
        output[f"TMR/t2m_R/{key}"] = float(val)

    mu_gen, cov_gen = calculate_activation_statistics(motion_emb)
    mu_text, cov_text = calculate_activation_statistics(text_emb)
    output["TMR/FID/gen_text"] = float(calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text))

    if gt_motion_emb is not None:
        if gt_motion_emb.shape != motion_emb.shape:
            raise ValueError(f"Expected gt motion embeddings shape {motion_emb.shape}, got {gt_motion_emb.shape}")

        scores_m2gm = get_score_matrix_unit(motion_emb, gt_motion_emb)
        scores_t2gm = get_score_matrix_unit(text_emb, gt_motion_emb)

        m2gm_metrics = contrastive_metrics(
            scores=scores_m2gm,
            scores_t2t=scores_t2t,
            threshold=0.99,
            rounding=rounding,
        )
        for key, val in m2gm_metrics.items():
            output[f"TMR/m2m_R/{key}"] = float(val)

        t2gm_metrics = contrastive_metrics(
            scores=scores_t2gm,
            scores_t2t=scores_t2t,
            threshold=0.99,
            rounding=rounding,
        )
        for key, val in t2gm_metrics.items():
            output[f"TMR/t2m_gt_R/{key}"] = float(val)

        mu_gt_motion, cov_gt_motion = calculate_activation_statistics(gt_motion_emb)
        output["TMR/FID/gen_gt"] = float(calculate_frechet_distance(mu_gen, cov_gen, mu_gt_motion, cov_gt_motion))
        output["TMR/FID/gt_text"] = float(calculate_frechet_distance(mu_gt_motion, cov_gt_motion, mu_text, cov_text))

    return output

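A minimal end-to-end sketch of this embedding-based entry point, using random unit-normalized embeddings (purely illustrative; real inputs come from a TMR encoder):

import numpy as np

rng = np.random.default_rng(0)
n, d = 64, 8
motion = rng.normal(size=(n, d))
motion /= np.linalg.norm(motion, axis=1, keepdims=True)
text = rng.normal(size=(n, d))
text /= np.linalg.norm(text, axis=1, keepdims=True)

metrics = compute_tmr_retrieval_metrics(motion, text)
print(metrics["TMR/t2m_R/R01"])    # R-precision at rank 1, in percent
print(metrics["TMR/FID/gen_text"])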
def all_contrastive_metrics(sims, emb=None, threshold=None, rounding=2, return_cols=False):
    text_selfsim = None
    if emb is not None:
        text_selfsim = emb @ emb.T

    t2m_m, t2m_cols = contrastive_metrics(sims, text_selfsim, threshold, return_cols=True, rounding=rounding)
    m2t_m, m2t_cols = contrastive_metrics(sims.T, text_selfsim, threshold, return_cols=True, rounding=rounding)

    all_m = {}
    for key in t2m_m:
        all_m[f"t2m/{key}"] = t2m_m[key]
        all_m[f"m2t/{key}"] = m2t_m[key]

    all_m["t2m/len"] = float(len(sims))
    all_m["m2t/len"] = float(len(sims[0]))
    if return_cols:
        return all_m, t2m_cols, m2t_cols
    return all_m


def contrastive_metrics(
    scores,
    scores_t2t=None,
    threshold=None,
    rounding=2,
    return_cols=False,
):
    n, m = scores.shape
    assert n == m
    num_queries = n

    dists = -scores
    sorted_dists = np.sort(dists, axis=1)
    # GT is on the diagonal
    gt_dists = np.diag(dists)[:, None]

    if scores_t2t is not None and threshold is not None:
        real_threshold = 2 * threshold - 1
        idx = np.argwhere(scores_t2t > real_threshold)
        partition = np.unique(idx[:, 0], return_index=True)[1]
        # take as GT the minimum distance among near-duplicate texts
        gt_dists = np.minimum.reduceat(dists[tuple(idx.T)], partition)
        gt_dists = gt_dists[:, None]

    rows, cols = np.where((sorted_dists - gt_dists) == 0)  # find the column position of the GT

    # if there are ties
    if rows.size > num_queries:
        assert np.unique(rows).size == num_queries, "issue in metric evaluation"
        avg_cols = break_ties_average(sorted_dists, gt_dists)
        cols = avg_cols

    msg = f"expected ranks to match queries ({cols.size} vs {num_queries})"
    assert cols.size == num_queries, msg

    metrics = {}
    vals = [str(x).zfill(2) for x in [1, 2, 3, 5, 10]]
    for val in vals:
        metrics[f"R{val}"] = 100 * float(np.sum(cols < int(val))) / num_queries

    metrics["MedR"] = float(np.median(cols) + 1)
    metrics["len"] = num_queries

    if rounding is not None:
        for key in metrics:
            metrics[key] = round(metrics[key], rounding)
    if return_cols:
        return metrics, cols
    return metrics

def break_ties_average(sorted_dists, gt_dists):
    # fast implementation, based on this code:
    # https://stackoverflow.com/a/49239335
    locs = np.argwhere((sorted_dists - gt_dists) == 0)

    # Find the split indices
    steps = np.diff(locs[:, 0])
    splits = np.nonzero(steps)[0] + 1
    splits = np.insert(splits, 0, 0)

    # Compute the result columns
    summed_cols = np.add.reduceat(locs[:, 1], splits)
    counts = np.diff(np.append(splits, locs.shape[0]))
    avg_cols = summed_cols / counts
    return avg_cols


def calculate_activation_statistics(activations):
    """
    Params:
    -- activations: num_samples x dim_feat
    Returns:
    -- mu: dim_feat
    -- sigma: dim_feat x dim_feat
    """
    mu = np.mean(activations, axis=0)
    cov = np.cov(activations, rowvar=False)
    return mu, cov


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative dataset.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative dataset.
    Returns:
    --       : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give a slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            # retry with a small offset added to the diagonals
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
kimodo/model/__init__.py
ADDED
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Kimodo model package: main model class, text encoders, and loading utilities."""

from .common import resolve_target
from .kimodo_model import Kimodo
from .llm2vec import LLM2VecEncoder
from .load_model import load_model
from .loading import (
    AVAILABLE_MODELS,
    DEFAULT_MODEL,
    DEFAULT_TEXT_ENCODER_URL,
    MODEL_NAMES,
    load_checkpoint_state_dict,
)
from .tmr import TMR
from .twostage_denoiser import TwostageDenoiser

__all__ = [
    "Kimodo",
    "LLM2VecEncoder",
    "TMR",
    "TwostageDenoiser",
    "load_model",
    "load_checkpoint_state_dict",
    "resolve_target",
    "AVAILABLE_MODELS",
    "DEFAULT_MODEL",
    "DEFAULT_TEXT_ENCODER_URL",
    "MODEL_NAMES",
]
|
kimodo/model/backbone.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Transformer backbone: padding, masking, and encoder stack for the denoiser."""

import logging
from typing import Optional, Union

import torch
from omegaconf import ListConfig
from pydantic.dataclasses import dataclass
from torch import Tensor, nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from kimodo.tools import validate

log = logging.getLogger(__name__)


def pad_x_and_mask_to_fixed_size(x: Tensor, mask: Tensor, size: int):
    """Pad a feature tensor x and its mask so they always have the same sequence length.

    Args:
        x (torch.Tensor): [B, T, D]
        mask (torch.Tensor): [B, T]
        size (int)
    Returns:
        torch.Tensor: [B, size, D]
        torch.Tensor: [B, size]
    """

    batch_size, cur_max_size, dim = x.shape[0], x.shape[1], x.shape[2]

    if cur_max_size == size:
        # already padded to this size, probably in the collate function
        return x, mask

    if cur_max_size > size:
        # This case should have been handled in the collate function;
        # useful as a check at test time
        log.warning("The size of the tensor is larger than the maximum size. Cropping the input..")
        cur_max_size = size

    new_x = torch.zeros(
        (batch_size, size, dim),
        dtype=x.dtype,
        device=x.device,
    )
    # slicing x makes the crop path work too (no-op when cur_max_size == x.shape[1])
    new_x[:, :cur_max_size] = x[:, :cur_max_size]

    # same for the mask
    new_mask = torch.zeros(
        (batch_size, size),
        dtype=mask.dtype,
        device=mask.device,
    )
    new_mask[:, :cur_max_size] = mask[:, :cur_max_size]
    return new_x, new_mask

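A toy invocation of the padding helper (the shapes are made up for illustration):

import torch

x = torch.randn(2, 5, 8)                   # [B=2, T=5, D=8]
mask = torch.ones(2, 5, dtype=torch.bool)  # all 5 frames valid
x_pad, mask_pad = pad_x_and_mask_to_fixed_size(x, mask, size=7)
print(x_pad.shape, mask_pad.shape)  # torch.Size([2, 7, 8]) torch.Size([2, 7])
print(mask_pad[0])                  # True x5, then False x2 for the padding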
@dataclass(frozen=True, config=dict(extra="forbid", arbitrary_types_allowed=True))
class TransformerEncoderBlockConfig:
    """Configuration for the transformer encoder backbone."""

    # input features dimension
    input_dim: int
    # output features dimension
    output_dim: int

    # skeleton object
    skeleton: object

    # dimension of the text embeddings
    llm_shape: Union[list[int], ListConfig]

    # mask the text or not
    use_text_mask: bool

    # latent dimension of the model
    latent_dim: int
    # dimension of the feedforward network in the transformer
    ff_size: int
    # number of layers in the transformer
    num_layers: int
    # number of heads in the transformer
    num_heads: int
    # activation in the transformer
    activation: str
    # dropout rate for the transformer
    dropout: float
    # dropout rate for the positional embeddings
    pe_dropout: float
    # use norm first or not
    norm_first: bool = False
    # artificially extend the number of text tokens
    num_text_tokens_override: Optional[int] = None

    # input the first heading angle
    input_first_heading_angle: bool = False


class TransformerEncoderBlock(nn.Module):
    @validate(TransformerEncoderBlockConfig, save_args=True, super_init=True)
    def __init__(self, conf):
        self.nbjoints = self.skeleton.nbjoints
        llm_dim = self.llm_shape[-1]
        self.embed_text = nn.Linear(llm_dim, self.latent_dim)

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.pe_dropout)

        # maximum number of text tokens
        self.num_text_tokens = self.llm_shape[0]
        if self.num_text_tokens_override is not None:
            self.num_text_tokens = self.num_text_tokens_override

        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)

        self.input_linear = nn.Linear(self.input_dim, self.latent_dim)
        self.output_linear = nn.Linear(self.latent_dim, self.output_dim)
        self.linear_first_heading_angle = nn.Linear(2, self.latent_dim)

        trans_enc_layer = TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation,
            batch_first=True,
            norm_first=self.norm_first,
        )
        self.seqTransEncoder = TransformerEncoder(
            trans_enc_layer,
            num_layers=self.num_layers,
            enable_nested_tensor=False,
        )

    def forward(
        self,
        x: Tensor,
        x_pad_mask: torch.Tensor,
        text_feat: torch.Tensor,
        text_feat_pad_mask: torch.Tensor,
        timesteps: Tensor,
        first_heading_angle: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            x (torch.Tensor): [B, T, dim_motion] current noisy motion
            x_pad_mask (torch.Tensor): [B, T] attention mask; positions with True may attend, False may not
            text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts
            text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask; positions with True may attend, False may not
            timesteps (torch.Tensor): [B,] current denoising step
            first_heading_angle (Optional[torch.Tensor]): [B,] first heading angle, required when input_first_heading_angle is set

        Returns:
            torch.Tensor: [B, T, output_dim]
        """
        batch_size = len(x)
        x = self.input_linear(x)  # [B, T, D]

        # Pad the text tokens + mask to always have the same size == self.num_text_tokens
        # (done here if it was not done in the collate function)
        if self.num_text_tokens is not None:
            text_feat, text_feat_pad_mask = pad_x_and_mask_to_fixed_size(
                text_feat,
                text_feat_pad_mask,
                self.num_text_tokens,
            )

        # Encode the text features and the time information
        emb_text = self.embed_text(text_feat)  # [B, max_text_len, D]
        emb_time = self.embed_timestep(timesteps)  # [B, 1, D]

        # Create a mask for the time information
        time_mask = torch.ones((batch_size, 1), dtype=torch.bool, device=x.device)

        # Create the prefix features (text, time, etc.): [B, max_text_len + 1 + etc.]
        prefix_feats = torch.cat((emb_text, emb_time), dim=1)

        # Behavior from old code: when not using the text mask -> True for all the tokens
        if not self.use_text_mask:
            text_feat_pad_mask = torch.ones(
                (batch_size, emb_text.shape[1]),
                dtype=torch.bool,
                device=x.device,
            )

        prefix_mask = torch.cat((text_feat_pad_mask, time_mask), dim=1)

        # add the input first heading angle
        if self.input_first_heading_angle:
            assert first_heading_angle is not None, "The first heading angle is mandatory for this model"
            # cos(angle) / sin(angle)
            first_heading_angle_feats = torch.stack(
                [
                    torch.cos(first_heading_angle),
                    torch.sin(first_heading_angle),
                ],
                dim=-1,
            )

            first_heading_angle_feats = self.linear_first_heading_angle(first_heading_angle_feats)
            first_heading_angle_feats = first_heading_angle_feats[:, None]  # for cat
            first_heading_angle_mask = torch.ones(
                (batch_size, 1),
                dtype=torch.bool,
                device=x.device,
            )
            prefix_feats = torch.cat((prefix_feats, first_heading_angle_feats), dim=1)
            prefix_mask = torch.cat((prefix_mask, first_heading_angle_mask), dim=1)

        # compute the number of prefix features
        pose_start_ind = prefix_feats.shape[1]

        # Concatenate prefix and x: [B, len(prefix) + T, D]
        xseq = torch.cat((prefix_feats, x), dim=1)

        # Concatenate the masks and negate them: [B, len(prefix) + T]
        src_key_padding_mask = ~torch.cat((prefix_mask, x_pad_mask), dim=1)

        # Add positional encoding
        xseq = self.sequence_pos_encoder(xseq)

        # Input to the transformer and keep only the motion indexes
        if isinstance(self.seqTransEncoder, nn.TransformerEncoder):
            assert not self.seqTransEncoder.use_nested_tensor, "Flash attention should be disabled due to a bug!"

        output = self.seqTransEncoder(
            xseq,
            src_key_padding_mask=src_key_padding_mask,
        )
        output = output[:, pose_start_ind:]  # [B, T, D]
        output = self.output_linear(output)  # [B, T, OD]
        return output


class PositionalEncoding(nn.Module):
    """Non-learned sinusoidal positional encoding."""

    def __init__(
        self,
        d_model: int,
        dropout: Optional[float] = 0.1,
        max_len: Optional[int] = 5000,
    ):
        """
        Args:
            d_model (int): input dim
            dropout (Optional[float] = 0.1): dropout probability on the output
            max_len (Optional[int] = 5000): maximum sequence length
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # Note: torch.exp() and math.log() are replaced with torch.pow(),
        # because MKL exp() and ln() throw floating point exceptions on certain CPUs
        # (see the corresponding commit and MR)
        div_term = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model)
        # equivalent to:
        # div_term = torch.exp(
        #     torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        # )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, T, D]

        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply positional encoding to the input sequence.

        Args:
            x (torch.Tensor): [B, T, D] input motion sequence

        Returns:
            torch.Tensor: [B, T, D] input motion with the PE added to it (and optionally dropout)
        """
        x = x + self.pe[:, : x.shape[1], :]
        return self.dropout(x)


class TimestepEmbedder(nn.Module):
    """Encoder for the diffusion step."""

    def __init__(self, latent_dim: int, sequence_pos_encoder: PositionalEncoding):
        """
        Args:
            latent_dim (int): dim to encode to
            sequence_pos_encoder (PositionalEncoding): the PE to use on timesteps
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Embed timesteps by looking up their PE, then passing it through linear layers.

        Args:
            timesteps (torch.Tensor): [B]

        Returns:
            torch.Tensor: [B, 1, D]
        """
        # pe is [1, max_len, D]; transposing to [max_len, 1, D] makes indexing by timesteps give [B, 1, D]
        return self.time_embed(self.sequence_pos_encoder.pe.transpose(0, 1)[timesteps])
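The pow-based `div_term` above can be checked against the usual exp/log formulation it replaces (a quick sketch, not part of the commit):

import math

import torch

d_model = 16
a = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model)
b = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
print(torch.allclose(a, b))  # True: pow(10000, -k/d) == exp(-k * ln(10000) / d)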
kimodo/model/cfg.py
ADDED
@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Classifier-free guidance wrapper for the denoiser at sampling time."""

from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

CFG_TYPES = ["nocfg", "regular", "separated"]


class ClassifierFreeGuidedModel(nn.Module):
    """Wrapper around the denoiser to use classifier-free guidance at sampling time."""

    def __init__(self, model: nn.Module, cfg_type: Optional[str] = "separated"):
        """Wrap the denoiser for classifier-free guidance; cfg_type must be one of CFG_TYPES
        (e.g. 'regular', 'nocfg')."""
        super().__init__()
        self.model = model
        assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}"
        self.cfg_type_default = cfg_type

    def forward(
        self,
        cfg_weight: Union[float, Tuple[float, float]],
        x: torch.Tensor,
        x_pad_mask: torch.Tensor,
        text_feat: torch.Tensor,
        text_feat_pad_mask: torch.Tensor,
        timesteps: torch.Tensor,
        first_heading_angle: Optional[torch.Tensor] = None,
        motion_mask: Optional[torch.Tensor] = None,
        observed_motion: Optional[torch.Tensor] = None,
        cfg_type: Optional[str] = None,
    ) -> torch.Tensor:
        """
        Args:
            cfg_weight: guidance weight, a single float, or a (text, constraint) tuple of floats when using separated cfg
            x (torch.Tensor): [B, T, dim_motion] current noisy motion
            x_pad_mask (torch.Tensor): [B, T] attention mask; positions with True may attend, False may not
            text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts
            text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask; positions with True may attend, False may not
            timesteps (torch.Tensor): [B,] current denoising step
            motion_mask (Optional[torch.Tensor]): mask of observed (constrained) motion entries
            observed_motion (Optional[torch.Tensor]): observed motion used for constraint conditioning
            cfg_type (Optional[str]): one of CFG_TYPES; defaults to the type given at construction

        Returns:
            torch.Tensor: same size as the input x
        """

        if cfg_type is None:
            cfg_type = self.cfg_type_default

        assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}"

        # the conditional and unconditional passes are batched together
        if cfg_type == "nocfg":
            return self.model(
                x,
                x_pad_mask,
                text_feat,
                text_feat_pad_mask,
                timesteps,
                first_heading_angle=first_heading_angle,
                motion_mask=motion_mask,
                observed_motion=observed_motion,
            )
        elif cfg_type == "regular":
            assert isinstance(cfg_weight, (float, int)), "cfg_weight must be a single float for regular CFG"
            # out_uncond + w * (out_text_and_constraint - out_uncond)
            text_feat = torch.concatenate([text_feat, 0 * text_feat], dim=0)
            if motion_mask is not None:
                motion_mask = torch.concatenate([motion_mask, 0 * motion_mask], dim=0)
            if observed_motion is not None:
                observed_motion = torch.concatenate([observed_motion, observed_motion], dim=0)
            if first_heading_angle is not None:
                first_heading_angle = torch.concatenate([first_heading_angle, first_heading_angle], dim=0)

            out_cond_uncond = self.model(
                torch.concatenate([x, x], dim=0),
                torch.concatenate([x_pad_mask, x_pad_mask], dim=0),
                text_feat,
                torch.concatenate([text_feat_pad_mask, False * text_feat_pad_mask], dim=0),
                torch.concatenate([timesteps, timesteps], dim=0),
                first_heading_angle=first_heading_angle,
                motion_mask=motion_mask,
                observed_motion=observed_motion,
            )

            out, out_uncond = torch.chunk(out_cond_uncond, 2)
            out_new = out_uncond + (cfg_weight * (out - out_uncond))
        elif cfg_type == "separated":
            assert len(cfg_weight) == 2, "cfg_weight must be a tuple of two floats for separated CFG"
            # out_uncond + w_text * (out_text - out_uncond) + w_constraint * (out_constraint - out_uncond)
            text_feat = torch.concatenate([text_feat, 0 * text_feat, 0 * text_feat], dim=0)
            if motion_mask is not None:
                motion_mask = torch.concatenate([0 * motion_mask, motion_mask, 0 * motion_mask], dim=0)
            if observed_motion is not None:
                observed_motion = torch.concatenate([observed_motion, observed_motion, observed_motion], dim=0)
            if first_heading_angle is not None:
                first_heading_angle = torch.concatenate(
                    [first_heading_angle, first_heading_angle, first_heading_angle],
                    dim=0,
                )

            out_cond_uncond = self.model(
                torch.concatenate([x, x, x], dim=0),
                torch.concatenate([x_pad_mask, x_pad_mask, x_pad_mask], dim=0),
                text_feat,
                torch.concatenate(
                    [
                        text_feat_pad_mask,
                        False * text_feat_pad_mask,
                        False * text_feat_pad_mask,
                    ],
                    dim=0,
                ),
                torch.concatenate([timesteps, timesteps, timesteps], dim=0),
                first_heading_angle=first_heading_angle,
                motion_mask=motion_mask,
                observed_motion=observed_motion,
            )

            out_text, out_constraint, out_uncond = torch.chunk(out_cond_uncond, 3)
            out_new = (
                out_uncond + (cfg_weight[0] * (out_text - out_uncond)) + (cfg_weight[1] * (out_constraint - out_uncond))
            )
        else:
            raise ValueError(f"Invalid cfg_type: {cfg_type}")

        return out_new
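The guidance arithmetic itself is just a weighted extrapolation away from the unconditional prediction; a minimal numeric sketch (tensors and weights are made up):

import torch

out_uncond = torch.zeros(1, 4)
out_text = torch.ones(1, 4)
out_constraint = 2 * torch.ones(1, 4)
w_text, w_constraint = 2.5, 1.0

out = out_uncond + w_text * (out_text - out_uncond) + w_constraint * (out_constraint - out_uncond)
print(out)  # tensor([[4.5000, 4.5000, 4.5000, 4.5000]])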
kimodo/model/common.py
ADDED
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Config hydration: env vars, _target_ resolution, and recursive instantiation."""

import importlib
import os


def get_env_var(name: str, default=None):
    """Read an env var by name, falling back to the lowercased name; return default if neither is set."""
    return os.getenv(name, os.getenv(name.lower(), default))


def resolve_target(target: str):
    """Import a module and return the attribute named by a dotted path (e.g. 'pkg.mod.Class')."""
    module_name, attr_name = target.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


def materialize_value(value):
    """Recursively turn dicts with a '_target_' key into instances; lists and dicts are
    traversed; other leaves are returned unchanged."""
    if isinstance(value, dict):
        if "_target_" in value:
            return instantiate_from_dict(value)
        return {k: materialize_value(v) for k, v in value.items()}
    if isinstance(value, list):
        return [materialize_value(v) for v in value]
    return value


def instantiate_from_dict(node, overrides=None):
    """Build an instance from a config dict: '_target_' names the class, the other keys become
    kwargs; overrides are merged in."""
    if not isinstance(node, dict) or "_target_" not in node:
        raise ValueError("Config node must be a dict with a '_target_' key.")

    target = resolve_target(node["_target_"])
    kwargs = {}
    for key, value in node.items():
        if key == "_target_":
            continue
        kwargs[key] = materialize_value(value)

    if overrides:
        kwargs.update({k: v for k, v in overrides.items() if v is not None})

    return target(**kwargs)
kimodo/model/diffusion.py
ADDED
@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Diffusion process and DDIM sampling for motion generation."""

import math
from typing import Optional, Tuple

import torch
from torch import nn


def get_beta_schedule(
    num_diffusion_timesteps: int,
    max_beta: Optional[float] = 0.999,
) -> torch.Tensor:
    """Get the cosine beta schedule."""

    def alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float)

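With the cosine schedule, the betas start tiny and grow toward the `max_beta` cap; a quick look at both ends (illustrative):

import torch

betas = get_beta_schedule(1000)
print(betas[0].item())   # tiny (~4e-5): almost no noise at the first step
print(betas[-1].item())  # 0.999: capped by max_beta at the last step
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(alphas_cumprod[-1].item())  # ~0: x_T is close to pure noise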
class Diffusion(torch.nn.Module):
    """Cosine-schedule diffusion process: betas, alphas, and DDIM step mapping."""

    def __init__(self, num_base_steps: int):
        """Set up the cosine beta schedule and precompute the diffusion variables for num_base_steps."""
        super().__init__()
        self.num_base_steps = num_base_steps
        betas_base = get_beta_schedule(self.num_base_steps)
        self.register_buffer("betas_base", betas_base, persistent=False)
        alphas_cumprod_base = torch.cumprod(1.0 - self.betas_base, dim=0)
        self.register_buffer("alphas_cumprod_base", alphas_cumprod_base, persistent=False)
        use_timesteps, _ = self.space_timesteps(self.num_base_steps)
        self.calc_diffusion_vars(use_timesteps)

    def extra_repr(self) -> str:
        return f"num_base_steps={self.num_base_steps}"

    @property
    def device(self):
        return self.betas_base.device

    def space_timesteps(self, num_denoising_steps: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (use_timesteps, map_tensor) for a subsampled denoising schedule of
        num_denoising_steps."""
        nsteps_train = self.num_base_steps
        frac_stride = (nsteps_train - 1) / max(1, num_denoising_steps - 1)
        # evenly spread num_denoising_steps indices over the base schedule
        use_timesteps = torch.round(torch.arange(num_denoising_steps, device=self.device) * frac_stride).to(torch.long)
        use_timesteps = torch.clamp(use_timesteps, max=nsteps_train - 1)
        map_tensor = torch.arange(nsteps_train, device=self.device, dtype=torch.long)[use_timesteps]
        return use_timesteps, map_tensor

    def calc_diffusion_vars(self, use_timesteps: torch.Tensor) -> None:
        """Update the buffers (betas, alphas, alphas_cumprod, etc.) for the given subsampled
        timesteps."""
        alphas_cumprod = self.alphas_cumprod_base[use_timesteps]
        last_alpha_cumprod = torch.cat([torch.tensor([1.0]).to(alphas_cumprod), alphas_cumprod[:-1]])
        betas = 1.0 - alphas_cumprod / last_alpha_cumprod
        self.register_buffer("betas", betas, persistent=False)

        alphas = 1.0 - self.betas
        self.register_buffer("alphas", alphas, persistent=False)
        alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        alphas_cumprod = torch.clamp(alphas_cumprod, min=1e-9)
        self.register_buffer("alphas_cumprod", alphas_cumprod, persistent=False)

        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(self.alphas_cumprod), self.alphas_cumprod[:-1]])
        self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev, persistent=False)

        # rsqrt(x) = sqrt(1/x), so this is sqrt(1 / alphas_cumprod)
        sqrt_recip_alphas_cumprod = torch.rsqrt(self.alphas_cumprod)
        self.register_buffer("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod, persistent=False)

        # sqrt(1 / alphas_cumprod - 1), written via rsqrt
        sqrt_recipm1_alphas_cumprod = torch.rsqrt(self.alphas_cumprod / (1.0 - self.alphas_cumprod))
        self.register_buffer("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod, persistent=False)

        posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.register_buffer("posterior_variance", posterior_variance, persistent=False)

        # rsqrt(1/x) = sqrt(x), so this is sqrt(alphas_cumprod)
        sqrt_alphas_cumprod = torch.rsqrt(1.0 / self.alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", sqrt_alphas_cumprod, persistent=False)

        # likewise sqrt(1 - alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = torch.rsqrt(1.0 / (1.0 - self.alphas_cumprod))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            sqrt_one_minus_alphas_cumprod,
            persistent=False,
        )

    def q_sample(
        self,
        x_start: torch.Tensor,
        t: torch.Tensor,
        noise: torch.Tensor = None,
    ):
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape

        xt = (
            self.sqrt_alphas_cumprod[t, None, None] * x_start
            + self.sqrt_one_minus_alphas_cumprod[t, None, None] * noise
        )
        return xt


class DDIMSampler(nn.Module):
    """Deterministic DDIM sampler (eta = 0)."""

    def __init__(self, diffusion: Diffusion):
        super().__init__()
        self.diffusion = diffusion

    def __call__(
        self,
        use_timesteps: torch.Tensor,
        x_t: torch.Tensor,
        pred_xstart: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        self.diffusion.calc_diffusion_vars(use_timesteps)
        # recover the predicted noise from x_t and the predicted x_0
        eps = (
            self.diffusion.sqrt_recip_alphas_cumprod[t, None, None] * x_t - pred_xstart
        ) / self.diffusion.sqrt_recipm1_alphas_cumprod[t, None, None]
        # deterministic DDIM step: x_{t-1} = sqrt(abar_{t-1}) * x0_hat + sqrt(1 - abar_{t-1}) * eps_hat
        alpha_bar_prev = self.diffusion.alphas_cumprod_prev[t, None, None]
        x = pred_xstart * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev) * eps
        return x
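A rough sketch of how these pieces compose at sampling time (a minimal sketch, assuming an identity "denoiser" that returns its input as the predicted x_0; real usage goes through the Kimodo model below):

import torch

diffusion = Diffusion(num_base_steps=1000)
sampler = DDIMSampler(diffusion)

# subsample a 50-step schedule from the 1000-step training schedule
use_timesteps, _ = diffusion.space_timesteps(50)

x = torch.randn(1, 60, 128)  # [B, T, motion_dim], pure noise
for i in reversed(range(50)):
    t = torch.tensor([i])
    pred_xstart = x  # stand-in for denoiser(x, t, ...)
    x = sampler(use_timesteps, x, pred_xstart, t)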
kimodo/model/kimodo_model.py
ADDED
@@ -0,0 +1,605 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Kimodo model: denoiser, text encoder, diffusion sampling, and post-processing."""

import logging
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from tqdm.auto import tqdm

from kimodo.constraints import FullBodyConstraintSet
from kimodo.motion_rep.feature_utils import compute_heading_angle, length_to_mask
from kimodo.postprocess import post_process_motion
from kimodo.sanitize import sanitize_texts
from kimodo.skeleton import SOMASkeleton30
from kimodo.tools import to_numpy

from .cfg import ClassifierFreeGuidedModel
from .diffusion import DDIMSampler, Diffusion

log = logging.getLogger(__name__)


class Kimodo(nn.Module):
    """Helper class for test time."""

    def __init__(
        self,
        denoiser: nn.Module,
        text_encoder: nn.Module,
        num_base_steps: int,
        device: Optional[Union[str, torch.device]] = None,
        cfg_type: Optional[str] = "separated",
    ):
        super().__init__()

        self.denoiser = denoiser.eval()

        if cfg_type is None:
            cfg_type = "nocfg"

        # Add classifier-free guidance to the model if needed
        self.denoiser = ClassifierFreeGuidedModel(self.denoiser, cfg_type=cfg_type)

        self.motion_rep = denoiser.motion_rep
        self.skeleton = self.motion_rep.skeleton

        self.fps = denoiser.motion_rep.fps

        self.diffusion = Diffusion(num_base_steps=num_base_steps)
        self.sampler = DDIMSampler(self.diffusion)
        self.text_encoder = text_encoder

        self.device = device

        self.to(device)

    @property
    def output_skeleton(self):
        """Skeleton used for model output (somaskel77 for SOMA, else unchanged)."""
        if isinstance(self.skeleton, SOMASkeleton30):
            return self.skeleton.somaskel77
        return self.skeleton

    def train(self, mode: bool):
        self.denoiser.train(mode)
        return self

    def eval(self):
        self.denoiser.eval()
        return self

    def denoising_step(
        self,
        motion: torch.Tensor,
        pad_mask: torch.Tensor,
        text_feat: torch.Tensor,
        text_pad_mask: torch.Tensor,
        t: torch.Tensor,
        first_heading_angle: Optional[torch.Tensor],
        motion_mask: torch.Tensor,
        observed_motion: torch.Tensor,
        num_denoising_steps: torch.Tensor,
        cfg_weight: Union[float, Tuple[float, float]],
        guide_masks: Optional[Dict] = None,
        cfg_type: Optional[str] = None,
    ) -> torch.Tensor:
        """Single denoising step.

        Returns:
            torch.Tensor: [B, T, D] noisy motion input to t-1
        """
        # Subsample timesteps.
        # NOTE: do this at every step because of ONNX export: num_denoising_steps may
        # change dynamically when running the ONNX version, so we account for that here.
        num_denoising_steps = num_denoising_steps[0]
        use_timesteps, map_tensor = self.diffusion.space_timesteps(num_denoising_steps)
        self.diffusion.calc_diffusion_vars(use_timesteps)

        # first compute the initial clean prediction from the denoiser
        t_map = map_tensor[t]

        with torch.inference_mode():
            pred_clean = self.denoiser(
                cfg_weight,
                motion,
                pad_mask,
                text_feat,
                text_pad_mask,
                t_map,
                first_heading_angle,
                motion_mask,
                observed_motion,
                cfg_type=cfg_type,
            )

        # the sampler computes the next-step noisy motion
        x_tm1 = self.sampler(use_timesteps, motion, pred_clean, t)
        return x_tm1
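The timestep bookkeeping above deserves a note: the loop index `t` lives on the spaced sampling schedule, and `map_tensor[t]` converts it to the base-schedule index the denoiser was trained on. A self-contained sketch of the idea, assuming an even spacing (the actual spacing is whatever `space_timesteps` implements):

import torch

# Illustration only: 5 sampling steps spread over 1000 base steps.
num_base_steps, num_sampling_steps = 1000, 5
map_tensor = torch.linspace(0, num_base_steps - 1, num_sampling_steps).long()
# e.g. tensor([  0, 249, 499, 749, 999])

t = torch.tensor([4, 4])  # spaced index for a batch of 2
t_map = map_tensor[t]     # -> tensor([999, 999]): what the denoiser actually sees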
    def _multiprompt(
        self,
        prompts: list[str],
        num_frames: int | list[int],
        num_denoising_steps: int,
        constraint_lst: Optional[list] = None,
        cfg_weight: Union[float, Tuple[float, float]] = (2.0, 2.0),
        num_samples: Optional[int] = None,
        cfg_type: Optional[str] = None,
        return_numpy: bool = False,
        first_heading_angle: Optional[torch.Tensor] = None,
        # for transitioning
        num_transition_frames: int = 5,
        share_transition: bool = True,
        percentage_transition_override: float = 0.10,
        # for postprocess
        post_processing: bool = False,
        root_margin: float = 0.04,
        # progress bar
        progress_bar=tqdm,
    ) -> dict:
        device = self.device

        tosqueeze = False
        if num_samples is None:
            num_samples = 1
            tosqueeze = True

        bs = num_samples
        texts = sanitize_texts(prompts)

        if isinstance(num_frames, int):
            # same duration for all the segments
            num_frames = [num_frames for _ in range(len(texts))]

        if constraint_lst is None:
            constraint_lst = []

        # Generate one chunk at a time
        current_frame = 0
        generated_motions = []

        for idx, (text, num_frame) in enumerate(zip(texts, num_frames)):
            texts_bs = [text for _ in range(num_samples)]

            lengths = torch.tensor(
                [num_frame for _ in range(num_samples)],
                device=device,
            )

            is_first_motion = not generated_motions

            observed_motion, motion_mask = None, None

            # filter the constraint_lst to keep only the relevant ones
            constraint_lst_base = [
                constraint.crop_move(current_frame, current_frame + num_frame) for constraint in constraint_lst
            ]  # this moves constraints temporally but not spatially

            observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched(
                constraint_lst_base,
                lengths,
                to_normalize=False,  # don't normalize yet, it needs to be moved around
                device=device,
            )

            if not is_first_motion:
                prev_num_frame = num_frames[idx - 1]
                if share_transition:
                    # Start the transition earlier, to "share" the transition between A and B.
                    # In any case, we still use "num_transition_frames" for conditioning,
                    # and we don't condition until the end of A.
                    # The number of transition frames is computed as a percentage of the last motion.
                    nb_transition_frames = num_transition_frames + int(prev_num_frame * percentage_transition_override)
                else:
                    nb_transition_frames = num_transition_frames

                latest_motions = generated_motions.pop()
                # remove the transition part of A (it will be put back afterwards)
                generated_motions.append(latest_motions[:, :-nb_transition_frames])
                latest_frames = latest_motions[:, -nb_transition_frames:]

                last_output = self.motion_rep.inverse(
                    latest_frames,
                    is_normalized=False,
                    return_numpy=False,
                )
                smooth_root_2d = last_output["smooth_root_pos"][..., [0, 2]]

                # add constraints at the beginning to allow natural transitions
                constraint_lst_transition = []
                for batch_id in range(bs):
                    new_constraint = FullBodyConstraintSet(
                        self.skeleton,
                        torch.arange(num_transition_frames),
                        last_output["posed_joints"][batch_id, :num_transition_frames],
                        last_output["local_rot_mats"][batch_id, :num_transition_frames],
                        smooth_root_2d[batch_id, :num_transition_frames],
                    )

                    # new lists
                    constraint_lst_transition.append([new_constraint])

                transition_lengths = torch.tensor(
                    [nb_transition_frames for _ in range(num_samples)],
                    device=device,
                )

                observed_motion_transition, motion_mask_transition = (
                    self.motion_rep.create_conditions_from_constraints_batched(
                        constraint_lst_transition,
                        transition_lengths,
                        to_normalize=False,  # don't normalize yet
                        device=device,
                    )
                )

                # concatenate the observed motion / motion mask
                observed_motion = torch.cat([observed_motion_transition, observed_motion], axis=1)
                motion_mask = torch.cat([motion_mask_transition, motion_mask], axis=1)

                # move each observed motion in the batch to its new starting point
                last_smooth_root_2d = smooth_root_2d[:, 0]
                observed_motion = self.motion_rep.translate_2d(
                    observed_motion, -last_smooth_root_2d
                )  # equivalent to: self.motion_rep.translate_2d_to_zero(observed_motion)

                # remove dummy values after moving
                observed_motion = observed_motion * motion_mask

                lengths = lengths + transition_lengths
                first_heading_angle = compute_heading_angle(last_output["posed_joints"], self.skeleton)[:, 0]
            else:
                if first_heading_angle is None:
                    # Start at 0 angle (this will change afterwards)
                    first_heading_angle = torch.tensor([0.0] * bs, device=device)
                else:
                    first_heading_angle = torch.as_tensor(first_heading_angle, device=device)
                    if first_heading_angle.numel() == 1:
                        first_heading_angle = first_heading_angle.repeat(bs)

            observed_motion = self.motion_rep.normalize(observed_motion)

            max_frames = max(lengths)
            motion_pad_mask = length_to_mask(lengths)

            motion = self._generate(
                texts_bs,
                max_frames,
                num_denoising_steps=num_denoising_steps,
                pad_mask=motion_pad_mask,
                first_heading_angle=first_heading_angle,
                motion_mask=motion_mask,
                observed_motion=observed_motion,
                cfg_weight=cfg_weight,
                cfg_type=cfg_type,
                progress_bar=progress_bar,
            )

            motion = self.motion_rep.unnormalize(motion)

            if not is_first_motion:
                motion_with_transition = self.motion_rep.translate_2d(
                    motion,
                    last_smooth_root_2d,
                )

                motion = motion_with_transition[:, num_transition_frames:]
                transition_frames = motion_with_transition[:, :num_transition_frames]
                # for share_transition=True, the new motion contains the very end of A

                # Linearly blend the previously generated transition frames with the newly
                # generated ones, so that we go linearly from the previous generation to the new one.
                alpha = torch.linspace(1, 0, num_transition_frames, device=device)[:, None]
                new_transition_frames = (
                    latest_frames[:, :num_transition_frames] * alpha + (1 - alpha) * transition_frames
                )

                # Add the new transition frames for A (merging B's prediction into the history).
                # For share_transition == True, this does not add back a small part of the end of A:
                # that last part of A has been re-generated by B.
                generated_motions.append(new_transition_frames)

            generated_motions.append(motion)
            current_frame += num_frame

        generated_motions = torch.cat(generated_motions, axis=1)  # temporal axis (b, t, d)

        if tosqueeze:
            generated_motions = generated_motions[0]

        output = self.motion_rep.inverse(
            generated_motions,
            is_normalized=False,
            return_numpy=False,
        )

        # Apply post-processing if requested
        if post_processing:
            corrected = post_process_motion(
                output["local_rot_mats"],
                output["root_positions"],
                output["foot_contacts"],
                self.skeleton,
                constraint_lst,
                root_margin=root_margin,
            )
            output.update(corrected)

        # Convert SOMA output to somaskel77 for the external API
        if isinstance(self.skeleton, SOMASkeleton30):
            output = self.skeleton.output_to_SOMASkeleton77(output)

        # Convert to numpy if requested
        if return_numpy:
            output = to_numpy(output)
        return output
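For intuition, the transition blend above is a per-frame linear cross-fade from the end of segment A to the re-generated overlap from segment B. A self-contained sketch with stand-in tensors (not real model outputs):

import torch

num_transition_frames = 5
alpha = torch.linspace(1, 0, num_transition_frames)[:, None]  # [5, 1], weights 1 -> 0

prev_end = torch.ones(1, num_transition_frames, 3)    # end of segment A (stand-in features)
new_start = torch.zeros(1, num_transition_frames, 3)  # B's re-generation of the same frames

blended = prev_end * alpha + (1 - alpha) * new_start
# blended[:, 0] is fully A, blended[:, -1] is fully B, so the stitched
# motion hands off smoothly into the rest of segment B.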
    def __call__(
        self,
        prompts: str | list[str],
        num_frames: int | list[int],
        num_denoising_steps: int,
        multi_prompt: bool = False,
        constraint_lst: Optional[list] = None,
        cfg_weight: Union[float, Tuple[float, float]] = (2.0, 2.0),
        num_samples: Optional[int] = None,
        cfg_type: Optional[str] = None,
        return_numpy: bool = False,
        first_heading_angle: Optional[torch.Tensor] = None,
        # for transitioning
        num_transition_frames: int = 5,
        share_transition: bool = True,
        percentage_transition_override: float = 0.10,
        # for postprocess
        post_processing: bool = False,
        root_margin: float = 0.04,
        # progress bar
        progress_bar=tqdm,
    ) -> dict:
        """Generate motion from text prompts and optional kinematic constraints.

        When a single prompt/num_frames pair is given, one motion is generated.
        Passing lists of prompts and/or num_frames produces a batch of
        independent motions. With ``multi_prompt=True``, the prompts are
        treated as sequential segments that are generated and stitched together
        with smooth transitions.

        Args:
            prompts: One or more text descriptions of the desired motion.
                A single string generates one sample; a list generates a batch
                (or sequential segments when ``multi_prompt=True``).
            num_frames: Duration of the generated motion in frames. Can be a
                single int applied to every prompt or a per-prompt list.
            num_denoising_steps: Number of DDIM denoising steps. More steps
                generally improve quality at the cost of speed.
            multi_prompt: If ``True``, treat ``prompts`` as an ordered sequence
                of segments and concatenate them with transitions.
            constraint_lst: Per-sample list of kinematic constraints (e.g.
                keyframe poses, end-effector targets, 2-D paths). Pass an
                empty list for unconstrained generation.
            cfg_weight: Classifier-free guidance scale(s). A two-element pair
                ``[text_cfg, constraint_cfg]`` controls text and constraint
                guidance independently.
            num_samples: Number of samples to generate.
            cfg_type: Override the default CFG strategy set at init
                (e.g. ``"separated"``).
            return_numpy: If ``True``, convert all output tensors to numpy
                arrays.
            first_heading_angle: Initial body heading in radians. Shape
                ``(B,)`` or scalar. Defaults to ``0`` (facing +Z).
            num_transition_frames: Number of overlapping frames used to blend
                consecutive segments in multi-prompt mode.
            share_transition: If ``True``, transition frames are shared between
                adjacent segments rather than appended.
            percentage_transition_override: Fraction of each segment's length
                that may be overridden by the transition blend.
            post_processing: If ``True``, apply post-processing
                (foot-skate cleanup and constraint enforcement).
            root_margin: Horizontal margin (in meters) used by the post-processor
                to decide when to correct root motion: when the root deviates from
                the constraint by more than this margin, the post-processor corrects it.
            progress_bar: Callable wrapping an iterable to display progress
                (default: ``tqdm``). Pass a no-op to silence output.

        Returns:
            dict: A dictionary of motion tensors (or numpy arrays if
            ``return_numpy=True``) with the following keys:

            - ``local_rot_mats`` – Local joint rotations as rotation matrices.
            - ``global_rot_mats`` – Global joint rotations as rotation matrices.
            - ``posed_joints`` – Joint positions in world space.
            - ``root_positions`` – Root joint positions.
            - ``smooth_root_pos`` – Smoothed root trajectory.
            - ``foot_contacts`` – Boolean foot-contact labels [left heel, left toe, right heel, right toe].
            - ``global_root_heading`` – Root heading angle over time.
        """
        device = self.device

        if constraint_lst is None:
            constraint_lst = []

        if multi_prompt:
            # multi-prompt generation
            return self._multiprompt(
                prompts,
                num_frames,
                num_denoising_steps,
                constraint_lst,
                cfg_weight,
                num_samples,
                cfg_type,
                return_numpy,
                first_heading_angle,
                num_transition_frames,
                share_transition,
                percentage_transition_override,
                post_processing,
                root_margin,
                progress_bar,
            )

        # Input checking
        tosqueeze = False
        if isinstance(prompts, list) and isinstance(num_frames, list):
            assert len(prompts) == len(num_frames), "The number of prompts should match the number of num_frames."
            num_samples = len(prompts)
        elif isinstance(prompts, list):
            num_samples = len(prompts)
            num_frames = [num_frames for _ in range(num_samples)]
        elif isinstance(num_frames, list):
            num_samples = len(num_frames)
            prompts = [prompts for _ in range(num_samples)]
        else:
            if num_samples is None:
                tosqueeze = True
                num_samples = 1
            prompts = [prompts for _ in range(num_samples)]
            num_frames = [num_frames for _ in range(num_samples)]

        bs = num_samples
        texts = sanitize_texts(prompts)

        lengths = torch.tensor(
            num_frames,
            device=device,
        )
        max_frames = max(lengths)
        motion_pad_mask = length_to_mask(lengths)

        if first_heading_angle is None:
            # Start at 0 angle
            first_heading_angle = torch.tensor([0.0] * bs, device=device)
        else:
            first_heading_angle = torch.as_tensor(first_heading_angle, device=device)
            if first_heading_angle.numel() == 1:
                first_heading_angle = first_heading_angle.repeat(bs)

        observed_motion, motion_mask = None, None
        if constraint_lst:
            observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched(
                constraint_lst,
                lengths,
                to_normalize=True,
                device=device,
            )

        motion = self._generate(
            texts,
            max_frames,
            num_denoising_steps=num_denoising_steps,
            pad_mask=motion_pad_mask,
            first_heading_angle=first_heading_angle,
            motion_mask=motion_mask,
            observed_motion=observed_motion,
            cfg_weight=cfg_weight,
            cfg_type=cfg_type,
            progress_bar=progress_bar,
        )

        if tosqueeze:
            motion = motion[0]

        output = self.motion_rep.inverse(
            motion,
            is_normalized=True,
            return_numpy=False,  # keep as tensors for potential post-processing
        )

        # Apply post-processing if requested
        if post_processing:
            corrected = post_process_motion(
                output["local_rot_mats"],
                output["root_positions"],
                output["foot_contacts"],
                self.skeleton,
                constraint_lst,
                root_margin=root_margin,
            )
            # keyframe outputs / foot contacts are not changed
            output.update(corrected)

        # Convert SOMA output to somaskel77 for the external API
        if isinstance(self.skeleton, SOMASkeleton30):
            output = self.skeleton.output_to_SOMASkeleton77(output)

        # Convert to numpy if requested
        if return_numpy:
            output = to_numpy(output)
        return output
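A usage sketch for `__call__` (illustrative, not part of the commit; `model` is assumed to be an already-constructed `Kimodo` instance and the prompts/frame counts are arbitrary):

# Single prompt, single sample; outputs converted to numpy.
out = model(
    "a person walks forward and waves",
    num_frames=120,
    num_denoising_steps=50,
    return_numpy=True,
)
joints = out["posed_joints"]  # world-space joint positions

# Two sequential segments stitched with blended transitions.
out_seq = model(
    ["walk in a circle", "sit down on the floor"],
    num_frames=[90, 60],
    num_denoising_steps=50,
    multi_prompt=True,
    num_samples=1,
)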
    def _generate(
        self,
        texts: List[str],
        max_frames: int,
        num_denoising_steps: int,
        pad_mask: torch.Tensor,
        first_heading_angle: Optional[torch.Tensor],
        motion_mask: torch.Tensor,
        observed_motion: torch.Tensor,
        cfg_weight: Union[float, Tuple[float, float]] = 2.0,
        text_feat: Optional[torch.Tensor] = None,
        text_pad_mask: Optional[torch.Tensor] = None,
        guide_masks: Optional[Dict] = None,
        cfg_type: Optional[str] = None,
        progress_bar=tqdm,
    ) -> torch.Tensor:
        """Sample the full denoising loop.

        Args:
            texts (List[str]): batch of text prompts to use for sampling (if text_feat is not passed in)
        """
        device = self.device
        if text_feat is None:
            assert text_pad_mask is None
            log.info("Encoding text...")
            text_feat, text_length = self.text_encoder(texts)
            text_feat = text_feat.to(device)

            # handle empty strings (set their features to zero)
            empty_text_mask = [len(text.strip()) == 0 for text in texts]
            text_feat[empty_text_mask] = 0

            # Create the pad mask for the text
            batch_size, maxlen = text_feat.shape[:2]
            tensor_text_length = torch.tensor(text_length, device=device)
            tensor_text_length[empty_text_mask] = 0
            text_pad_mask = torch.arange(maxlen, device=device).expand(batch_size, maxlen) < tensor_text_length[:, None]

        if motion_mask is not None:
            if motion_mask.dtype == torch.bool:
                motion_mask = 1 * motion_mask

        batch_size = text_feat.shape[0]

        # sample loop
        indices = list(range(num_denoising_steps))[::-1]
        shape = (batch_size, max_frames, self.motion_rep.motion_rep_dim)
        cur_mot = torch.randn(shape, device=self.device)
        num_denoising_steps = torch.tensor(
            [num_denoising_steps], device=self.device
        )  # this and t need to be tensors for ONNX export
        # init diffusion with the correct number of steps before looping
        use_timesteps = self.diffusion.space_timesteps(num_denoising_steps[0])[0]
        self.diffusion.calc_diffusion_vars(use_timesteps)
        for i in progress_bar(indices):
            t = torch.tensor([i] * cur_mot.size(0), device=self.device)
            with torch.inference_mode():
                cur_mot = self.denoising_step(
                    cur_mot,
                    pad_mask,
                    text_feat,
                    text_pad_mask,
                    t,
                    first_heading_angle,
                    motion_mask,
                    observed_motion,
                    num_denoising_steps,
                    cfg_weight,
                    guide_masks=guide_masks,
                    cfg_type=cfg_type,
                )
        return cur_mot
kimodo/model/llm2vec/README.md
ADDED
@@ -0,0 +1 @@
This is a patched version of the original [LLM2Vec](https://github.com/McGill-NLP/llm2vec) codebase so that `McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised` works with `transformers==5.0.0rc3`.
kimodo/model/llm2vec/__init__.py
ADDED
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""LLM2Vec text encoder and wrapper for Kimodo."""

from .llm2vec import LLM2Vec
from .llm2vec_wrapper import LLM2VecEncoder

__all__ = [
    "LLM2Vec",
    "LLM2VecEncoder",
]
kimodo/model/llm2vec/llm2vec.py
ADDED
@@ -0,0 +1,477 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.


# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
from functools import partial
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.multiprocessing as mp
from peft import PeftModel
from torch import Tensor, device, nn
from tqdm.autonotebook import tqdm, trange
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    GemmaConfig,
    LlamaConfig,
    MistralConfig,
    PretrainedConfig,
    Qwen2Config,
)

logger = logging.getLogger(__name__)


def batch_to_device(batch, target_device: device):
    """Send a pytorch batch to a device (CPU/GPU)"""
    for key in batch:
        if isinstance(batch[key], Tensor):
            batch[key] = batch[key].to(target_device)
    return batch


class LLM2Vec(nn.Module):
    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        pooling_mode: str = "mean",
        max_length: int = 512,
        doc_max_length: int = 400,
        skip_instruction: bool = True,
    ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.pooling_mode = pooling_mode
        self.skip_instruction = skip_instruction
        self.max_length = max_length
        self.doc_max_length = doc_max_length
        self.config = model.config

    @classmethod
    def _get_model_class(cls, config_class_name, enable_bidirectional):
        if not enable_bidirectional:
            return AutoModel
        if config_class_name == "MistralConfig":
            from .models.bidirectional_mistral import MistralBiModel

            return MistralBiModel
        elif config_class_name == "LlamaConfig":
            from .models.bidirectional_llama import LlamaBiModel

            return LlamaBiModel
        elif config_class_name == "GemmaConfig":
            from .models.bidirectional_gemma import GemmaBiModel

            return GemmaBiModel
        elif config_class_name == "Qwen2Config":
            from .models.bidirectional_qwen2 import Qwen2BiModel

            return Qwen2BiModel
        else:
            raise ValueError(f"{config_class_name} is not supported yet with bidirectional models.")

    @classmethod
    def from_pretrained(
        cls,
        base_model_name_or_path,
        peft_model_name_or_path=None,
        merge_peft=False,
        enable_bidirectional=True,
        **kwargs,
    ):
        # pop out encoder args
        keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"]
        encoder_args = {key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None}

        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        config = AutoConfig.from_pretrained(base_model_name_or_path)
        config_class_name = config.__class__.__name__

        model_class = cls._get_model_class(config_class_name, enable_bidirectional=enable_bidirectional)

        model = model_class.from_pretrained(base_model_name_or_path, **kwargs)

        if os.path.isdir(base_model_name_or_path) and os.path.exists(f"{base_model_name_or_path}/config.json"):
            with open(f"{base_model_name_or_path}/config.json", "r") as fIn:
                config_dict = json.load(fIn)
            config = PretrainedConfig.from_dict(config_dict)
            model.config._name_or_path = config._name_or_path

        # For the special case where config.json and adapter weights are in the same directory
        if hasattr(model, "peft_config"):
            model = PeftModel.from_pretrained(
                model,
                base_model_name_or_path,
            )
            model = model.merge_and_unload()

        if peft_model_name_or_path is not None:
            model = PeftModel.from_pretrained(
                model,
                peft_model_name_or_path,
            )
            if merge_peft:
                model = model.merge_and_unload()

        config = {}
        config_addr = peft_model_name_or_path if peft_model_name_or_path is not None else base_model_name_or_path
        if os.path.exists(f"{config_addr}/llm2vec_config.json"):
            with open(f"{config_addr}/llm2vec_config.json", "r") as fIn:
                llm2vec_config = json.load(fIn)
            config.update(llm2vec_config)

        for key, value in encoder_args.items():
            config[key] = value

        return cls(model=model, tokenizer=tokenizer, **config)

    def prepare_for_tokenization(self, text):
        if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
            text = "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
            return text
        if self.model.config._name_or_path in [
            "mistralai/Mistral-7B-Instruct-v0.2",
            "meta-llama/Llama-2-7b-chat-hf",
        ]:
            text = "[INST] " + text.strip() + " [/INST]"
        if self.model.config._name_or_path in [
            "google/gemma-2-9b-it",
        ]:
            text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
        if self.model.config._name_or_path in [
            "Qwen/Qwen2-1.5B-Instruct",
            "Qwen/Qwen2-7B-Instruct",
        ]:
            text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
        if self.pooling_mode == "eos_token":
            if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
                text = text.strip() + "<|end_of_text|>"
            elif isinstance(self.model.config, LlamaConfig) or isinstance(self.model.config, MistralConfig):
                text = text.strip() + " </s>"
            elif isinstance(self.model.config, GemmaConfig):
                text = text.strip() + "<eos>"
            elif isinstance(self.model.config, Qwen2Config):
                text = text.strip() + "<|endoftext|>"
        return text

    def tokenize(self, texts):
        texts_2 = []
        original_texts = []
        for text in texts:
            t = text.split("!@#$%^&*()")
            texts_2.append(t[1] if len(t) > 1 else "")
            original_texts.append("".join(t))

        original = self.tokenizer(
            original_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        embed_mask = None
        for t_i, t in enumerate(texts_2):
            ids = self.tokenizer(
                [t],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=False,
            )
            if embed_mask is None:
                e_m = torch.zeros_like(original["attention_mask"][t_i])
                if len(ids["input_ids"][0]) > 0:
                    e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0]))
                embed_mask = e_m.unsqueeze(0)
            else:
                e_m = torch.zeros_like(original["attention_mask"][t_i])
                if len(ids["input_ids"][0]) > 0:
                    e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0]))
                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)

        original["embed_mask"] = embed_mask
        return original
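How the pieces above fit together: `_convert_to_str` (defined further down) joins an instruction and a text with the "!@#$%^&*()" separator, and `tokenize` uses that separator to build an `embed_mask` that is 1 only on the text tokens, so pooling can ignore the instruction when `skip_instruction=True`. A sketch, assuming `l2v` is an already-constructed `LLM2Vec` instance:

joined = l2v._convert_to_str("Represent the motion description:", "a person walks in a circle")
# -> "Represent the motion description: !@#$%^&*()a person walks in a circle"

features = l2v.tokenize([l2v.prepare_for_tokenization(joined)])
# features["embed_mask"] is zero over the instruction tokens and one over the text
# tokens; get_pooling swaps it in for attention_mask before mean/weighted_mean pooling.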
    def _skip_instruction(self, sentence_feature):
        assert sentence_feature["attention_mask"].shape == sentence_feature["embed_mask"].shape
        sentence_feature["attention_mask"] = sentence_feature["embed_mask"]

    def forward(self, sentence_feature: Dict[str, Tensor]):
        embed_mask = None
        if "embed_mask" in sentence_feature:
            embed_mask = sentence_feature.pop("embed_mask")
        reps = self.model(**sentence_feature)
        sentence_feature["embed_mask"] = embed_mask

        return self.get_pooling(sentence_feature, reps.last_hidden_state)

    def get_pooling(self, features, last_hidden_states):  # All models padded from the left
        assert self.tokenizer.padding_side == "left", "Pooling modes are implemented for padding from left."
        if self.skip_instruction:
            self._skip_instruction(features)
        seq_lengths = features["attention_mask"].sum(dim=-1)
        if self.pooling_mode == "mean":
            return torch.stack(
                [last_hidden_states[i, -length:, :].mean(dim=0) for i, length in enumerate(seq_lengths)],
                dim=0,
            )
        elif self.pooling_mode == "weighted_mean":
            bs, l, _ = last_hidden_states.shape
            complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)
            for i, seq_l in enumerate(seq_lengths):
                if seq_l > 0:
                    complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1
                    complete_weights[i] /= torch.clamp(complete_weights[i].sum(), min=1e-9)
            return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
        elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token":
            return last_hidden_states[:, -1]
        elif self.pooling_mode == "bos_token":
            return last_hidden_states[features["input_ids"] == self.tokenizer.bos_token_id]
        else:
            raise ValueError(f"{self.pooling_mode} is not implemented yet.")

    def _convert_to_str(self, instruction, text):
        tokenized_q = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False,
        )
        tokenized_q_length = len(tokenized_q["input_ids"][0])

        while tokenized_q_length > self.doc_max_length:
            reduction_ratio = self.doc_max_length / tokenized_q_length
            reduced_length = int(len(text.split()) * reduction_ratio)
            text = " ".join(text.split()[:reduced_length])
            tokenized_q = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=False,
            )
            tokenized_q_length = len(tokenized_q["input_ids"][0])

        return f"{instruction.strip()} !@#$%^&*(){text}" if instruction else f"!@#$%^&*(){text}"

    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        show_progress_bar: bool = True,
        convert_to_numpy: bool = False,
        convert_to_tensor: bool = False,
        device: Optional[str] = None,
    ):
        """Encode a list of sentences into their respective embeddings.

        The sentences can be a list of strings or a single string.

        Args:
            sentences: sentence or sentences to encode.
            batch_size: batch size for turning sentence tokens into embeddings.
            show_progress_bar: whether to show progress bars during the encoding steps.
            convert_to_numpy: if True, return numpy arrays instead of torch tensors.
            convert_to_tensor: if True, return torch tensors (default).
            device: torch backend device identifier (e.g. 'cuda', 'cpu', 'mps'). If not
                specified, defaults to cuda when available, otherwise cpu. Note that only
                'cuda' supports multiprocessing as currently implemented.

        Returns:
            Embeddings of the sentences. Embeddings are detached and always on the CPU
            (see the _encode implementation).
        """
        if isinstance(sentences[0], str) and isinstance(sentences[-1], int):
            sentences = [sentences]
        # required for the MEDI version of MTEB
        if isinstance(sentences[0], str):
            sentences = [[""] + [sentence] for sentence in sentences]

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        concatenated_input_texts = []
        for sentence in sentences:
            assert isinstance(sentence[0], str)
            assert isinstance(sentence[1], str)
            concatenated_input_texts.append(self._convert_to_str(sentence[0], sentence[1]))
        sentences = concatenated_input_texts

        self.eval()

        if convert_to_tensor:
            convert_to_numpy = False

        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
        all_embeddings = []

        if torch.cuda.device_count() <= 1:
            # This branch also supports mps devices
            self.to(device)
            for start_index in trange(
                0,
                len(sentences),
                batch_size,
                desc="Batches",
                disable=not show_progress_bar,
            ):
                sentences_batch = sentences_sorted[start_index : start_index + batch_size]
                embeddings = self._encode(sentences_batch, device=device, convert_to_numpy=convert_to_numpy)
                all_embeddings.append(embeddings)
        else:
            num_proc = torch.cuda.device_count()
            cuda_compatible_multiprocess = mp.get_context("spawn")
            with cuda_compatible_multiprocess.Pool(num_proc) as p:
                sentences_batches = [
                    sentences_sorted[start_index : start_index + batch_size]
                    for start_index in range(0, len(sentences), batch_size)
                ]

                progress_bar = tqdm(
                    total=len(sentences_batches),
                    desc="Batches",
                    disable=not show_progress_bar,
                )
                results = []

                def update(*args):
                    progress_bar.update()

                for batch in sentences_batches:
                    results.append(
                        p.apply_async(
                            self._encode,
                            args=(batch, None, convert_to_numpy, True),
                            callback=update,
                        )
                    )

                all_embeddings = [result.get() for result in results]
                progress_bar.close()

        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]
        all_embeddings = all_embeddings.to(torch.float32)
        if convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
        return all_embeddings

    def save(self, output_path, merge_before_save=False, save_config=True):
        if merge_before_save and isinstance(self.model, PeftModel):
            self.model = self.model.merge_and_unload()
            # Fixes a saving issue - https://huggingface.co/McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse/discussions/1
            if hasattr(self.model, "_hf_peft_config_loaded"):
                self.model._hf_peft_config_loaded = False

        self.model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        llm2vec_config = {
            "pooling_mode": self.pooling_mode,
            "max_length": self.max_length,
            "doc_max_length": self.doc_max_length,
            "skip_instruction": self.skip_instruction,
        }

        if save_config:
            os.makedirs(output_path, exist_ok=True)
            with open(f"{output_path}/llm2vec_config.json", "w") as fOut:
                json.dump(llm2vec_config, fOut, indent=4)

    def _encode(
        self,
        sentences_batch,
        device: Optional[str] = None,
        convert_to_numpy: bool = False,
        multiprocessing=False,
    ):
        if multiprocessing:
            # multiprocessing only supports CUDA devices at this time, so we ignore the value
            # of device and use cuda:rank instead
            rank = mp.current_process()._identity[0]
            if device is None and torch.cuda.is_available():
                device = f"cuda:{rank % torch.cuda.device_count()}"

        self.to(device)
        features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch])
        features = batch_to_device(features, device)

        with torch.no_grad():
            embeddings = self.forward(features)
            embeddings = embeddings.detach()
            embeddings = embeddings.cpu()

        return embeddings

    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """Helper function to get the length of the input text.

        Text can be either a string (a single text), a list of ints (a single
        tokenized text), or a tuple of lists of ints (several text inputs to the model).
        """
        if (
            isinstance(text, str) or (isinstance(text, list) and isinstance(text[0], int)) or len(text) == 0
        ):  # Single text, list of ints, or empty
            return len(text)
        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values())))
        elif not hasattr(text, "__len__"):  # Object has no len() method
            return 1
        else:
            return sum([len(t) for t in text])

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        return self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
kimodo/model/llm2vec/llm2vec_wrapper.py
ADDED
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""LLM2Vec encoder wrapper for Kimodo text conditioning."""

import os

import numpy as np
import torch

from .llm2vec import LLM2Vec


class LLM2VecEncoder:
    """LLM2Vec text embeddings."""

    def __init__(
        self,
        base_model_name_or_path: str,
        peft_model_name_or_path: str,
        dtype: str,
        llm_dim: int,
    ) -> None:
        torch_dtype = getattr(torch, dtype)
        self.llm_dim = llm_dim

        cache_dir = os.environ.get("HUGGINGFACE_CACHE_DIR")

        if "TEXT_ENCODERS_DIR" in os.environ:
            base_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], base_model_name_or_path)
            peft_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], peft_model_name_or_path)

        self.model = LLM2Vec.from_pretrained(
            base_model_name_or_path=base_model_name_or_path,
            peft_model_name_or_path=peft_model_name_or_path,
            torch_dtype=torch_dtype,
            cache_dir=cache_dir,
        )
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    def to(self, device: torch.device):
        self.model = self.model.to(device)
        return self

    def eval(self):
        self.model.eval()
        return self

    def get_device(self):
        return self.model.model.device

    def __call__(self, text: list[str] | str):
        is_string = False
        if isinstance(text, str):
            text = [text]
            is_string = True

        with torch.no_grad():
            encoded_text = self.model.encode(text, batch_size=len(text), show_progress_bar=False)

        assert len(encoded_text.shape) == 2
        assert self.llm_dim == encoded_text.shape[-1]

        encoded_text = encoded_text[:, None]
        lengths = np.ones(len(encoded_text), dtype=int).tolist()

        if is_string:
            encoded_text = encoded_text[0]
            lengths = lengths[0]

        encoded_text = torch.tensor(encoded_text).to(self.get_device())
        return encoded_text, lengths
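A construction sketch for the wrapper (illustrative rather than pinned by the commit: the base/PEFT names below follow the model mentioned in the README, and `llm_dim=4096` assumes the Llama-3-8B hidden size):

encoder = LLM2VecEncoder(
    base_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
    peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
    dtype="bfloat16",
    llm_dim=4096,
)
text_feat, lengths = encoder(["a person walks forward"])
# text_feat: [1, 1, 4096] - one pooled embedding per prompt; lengths: [1]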
kimodo/model/llm2vec/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
# from .bidirectional_gemma import GemmaBiForMNTP, GemmaBiModel
# from .bidirectional_llama import LlamaBiForMNTP, LlamaBiModel
# from .bidirectional_mistral import MistralBiForMNTP, MistralBiModel
# from .bidirectional_qwen2 import Qwen2BiForMNTP, Qwen2BiModel
kimodo/model/llm2vec/models/attn_mask_utils.py
ADDED
@@ -0,0 +1,181 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import List, Optional, Tuple, Union

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D
    mask of shape `(batch_size, key_value_length)`

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    attn_mask_converter = AttentionMaskConverter(
        is_causal=False, sliding_window=sliding_window
    )  # is_causal=True in the original implementation

    key_value_length = input_shape[-1] + past_key_values_length

    # 4d mask is passed through the layers
    if attention_mask is not None and len(attention_mask.shape) == 2:
        attention_mask = attn_mask_converter.to_4d(
            attention_mask,
            input_shape[-1],
            key_value_length=key_value_length,
            dtype=inputs_embeds.dtype,
        )
    elif attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        else:
            # if the 4D mask has correct shape - invert it and fill with negative infinity
            inverted_mask = 1.0 - attention_mask
            attention_mask = inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
| 78 |
+
input_shape[0],
|
| 79 |
+
input_shape[-1],
|
| 80 |
+
key_value_length,
|
| 81 |
+
dtype=inputs_embeds.dtype,
|
| 82 |
+
device=inputs_embeds.device,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
return attention_mask
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Adapted from _prepare_4d_causal_attention_mask
|
| 89 |
+
def _prepare_4d_causal_attention_mask_for_sdpa(
|
| 90 |
+
attention_mask: Optional[torch.Tensor],
|
| 91 |
+
input_shape: Union[torch.Size, Tuple, List],
|
| 92 |
+
inputs_embeds: torch.Tensor,
|
| 93 |
+
past_key_values_length: int,
|
| 94 |
+
sliding_window: Optional[int] = None,
|
| 95 |
+
):
|
| 96 |
+
"""Prepares the correct `attn_mask` argument to be used by
|
| 97 |
+
`torch.nn.functional.scaled_dot_product_attention`.
|
| 98 |
+
|
| 99 |
+
In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
|
| 100 |
+
`key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
|
| 101 |
+
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
|
| 102 |
+
"""
|
| 103 |
+
attn_mask_converter = AttentionMaskConverter(
|
| 104 |
+
is_causal=False, sliding_window=sliding_window
|
| 105 |
+
) # is_causal=True in original implementation
|
| 106 |
+
|
| 107 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
| 108 |
+
batch_size, query_length = input_shape
|
| 109 |
+
|
| 110 |
+
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
| 111 |
+
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
| 112 |
+
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
|
| 113 |
+
is_tracing = (
|
| 114 |
+
torch.jit.is_tracing()
|
| 115 |
+
or isinstance(inputs_embeds, torch.fx.Proxy)
|
| 116 |
+
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
if attention_mask is not None:
|
| 120 |
+
# 4d mask is passed through
|
| 121 |
+
if len(attention_mask.shape) == 4:
|
| 122 |
+
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
| 123 |
+
if tuple(attention_mask.shape) != expected_shape:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
| 126 |
+
)
|
| 127 |
+
else:
|
| 128 |
+
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
| 129 |
+
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
|
| 130 |
+
attention_mask = inverted_mask.masked_fill(
|
| 131 |
+
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
| 132 |
+
)
|
| 133 |
+
return attention_mask
|
| 134 |
+
|
| 135 |
+
elif not is_tracing and torch.all(attention_mask == 1):
|
| 136 |
+
if query_length == 1:
|
| 137 |
+
# For query_length == 1, causal attention and bi-directional attention are the same.
|
| 138 |
+
attention_mask = None
|
| 139 |
+
elif key_value_length == query_length:
|
| 140 |
+
attention_mask = None
|
| 141 |
+
else:
|
| 142 |
+
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
|
| 143 |
+
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
| 144 |
+
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
| 145 |
+
pass
|
| 146 |
+
elif query_length > 1 and key_value_length != query_length:
|
| 147 |
+
# See the comment above (https://github.com/pytorch/pytorch/issues/108108).
|
| 148 |
+
# Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
|
| 149 |
+
attention_mask = True
|
| 150 |
+
elif is_tracing:
|
| 151 |
+
raise ValueError(
|
| 152 |
+
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
if attention_mask is None:
|
| 156 |
+
expanded_4d_mask = None
|
| 157 |
+
elif attention_mask is True:
|
| 158 |
+
expanded_4d_mask = attn_mask_converter.to_causal_4d(
|
| 159 |
+
input_shape[0],
|
| 160 |
+
input_shape[-1],
|
| 161 |
+
key_value_length,
|
| 162 |
+
dtype=inputs_embeds.dtype,
|
| 163 |
+
device=inputs_embeds.device,
|
| 164 |
+
)
|
| 165 |
+
else:
|
| 166 |
+
expanded_4d_mask = attn_mask_converter.to_4d(
|
| 167 |
+
attention_mask,
|
| 168 |
+
input_shape[-1],
|
| 169 |
+
dtype=inputs_embeds.dtype,
|
| 170 |
+
key_value_length=key_value_length,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
|
| 174 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 175 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 176 |
+
if not is_tracing and expanded_4d_mask.device.type == "cuda":
|
| 177 |
+
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
|
| 178 |
+
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
return expanded_4d_mask
|
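A minimal sketch of what the first helper returns, assuming only what the code above shows: with `is_causal=False` the 4D mask is purely additive (0 where attention is allowed, the dtype minimum at padded keys) and carries no upper-triangular future masking.

# Sketch: batch of 2 sequences of length 4, the second right-padded by two tokens.
import torch

attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])
inputs_embeds = torch.zeros(2, 4, 8)  # (batch, query_length, hidden)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask=attention_mask,
    input_shape=(2, 4),
    inputs_embeds=inputs_embeds,
    past_key_values_length=0,
)
assert mask_4d.shape == (2, 1, 4, 4)
# mask_4d[1, 0] has ~-3.4e38 in the last two key columns and 0 elsewhere:
# padding is blocked, but every real token can attend forwards and backwards.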
kimodo/model/llm2vec/models/bidirectional_llama.py
ADDED
@@ -0,0 +1,224 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from peft import PeftModel
from torch import nn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    # LlamaFlashAttention2,
    LlamaMLP,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    # LlamaSdpaAttention,
)
from transformers.utils import logging

from .utils import is_transformers_attn_greater_or_equal_4_43_1

logger = logging.get_logger(__name__)


class ModifiedLlamaAttention(LlamaAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False


# class ModifiedLlamaFlashAttention2(LlamaFlashAttention2):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.is_causal = False


# class ModifiedLlamaSdpaAttention(LlamaSdpaAttention):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.is_causal = False


# LLAMA_ATTENTION_CLASSES = {
#     "eager": ModifiedLlamaAttention,
#     "flash_attention_2": ModifiedLlamaFlashAttention2,
#     "sdpa": ModifiedLlamaSdpaAttention,
# }


class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)
        # self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
        #     config=config, layer_idx=layer_idx
        # )

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class LlamaBiModel(LlamaModel):
    _no_split_modules = ["ModifiedLlamaDecoderLayer"]

    def __init__(self, config: LlamaConfig):
        if not is_transformers_attn_greater_or_equal_4_43_1():
            raise ValueError(
                "The current implementation of LlamaEncoderModel follows modeling_llama.py of transformers version >= 4.43.1"
            )
        LlamaPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def _update_causal_mask(
        self,
        attention_mask,
        input_tensor,
        cache_position,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        # if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
        #     if AttentionMaskConverter._ignore_causal_mask_sdpa(
        #         attention_mask,
        #         inputs_embeds=input_tensor,
        #         past_key_values_length=past_seen_tokens,
        #         is_training=self.training,
        #     ):
        #         return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        causal_mask = torch.zeros(
            (sequence_length, target_length), dtype=dtype, device=device
        )  # in original implementation - torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        # Commenting out next 2 lines to disable causal masking
        # if sequence_length != 1:
        #     causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
            elif attention_mask.dim() == 4:
                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                # cache. In that case, the 4D attention mask attends to the newest tokens only.
                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
                    offset = cache_position[0]
                else:
                    offset = 0
                mask_shape = attention_mask.shape
                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
                causal_mask[
                    : mask_shape[0],
                    : mask_shape[1],
                    offset : mask_shape[2] + offset,
                    : mask_shape[3],
                ] = mask_slice

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask


class LlamaBiForMNTP(LlamaForCausalLM):
    def __init__(self, config):
        LlamaPreTrainedModel.__init__(self, config)
        self.model = LlamaBiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    # getter for PEFT model
    def get_model_for_peft(self):
        return self.model

    # setter for PEFT model
    def set_model_for_peft(self, model: PeftModel):
        self.model = model

    # save the PEFT model
    def save_peft_model(self, path):
        self.model.save_pretrained(path)
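To make the effect of the commented-out `torch.triu` concrete, here is an illustrative sketch (tiny made-up config, not from this repo) that inspects the mask `_update_causal_mask` builds: it starts from zeros (everything visible) and only re-masks padding, which is what makes the model bidirectional.

# Illustrative sketch with a made-up tiny config; requires transformers >= 4.43.1.
import torch
from transformers import LlamaConfig

config = LlamaConfig(
    vocab_size=128, hidden_size=32, intermediate_size=64,
    num_hidden_layers=1, num_attention_heads=4, num_key_value_heads=4,
)
model = LlamaBiModel(config)

attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
mask = model._update_causal_mask(
    attention_mask,
    input_tensor=torch.zeros(1, 4, 32),
    cache_position=torch.arange(4),
    past_key_values=None,
    output_attentions=False,
)
# mask[0, 0] is 4x4 with zeros everywhere except the padded key column:
# token 0 can attend to tokens 1-3, i.e. no causal triangle remains.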
kimodo/model/llm2vec/models/utils.py
ADDED
@@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import importlib.metadata

from packaging import version
from transformers.utils.import_utils import _is_package_available


def is_transformers_attn_greater_or_equal_4_43_1():
    if not _is_package_available("transformers"):
        return False

    return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.43.1")
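The guard above is what `LlamaBiModel.__init__` relies on; the same pattern can gate the import itself, e.g.:

# Sketch: gate the bidirectional import on the version check above.
from kimodo.model.llm2vec.models.utils import is_transformers_attn_greater_or_equal_4_43_1

if is_transformers_attn_greater_or_equal_4_43_1():
    from kimodo.model.llm2vec.models.bidirectional_llama import LlamaBiModel  # noqa: F401
else:
    raise ImportError("LlamaBiModel requires transformers >= 4.43.1")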
kimodo/model/load_model.py
ADDED
@@ -0,0 +1,194 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Load Kimodo diffusion models from local checkpoints or Hugging Face."""

from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download
from omegaconf import OmegaConf

from .loading import (
    AVAILABLE_MODELS,
    DEFAULT_MODEL,
    DEFAULT_TEXT_ENCODER_URL,
    MODEL_NAMES,
    TMR_MODELS,
    get_env_var,
    instantiate_from_dict,
)
from .registry import get_model_info, resolve_model_name

DEFAULT_TEXT_ENCODER = "llm2vec"
TEXT_ENCODER_PRESETS = {
    "llm2vec": {
        "target": "kimodo.model.LLM2VecEncoder",
        "kwargs": {
            "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
            "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
            "dtype": "bfloat16",
            "llm_dim": 4096,
        },
    }
}


def _resolve_hf_model_path(modelname: str) -> Path:
    """Resolve a model name to a local path via the Hugging Face cache, downloading if needed."""
    try:
        repo_id = MODEL_NAMES[modelname]
    except KeyError:
        raise ValueError(f"Model '{modelname}' not found. Available models: {list(MODEL_NAMES.keys())}")

    local_cache = get_env_var("LOCAL_CACHE", "False").lower() == "true"
    if not local_cache:
        snapshot_dir = snapshot_download(repo_id=repo_id)  # will check online no matter what
        return Path(snapshot_dir)

    try:
        snapshot_dir = snapshot_download(repo_id=repo_id, local_files_only=True)  # local cache only
        return Path(snapshot_dir)
    except Exception:
        # If the local cache does not have the model, fall back to downloading it.
        try:
            snapshot_dir = snapshot_download(repo_id=repo_id)
            return Path(snapshot_dir)
        except Exception:
            raise RuntimeError(f"Could not resolve model '{modelname}' from Hugging Face (repo: {repo_id}).") from None


def _build_api_text_encoder_conf(text_encoder_url: str) -> dict:
    return {
        "_target_": "kimodo.model.text_encoder_api.TextEncoderAPI",
        "url": text_encoder_url,
    }


def _build_local_text_encoder_conf() -> dict:
    text_encoder_name = get_env_var("TEXT_ENCODER", DEFAULT_TEXT_ENCODER)
    if text_encoder_name not in TEXT_ENCODER_PRESETS:
        available = ", ".join(sorted(TEXT_ENCODER_PRESETS))
        raise ValueError(f"Unknown TEXT_ENCODER='{text_encoder_name}'. Available: {available}")

    preset = TEXT_ENCODER_PRESETS[text_encoder_name]
    return {
        "_target_": preset["target"],
        **preset["kwargs"],
    }


def _select_text_encoder_conf(text_encoder_url: str) -> dict:
    # TEXT_ENCODER_MODE options:
    # - "api": force TextEncoderAPI
    # - "local": force local LLM2VecEncoder
    # - "auto": try API first, fall back to local if unreachable
    mode = get_env_var("TEXT_ENCODER_MODE", "auto").lower()
    if mode == "local":
        return _build_local_text_encoder_conf()
    if mode == "api":
        return _build_api_text_encoder_conf(text_encoder_url)

    api_conf = _build_api_text_encoder_conf(text_encoder_url)
    try:
        text_encoder = instantiate_from_dict(api_conf)
        # Probe availability early so inference doesn't fail later.
        text_encoder(["healthcheck"])
        return api_conf
    except Exception as error:
        print(
            "Text encoder service is unreachable, falling back to local LLM2Vec "
            f"encoder. ({type(error).__name__}: {error})"
        )
        return _build_local_text_encoder_conf()


def load_model(
    modelname=None,
    device=None,
    eval_mode: bool = True,
    default_family: Optional[str] = "Kimodo",
    return_resolved_name: bool = False,
):
    """Load a kimodo model by name (e.g. 'g1', 'soma').

    Resolution of partial/full names (e.g. Kimodo-SOMA-RP-v1, SOMA) is done
    inside this function using default_family when the name is not a known
    short key.

    Args:
        modelname: Model identifier; uses DEFAULT_MODEL if None. Can be a short key,
            a full name (e.g. Kimodo-SOMA-RP-v1), or a partial name; unknown names
            are resolved via resolve_model_name using default_family.
        device: Target device for the model (e.g. 'cuda', 'cpu').
        eval_mode: If True, set the model to eval mode.
        default_family: Used when modelname is not in AVAILABLE_MODELS to resolve
            partial names ("Kimodo" for demo/generation, "TMR" for the embed script).
            Default "Kimodo".
        return_resolved_name: If True, return (model, resolved_short_key). If False,
            return only the model.

    Returns:
        Loaded model in eval mode, or (model, resolved short key) if
        return_resolved_name is True.

    Raises:
        ValueError: If modelname is not in AVAILABLE_MODELS and cannot be resolved.
        FileNotFoundError: If config.yaml is missing in the checkpoint folder.
    """
    if modelname is None:
        modelname = DEFAULT_MODEL
    if modelname not in AVAILABLE_MODELS:
        if default_family is not None:
            modelname = resolve_model_name(modelname, default_family)
        else:
            raise ValueError(
                f"""The model is not recognized.
                Please choose between: {AVAILABLE_MODELS}"""
            )

    resolved_modelname = modelname

    # In case a custom checkpoint directory is specified
    configured_checkpoint_dir = get_env_var("CHECKPOINT_DIR")
    if configured_checkpoint_dir:
        print(f"CHECKPOINT_DIR is set to {configured_checkpoint_dir}, checking the local cache...")
        # Checkpoint folders are named by display name (e.g. Kimodo-SOMA-RP-v1)
        info = get_model_info(modelname)
        checkpoint_folder_name = info.display_name if info is not None else modelname
        model_path = Path(configured_checkpoint_dir) / checkpoint_folder_name
        if not model_path.exists() and modelname != checkpoint_folder_name:
            # Fallback: try the short key for backward compatibility
            model_path = Path(configured_checkpoint_dir) / modelname
        if not model_path.exists():
            print(f"Model folder not found at '{model_path}', downloading it from Hugging Face...")
            model_path = _resolve_hf_model_path(modelname)
    else:
        # Otherwise, load the model from the local cache or download it from Hugging Face.
        model_path = _resolve_hf_model_path(modelname)

    model_config_path = model_path / "config.yaml"
    if not model_config_path.exists():
        raise FileNotFoundError(f"The model checkpoint folder exists but config.yaml is missing: {model_config_path}")

    model_conf = OmegaConf.load(model_config_path)

    if modelname in TMR_MODELS:
        # Same process at the moment for TMR and Kimodo
        pass

    text_encoder_url = get_env_var("TEXT_ENCODER_URL", DEFAULT_TEXT_ENCODER_URL)
    runtime_conf = OmegaConf.create(
        {
            "checkpoint_dir": str(model_path),
            "text_encoder": _select_text_encoder_conf(text_encoder_url),
        }
    )
    model_cfg = OmegaConf.to_container(OmegaConf.merge(model_conf, runtime_conf), resolve=True)
    model_cfg.pop("checkpoint_dir", None)

    model = instantiate_from_dict(model_cfg, overrides={"device": device})
    if eval_mode:
        model = model.eval()
    if return_resolved_name:
        return model, resolved_modelname
    return model
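A hedged usage sketch of `load_model` with the environment variables it reads; the paths and the CUDA device below are illustrative:

# Usage sketch; TEXT_ENCODER_MODE / CHECKPOINT_DIR are the env vars read above.
import os

os.environ["TEXT_ENCODER_MODE"] = "local"     # skip the API healthcheck probe
os.environ["CHECKPOINT_DIR"] = "/data/ckpts"  # optional: folders named Kimodo-SOMA-RP-v1, ...

from kimodo import load_model

# Partial names go through resolve_model_name: "soma" -> kimodo-soma-rp (latest version).
model, short_key = load_model("soma", device="cuda", return_resolved_name=True)
print(short_key)  # kimodo-soma-rp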
kimodo/model/loading.py
ADDED
@@ -0,0 +1,81 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Model loading utilities: checkpoints, registry, env, and Hydra-based instantiation."""

import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf
from safetensors.torch import load_file as load_safetensors

from .registry import (
    AVAILABLE_MODELS,
    DEFAULT_MODEL,
    DEFAULT_TEXT_ENCODER_URL,
    KIMODO_MODELS,
    MODEL_NAMES,
    TMR_MODELS,
)


def get_env_var(name: str, default: Optional[str] = None) -> Optional[str]:
    """Return environment variable value, or default if unset/empty."""
    return os.environ.get(name) or default


def instantiate_from_dict(
    cfg: Dict[str, Any],
    overrides: Optional[Dict[str, Any]] = None,
):
    """Instantiate an object from a config dict (e.g. from OmegaConf.to_container).

    The dict must contain _target_ with a fully qualified class path. Nested configs are
    instantiated recursively.
    """
    if overrides:
        cfg = {**cfg, **overrides}
    conf = OmegaConf.create(cfg)
    return instantiate(conf)


def load_checkpoint_state_dict(ckpt_path: Union[str, Path]) -> dict:
    """Load a state dict from a checkpoint file.

    If the checkpoint is a dict with a 'state_dict' key (e.g. PyTorch Lightning),
    that is returned; otherwise the whole checkpoint is treated as the state dict.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        state_dict suitable for model.load_state_dict().
    """
    ckpt_path = str(ckpt_path)

    if ckpt_path.endswith(".safetensors"):
        state_dict = load_safetensors(ckpt_path)
    else:
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        elif isinstance(checkpoint, dict):
            state_dict = checkpoint
        else:
            raise ValueError(f"Unsupported checkpoint format: {ckpt_path}")
    return {key: val.detach().cpu() for key, val in state_dict.items()}


__all__ = [
    "get_env_var",
    "instantiate_from_dict",
    "KIMODO_MODELS",
    "TMR_MODELS",
    "AVAILABLE_MODELS",
    "MODEL_NAMES",
    "DEFAULT_MODEL",
    "DEFAULT_TEXT_ENCODER_URL",
    "load_checkpoint_state_dict",
]
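A minimal sketch of the `_target_` convention `instantiate_from_dict` follows (the target here is a standard-library class chosen just for illustration):

# Sketch: _target_ names any importable callable; remaining keys become kwargs,
# and `overrides` shallow-merges on top before instantiation.
cfg = {"_target_": "collections.Counter", "red": 2, "blue": 1}
counter = instantiate_from_dict(cfg, overrides={"blue": 5})
print(counter)  # Counter({'blue': 5, 'red': 2})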
kimodo/model/registry.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
"""Registry of model names and Hugging Face repo IDs for Kimodo and TMR.
|
| 4 |
+
|
| 5 |
+
Canonical source of truth is the list of repo IDs. Short keys (e.g. soma-rp) and metadata (dataset,
|
| 6 |
+
skeleton, version, display name) are derived by parsing.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import re
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
# Canonical list: repo IDs in the same syntax as Hugging Face (org/Model-Name-v1).
|
| 14 |
+
# Parser expects: org/Family-SKELETON-DATASET-version (e.g. Kimodo-SOMA-RP-v1).
|
| 15 |
+
KIMODO_REPO_IDS = [
|
| 16 |
+
"nvidia/Kimodo-SOMA-RP-v1",
|
| 17 |
+
"nvidia/Kimodo-SMPLX-RP-v1",
|
| 18 |
+
"nvidia/Kimodo-G1-RP-v1",
|
| 19 |
+
"nvidia/Kimodo-SOMA-SEED-v1",
|
| 20 |
+
"nvidia/Kimodo-G1-SEED-v1",
|
| 21 |
+
]
|
| 22 |
+
TMR_REPO_IDS = [
|
| 23 |
+
"nvidia/TMR-SOMA-RP-v1",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
# Repo ID without org, for display (e.g. Kimodo-SOMA-RP-v1).
|
| 27 |
+
_REPO_NAME_PATTERN = re.compile(r"^(Kimodo|TMR)-([A-Za-z0-9]+)-(RP|SEED)-v(\d+)$")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class ModelInfo:
|
| 32 |
+
"""Structured metadata for one model, derived from its repo ID."""
|
| 33 |
+
|
| 34 |
+
repo_id: str
|
| 35 |
+
short_key: str
|
| 36 |
+
family: str
|
| 37 |
+
skeleton: str
|
| 38 |
+
dataset: str
|
| 39 |
+
version: str
|
| 40 |
+
display_name: str
|
| 41 |
+
|
| 42 |
+
@property
|
| 43 |
+
def dataset_ui_label(self) -> str:
|
| 44 |
+
return "Rigplay" if self.dataset == "RP" else "SEED"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _parse_repo_id(repo_id: str) -> Optional[ModelInfo]:
|
| 48 |
+
"""Parse a repo ID into ModelInfo.
|
| 49 |
+
|
| 50 |
+
Returns None if format is unrecognized.
|
| 51 |
+
"""
|
| 52 |
+
# repo_id is "org/Model-Name-v1"
|
| 53 |
+
if "/" in repo_id:
|
| 54 |
+
_, name = repo_id.split("/", 1)
|
| 55 |
+
else:
|
| 56 |
+
name = repo_id
|
| 57 |
+
m = _REPO_NAME_PATTERN.match(name)
|
| 58 |
+
if not m:
|
| 59 |
+
return None
|
| 60 |
+
family, skeleton, dataset, ver = m.groups()
|
| 61 |
+
# Normalize skeleton for display (as is for now)
|
| 62 |
+
skeleton_display = skeleton
|
| 63 |
+
# Include family so Kimodo-SOMA-RP and TMR-SOMA-RP have distinct keys.
|
| 64 |
+
short_key = f"{family.lower()}-{skeleton.lower()}-{dataset.lower()}"
|
| 65 |
+
return ModelInfo(
|
| 66 |
+
repo_id=repo_id,
|
| 67 |
+
short_key=short_key,
|
| 68 |
+
family=family,
|
| 69 |
+
skeleton=skeleton_display,
|
| 70 |
+
dataset=dataset,
|
| 71 |
+
version=f"v{ver}",
|
| 72 |
+
display_name=name,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _build_registry() -> tuple[list[ModelInfo], dict[str, str], list[str]]:
|
| 77 |
+
"""Build model infos, short_key -> repo_id map, and list of short keys.
|
| 78 |
+
|
| 79 |
+
When multiple versions exist for the same (family, skeleton, dataset), the base short_key (e.g.
|
| 80 |
+
kimodo-soma-rp) maps to the latest version's repo_id so that HF resolution finds the newest
|
| 81 |
+
model.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def _version_key(info: ModelInfo) -> int:
|
| 85 |
+
v = info.version
|
| 86 |
+
if v.startswith("v") and v[1:].isdigit():
|
| 87 |
+
return int(v[1:])
|
| 88 |
+
return 0
|
| 89 |
+
|
| 90 |
+
all_repos = KIMODO_REPO_IDS + TMR_REPO_IDS
|
| 91 |
+
infos: list[ModelInfo] = []
|
| 92 |
+
for repo_id in all_repos:
|
| 93 |
+
info = _parse_repo_id(repo_id)
|
| 94 |
+
if info is None:
|
| 95 |
+
raise ValueError(f"Registry repo ID does not match expected pattern: {repo_id}")
|
| 96 |
+
infos.append(info)
|
| 97 |
+
|
| 98 |
+
# Map each base short_key to the latest version's repo_id (by version number)
|
| 99 |
+
model_names: dict[str, str] = {}
|
| 100 |
+
seen_short_keys: set[str] = set()
|
| 101 |
+
for info in infos:
|
| 102 |
+
if info.short_key in seen_short_keys:
|
| 103 |
+
continue
|
| 104 |
+
seen_short_keys.add(info.short_key)
|
| 105 |
+
candidates = [
|
| 106 |
+
i for i in infos if i.family == info.family and i.skeleton == info.skeleton and i.dataset == info.dataset
|
| 107 |
+
]
|
| 108 |
+
if candidates:
|
| 109 |
+
latest = max(candidates, key=_version_key)
|
| 110 |
+
model_names[info.short_key] = latest.repo_id
|
| 111 |
+
|
| 112 |
+
return infos, model_names, list(model_names.keys())
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
MODEL_INFOS, MODEL_NAMES, _SHORT_KEYS = _build_registry()
|
| 116 |
+
AVAILABLE_MODELS = _SHORT_KEYS
|
| 117 |
+
|
| 118 |
+
# Short-key lists for Kimodo vs TMR (load_model uses TMR_MODELS to branch).
|
| 119 |
+
KIMODO_MODELS = [info.short_key for info in MODEL_INFOS if info.family == "Kimodo"]
|
| 120 |
+
TMR_MODELS = [info.short_key for info in MODEL_INFOS if info.family == "TMR"]
|
| 121 |
+
|
| 122 |
+
# Backward compatibility: FRIENDLY_NAMES for any code that still expects it.
|
| 123 |
+
FRIENDLY_NAMES = {info.short_key: info.display_name for info in MODEL_INFOS}
|
| 124 |
+
|
| 125 |
+
DEFAULT_MODEL = "kimodo-soma-rp"
|
| 126 |
+
DEFAULT_TEXT_ENCODER_URL = "http://127.0.0.1:9550/"
|
| 127 |
+
|
| 128 |
+
# Friendly names for skeleton dropdown (key -> label).
|
| 129 |
+
SKELETON_DISPLAY_NAMES = {
|
| 130 |
+
"SOMA": "SOMA Human Body",
|
| 131 |
+
"SMPLX": "SMPLX Human Body",
|
| 132 |
+
"G1": "Unitree G1 Humanoid Robot",
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
# Order for skeleton dropdown: SOMA, SMPLX, G1.
|
| 136 |
+
SKELETON_ORDER = ("SOMA", "SMPLX", "G1")
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_skeleton_display_name(skeleton_key: str) -> str:
|
| 140 |
+
"""Return the UI label for a skeleton key (e.g. SOMA -> SOMA Human Body)."""
|
| 141 |
+
return SKELETON_DISPLAY_NAMES.get(skeleton_key, skeleton_key)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_skeleton_key_from_display_name(display_name: str) -> Optional[str]:
|
| 145 |
+
"""Return the skeleton key for a UI label, or None."""
|
| 146 |
+
for key, label in SKELETON_DISPLAY_NAMES.items():
|
| 147 |
+
if label == display_name:
|
| 148 |
+
return key
|
| 149 |
+
return None
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_skeleton_display_names_for_dataset(dataset_ui_label: str, family: Optional[str] = None) -> list[str]:
|
| 153 |
+
"""Return skeleton UI labels for the given dataset.
|
| 154 |
+
|
| 155 |
+
If family is set (e.g. "Kimodo"), only skeletons with a model of that family are included.
|
| 156 |
+
"""
|
| 157 |
+
keys = get_skeletons_for_dataset(dataset_ui_label, family=family)
|
| 158 |
+
return [get_skeleton_display_name(k) for k in keys]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_short_key(repo_id: str) -> Optional[str]:
|
| 162 |
+
"""Return the short key for a repo ID, or None if not in registry."""
|
| 163 |
+
for info in MODEL_INFOS:
|
| 164 |
+
if info.repo_id == repo_id:
|
| 165 |
+
return info.short_key
|
| 166 |
+
return None
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def get_model_info(short_key: str) -> Optional[ModelInfo]:
|
| 170 |
+
"""Return ModelInfo for a short key, or None if not found.
|
| 171 |
+
|
| 172 |
+
When multiple versions share the same short_key, returns the one used for loading (the latest
|
| 173 |
+
version), so CHECKPOINT_DIR and HF use the same version.
|
| 174 |
+
"""
|
| 175 |
+
repo_id = MODEL_NAMES.get(short_key)
|
| 176 |
+
if repo_id is None:
|
| 177 |
+
return None
|
| 178 |
+
for info in MODEL_INFOS:
|
| 179 |
+
if info.repo_id == repo_id:
|
| 180 |
+
return info
|
| 181 |
+
return None
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def get_short_key_from_display_name(display_name: str) -> Optional[str]:
|
| 185 |
+
"""Return short_key for a display name (e.g. Kimodo-SOMA-RP-v1), or None."""
|
| 186 |
+
for info in MODEL_INFOS:
|
| 187 |
+
if info.display_name == display_name:
|
| 188 |
+
return info.short_key
|
| 189 |
+
return None
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_models_for_demo() -> list[ModelInfo]:
|
| 193 |
+
"""Return all model infos in registry order (for demo model list)."""
|
| 194 |
+
return list(MODEL_INFOS)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_datasets(family: Optional[str] = None) -> list[str]:
|
| 198 |
+
"""Return unique dataset UI labels (Rigplay, SEED) present in registry.
|
| 199 |
+
|
| 200 |
+
If family is set (e.g. "Kimodo"), only datasets that have a model of that family are included.
|
| 201 |
+
"""
|
| 202 |
+
infos = MODEL_INFOS
|
| 203 |
+
if family is not None:
|
| 204 |
+
infos = [i for i in infos if i.family == family]
|
| 205 |
+
labels = set()
|
| 206 |
+
for info in infos:
|
| 207 |
+
labels.add(info.dataset_ui_label)
|
| 208 |
+
return sorted(labels)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def get_skeletons_for_dataset(dataset_ui_label: str, family: Optional[str] = None) -> list[str]:
|
| 212 |
+
"""Return skeleton names that have a model for the given dataset.
|
| 213 |
+
|
| 214 |
+
Order: SOMA, SMPLX, G1 (only those present for the dataset).
|
| 215 |
+
If family is set (e.g. "Kimodo"), only skeletons with a model of that
|
| 216 |
+
family are included.
|
| 217 |
+
"""
|
| 218 |
+
dataset = "RP" if dataset_ui_label == "Rigplay" else "SEED"
|
| 219 |
+
infos = MODEL_INFOS
|
| 220 |
+
if family is not None:
|
| 221 |
+
infos = [i for i in infos if i.family == family]
|
| 222 |
+
skeletons = set()
|
| 223 |
+
for info in infos:
|
| 224 |
+
if info.dataset == dataset:
|
| 225 |
+
skeletons.add(info.skeleton)
|
| 226 |
+
return [s for s in SKELETON_ORDER if s in skeletons]
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def get_versions_for_dataset_skeleton(dataset_ui_label: str, skeleton: str) -> list[str]:
|
| 230 |
+
"""Return version strings (e.g. v1) for the given dataset/skeleton.
|
| 231 |
+
|
| 232 |
+
Sorted by version number so the last element is the highest (e.g. v1, v2).
|
| 233 |
+
"""
|
| 234 |
+
dataset = "RP" if dataset_ui_label == "Rigplay" else "SEED"
|
| 235 |
+
versions = []
|
| 236 |
+
for info in MODEL_INFOS:
|
| 237 |
+
if info.dataset == dataset and info.skeleton == skeleton:
|
| 238 |
+
versions.append(info.version)
|
| 239 |
+
|
| 240 |
+
# Sort by numeric part so v2 comes after v1.
|
| 241 |
+
def version_key(v: str) -> int:
|
| 242 |
+
if v.startswith("v") and v[1:].isdigit():
|
| 243 |
+
return int(v[1:])
|
| 244 |
+
return 0
|
| 245 |
+
|
| 246 |
+
return sorted(set(versions), key=version_key)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def get_models_for_dataset_skeleton(
|
| 250 |
+
dataset_ui_label: str, skeleton: str, family: Optional[str] = None
|
| 251 |
+
) -> list[ModelInfo]:
|
| 252 |
+
"""Return model infos for the given dataset/skeleton, sorted by version (max first).
|
| 253 |
+
|
| 254 |
+
Used to build the Version dropdown (options = full display names, one per model). If family is
|
| 255 |
+
set (e.g. "Kimodo"), only models of that family are returned.
|
| 256 |
+
"""
|
| 257 |
+
dataset = "RP" if dataset_ui_label == "Rigplay" else "SEED"
|
| 258 |
+
infos = [info for info in MODEL_INFOS if info.dataset == dataset and info.skeleton == skeleton]
|
| 259 |
+
if family is not None:
|
| 260 |
+
infos = [i for i in infos if i.family == family]
|
| 261 |
+
|
| 262 |
+
def version_key(info: ModelInfo) -> int:
|
| 263 |
+
v = info.version
|
| 264 |
+
if v.startswith("v") and v[1:].isdigit():
|
| 265 |
+
return int(v[1:])
|
| 266 |
+
return 0
|
| 267 |
+
|
| 268 |
+
return sorted(infos, key=version_key, reverse=True)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def resolve_to_short_key(dataset_ui_label: str, skeleton: str, version: str) -> Optional[str]:
|
| 272 |
+
"""Return the short key for (dataset, skeleton, version), or None."""
|
| 273 |
+
for info in MODEL_INFOS:
|
| 274 |
+
if info.dataset_ui_label == dataset_ui_label and info.skeleton == skeleton and info.version == version:
|
| 275 |
+
return info.short_key
|
| 276 |
+
return None
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# -----------------------------------------------------------------------------
|
| 280 |
+
# Flexible model name resolution (partial names, case-insensitive, defaults)
|
| 281 |
+
# -----------------------------------------------------------------------------
|
| 282 |
+
|
| 283 |
+
_FAMILY_ALIASES = {"kimodo": "Kimodo", "tmr": "TMR"}
|
| 284 |
+
_DATASET_ALIASES = {"rp": "RP", "rigplay": "RP", "seed": "SEED"}
|
| 285 |
+
_SKELETON_ALIASES = {
|
| 286 |
+
"soma": "SOMA",
|
| 287 |
+
"smplx": "SMPLX",
|
| 288 |
+
"g1": "G1",
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def _normalize_family(s: str) -> Optional[str]:
|
| 293 |
+
"""Return canonical family (Kimodo/TMR) or None if unknown."""
|
| 294 |
+
return _FAMILY_ALIASES.get(s.strip().lower())
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def _normalize_dataset(s: str) -> Optional[str]:
|
| 298 |
+
"""Return canonical dataset (RP/SEED) or None if unknown."""
|
| 299 |
+
return _DATASET_ALIASES.get(s.strip().lower())
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def _normalize_skeleton(s: str) -> Optional[str]:
|
| 303 |
+
"""Return canonical skeleton (SOMA/SMPLX/G1) or None if unknown."""
|
| 304 |
+
return _SKELETON_ALIASES.get(s.strip().lower())
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _get_latest_for_family_skeleton_dataset(family: str, skeleton: str, dataset: str) -> Optional[ModelInfo]:
|
| 308 |
+
"""Return the model info with the highest version for (family, skeleton, dataset)."""
|
| 309 |
+
candidates = [
|
| 310 |
+
        info for info in MODEL_INFOS if info.family == family and info.skeleton == skeleton and info.dataset == dataset
    ]
    if not candidates:
        return None

    def version_key(info: ModelInfo) -> int:
        v = info.version
        if v.startswith("v") and v[1:].isdigit():
            return int(v[1:])
        return 0

    return max(candidates, key=version_key)


def kimodo_short_key_for_skeleton_dataset(skeleton: str, dataset: str) -> Optional[str]:
    """Return the latest Kimodo model short_key for ``skeleton`` and ``dataset`` (RP/SEED), or None."""
    info = _get_latest_for_family_skeleton_dataset("Kimodo", skeleton, dataset)
    return info.short_key if info is not None else None


def registry_skeleton_for_joint_count(nb_joints: int) -> str:
    """Map a motion joint count to a registry skeleton key (SOMA / SMPLX / G1)."""
    if nb_joints == 34:
        return "G1"
    if nb_joints == 22:
        return "SMPLX"
    if nb_joints in (77, 30):
        return "SOMA"
    raise ValueError(f"No Kimodo model registered for motion with J={nb_joints}")


# Full form, optional version: Family-Skeleton-Dataset-vN or Family-Skeleton-Dataset
# (lowercase alternatives are covered by re.IGNORECASE)
_RESOLVE_FULL_PATTERN = re.compile(
    r"^(Kimodo|TMR)[\-_]([A-Za-z0-9]+)[\-_](RP|SEED)(?:[\-_]v(\d+))?$",
    re.IGNORECASE,
)
# Partial form: Skeleton-Dataset, Skeleton, or Dataset (no family)
_RESOLVE_PARTIAL_PATTERN = re.compile(
    r"^([A-Za-z0-9]+)(?:[\-_](RP|SEED))?(?:[\-_]v(\d+))?$",
    re.IGNORECASE,
)


def resolve_model_name(name: Optional[str], default_family: Optional[str] = None) -> str:
    """Resolve a user-facing model name to a short_key.

    Accepts full names (e.g. Kimodo-SOMA-RP-v1) with case-insensitive matching,
    and partial names with defaults: dataset=RP, skeleton=SOMA, family from
    default_family (Kimodo for demo/generation, TMR for the embed script).
    An omitted version resolves to the latest version of that model.

    Args:
        name: User-provided name (can be None or empty).
        default_family: "Kimodo" or "TMR", used when name is empty or omits the family.

    Returns:
        Short key (e.g. kimodo-soma-rp) for use with load_model / MODEL_NAMES.

    Raises:
        ValueError: If name cannot be resolved or default_family is missing when needed.
    """
    if name is not None:
        name = name.strip()
    if not name:
        if default_family is None:
            raise ValueError('Model name is empty; provide a name or set default_family ("Kimodo" or "TMR").')
        fam = _normalize_family(default_family)
        if fam is None:
            raise ValueError(f"default_family must be 'Kimodo' or 'TMR', got {default_family!r}")
        info = _get_latest_for_family_skeleton_dataset(fam, "SOMA", "RP")
        if info is None:
            raise ValueError(f"No model found for {fam}-SOMA-RP. Available: {list(MODEL_NAMES.keys())}")
        return info.short_key

    # Exact short_key
    if name in MODEL_NAMES:
        return name

    # Case-insensitive match against short_key or display_name
    name_lower = name.lower()
    matches = []
    for info in MODEL_INFOS:
        if name_lower == info.short_key.lower():
            matches.append(info)
        disp = info.display_name.lower()
        if name_lower == disp or name_lower == ("nvidia/" + disp):
            matches.append(info)
    if matches:
        # A name can match both its short_key and its display_name; take the first hit.
        return matches[0].short_key

    # Parsed full form: Family-Skeleton-Dataset or Family-Skeleton-Dataset-vN
    m = _RESOLVE_FULL_PATTERN.match(name)
    if m:
        fam_raw, skel_raw, ds_raw, ver_num = m.groups()
        fam = _normalize_family(fam_raw)
        skel = _normalize_skeleton(skel_raw)
        ds = _normalize_dataset(ds_raw)
        if fam is not None and skel is not None and ds is not None:
            if ver_num is not None:
                version = f"v{ver_num}"
                for info in MODEL_INFOS:
                    if info.family == fam and info.skeleton == skel and info.dataset == ds and info.version == version:
                        return info.short_key
            else:
                info = _get_latest_for_family_skeleton_dataset(fam, skel, ds)
                if info is not None:
                    return info.short_key

    # Parsed partial form: Skeleton-Dataset, Skeleton, or Dataset (use default_family)
    if default_family is not None:
        m = _RESOLVE_PARTIAL_PATTERN.match(name)
        if m:
            tok1, ds_raw, ver_num = m.groups()
            fam = _normalize_family(default_family)
            if fam is not None:
                skel = _normalize_skeleton(tok1)
                ds_candidate = _normalize_dataset(ds_raw) if ds_raw else None
                if skel is not None and ds_candidate is not None:
                    ds = ds_candidate
                elif skel is not None:
                    ds = "RP"
                else:
                    skel = "SOMA"
                    ds = _normalize_dataset(tok1) if tok1 else "RP"
                    if ds is None:
                        ds = "RP"
                if ver_num is not None:
                    version = f"v{ver_num}"
                    for info in MODEL_INFOS:
                        if (
                            info.family == fam
                            and info.skeleton == skel
                            and info.dataset == ds
                            and info.version == version
                        ):
                            return info.short_key
                else:
                    info = _get_latest_for_family_skeleton_dataset(fam, skel, ds)
                    if info is not None:
                        return info.short_key

        # Single token: skeleton or dataset
        fam = _normalize_family(default_family)
        if fam is not None:
            skel = _normalize_skeleton(name)
            if skel is not None:
                info = _get_latest_for_family_skeleton_dataset(fam, skel, "RP")
                if info is not None:
                    return info.short_key
            ds = _normalize_dataset(name)
            if ds is not None:
                info = _get_latest_for_family_skeleton_dataset(fam, "SOMA", ds)
                if info is not None:
                    return info.short_key

    raise ValueError(
        f"Model name {name!r} could not be resolved. "
        f"Use a short key (e.g. {list(MODEL_NAMES.keys())[:3]}...), "
        "a full name (e.g. Kimodo-SOMA-RP-v1), or a partial name (e.g. SOMA-RP, SOMA) "
        "with default_family set."
    )
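For reference, a minimal usage sketch of the resolver above. The short key shown is illustrative; the concrete keys depend on what MODEL_INFOS registers.

# All of these calls are intended to resolve to the same short key,
# assuming a Kimodo-SOMA-RP-v1 entry exists in the registry.
from kimodo.model.registry import resolve_model_name

resolve_model_name("Kimodo-SOMA-RP-v1")              # full name with explicit version
resolve_model_name("kimodo-soma-rp")                 # exact short key
resolve_model_name("SOMA", default_family="Kimodo")  # partial: dataset defaults to RP
resolve_model_name(None, default_family="Kimodo")    # empty: defaults to SOMA-RP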
kimodo/model/text_encoder_api.py
ADDED
@@ -0,0 +1,74 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Remote text encoder API client (Gradio) for motion generation."""

import logging
import uuid

import numpy as np
import torch
from gradio_client import Client

# Suppress the [httpx] logs (GET requests)
logging.getLogger("httpx").setLevel(logging.WARNING)

# Suppress internal gradio_client logs
logging.getLogger("gradio_client").setLevel(logging.WARNING)


class TextEncoderAPI:
    """Text encoder API client for motion generation."""

    def __init__(self, url: str):
        self.client = Client(url, verbose=False)
        self.device = "cpu"
        self.dtype = torch.float

    def _create_np_random_name(self):
        return str(uuid.uuid4()) + ".npy"

    def to(self, device=None, dtype=None):
        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype
        return self

    def __call__(self, texts):
        """Encode text prompts into tensors.

        Args:
            texts (str | list[str]): text prompts to encode

        Returns:
            tuple[torch.Tensor, list[int]]: encoded text tensors and their lengths
        """
        if isinstance(texts, str):
            texts = [texts]

        tensors = []
        lengths = []
        for text in texts:
            filename = self._create_np_random_name()

            # Use a long result timeout to tolerate text-encoder cold start
            # (the LLM2Vec model load takes ~60-120s).
            result = self.client.submit(
                text=text,
                filename=filename,
                api_name="/DemoWrapper",
            ).result(timeout=300)
            path = result[0]["value"]
            tensor = np.load(path)
            length = tensor.shape[0]

            tensors.append(tensor)
            lengths.append(length)

        padded_tensor = np.zeros((len(lengths), max(lengths), tensors[0].shape[-1]), dtype=tensors[0].dtype)
        for idx, (tensor, length) in enumerate(zip(tensors, lengths)):
            padded_tensor[idx, :length] = tensor

        padded_tensor = torch.from_numpy(padded_tensor)
        padded_tensor = padded_tensor.to(device=self.device, dtype=self.dtype)
        return padded_tensor, lengths
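A minimal usage sketch for the client above. The URL is a placeholder; the remote Space must expose the /DemoWrapper endpoint the class calls.

import torch
from kimodo.model.text_encoder_api import TextEncoderAPI

encoder = TextEncoderAPI("https://example-text-encoder.hf.space")  # hypothetical URL
encoder.to(device="cpu", dtype=torch.float)
feats, lengths = encoder(["a person walks forward", "a person jumps"])
# feats: [2, max_len, llm_dim] zero-padded tensor; lengths: per-prompt token counts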
kimodo/model/tmr.py
ADDED
@@ -0,0 +1,382 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""TMR model: encoders and text-to-motion retrieval head."""

import contextlib
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from einops import repeat
from torch import Tensor

from kimodo.model import load_checkpoint_state_dict
from kimodo.motion_rep.feature_utils import length_to_mask
from kimodo.sanitize import sanitize_texts
from kimodo.skeleton import SkeletonBase, build_skeleton
from kimodo.tools import ensure_batched


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequences (batch_first optional)."""

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None:
        super().__init__()
        self.batch_first = batch_first

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Note: torch.exp() and math.log() are replaced with torch.pow()
        # because MKL exp()/ln() throw floating point exceptions on certain CPUs.
        div_term = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model)

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        if self.batch_first:
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)


def load_ckpt(self, ckpt_path):
    """Load model weights from a checkpoint path (module-level helper; ``self`` is the module)."""
    state_dict = load_checkpoint_state_dict(ckpt_path)
    self.load_state_dict(state_dict)


class ACTORStyleEncoder(nn.Module):
    """Motion encoder in ACTOR style: optional motion_rep projection, VAE/MLP tokens, transformer."""

    def __init__(
        self,
        motion_rep: Optional[nn.Module],
        llm_shape: Optional[Tuple],
        vae: bool,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        ckpt_path: Optional[str] = None,
    ) -> None:
        super().__init__()

        self.motion_rep = motion_rep
        if motion_rep is not None and llm_shape is None:
            nfeats = motion_rep.motion_rep_dim
        elif motion_rep is None and llm_shape is not None:
            nfeats = llm_shape[-1]
        else:
            raise ValueError("Provide exactly one of motion_rep or llm_shape.")

        self.nfeats = nfeats
        self.projection = nn.Linear(nfeats, latent_dim)

        self.vae = vae
        self.nbtokens = 2 if vae else 1
        self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim))

        self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout=dropout, batch_first=True)

        seq_trans_encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )

        self.seqTransEncoder = nn.TransformerEncoder(
            seq_trans_encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=False,
        )

        if ckpt_path is not None:
            load_ckpt(self, ckpt_path)

    def forward(self, x_dict: Dict) -> Tensor:
        x = x_dict["x"]
        mask = x_dict["mask"]

        x = self.projection(x)

        device = x.device
        bs = len(x)

        tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs)
        xseq = torch.cat((tokens, x), 1)

        token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device)
        aug_mask = torch.cat((token_mask, mask), 1)

        # add positional encoding
        xseq = self.sequence_pos_encoding(xseq)
        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
        return final[:, : self.nbtokens]


class TMR(nn.Module):
    r"""TMR: Text-to-Motion Retrieval inference code (no decoder).

    Find more information about the model on the following website:
    https://mathis.petrovich.fr/tmr
    """

    @classmethod
    def from_args(
        cls,
        motion_rep: nn.Module,
        llm_shape: tuple | list,
        vae: bool,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        ckpt_folder: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        motion_encoder = ACTORStyleEncoder(
            motion_rep=motion_rep,
            llm_shape=None,
            vae=vae,
            latent_dim=latent_dim,
            ff_size=ff_size,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            activation=activation,
            ckpt_path=Path(ckpt_folder) / "motion_encoder.pt",
        ).to(device)

        top_text_encoder = ACTORStyleEncoder(
            motion_rep=None,
            llm_shape=llm_shape,
            vae=vae,
            latent_dim=latent_dim,
            ff_size=ff_size,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            activation=activation,
            ckpt_path=Path(ckpt_folder) / "text_encoder.pt",
        ).to(device)
        return cls(
            motion_encoder,
            top_text_encoder,
            vae,
            device=device,
            **kwargs,
        )

    def __init__(
        self,
        motion_encoder: nn.Module,
        top_text_encoder: nn.Module,
        vae: bool,
        text_encoder: Optional[nn.Module] = None,
        fact: Optional[float] = None,
        sample_mean: Optional[bool] = True,
        unit_vector: Optional[bool] = False,
        compute_grads: bool = False,
        device: Optional[str] = None,
    ) -> None:
        super().__init__()

        self.motion_encoder = motion_encoder
        self.text_encoder = top_text_encoder
        self.raw_text_encoder = text_encoder

        self.motion_rep = None
        self.skeleton = None
        if self.motion_encoder is not None:
            self.motion_rep = self.motion_encoder.motion_rep
            if self.motion_rep is not None:
                self.skeleton = self.motion_rep.skeleton

        self.compute_grads = compute_grads

        self.device = device

        # sampling parameters
        self.vae = vae
        self.fact = fact if fact is not None else 1.0
        self.sample_mean = sample_mean
        self.unit_vector = unit_vector

    def full_text_encoder(self, texts: list[str]):
        assert isinstance(texts, list), "The input should be batched."
        # Sanitize the texts first, embed them with the raw text encoder,
        # then pass the embeddings through the top text encoder.
        texts = sanitize_texts(texts)
        text_feat, text_length = self.raw_text_encoder(texts)
        if isinstance(text_length, list):
            text_length = torch.tensor(text_length, device=self.device)
        else:
            text_length = text_length.to(self.device)
        inputs = {
            "x": text_feat.to(self.device),
            "mask": length_to_mask(text_length, device=self.device),
        }
        return self.text_encoder(inputs)

    def _find_encoder(self, inputs, modality):
        assert modality in ["text", "motion", "raw_text", "auto"]

        if modality == "text":
            return self.text_encoder
        elif modality == "motion":
            return self.motion_encoder
        elif modality == "raw_text":
            return self.full_text_encoder

        # raw text inputs arrive as a list of strings; dict inputs are resolved by feature size
        if isinstance(inputs, (list, tuple)) and isinstance(inputs[0], str):
            return self.full_text_encoder

        m_nfeats = self.motion_encoder.nfeats
        t_nfeats = self.text_encoder.nfeats

        if m_nfeats == t_nfeats:
            raise ValueError("Cannot automatically find the encoder, as they share the same input space.")

        nfeats = inputs["x"].shape[-1]
        if nfeats == m_nfeats:
            return self.motion_encoder
        elif nfeats == t_nfeats:
            return self.text_encoder
        else:
            raise ValueError("The input is not recognized.")

    def _encode(
        self,
        inputs,
        modality: str = "auto",
        sample_mean: Optional[bool] = None,
        fact: Optional[float] = None,
        return_distribution: bool = False,
        unit_vector: Optional[bool] = None,
    ):
        sample_mean = self.sample_mean if sample_mean is None else sample_mean
        fact = self.fact if fact is None else fact
        unit_vector = self.unit_vector if unit_vector is None else unit_vector

        # Encode the inputs
        encoder = self._find_encoder(inputs, modality)
        encoded = encoder(inputs)

        # Sampling
        if self.vae:
            dists = encoded.unbind(1)
            mu, logvar = dists
            if sample_mean:
                latent_vectors = mu
            else:
                # Reparameterization trick
                std = logvar.exp().pow(0.5)
                eps = std.data.new(std.size()).normal_()
                latent_vectors = mu + fact * eps * std
        else:
            dists = None
            (latent_vectors,) = encoded.unbind(1)

        if unit_vector:
            latent_vectors = torch.nn.functional.normalize(latent_vectors, dim=-1)

        if return_distribution:
            return latent_vectors, dists

        return latent_vectors

    @ensure_batched(posed_joints=4, lengths=1)
    def encode_motion(
        self,
        posed_joints: torch.Tensor,
        original_skeleton: Optional[SkeletonBase] = None,
        lengths: Optional[torch.Tensor] = None,
        unit_vector: Optional[bool] = None,
    ):
        convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext()

        if original_skeleton is None:
            original_skeleton = build_skeleton(posed_joints.shape[-2])

        if lengths is None:
            nbatch, nbframes = posed_joints.shape[:2]
            device = posed_joints.device
            assert nbatch == 1, "If lengths is not provided, the input should not be batched."
            lengths = torch.tensor([nbframes], device=device)

        # slice the posed joints if we use fewer joints
        skel_slice = self.motion_rep.skeleton.get_skel_slice(original_skeleton)
        posed_joints = posed_joints[..., skel_slice, :]

        with convert_ctx:
            features = self.motion_rep(
                posed_joints=posed_joints,
                to_normalize=True,
                lengths=lengths,
            )
            mask = length_to_mask(lengths, device=features.device)
            x_dict = {"x": features, "mask": mask}
            latent_vectors = self._encode(
                x_dict,
                modality="motion",
                unit_vector=unit_vector,
            )
        return latent_vectors

    def encode_text(
        self,
        x_dict: Dict,
        unit_vector: Optional[bool] = None,
    ):
        # TODO: make it ensure batched
        convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext()

        with convert_ctx:
            latent_vectors = self._encode(
                x_dict,
                modality="text",
                unit_vector=unit_vector,
            )
        return latent_vectors

    def encode_raw_text(
        self,
        texts: List[str],
        unit_vector: Optional[bool] = None,
    ):
        is_batched = True
        if isinstance(texts, str):
            is_batched = False
            texts = [texts]

        convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext()

        with convert_ctx:
            latent_vectors = self._encode(
                texts,
                modality="raw_text",
                unit_vector=unit_vector,
            )
        if not is_batched:
            latent_vectors = latent_vectors[0]
        return latent_vectors
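A hedged retrieval sketch using the encoders above. It assumes `tmr` is a constructed TMR instance with a raw text encoder attached, and that `joints` and `lengths` are a batched [B, T, J, 3] joint tensor with per-sequence lengths.

text_latent = tmr.encode_raw_text("a person waves", unit_vector=True)          # [latent_dim]
motion_latents = tmr.encode_motion(joints, lengths=lengths, unit_vector=True)  # [B, latent_dim]
scores = motion_latents @ text_latent  # cosine similarities, since both sides are unit vectors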
kimodo/model/twostage_denoiser.py
ADDED
@@ -0,0 +1,153 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Two-stage transformer denoiser: root stage then body stage for motion diffusion."""

import contextlib
from typing import Optional

import torch
from torch import nn

from .backbone import TransformerEncoderBlock
from .loading import load_checkpoint_state_dict


class TwostageDenoiser(nn.Module):
    """Two-stage denoiser: first predicts global root features, then body features conditioned on the local root."""

    def __init__(
        self,
        motion_rep,
        motion_mask_mode,
        ckpt_path: Optional[str] = None,
        **kwargs,
    ):
        """Build root and body transformer blocks; optionally load a checkpoint from ckpt_path."""
        super().__init__()
        self.motion_rep = motion_rep
        self.motion_mask_mode = motion_mask_mode

        # motion_rep should be a dual representation, global by default:
        # the denoiser takes the global motion_rep as input.
        input_dim = motion_rep.motion_rep_dim
        will_concatenate = motion_mask_mode == "concat"

        # stage 1: root only
        root_input_dim = input_dim * 2 if will_concatenate else input_dim
        root_output_dim = motion_rep.global_root_dim

        self.root_model = TransformerEncoderBlock(
            input_dim=root_input_dim,
            output_dim=root_output_dim,
            skeleton=self.motion_rep.skeleton,
            **kwargs,
        )

        # replace the global root by the local root
        local_motion_rep_dim = input_dim - motion_rep.global_root_dim + motion_rep.local_root_dim

        # stage 2: local body
        # The body stage always takes local root info for the motion (but still the global mask).
        body_input_dim = local_motion_rep_dim + (input_dim if will_concatenate else 0)

        body_output_dim = input_dim - motion_rep.global_root_dim
        self.body_model = TransformerEncoderBlock(
            input_dim=body_input_dim,
            output_dim=body_output_dim,
            skeleton=self.motion_rep.skeleton,
            **kwargs,
        )

        if ckpt_path:
            self.load_ckpt(ckpt_path)

    def load_ckpt(self, ckpt_path: str) -> None:
        """Load a checkpoint from path; state dict keys are stripped of the 'denoiser.backbone.' prefix."""
        state_dict = load_checkpoint_state_dict(ckpt_path)
        state_dict = {key.replace("denoiser.backbone.", ""): val for key, val in state_dict.items()}
        self.load_state_dict(state_dict)

    def forward(
        self,
        x: torch.Tensor,
        x_pad_mask: torch.Tensor,
        text_feat: torch.Tensor,
        text_feat_pad_mask: torch.Tensor,
        timesteps: torch.Tensor,
        first_heading_angle: Optional[torch.Tensor] = None,
        motion_mask: Optional[torch.Tensor] = None,
        observed_motion: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): [B, T, dim_motion] current noisy motion
            x_pad_mask (torch.Tensor): [B, T] attention mask; True positions may attend, False may not
            text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts
            text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask; True positions may attend, False may not
            timesteps (torch.Tensor): [B,] current denoising step
            first_heading_angle (Optional[torch.Tensor]): [B,] heading angle of the first frame
            motion_mask (Optional[torch.Tensor]): [B, T, dim_motion] mask marking observed features (1 = observed)
            observed_motion (Optional[torch.Tensor]): [B, T, dim_motion] observed motion features to inpaint around

        Returns:
            torch.Tensor: same size as input x
        """

        if self.motion_mask_mode == "concat":
            if motion_mask is None or observed_motion is None:
                motion_mask = torch.zeros_like(x)
                observed_motion = torch.zeros_like(x)
            x = x * (1 - motion_mask) + observed_motion * motion_mask
            x_extended = torch.cat([x, motion_mask], axis=-1)
        else:
            x_extended = x

        # Stage 1: predict root motion in global coordinates
        root_motion_pred = self.root_model(
            x_extended,
            x_pad_mask,
            text_feat,
            text_feat_pad_mask,
            timesteps,
            first_heading_angle,
        )  # [B, T, 5]

        # Maybe pass this as an argument instead of recomputing it
        lengths = x_pad_mask.sum(-1)

        # Convert the root prediction to the local representation.
        # At test time we want to allow gradients through for guidance.
        convert_ctx = torch.no_grad() if self.training else contextlib.nullcontext()
        with convert_ctx:
            root_motion_local = self.motion_rep.global_root_to_local_root(
                root_motion_pred,
                normalized=True,
                lengths=lengths,
            )
        if self.training:
            root_motion_local = root_motion_local.detach()

        # concatenate the predicted local root with the body motion
        body_x = x[..., self.motion_rep.body_slice]
        x_new = torch.cat([root_motion_local, body_x], axis=-1)

        if self.motion_mask_mode == "concat":
            x_new_extended = torch.cat([x_new, motion_mask], axis=-1)
        else:
            x_new_extended = x_new

        # Stage 2: predict local body motion based on the local root
        predicted_body = self.body_model(
            x_new_extended,
            x_pad_mask,
            text_feat,
            text_feat_pad_mask,
            timesteps,
            first_heading_angle,
        )

        # concatenate the predicted global root with the predicted local body
        output = torch.cat([root_motion_pred, predicted_body], axis=-1)
        return output
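To make the dimension bookkeeping above concrete, a sketch of the tensor shapes through one forward pass in "concat" mode. The sizes are illustrative; the real ones come from the motion_rep.

# Suppose input_dim = 128, global_root_dim = 5, local_root_dim = 4.
# x:                 [B, T, 128]
# x_extended:        [B, T, 256]  (x concatenated with motion_mask)
# root_motion_pred:  [B, T, 5]    (stage 1: global root)
# root_motion_local: [B, T, 4]    (converted to the local root rep)
# x_new:             [B, T, 127]  (4 local root + 123 body)
# x_new_extended:    [B, T, 255]  (x_new concatenated with motion_mask)
# predicted_body:    [B, T, 123]  (stage 2: local body)
# output:            [B, T, 128]  (5 global root + 123 body)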
kimodo/motion_rep/__init__.py
ADDED
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Motion representation utilities."""

from .reps import KimodoMotionRep, MotionRepBase, TMRMotionRep

__all__ = [
    "MotionRepBase",
    "KimodoMotionRep",
    "TMRMotionRep",
]
kimodo/motion_rep/conditioning.py
ADDED
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Constraint conditioning: build index and data dicts from constraint sets for the denoiser."""

from collections import defaultdict

import torch


def build_condition_dicts(constraints_lst: list):
    """Collect per-constraint indices and data into two defaultdicts."""
    index_dict = defaultdict(list)
    data_dict = defaultdict(list)
    for constraint in constraints_lst:
        constraint.update_constraints(data_dict, index_dict)
    return index_dict, data_dict


def get_unique_index_and_data(indices_lst, data):
    """Deduplicate (t, j) index rows, keeping one data value per unique row."""
    # unique + sort the rows by t
    indices_unique, inverse = torch.unique(indices_lst, dim=0, return_inverse=True)
    # pick one representative row index for each unique (t, j);
    # with duplicate indices, scatter_ write order decides which one survives
    first_idx = torch.zeros(indices_unique.size(0), dtype=torch.long, device=inverse.device)
    first_idx.scatter_(0, inverse, torch.arange(len(inverse), device=inverse.device))
    assert (indices_lst[first_idx] == indices_unique).all()
    # gather the data
    indices_lst = indices_lst[first_idx]
    data = data[first_idx]
    return indices_lst, data
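A small illustrative call showing the deduplication behavior (values are made up):

import torch
from kimodo.motion_rep.conditioning import get_unique_index_and_data

idx = torch.tensor([[0, 3], [5, 1], [0, 3]])  # (t, j) rows; first and last collide
val = torch.tensor([1.0, 2.0, 3.0])
uniq_idx, uniq_val = get_unique_index_and_data(idx, val)
# uniq_idx: [[0, 3], [5, 1]] (torch.unique sorts rows); uniq_val keeps one value
# per unique row; which duplicate survives is a scatter_ write-order detail.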
kimodo/motion_rep/feature_utils.py
ADDED
@@ -0,0 +1,212 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Motion representation helpers: velocity, heading, masks, and rotation of features."""

from typing import List, Optional, Union

import einops
import torch

from kimodo.geometry import cont6d_to_matrix, matrix_to_cont6d
from kimodo.skeleton import SkeletonBase
from kimodo.tools import ensure_batched


def diff_angles(angles: torch.Tensor, fps: float) -> torch.Tensor:
    """Compute frame-to-frame angular differences in radians, scaled by fps.

    Args:
        angles: [..., T] batched sequences of rotation angles in radians.
        fps: Sampling rate used to convert frame differences to a per-second rate.

    Returns:
        [..., T-1] differences between consecutive angles (rad/s).
    """

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    cos_diff = cos[..., 1:] * cos[..., :-1] + sin[..., 1:] * sin[..., :-1]
    sin_diff = sin[..., 1:] * cos[..., :-1] - cos[..., 1:] * sin[..., :-1]

    # close to angles.diff() but robust to angle wrap-around;
    # multiply by fps = 1 / dt
    angles_diff = fps * torch.arctan2(sin_diff, cos_diff)
    return angles_diff


@ensure_batched(positions=4, lengths=1)
def compute_vel_xyz(
    positions: torch.Tensor,
    fps: float,
    lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute the velocities from positions: dx/dt. Works with batches.
    The last velocity is duplicated to keep the same size.

    Args:
        positions (torch.Tensor): [..., T, J, 3] xyz positions of a human skeleton
        fps (float): frames per second
        lengths (Optional[torch.Tensor]): [...] size of each batched input. If not provided, positions should not be batched

    Returns:
        velocity (torch.Tensor): [..., T, J, 3] velocities computed from the positions
    """
    device = positions.device

    if lengths is None:
        assert positions.shape[0] == 1, "If lengths is not provided, the input should not be batched."
        # the decorator has already added the batch dimension, so the
        # sequence length is dim 1, not len(positions)
        lengths = torch.tensor([positions.shape[1]], device=device)

    # useful for indexing
    range_len = torch.arange(len(lengths))

    # compute velocities with fps
    velocity = fps * (positions[:, 1:] - positions[:, :-1])
    # pad the velocity vector
    vel_pad = torch.zeros_like(velocity[:, 0])
    velocity, _ = einops.pack([velocity, vel_pad], "batch * nbjoints dim")

    # repeat the last velocities,
    # with special care for different lengths within batches
    velocity[(range_len, lengths - 1)] = velocity[(range_len, lengths - 2)]
    return velocity


@ensure_batched(root_rot_angles=2, lengths=1)
def compute_vel_angle(
    root_rot_angles: torch.Tensor,
    fps: float,
    lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute the local root rotation velocity: dtheta/dt.

    Args:
        root_rot_angles (torch.Tensor): [..., T] rotation angles (in radians)
        fps (float): frames per second
        lengths (Optional[torch.Tensor]): [...] size of each batched input. If not provided, root_rot_angles should not be batched

    Returns:
        local_root_rot_vel (torch.Tensor): [..., T] local root rotation velocity (in radians/s)
    """
    device = root_rot_angles.device
    if lengths is None:
        assert root_rot_angles.shape[0] == 1, "If lengths is not provided, the input should not be batched."
        # as above, the decorator has already added the batch dimension
        lengths = torch.tensor([root_rot_angles.shape[1]], device=device)

    # useful for indexing
    range_len = torch.arange(len(lengths))

    local_root_rot_vel = diff_angles(root_rot_angles, fps)
    pad_rot_vel_angles = torch.zeros_like(root_rot_angles[:, 0])
    local_root_rot_vel, _ = einops.pack(
        [local_root_rot_vel, pad_rot_vel_angles],
        "batch *",
    )
    # repeat the last rotation velocity,
    # with special care for different lengths within batches
    local_root_rot_vel[(range_len, lengths - 1)] = local_root_rot_vel[(range_len, lengths - 2)]
    return local_root_rot_vel


@ensure_batched(posed_joints=4)
def compute_heading_angle(posed_joints: torch.Tensor, skeleton: SkeletonBase) -> torch.Tensor:
    """Compute the heading direction from joint positions using the hip vector.

    Args:
        posed_joints: [B, T, J, 3] global joint positions.
        skeleton: Skeleton instance used to get hip joint indices.

    Returns:
        [B, T] heading angle in radians.
    """
    # compute root heading for the sequence from hip positions
    r_hip, l_hip = skeleton.hip_joint_idx
    diff = posed_joints[:, :, r_hip] - posed_joints[:, :, l_hip]
    heading_angle = torch.atan2(diff[..., 2], -diff[..., 0])
    return heading_angle


def length_to_mask(
    length: Union[torch.Tensor, List],
    max_len: Optional[int] = None,
    device=None,
) -> torch.Tensor:
    """Convert sequence lengths to a boolean validity mask.

    Args:
        length: Sequence lengths, either a tensor ``[B]`` or a Python list.
        max_len: Optional mask width. If omitted, uses ``max(length)``.
        device: Optional device. When ``length`` is a list, this controls where
            the new tensor is created.

    Returns:
        A boolean tensor of shape ``[B, max_len]`` where ``True`` marks valid
        timesteps.
    """
    if isinstance(length, list):
        if device is None:
            device = "cpu"
        length = torch.tensor(length, device=device)

    # Use the requested device for the output; move length if needed so mask and length match
    if device is not None:
        target = torch.device(device)
        if length.device != target:
            length = length.to(target)
    device = length.device

    if max_len is None:
        max_len = max(length)

    mask = torch.arange(max_len, device=device).expand(len(length), max_len) < length.unsqueeze(1)
    return mask


class RotateFeatures:
    """Helper that applies a global heading rotation to motion features."""

    def __init__(self, angle: torch.Tensor):
        """Precompute 2D and 3D rotation matrices for a batch of angles.

        Args:
            angle: Rotation angle(s) in radians, shaped ``[B]``.
        """
        self.angle = angle

        # Create the necessary rotation matrices
        cos, sin = torch.cos(angle), torch.sin(angle)
        one, zero = torch.ones_like(angle), torch.zeros_like(angle)

        # 2D rotation, transposed (the sin terms are negated)
        self.corrective_mat_2d_T = torch.stack((cos, sin, -sin, cos), -1).reshape(angle.shape + (2, 2))
        # 3D rotation around the Y axis
        self.corrective_mat_Y = torch.stack((cos, zero, sin, zero, one, zero, -sin, zero, cos), -1).reshape(
            angle.shape + (3, 3)
        )
        self.corrective_mat_Y_T = self.corrective_mat_Y.transpose(1, 2).contiguous()

    def rotate_positions(self, positions: torch.Tensor):
        """Rotate 3D positions around the Y axis."""
        return positions @ self.corrective_mat_Y_T

    def rotate_2d_positions(self, positions_2d: torch.Tensor):
        """Rotate 2D ``(x, z)`` vectors in the ground plane."""
        return positions_2d @ self.corrective_mat_2d_T

    def rotate_rotations(self, rotations: torch.Tensor):
        """Left-multiply global rotation matrices by the heading correction."""
        # "Rotating" the global rotations means adding an extra Y rotation
        # after the transform, i.e. on the left: R' = R_y R
        # (since we use the convention x' = R x).
        # "bik,btdkj->btdij"
        B, T, J = rotations.shape[:3]
        BTJ = B * T * J
        return (
            self.corrective_mat_Y[:, None, None].expand(B, T, J, 3, 3).reshape(BTJ, 3, 3) @ rotations.reshape(BTJ, 3, 3)
        ).reshape(B, T, J, 3, 3)

    def rotate_6d_rotations(self, rotations_6d: torch.Tensor):
        """Rotate 6D rotation features via matrix conversion."""
        return matrix_to_cont6d(self.rotate_rotations(cont6d_to_matrix(rotations_6d)))
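A quick illustration of length_to_mask:

import torch
from kimodo.motion_rep.feature_utils import length_to_mask

mask = length_to_mask([3, 1], device="cpu")
# tensor([[ True,  True,  True],
#         [ True, False, False]])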
kimodo/motion_rep/feet.py
ADDED
@@ -0,0 +1,60 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Foot contact detection from joint positions and velocities."""

import torch

from ..tools import ensure_batched


@ensure_batched(positions=4, velocity=4)
def foot_detect_from_pos_and_vel(
    positions: torch.Tensor,
    velocity: torch.Tensor,
    skeleton,
    vel_thres: float,
    height_thresh: float,
) -> torch.Tensor:
    """Compute foot contact labels using heuristics that combine joint height and velocity.

    Args:
        positions (torch.Tensor): [X, T, J, 3] global joint positions
        velocity (torch.Tensor): [X, T, J, 3] velocities (already padded correctly and multiplied by 1 / dt)
        skeleton: skeleton providing the left/right foot joint indices
        vel_thres (float): threshold on joint velocity
        height_thresh (float): threshold on joint height

    Returns:
        torch.Tensor: [X, T, 4] contact labels for left and right foot joints
            (heel/toe order follows the skeleton joint index definition), where
            ``1`` denotes contact.
    """

    device = positions.device
    # Use at most 2 foot joints per side (ankle + toe); SOMA77 defines a
    # third end-effector (ToeEnd) that SOMA30 and other skeletons omit.
    fid_l = skeleton.left_foot_joint_idx[:2]
    fid_r = skeleton.right_foot_joint_idx[:2]

    velfactor, heightfactor = (
        torch.tensor([vel_thres, vel_thres], device=device),
        torch.tensor([height_thresh, height_thresh], device=device),
    )

    feet_l_v = torch.linalg.norm(velocity[:, :, fid_l], axis=-1)
    feet_l_h = positions[:, :, fid_l, 1]

    feet_l = torch.logical_and(
        feet_l_v < velfactor,
        feet_l_h < heightfactor,
    ).to(positions.dtype)

    feet_r_v = torch.linalg.norm(velocity[:, :, fid_r], axis=-1)
    feet_r_h = positions[:, :, fid_r, 1]

    feet_r = torch.logical_and(
        feet_r_v < velfactor,
        feet_r_h < heightfactor,
    ).to(positions.dtype)

    foot_contacts = torch.cat((feet_l, feet_r), axis=-1)
    return foot_contacts
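A hedged sketch combining the velocity helper with the contact detector. The threshold values below are illustrative placeholders, not the ones the pipeline ships with.

from kimodo.motion_rep.feature_utils import compute_vel_xyz
from kimodo.motion_rep.feet import foot_detect_from_pos_and_vel

vel = compute_vel_xyz(joints, fps=30.0, lengths=lengths)  # joints: [B, T, J, 3]
contacts = foot_detect_from_pos_and_vel(
    joints, vel, skeleton, vel_thres=0.3, height_thresh=0.08  # example thresholds
)  # [B, T, 4] binary labels: left heel/toe, right heel/toe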
kimodo/motion_rep/reps/__init__.py
ADDED
@@ -0,0 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Motion representation implementations: base, Kimodo, and TMR."""

from .base import MotionRepBase
from .kimodo_motionrep import KimodoMotionRep
from .tmr_motionrep import TMRMotionRep

__all__ = [
    "MotionRepBase",
    "KimodoMotionRep",
    "TMRMotionRep",
]
|
kimodo/motion_rep/reps/base.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Base motion representation: feature layout, normalization, and conditioning helpers."""

import os
from typing import Optional

import einops
import numpy as np
import torch
from einops import repeat

from ...tools import ensure_batched
from ..conditioning import build_condition_dicts
from ..feature_utils import compute_vel_angle, compute_vel_xyz
from ..stats import Stats


def _require_split_stats_layout(stats_path: str) -> None:
    """Raise if stats_path does not contain the required global_root, local_root, body subdirs."""
    subdirs = ("global_root", "local_root", "body")
    missing = []
    for name in subdirs:
        subpath = os.path.join(stats_path, name)
        mean_path = os.path.join(subpath, "mean.npy")
        if not os.path.isfile(mean_path):
            missing.append(f"{subpath}/ (mean.npy)")
    if missing:
        raise FileNotFoundError(
            f"Checkpoint stats must use the split layout with subfolders "
            f"global_root/, local_root/, and body/ under '{stats_path}'. "
            f"Missing or incomplete: {', '.join(missing)}."
        )


class MotionRepBase:
    """Base class for motion representations used in generation and conditioning.

    Subclasses define:
    - ``size_dict``: feature blocks and their shapes,
    - ``last_root_feature``: last entry of the root block,
    - ``local_root_size_dict``: local-root feature layout,
    and implement transform-specific methods such as ``__call__``, ``inverse``,
    ``rotate``, ``translate_2d`` and ``create_conditions``.
    """

    def __init__(
        self,
        skeleton,
        fps,
        stats_path: Optional[str] = None,
    ):
        """Initialize feature slicing metadata and optional normalization stats."""

        self.skeleton = skeleton
        self.fps = fps
        self.nbjoints = skeleton.nbjoints

        self.feature_names = list(self.size_dict.keys())
        self.ps = list(self.size_dict.values())
        self.nfeats_dict = {key: val.numel() for key, val in self.size_dict.items()}
        feats_cumsum = np.cumsum([0] + list(self.nfeats_dict.values())).tolist()
        self.slice_dict = {key: slice(feats_cumsum[i], feats_cumsum[i + 1]) for i, key in enumerate(self.feature_names)}

        self.motion_rep_dim = sum(self.nfeats_dict.values())
        self.root_slice = slice(0, self.slice_dict[self.last_root_feature].stop)
        self.body_slice = slice(self.root_slice.stop, self.motion_rep_dim)
        self.body_dim = self.body_slice.stop - self.body_slice.start
        self.global_root_dim = self.root_slice.stop
        self.local_root_dim = sum(val.numel() for val in self.local_root_size_dict.values())

        if stats_path:
            _require_split_stats_layout(stats_path)
            self.global_root_stats = Stats(os.path.join(stats_path, "global_root"))
            self.local_root_stats = Stats(os.path.join(stats_path, "local_root"))
            self.body_stats = Stats(os.path.join(stats_path, "body"))
            # self.stats is intentionally not set; normalize/unnormalize apply per-part stats below

    def get_root_pos(self, features: torch.Tensor, fallback_to_smooth: bool = True):
        """Extract root positions from a feature tensor.

        Supports both ``root_pos`` and ``smooth_root_pos`` representations.
        """
        if "root_pos" in self.slice_dict:
            return features[..., self.slice_dict["root_pos"]]

        if "smooth_root_pos" not in self.slice_dict:
            raise TypeError("This motion rep should have either a root_pos or smooth_root_pos field")

        if fallback_to_smooth:
            return features[..., self.slice_dict["smooth_root_pos"]]

        # else compute the root pos from the smooth root and the local joint offsets
        smooth_root_pos = features[..., self.slice_dict["smooth_root_pos"]].clone()
        local_joints_positions_flatten = features[..., self.slice_dict["local_joints_positions"]]
        hips_offset = local_joints_positions_flatten[..., self.skeleton.root_idx : self.skeleton.root_idx + 3]
        root_pos = torch.stack(
            [
                smooth_root_pos[..., 0] + hips_offset[..., 0],
                smooth_root_pos[..., 1],
                smooth_root_pos[..., 2] + hips_offset[..., 2],
            ],
            axis=-1,
        )
        return root_pos

    @ensure_batched(root_features=3, lengths=1)
    def global_root_to_local_root(
        self,
        root_features: torch.Tensor,
        normalized: bool,
        lengths: Optional[torch.Tensor],
    ):
        """Convert global root features to local-root motion features.

        Args:
            root_features: Root feature tensor containing root position and
                global heading, shaped ``[B, T, D_root]``.
            normalized: Whether ``root_features`` are normalized.
            lengths: Optional valid lengths per sequence.

        Returns:
            Tensor ``[B, T, 4]`` with local root rotational velocity, planar
            velocity, and global root height.
        """
        if normalized:
            root_features = self.global_root_stats.unnormalize(root_features)

        [root_pos, global_root_heading] = einops.unpack(root_features, self.ps[:2], "batch time *")
        cos, sin = global_root_heading.unbind(-1)
        heading_angle = torch.arctan2(sin, cos)

        local_root_rot_vel = compute_vel_angle(heading_angle, self.fps, lengths=lengths)
        local_root_vel = compute_vel_xyz(
            root_pos[..., None, :],
            self.fps,
            lengths=lengths,
        )[..., 0, [0, 2]]
        global_root_y = root_pos[..., 1]
        local_root_motion = torch.cat(
            [
                local_root_rot_vel[..., None],
                local_root_vel,
                global_root_y[..., None],
            ],
            axis=-1,
        )

        if normalized:
            local_root_motion = self.local_root_stats.normalize(local_root_motion)
        return local_root_motion

    def get_root_heading_angle(self, features: torch.Tensor) -> torch.Tensor:
        """Compute the root heading angle from cosine/sine heading features."""
        global_root_heading = features[..., self.slice_dict["global_root_heading"]]
        cos, sin = global_root_heading.unbind(-1)
        return torch.arctan2(sin, cos)

    @ensure_batched(features=3)
    def rotate_to(
        self,
        features: torch.Tensor,
        target_angle: torch.Tensor,
        return_delta_angle=False,
    ):
        """Rotate each sequence so the frame-0 heading matches ``target_angle``."""
        current_first_angle = self.get_root_heading_angle(features)[:, 0]
        delta_angle = target_angle - current_first_angle
        rotated_features = self.rotate(features, delta_angle)
        if return_delta_angle:
            return rotated_features, delta_angle
        return rotated_features

    @ensure_batched(features=3)
    def rotate_to_zero(
        self,
        features: torch.Tensor,
        return_delta_angle=False,
    ):
        """Rotate each sequence so the frame-0 heading becomes zero."""
        target_angle = torch.zeros(len(features), device=features.device)
        return self.rotate_to(features, target_angle, return_delta_angle=return_delta_angle)

    @ensure_batched(features=3)
    def randomize_first_heading(
        self,
        features: torch.Tensor,
        return_delta_angle=False,
    ) -> torch.Tensor:
        """Rotate each sequence to a random frame-0 heading."""
        # create the random target on the same device as the features
        target_heading_angle = torch.rand(features.shape[0], device=features.device) * 2 * np.pi
        return self.rotate_to(
            features,
            target_heading_angle,
            return_delta_angle=return_delta_angle,
        )

    @ensure_batched(features=3, target_2d_pos=2)
    def translate_2d_to(
        self,
        features: torch.Tensor,
        target_2d_pos: torch.Tensor,
        return_delta_pos: bool = False,
    ) -> torch.Tensor:
        """Translate each sequence so the frame-0 root ``(x, z)`` matches a target."""
        root_pos = self.get_root_pos(features)
        current_first_2d_pos = root_pos[:, 0, [0, 2]].clone()
        delta_2d_pos = target_2d_pos - current_first_2d_pos
        translated_features = self.translate_2d(features, delta_2d_pos)
        if return_delta_pos:
            return translated_features, delta_2d_pos
        return translated_features

    @ensure_batched(features=3)
    def translate_2d_to_zero(
        self,
        features: torch.Tensor,
        return_delta_pos: bool = False,
    ) -> torch.Tensor:
        """Translate each sequence so the frame-0 root ``(x, z)`` is at the origin."""
        target_2d_pos = torch.zeros(len(features), 2, device=features.device)
        return self.translate_2d_to(features, target_2d_pos, return_delta_pos=return_delta_pos)

    @ensure_batched(features=3)
    def canonicalize(self, features: torch.Tensor):
        """Canonicalize heading and planar position at frame 0."""
        rotated_features = self.rotate_to_zero(features)
        return self.translate_2d_to_zero(rotated_features)

    def normalize(self, features):
        """Normalize features using per-part stats (global_root, local_root, body)."""
        gr = slice(0, self.global_root_dim)
        lr = slice(self.global_root_dim, self.global_root_dim + self.local_root_dim)
        out = torch.empty_like(features, device=features.device, dtype=features.dtype)
        out[..., gr] = self.global_root_stats.normalize(features[..., gr])
        out[..., lr] = self.local_root_stats.normalize(features[..., lr])
        out[..., self.body_slice] = self.body_stats.normalize(features[..., self.body_slice])
        return out

    def unnormalize(self, features):
        """Undo feature normalization using per-part stats."""
        gr = slice(0, self.global_root_dim)
        lr = slice(self.global_root_dim, self.global_root_dim + self.local_root_dim)
        out = torch.empty_like(features, device=features.device, dtype=features.dtype)
        out[..., gr] = self.global_root_stats.unnormalize(features[..., gr])
        out[..., lr] = self.local_root_stats.unnormalize(features[..., lr])
        out[..., self.body_slice] = self.body_stats.unnormalize(features[..., self.body_slice])
        return out

    def create_conditions_from_constraints(
        self,
        constraints_lst: list,
|
| 255 |
+
length: int,
|
| 256 |
+
to_normalize: bool,
|
| 257 |
+
device: str,
|
| 258 |
+
):
|
| 259 |
+
"""Create a conditioning tensor and mask from constraint objects."""
|
| 260 |
+
index_dict, data_dict = build_condition_dicts(constraints_lst)
|
| 261 |
+
return self.create_conditions(index_dict, data_dict, length, to_normalize, device)
|
| 262 |
+
|
| 263 |
+
def create_conditions_from_constraints_batched(
|
| 264 |
+
self,
|
| 265 |
+
constraints_lst: list | list[list],
|
| 266 |
+
lengths: torch.Tensor,
|
| 267 |
+
to_normalize: bool,
|
| 268 |
+
device: str,
|
| 269 |
+
):
|
| 270 |
+
"""Batched version of ``create_conditions_from_constraints``.
|
| 271 |
+
|
| 272 |
+
Supports either one shared constraint list for all batch elements, or a per-sample list of
|
| 273 |
+
constraint lists.
|
| 274 |
+
"""
|
| 275 |
+
num_samples = len(lengths)
|
| 276 |
+
if not constraints_lst or not isinstance(constraints_lst[0], list):
|
| 277 |
+
# If no constraints, or constraints are shared across the batch,
|
| 278 |
+
# build once and repeat.
|
| 279 |
+
observed_motion, motion_mask = self.create_conditions_from_constraints(
|
| 280 |
+
constraints_lst, int(lengths.max()), to_normalize, device
|
| 281 |
+
)
|
| 282 |
+
observed_motion = repeat(observed_motion, "t d -> b t d", b=num_samples)
|
| 283 |
+
motion_mask = repeat(motion_mask, "t d -> b t d", b=num_samples)
|
| 284 |
+
return observed_motion, motion_mask
|
| 285 |
+
|
| 286 |
+
length = int(lengths.max())
|
| 287 |
+
observed_motion_lst = []
|
| 288 |
+
motion_mask_lst = []
|
| 289 |
+
for constraints_lst_el in constraints_lst:
|
| 290 |
+
observed_motion, motion_mask = self.create_conditions_from_constraints(
|
| 291 |
+
constraints_lst_el,
|
| 292 |
+
length,
|
| 293 |
+
to_normalize,
|
| 294 |
+
device,
|
| 295 |
+
)
|
| 296 |
+
observed_motion_lst.append(observed_motion)
|
| 297 |
+
motion_mask_lst.append(motion_mask)
|
| 298 |
+
observed_motion = torch.stack(observed_motion_lst, axis=0)
|
| 299 |
+
motion_mask = torch.stack(motion_mask_lst, axis=0)
|
| 300 |
+
return observed_motion, motion_mask
|
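The canonicalization and conditioning helpers above compose into a short preprocessing path. The sketch below illustrates the intended call pattern only; it assumes `rep` is an already-constructed motion-rep instance with loaded stats and that `constraints` is a list of constraint objects understood by `build_condition_dicts`, so it is not a standalone script.

import torch

# features: [B, T, D] unnormalized motion features produced by rep(...).
# Canonicalize: frame-0 heading to zero, then frame-0 root (x, z) to origin.
features = rep.canonicalize(features)

# Optional augmentation: give each sequence a random initial heading.
features, delta = rep.randomize_first_heading(features, return_delta_angle=True)

# Build sparse observation targets and a mask from one shared constraint
# list; the helper builds a single [T, D] result and repeats it per sample.
lengths = torch.tensor([120, 90, 150])
observed_motion, motion_mask = rep.create_conditions_from_constraints_batched(
    constraints, lengths, to_normalize=True, device="cpu"
)
# observed_motion: [B, T_max, D] targets; motion_mask: [B, T_max, D] bool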
kimodo/motion_rep/reps/kimodo_motionrep.py ADDED
@@ -0,0 +1,301 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import einops
import torch
from torch import Tensor

from kimodo.tools import to_numpy

from ...geometry import cont6d_to_matrix, matrix_to_cont6d
from ...skeleton.kinematics import fk
from ...skeleton.transforms import global_rots_to_local_rots
from ...tools import ensure_batched
from ..conditioning import get_unique_index_and_data
from ..feature_utils import RotateFeatures, compute_heading_angle, compute_vel_xyz
from ..feet import foot_detect_from_pos_and_vel
from ..smooth_root import get_smooth_root_pos
from .base import MotionRepBase


class KimodoMotionRep(MotionRepBase):
    """Global root / global joints rotations representation, relative to a smooth root."""

    def __init__(
        self,
        skeleton,
        fps,
        stats_path: Optional[str] = None,
    ):
        nbjoints = skeleton.nbjoints

        self.size_dict = {
            "smooth_root_pos": torch.Size([3]),
            "global_root_heading": torch.Size([2]),
            "local_joints_positions": torch.Size([nbjoints, 3]),
            "global_rot_data": torch.Size([nbjoints, 6]),
            "velocities": torch.Size([nbjoints, 3]),
            "foot_contacts": torch.Size([4]),
        }
        self.last_root_feature = "global_root_heading"
        self.local_root_size_dict = {
            "local_root_rot_vel": torch.Size([1]),
            "local_root_vel": torch.Size([2]),
            "global_root_y": torch.Size([1]),
        }
        super().__init__(skeleton, fps, stats_path)

    @ensure_batched(local_joint_rots=5, root_positions=3, lengths=1)
    def __call__(
        self,
        local_joint_rots: torch.Tensor,
        root_positions: torch.Tensor,
        to_normalize: bool,
        lengths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Convert local rotations and root trajectory into smooth-root features.

        Args:
            local_joint_rots: Local joint rotation matrices ``[B, T, J, 3, 3]``.
            root_positions: Root positions ``[B, T, 3]``.
            to_normalize: Whether to normalize output features.
            lengths: Optional valid lengths for variable-length batches.

        Returns:
            Motion features with shape ``[B, T, motion_rep_dim]``.
        """
        device = local_joint_rots.device
        if lengths is None:
            assert local_joint_rots.shape[0] == 1, "If lengths is not provided, the input should not be batched."
            lengths = torch.tensor([local_joint_rots.shape[1]], device=device)

        (
            global_joints_rots,
            global_joints_positions,
            local_joints_positions_origin_is_pelvis,
        ) = fk(local_joint_rots, root_positions, self.skeleton)

        root_heading_angle = compute_heading_angle(global_joints_positions, self.skeleton)
        global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1)

        smooth_root_pos = get_smooth_root_pos(root_positions)
        hips_offset = root_positions - smooth_root_pos
        hips_offset[..., 1] = root_positions[..., 1]
        local_joints_positions = local_joints_positions_origin_is_pelvis + hips_offset[:, :, None]

        velocities = compute_vel_xyz(global_joints_positions, self.fps, lengths=lengths)
        foot_contacts = foot_detect_from_pos_and_vel(global_joints_positions, velocities, self.skeleton, 0.15, 0.10)
        global_rot_data = matrix_to_cont6d(global_joints_rots)

        features, _ = einops.pack(
            [
                smooth_root_pos,
                global_root_heading,
                local_joints_positions,
                global_rot_data,
                velocities,
                foot_contacts,
            ],
            "batch time *",
        )

        if to_normalize:
            features = self.normalize(features)
        return features

    @ensure_batched(features=3, angle=1)
    def rotate(self, features: torch.Tensor, angle: torch.Tensor):
        """Rotate root/joint positional and rotational features by heading."""
        # assume it is not normalized
        bs = features.shape[0]
        device = features.device
        [
            smooth_root_pos,
            global_root_heading,
            local_joints_positions,
            global_rot_data,
            velocities,
            foot_contacts,
        ] = einops.unpack(features, self.ps, "batch time *")

        if not isinstance(angle, torch.Tensor):
            angle = torch.tensor(angle, device=device)
        if len(angle.shape) == 0:
            angle = angle.repeat(bs)

        RF = RotateFeatures(angle)
        new_features, _ = einops.pack(
            [
                RF.rotate_positions(smooth_root_pos),
                RF.rotate_2d_positions(global_root_heading),
                RF.rotate_positions(local_joints_positions),
                RF.rotate_6d_rotations(global_rot_data),
                RF.rotate_positions(velocities),
                foot_contacts,
            ],
            "batch time *",
        )
        return new_features

    @ensure_batched(features=3, translation_2d=2)
    def translate_2d(
        self,
        features: torch.Tensor,
        translation_2d: torch.Tensor,
    ) -> torch.Tensor:
        """Translate smooth root planar position by ``(dx, dz)``."""
        # only move on the ground
        # If we need a translate_3D function, we should not forget to move the local_joints_positions as well
        bs = features.shape[0]
        if len(translation_2d.shape) == 1:
            translation_2d = translation_2d.repeat(bs, 1)

        new_features = features.clone()
        new_smooth_root_pos = new_features[:, :, self.slice_dict["smooth_root_pos"]]
        new_smooth_root_pos[:, :, 0] += translation_2d[:, [0]]
        new_smooth_root_pos[:, :, 2] += translation_2d[:, [1]]
        return new_features

    @ensure_batched(features=3)
    def inverse(
        self,
        features: torch.Tensor,
        is_normalized: bool,
        posed_joints_from="rotations",
        return_numpy: bool = False,
    ) -> torch.Tensor:
        """Decode smooth-root features into motion tensors."""
        assert posed_joints_from in [
            "rotations",
            "positions",
        ], "posed_joints_from should be 'rotations' or 'positions'"

        if is_normalized:
            features = self.unnormalize(features)

        [
            smooth_root_pos,
            global_root_heading,
            local_joints_positions,
            global_rot_data,
            velocities,
            foot_contacts,
        ] = einops.unpack(features, self.ps, "batch time *")

        global_rot_mats = cont6d_to_matrix(global_rot_data)
        local_rot_mats = global_rots_to_local_rots(global_rot_mats, self.skeleton)

        posed_joints_from_pos = local_joints_positions.clone()
        posed_joints_from_pos[..., 0] += smooth_root_pos[..., None, 0]
        posed_joints_from_pos[..., 2] += smooth_root_pos[..., None, 2]
        root_positions = posed_joints_from_pos[..., self.skeleton.root_idx, :]
        foot_contacts = foot_contacts > 0.5

        if posed_joints_from == "rotations":
            _, posed_joints, _ = self.skeleton.fk(
                local_rot_mats,
                root_positions,
            )
        else:
            posed_joints = posed_joints_from_pos

        output_tensor_dict = {
            "local_rot_mats": local_rot_mats,
            "global_rot_mats": global_rot_mats,
            "posed_joints": posed_joints,
            "root_positions": root_positions,
            "smooth_root_pos": smooth_root_pos,
            "foot_contacts": foot_contacts,
            "global_root_heading": global_root_heading,
        }
        if return_numpy:
            return to_numpy(output_tensor_dict)
        return output_tensor_dict

    def create_conditions(
        self,
        index_dict: dict[str, list[Tensor]],
        data_dict: dict[str, list[Tensor]],
        length: int,
        to_normalize: bool,
        device: str,
    ):
        """Build sparse conditioning tensors for smooth-root representation."""
        # create empty features and mask to be filled in
        observed_motion = torch.zeros(length, self.motion_rep_dim, device=device)
        motion_mask = torch.zeros(length, self.motion_rep_dim, dtype=bool, device=device)

        def _cat_indices(indices_list: list[Tensor]) -> Tensor:
            indices = torch.cat([torch.tensor(x) if not isinstance(x, Tensor) else x for x in indices_list])
            return indices.to(device=device, dtype=torch.long)

        def _match_obs_dtype(tensor: Tensor) -> Tensor:
            return tensor.to(device=device, dtype=observed_motion.dtype)

        if (fname := "smooth_root_2d") in index_dict and index_dict[fname]:
            indices = _cat_indices(index_dict[fname])
            indices, smooth_root_2d = get_unique_index_and_data(indices, torch.cat(data_dict[fname]))
            smooth_root_2d = _match_obs_dtype(smooth_root_2d)
            f_sliced = observed_motion[:, self.slice_dict["smooth_root_pos"]]
            f_sliced[indices, 0] = smooth_root_2d[:, 0]
            f_sliced[indices, 2] = smooth_root_2d[:, 1]
            m_sliced = motion_mask[:, self.slice_dict["smooth_root_pos"]]
            m_sliced[indices, 0] = True
            m_sliced[indices, 2] = True

        if (fname := "root_y_pos") in index_dict and index_dict[fname]:
            indices = _cat_indices(index_dict[fname])
            indices, root_pos_Y = get_unique_index_and_data(indices, torch.cat(data_dict[fname]))
            root_pos_Y = _match_obs_dtype(root_pos_Y)
            f_sliced = observed_motion[:, self.slice_dict["smooth_root_pos"]]
            f_sliced[indices, 1] = root_pos_Y
            m_sliced = motion_mask[:, self.slice_dict["smooth_root_pos"]]
            m_sliced[indices, 1] = True

        if (fname := "global_root_heading") in index_dict and index_dict[fname]:
            indices = _cat_indices(index_dict[fname])
            indices, global_root_heading = get_unique_index_and_data(indices, torch.cat(data_dict[fname]))
            global_root_heading = _match_obs_dtype(global_root_heading)
            f_sliced = observed_motion[:, self.slice_dict[fname]]
            f_sliced[indices] = global_root_heading
            m_sliced = motion_mask[:, self.slice_dict[fname]]
            m_sliced[indices] = True

        if (fname := "global_joints_rots") in index_dict and index_dict[fname]:
            indices_lst = _cat_indices(index_dict[fname])
            indices_lst, global_joints_rots = get_unique_index_and_data(indices_lst, torch.cat(data_dict[fname]))
            global_joints_rots = _match_obs_dtype(global_joints_rots)
            global_rot_data = matrix_to_cont6d(global_joints_rots)
            f_sliced = observed_motion[:, self.slice_dict["global_rot_data"]]
            masking = torch.zeros(len(f_sliced) * self.nbjoints, 6, device=device, dtype=bool)
            masking[indices_lst.T[0] * self.nbjoints + indices_lst.T[1]] = True
            masking = masking.reshape(len(f_sliced), self.nbjoints * 6)
            f_sliced[masking] = global_rot_data.flatten()
            m_sliced = motion_mask[:, self.slice_dict["global_rot_data"]]
            m_sliced[masking] = True

        if (fname := "global_joints_positions") in index_dict and index_dict[fname]:
            indices_lst = _cat_indices(index_dict[fname])
            indices_lst, global_joints_positions = get_unique_index_and_data(indices_lst, torch.cat(data_dict[fname]))
            global_joints_positions = _match_obs_dtype(global_joints_positions)
            T_indices = indices_lst[:, 0].contiguous()
            _test = motion_mask[T_indices, self.slice_dict["smooth_root_pos"]]
            if not _test[:, [0, 2]].all():
                raise ValueError("For constraining global positions, the smooth root should also be constrained.")
            smooth_root_pos = observed_motion[T_indices, self.slice_dict["smooth_root_pos"]].clone()
            local_reference = smooth_root_pos.clone()
            local_reference[..., 1] = 0.0
            local_joints_positions = global_joints_positions - local_reference
            f_sliced = observed_motion[:, self.slice_dict["local_joints_positions"]]
            masking = torch.zeros(len(f_sliced) * self.nbjoints, 3, device=device, dtype=bool)
            masking[indices_lst.T[0] * self.nbjoints + indices_lst.T[1]] = True
            masking = masking.reshape(len(f_sliced), self.nbjoints * 3)
            f_sliced[masking] = local_joints_positions.flatten()
            m_sliced = motion_mask[:, self.slice_dict["local_joints_positions"]]
            m_sliced[masking] = True

        if to_normalize:
            observed_motion = self.normalize(observed_motion)
        return observed_motion, motion_mask
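A round-trip sketch for the representation above. It is hedged on a `skeleton` object loaded elsewhere (only `nbjoints`, `root_idx`, and `fk` are used); the rest-pose inputs are illustrative, not from the commit.

import torch
from kimodo.motion_rep.reps.kimodo_motionrep import KimodoMotionRep

rep = KimodoMotionRep(skeleton, fps=30)  # skeleton assumed loaded elsewhere

B, T, J = 1, 120, skeleton.nbjoints
local_joint_rots = torch.eye(3).expand(B, T, J, 3, 3).contiguous()  # rest pose
root_positions = torch.zeros(B, T, 3)

# Encode: FK -> heading -> smooth root -> packed [B, T, motion_rep_dim].
features = rep(local_joint_rots, root_positions, to_normalize=False)

# Decode: posed joints can come from the 6D rotations or the stored positions.
out = rep.inverse(features, is_normalized=False, posed_joints_from="rotations")
posed_joints = out["posed_joints"]  # [B, T, J, 3]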
kimodo/motion_rep/reps/tmr_motionrep.py ADDED
@@ -0,0 +1,222 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""TMR motion representation: global root, global joints, velocities, and foot contacts."""

from typing import Optional

import einops
import torch

from ...skeleton.kinematics import fk
from ...tools import ensure_batched, to_numpy
from ..feature_utils import RotateFeatures, compute_heading_angle, compute_vel_xyz
from ..feet import foot_detect_from_pos_and_vel
from .base import MotionRepBase


class TMRMotionRep(MotionRepBase):
    """Motion representation with global root and global joint positions.

    Feature layout:
    - root position ``(x, y, z)``
    - root heading as ``(cos(theta), sin(theta))``
    - local joint positions (root removed, ground-referenced)
    - global joint velocities
    - binary foot contacts
    """

    def __init__(
        self,
        skeleton,
        fps,
        stats_path: Optional[str] = None,
    ):
        nbjoints = skeleton.nbjoints

        self.size_dict = {
            "root_pos": torch.Size([3]),
            "global_root_heading": torch.Size([2]),
            "local_joints_positions": torch.Size([nbjoints - 1, 3]),
            "velocities": torch.Size([nbjoints, 3]),
            "foot_contacts": torch.Size([4]),
        }
        self.last_root_feature = "global_root_heading"
        self.local_root_size_dict = {
            "local_root_rot_vel": torch.Size([1]),
            "local_root_vel": torch.Size([2]),
            "global_root_y": torch.Size([1]),
        }
        super().__init__(skeleton, fps, stats_path)

    @ensure_batched(local_joint_rots=5, root_positions=3, posed_joints=4, lengths=1)
    def __call__(
        self,
        local_joint_rots: Optional[torch.Tensor] = None,
        root_positions: Optional[torch.Tensor] = None,
        posed_joints: Optional[torch.Tensor] = None,
        *,
        to_normalize: bool,
        lengths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Convert motion inputs to this feature representation.

        Args:
            local_joint_rots: Local joint rotation matrices ``[B, T, J, 3, 3]``.
                Required when ``posed_joints`` is not provided.
            root_positions: Root translations ``[B, T, 3]``. Required when
                ``posed_joints`` is not provided.
            posed_joints: Optional precomputed global joint positions
                ``[B, T, J, 3]``. If passed, FK is skipped.
            to_normalize: Whether to normalize output features.
            lengths: Optional valid lengths for variable-length batches.

        Returns:
            Motion features with shape ``[B, T, motion_rep_dim]``.
        """
        if posed_joints is not None:
            device = posed_joints.device
            nbatch, nbframes, nbjoints = posed_joints.shape[:3]
        else:
            device = local_joint_rots.device
            nbatch, nbframes, nbjoints = local_joint_rots.shape[:3]

        if lengths is None:
            assert nbatch == 1, "If lengths is not provided, the input should not be batched."
            lengths = torch.tensor([nbframes], device=device)

        if posed_joints is None:
            _, global_positions, local_joints_positions_origin_is_pelvis = fk(
                local_joint_rots, root_positions, self.skeleton
            )
        else:
            global_positions = posed_joints
            root_positions = posed_joints[:, :, 0]
            local_joints_positions_origin_is_pelvis = posed_joints - root_positions[:, :, None]

        root_heading_angle = compute_heading_angle(global_positions, self.skeleton)
        global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1)

        ground_offset = 0 * root_positions
        ground_offset[..., 1] = root_positions[..., 1]
        local_joints_positions = local_joints_positions_origin_is_pelvis[:, :, 1:] + ground_offset[:, :, None]
        velocities = compute_vel_xyz(global_positions, self.fps, lengths=lengths)
        foot_contacts = foot_detect_from_pos_and_vel(global_positions, velocities, self.skeleton, 0.15, 0.10)

        features, _ = einops.pack(
            [
                root_positions,
                global_root_heading,
                local_joints_positions,
                velocities,
                foot_contacts,
            ],
            "batch time *",
        )

        if to_normalize:
            features = self.normalize(features)
        return features

    @ensure_batched(features=3, angle=1)
    def rotate(self, features: torch.Tensor, angle: torch.Tensor):
        """Rotate all spatial features by a heading delta (radians)."""
        # rotate by the angle, i.e. add the delta to the current heading
        # assume it is not normalized
        bs = features.shape[0]
        device = features.device
        [
            root_pos,
            global_root_heading,
            local_joints_positions,
            velocities,
            foot_contacts,
        ] = einops.unpack(features, self.ps, "batch time *")

        if not isinstance(angle, torch.Tensor):
            angle = torch.tensor(angle, device=device)
        if len(angle.shape) == 0:
            angle = angle.repeat(bs)

        RF = RotateFeatures(angle)
        new_features, _ = einops.pack(
            [
                RF.rotate_positions(root_pos),
                RF.rotate_2d_positions(global_root_heading),
                RF.rotate_positions(local_joints_positions),
                RF.rotate_positions(velocities),
                foot_contacts,
            ],
            "batch time *",
        )
        return new_features

    @ensure_batched(features=3, translation_2d=2)
    def translate_2d(
        self,
        features: torch.Tensor,
        translation_2d: torch.Tensor,
    ) -> torch.Tensor:
        """Translate root planar position by ``(dx, dz)``."""
        # only move on the ground
        # For 3D, we should not forget to move the local_joints_positions as well
        bs = features.shape[0]
        if len(translation_2d.shape) == 1:
            translation_2d = translation_2d.repeat(bs, 1)

        new_features = features.clone()
        new_root_pos = new_features[:, :, self.slice_dict["root_pos"]]
        new_root_pos[:, :, 0] += translation_2d[:, [0]]
        new_root_pos[:, :, 2] += translation_2d[:, [1]]
        return new_features

    @ensure_batched(features=3)
    def inverse(
        self,
        features: torch.Tensor,
        is_normalized: bool,
        posed_joints_from="positions",
        return_numpy: bool = False,
    ) -> torch.Tensor:
        """Decode features back to a motion dictionary.

        Args:
            features: Feature tensor ``[B, T, D]``.
            is_normalized: Whether input features are normalized.
            posed_joints_from: Must be ``"positions"`` for this representation.
            return_numpy: Whether to convert tensors to numpy arrays.

        Returns:
            Dictionary containing reconstructed positions and auxiliary data.
        """
        assert posed_joints_from == "positions"
        if is_normalized:
            features = self.unnormalize(features)

        [
            root_positions,
            global_root_heading,
            local_joints_positions,
            velocities,
            foot_contacts,
        ] = einops.unpack(features, self.ps, "batch time *")

        dummy_root = 0 * local_joints_positions[:, :, [0]]
        posed_joints_from_pos = torch.cat([dummy_root, local_joints_positions], dim=2)
        posed_joints_from_pos[..., 0] += root_positions[..., None, 0]
        posed_joints_from_pos[..., 2] += root_positions[..., None, 2]
        root_positions = posed_joints_from_pos[..., self.skeleton.root_idx, :]
        foot_contacts = foot_contacts > 0.5
        posed_joints = posed_joints_from_pos

        output_tensor_dict = {
            "local_rot_mats": None,
            "global_rot_mats": None,
            "posed_joints": posed_joints,
            "root_positions": root_positions,
            "foot_contacts": foot_contacts,
            "global_root_heading": global_root_heading,
        }
        if return_numpy:
            return to_numpy(output_tensor_dict)
        return output_tensor_dict
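The ``posed_joints`` fast path above skips FK entirely, which is convenient when only joint positions are available (e.g. for TMR-style evaluation). A hedged sketch, again assuming a `skeleton` object obtained elsewhere and illustrative random inputs:

import torch
from kimodo.motion_rep.reps.tmr_motionrep import TMRMotionRep

rep = TMRMotionRep(skeleton, fps=30)  # skeleton assumed loaded elsewhere

posed_joints = torch.randn(1, 120, skeleton.nbjoints, 3)  # [B, T, J, 3]
features = rep(posed_joints=posed_joints, to_normalize=False)  # FK skipped

# This representation decodes from positions only; no rotations are stored,
# so "local_rot_mats"/"global_rot_mats" come back as None.
out = rep.inverse(features, is_normalized=False, posed_joints_from="positions")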
kimodo/motion_rep/smooth_root.py ADDED
@@ -0,0 +1,234 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Smooth root trajectory: ADMM-based smoother with margin constraints and get_smooth_root_pos helper."""

import math

import numpy as np
import torch
from scipy import sparse
from scipy.sparse.linalg import splu

from kimodo.tools import ensure_batched


class TrajectorySmoother:
    """Modify trajectories to hit target values while respecting soft constraints.

    This smoother keeps the trajectory close to the original positions while minimizing
    accelerations. Targets are enforced at specified frames via soft constraints.
    """

    def __init__(
        self,
        margins,
        pos_weight=0.0,
        loop=False,
        admm_iters=100,
        alpha_overrelax=1.0,
        circle_project=False,
    ):
        """Initialize the TrajectorySmoother.

        Args:
            margins: Array of margin values for each frame.
                margins[i] < 0: unconstrained
                margins[i] == 0: pinned on this frame
                margins[i] > 0: can deviate within the margin
            pos_weight: Weight for position preservation
            loop: Whether the trajectory should loop
            admm_iters: Number of ADMM iterations
            alpha_overrelax: ADMM over-relaxation coefficient
            circle_project: If True, project each iterate to the unit sphere
        """
        self.pos_weight = pos_weight
        self.admm_iters = admm_iters
        self.alpha_overrelax = alpha_overrelax
        self.circle_project = circle_project
        N = len(margins)

        # Store margin information as numpy arrays
        self.margin_vals = margins

        # Build acceleration matrix A
        a_data = []
        a_rows = []
        a_cols = []

        for i in range(1, N - 1):
            scale = 1.0
            a_data.extend([-scale, 2.0 * scale, -scale])
            a_rows.extend([i, i, i])
            a_cols.extend([i - 1, i, i + 1])

        if loop:
            # Add periodic accelerations
            scale = 1.0
            a_data.extend([-scale, 2.0 * scale, -scale])
            a_rows.extend([0, 0, 0])
            a_cols.extend([N - 1, 0, 1])

            scale = 1.0
            a_data.extend([-scale, 2.0 * scale, -scale])
            a_rows.extend([N - 1, N - 1, N - 1])
            a_cols.extend([N - 2, N - 1, 0])

        A = sparse.csr_matrix((a_data, (a_rows, a_cols)), shape=(N, N))

        # Build identity matrix
        identity_matrix = sparse.eye(N)

        # Build system matrix M
        M = pos_weight * identity_matrix + A.T @ A

        # Calculate ADMM step size
        diag_max = max(abs(M.diagonal()))
        self.admm_stepsize = 0.25 * np.sqrt(diag_max)

        M = M + self.admm_stepsize * identity_matrix
        self.system_lu = splu(M.tocsc())

    def smooth(self, targets, x0):
        """Interpolate between reference positions while satisfying constraints.

        Args:
            targets: Target positions for constrained frames (numpy array)
            x0: Initial guess defining the original shape (numpy array)

        Returns:
            Interpolated positions (numpy array)
        """
        x_target = targets.copy()
        x = x0.copy()
        z = np.zeros_like(x)
        u = np.zeros_like(x)

        for _ in range(self.admm_iters):
            self.z_update(z, x, x_target, u)
            self.u_update(u, x, z)
            self.x_update(x, z, u, x_target)

        return x

    def x_update(self, x, z, u, x_t):
        """Update x in the ADMM iteration."""
        # x = (wp * I + A^T A + p I)^-1 (wp * x_orig + p (z - u))
        r = self.pos_weight * x_t + self.admm_stepsize * (z - u)
        x[:] = self.system_lu.solve(r)

    def z_update(self, z, x, z_t, u):
        """Update z in the ADMM iteration using vectorized operations."""
        # Compute the difference from target for all margin locations at once
        z[:] = x + u - z_t

        # Check if we need to project back to margin
        z_diff_norms = np.linalg.norm(z, axis=1)
        mask = z_diff_norms > self.margin_vals
        if np.any(mask):
            scale_factors = self.margin_vals[mask] / z_diff_norms[mask]
            z[mask] *= scale_factors[:, np.newaxis]

        # Add back the target
        z[:] += z_t

        if self.circle_project:
            z[:] = z / (np.linalg.norm(z, axis=1, keepdims=True) + 1.0e-6)

    def u_update(self, u, x, z):
        """Update u in the ADMM iteration using vectorized operations."""
        u[:] += self.alpha_overrelax * (x - z)


def smooth_signal(x, margins, pos_weight=0, alpha_overrelax=1.8, admm_iters=500, circle_project=False):
    """Multigrid trajectory smoothing with margin constraints.

    Args:
        x: Input trajectory ``[T, D]`` as a NumPy array.
        margins: Allowed radius around each target frame ``[T]``.
        pos_weight: Weight for staying close to the original signal.
        alpha_overrelax: ADMM over-relaxation coefficient.
        admm_iters: ADMM iterations per multigrid level.
        circle_project: If ``True``, project each vector to the unit sphere.

    Returns:
        Smoothed trajectory of shape ``[T, D]``.
    """
    x_smoothed = x.copy()
    x_smoothed[:] = x.mean(axis=0, keepdims=True)

    # smooth the signal, multigrid style by starting out coarse,
    # doubling the resolution and repeating until we're at the full
    # resolution, using the previous result as the initial guess.
    levels = int(math.floor(math.log2(len(x))))
    levels = max(levels - 4, 1)

    stepsize = 2**levels
    while True:
        # smooth signals at this level:
        num_steps = len(x_smoothed[::stepsize])
        smoother = TrajectorySmoother(
            margins=margins[::stepsize],
            pos_weight=pos_weight,
            alpha_overrelax=alpha_overrelax,
            admm_iters=admm_iters,
            circle_project=circle_project,
        )
        x_smoothed[::stepsize] = smoother.smooth(x[::stepsize], x_smoothed[::stepsize])

        # interpolate to next level:
        next_stepsize = stepsize // 2
        num_interleaved = len(x_smoothed[next_stepsize::stepsize])
        if num_interleaved == num_steps:
            # linearly extrapolate the last value if we have to:
            x_smoothed[next_stepsize::stepsize][-1] = (
                x_smoothed[::stepsize][-1] + (x_smoothed[::stepsize][-1] - x_smoothed[::stepsize][-2]) / 2
            )
            num_interleaved = num_interleaved - 1

        # linearly interpolate the remaining values:
        x_smoothed[next_stepsize::stepsize][:num_interleaved] = (
            x_smoothed[::stepsize][:-1] + x_smoothed[::stepsize][1:]
        ) / 2

        if stepsize == 1:
            break

        stepsize //= 2

    return x_smoothed


@ensure_batched(hip_translations=3)
def get_smooth_root_pos(hip_translations):
    """Smooth root trajectory in the ground plane while preserving height.

    Args:
        hip_translations: Root translations ``[B, T, 3]``.

    Returns:
        Smoothed root translations ``[B, T, 3]`` where ``x/z`` are smoothed and
        ``y`` remains unchanged.
    """
    root_translations_xz = hip_translations[..., [0, 2]]
    root_translations_y = hip_translations[..., [1]]

    batch_size, nframes = root_translations_xz.shape[:2]
    margins = np.full(root_translations_xz.shape[1], 0.06)

    root_translations_smoothed_xz = []
    for batch in range(batch_size):
        root_translations_smoothed_xz.append(
            smooth_signal(root_translations_xz[batch].detach().cpu().numpy(), margins)[None]
        )

    root_translations_smoothed_xz = torch.tensor(np.concatenate(root_translations_smoothed_xz))

    root_translations = torch.cat(
        [
            root_translations_smoothed_xz.to(root_translations_y.device),
            root_translations_y,
        ],
        dim=-1,
    )[..., [0, 2, 1]]

    return root_translations
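Since the smoother depends only on NumPy/SciPy, it can be exercised in isolation. A minimal check that per-frame deviations stay near the margin; the constraints are soft, so the bound is only approximate under finite ADMM iterations, and the trajectory values here are made up for illustration:

import numpy as np
from kimodo.motion_rep.smooth_root import smooth_signal

rng = np.random.default_rng(0)
T = 256
t = np.linspace(0.0, 1.0, T)
# Forward walk along x with sideways jitter in z.
xz = np.stack([3.0 * t, 0.05 * rng.standard_normal(T)], axis=1)

margins = np.full(T, 0.06)  # same 6 cm radius used by get_smooth_root_pos
xz_smooth = smooth_signal(xz, margins)
# Per-frame deviation from the noisy input, roughly bounded by the margin.
print(np.linalg.norm(xz_smooth - xz, axis=1).max())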
kimodo/motion_rep/stats.py ADDED
@@ -0,0 +1,123 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Feature normalization statistics (mean/std) for motion representations."""

import logging
import os
from typing import Optional

import numpy as np
import torch

log = logging.getLogger(__name__)


class Stats(torch.nn.Module):
    """Utility module for feature normalization statistics.

    Normalization follows:
    ``(data - mean) / sqrt(std**2 + eps)``
    """

    def __init__(
        self,
        folder: Optional[str] = None,
        load: bool = True,
        eps=1e-05,
    ):
        super().__init__()
        self.folder = folder
        self.eps = eps
        if folder is not None and load:
            self.load()

    def sliced(self, indices):
        """Return a new ``Stats`` object containing selected feature indices."""
        new_stats = Stats(folder=self.folder, load=False, eps=self.eps)
        new_stats.register_from_tensors(
            self.mean[..., indices].clone(),
            self.std[..., indices].clone(),
        )
        return new_stats

    def load(self):
        """Load ``mean.npy`` and ``std.npy`` from ``self.folder``."""
        mean_path = os.path.join(self.folder, "mean.npy")
        std_path = os.path.join(self.folder, "std.npy")
        if not os.path.exists(mean_path) or not os.path.exists(std_path):
            raise FileNotFoundError(
                f"Missing stats files in '{self.folder}'. Expected:\n"
                f"  - {mean_path}\n"
                f"  - {std_path}\n\n"
                "Make sure the checkpoint/stats have been downloaded and are mounted into the container.\n"
                "If you're using Docker Compose, run it from the repo root so `./:/workspace` mounts the correct directory."
            )

        mean = torch.from_numpy(np.load(mean_path))
        std = torch.from_numpy(np.load(std_path))
        self.register_from_tensors(mean, std)

    def register_from_tensors(self, mean: torch.Tensor, std: torch.Tensor):
        """Register mean/std tensors as non-persistent buffers."""
        self.register_buffer("mean", mean, persistent=False)
        self.register_buffer("std", std, persistent=False)

    def normalize(self, data: torch.Tensor) -> torch.Tensor:
        """Normalize data using the stored statistics."""
        mean = self.mean.to(device=data.device, dtype=data.dtype)
        std = self.std.to(device=data.device, dtype=data.dtype)
        # adjust std with eps
        return (data - mean) / torch.sqrt(std**2 + self.eps)

    def unnormalize(self, data: torch.Tensor) -> torch.Tensor:
        """Undo normalization using the stored statistics."""
        mean = self.mean.to(device=data.device, dtype=data.dtype)
        std = self.std.to(device=data.device, dtype=data.dtype)
        # adjust std with eps
        return data * torch.sqrt(std**2 + self.eps) + mean

    def is_loaded(self):
        """Return whether statistics are currently available."""
        return hasattr(self, "mean")

    def get_dim(self):
        """Return feature dimensionality."""
        return self.mean.shape[0]

    def save(
        self,
        folder: Optional[str] = None,
        mean: Optional[torch.Tensor] = None,
        std: Optional[torch.Tensor] = None,
    ):
        """Save statistics to ``folder`` as ``mean.npy`` and ``std.npy``."""
        if folder is None:
            folder = self.folder
        if folder is None:
            raise ValueError("No folder to save stats")

        if mean is None and std is None:
            try:
                mean = self.mean.cpu().numpy()
                std = self.std.cpu().numpy()
            except AttributeError:
                raise ValueError("Stats were not loaded")

        # don't override stats folder
        os.makedirs(folder, exist_ok=False)

        np.save(os.path.join(folder, "mean.npy"), mean)
        np.save(os.path.join(folder, "std.npy"), std)

    def __eq__(self, other):
        return (self.mean.cpu() == other.mean.cpu()).all() and (self.std.cpu() == other.std.cpu()).all()

    # should define a hash value for pytorch, as we defined __eq__
    def __hash__(self):
        # Convert mean and std to bytes for a consistent hash value
        mean_hash = hash(self.mean.detach().cpu().numpy().tobytes())
        std_hash = hash(self.std.detach().cpu().numpy().tobytes())
        return hash((mean_hash, std_hash))

    def __repr__(self):
        return f'Stats(folder="{self.folder}")'
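A self-contained example of the normalize/unnormalize round trip, registering statistics by hand instead of loading ``mean.npy``/``std.npy`` (the tensor values are illustrative):

import torch
from kimodo.motion_rep.stats import Stats

stats = Stats()  # no folder: register tensors manually
stats.register_from_tensors(
    mean=torch.tensor([0.0, 1.0, -2.0]),
    std=torch.tensor([1.0, 0.5, 2.0]),
)

x = torch.randn(4, 10, 3)    # [B, T, D] features
x_norm = stats.normalize(x)  # (x - mean) / sqrt(std**2 + eps)
assert torch.allclose(stats.unnormalize(x_norm), x, atol=1e-5)

head_stats = stats.sliced([0, 1])  # Stats restricted to the first two dims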
kimodo/pipeline/__init__.py ADDED
@@ -0,0 +1,28 @@
"""Pipeline utilities for prompt/script to Kimodo generation flows."""

from .blend_quality import (
    BlendGuardrailConfig,
    TransitionSettings,
    apply_transition_guardrails,
    harmonize_scene_transitions,
)
from .script_to_kimodo import (
    CharacterKimodoPlan,
    build_character_plan,
    generator_request_to_plans,
    run_multi_character_generation,
)
from .scheduler_runtime import SceneScheduleResult, run_scheduled_scene

__all__ = [
    "CharacterKimodoPlan",
    "BlendGuardrailConfig",
    "TransitionSettings",
    "apply_transition_guardrails",
    "harmonize_scene_transitions",
    "build_character_plan",
    "generator_request_to_plans",
    "run_multi_character_generation",
    "SceneScheduleResult",
    "run_scheduled_scene",
]
kimodo/pipeline/blend_quality.py ADDED
@@ -0,0 +1,116 @@
"""Card 7 blend quality guardrails for transition blending safety and consistency."""

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class TransitionSettings:
    """Transition settings passed to Kimodo generation."""

    num_transition_frames: int
    share_transition: bool
    percentage_transition_override: float


@dataclass(frozen=True)
class BlendGuardrailConfig:
    """Runtime safety bounds for transition blending."""

    min_transition_frames: int = 1
    max_transition_frames: int = 12
    min_segment_frames_for_share: int = 12
    max_transition_ratio: float = 0.30
    max_shared_window_frames: int = 24
    harmonize_window: int = 2


def _clamp(value: float, low: float, high: float) -> float:
    return max(low, min(high, value))


def apply_transition_guardrails(
    segment_frames: list[int],
    policies: list[str],
    requested: TransitionSettings,
    *,
    config: BlendGuardrailConfig = BlendGuardrailConfig(),
) -> TransitionSettings:
    """Clamp transition settings to safe ranges for short/long segments.

    Guardrails avoid transition windows that dominate short segments and reduce blending artifacts
    for scripted interactions.
    """
    if len(segment_frames) < 2:
        safe_frames = int(
            _clamp(requested.num_transition_frames, config.min_transition_frames, config.max_transition_frames)
        )
        return TransitionSettings(
            num_transition_frames=safe_frames,
            share_transition=False,
            percentage_transition_override=0.0,
        )

    min_prev = min(segment_frames[:-1])
    min_next = min(segment_frames[1:])
    # Keep at least one non-transition frame in the shortest pair.
    shortest_pair_budget = max(config.min_transition_frames, min(min_prev, min_next) - 1)

    safe_frames = int(
        _clamp(
            requested.num_transition_frames,
            config.min_transition_frames,
            min(config.max_transition_frames, shortest_pair_budget),
        )
    )

    has_cut = "cut" in policies
    can_share = (
        requested.share_transition
        and not has_cut
        and min_prev >= config.min_segment_frames_for_share
        and min_next >= config.min_segment_frames_for_share
    )

    if not can_share:
        return TransitionSettings(
            num_transition_frames=safe_frames,
            share_transition=False,
            percentage_transition_override=0.0,
        )

    safe_pct = _clamp(requested.percentage_transition_override, 0.0, config.max_transition_ratio)

    # Cap shared overlap by configured hard ceiling and shortest-pair budget.
    max_pct_from_shared_window = max(0.0, (config.max_shared_window_frames - safe_frames) / max(1, min_prev))
    max_pct_from_shortest_pair = max(0.0, (shortest_pair_budget - safe_frames) / max(1, min_prev))
    safe_pct = min(safe_pct, max_pct_from_shared_window, max_pct_from_shortest_pair)

    return TransitionSettings(
        num_transition_frames=safe_frames,
        share_transition=True,
        percentage_transition_override=float(safe_pct),
    )


def harmonize_scene_transitions(
    settings_by_character: dict[str, TransitionSettings],
    *,
    config: BlendGuardrailConfig = BlendGuardrailConfig(),
) -> dict[str, TransitionSettings]:
    """Nudge transition-frame counts toward a scene median for multi-character consistency."""
    if len(settings_by_character) < 2:
        return settings_by_character

    frame_values = sorted(setting.num_transition_frames for setting in settings_by_character.values())
    median = frame_values[len(frame_values) // 2]
    low = max(config.min_transition_frames, median - config.harmonize_window)
    high = min(config.max_transition_frames, median + config.harmonize_window)

    harmonized: dict[str, TransitionSettings] = {}
    for character_id, setting in settings_by_character.items():
        harmonized[character_id] = TransitionSettings(
            num_transition_frames=int(_clamp(setting.num_transition_frames, low, high)),
            share_transition=setting.share_transition,
            percentage_transition_override=setting.percentage_transition_override,
        )
    return harmonized
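The guardrails are pure functions, so they can be checked directly. In this illustrative call, an over-long transition request against a short first segment is clamped to the shortest-pair budget, and sharing is refused because the 8-frame segment falls below ``min_segment_frames_for_share``:

from kimodo.pipeline.blend_quality import (
    TransitionSettings,
    apply_transition_guardrails,
    harmonize_scene_transitions,
)

requested = TransitionSettings(
    num_transition_frames=20,  # longer than the 8-frame segment allows
    share_transition=True,
    percentage_transition_override=0.5,
)
safe = apply_transition_guardrails(
    segment_frames=[8, 40, 60],
    policies=["blend", "blend"],
    requested=requested,
)
# -> num_transition_frames=7 (shortest pair minus one), share_transition=False

scene = {"alice": safe, "bob": TransitionSettings(12, False, 0.0)}
print(harmonize_scene_transitions(scene))  # counts pulled toward the median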
kimodo/pipeline/scheduler_runtime.py ADDED
@@ -0,0 +1,139 @@
"""Card 8 runtime orchestration: deterministic multi-character scheduling."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any, Optional

from kimodo.pipeline.script_to_kimodo import run_multi_character_generation
from kimodo.schemas import GeneratorRequest
from kimodo.scheduler import (
    CharacterState,
    CharacterSegmentState,
    ConflictResolutionPolicy,
    DeterministicLoop,
)

LOGGER = logging.getLogger(__name__)


@dataclass(frozen=True)
class SceneScheduleResult:
    """Structured result for scheduled scene execution."""

    outputs: dict[str, dict[str, Any]]
    errors: dict[str, str]
    plans: dict[str, Any]
    state_hashes: list[str]
    interactions: list[tuple[int, str, str]]
    completed_segments: dict[str, int]


def _activate_next_segment(loop: DeterministicLoop, character_id: str, plan: Any, segment_index: int) -> None:
    """Set active segment in loop state for one character."""
    slot = loop.characters[character_id]
    slot.segment_state = CharacterSegmentState(
        character_id=character_id,
        segment_index=segment_index,
        frames_elapsed=0,
        total_frames=plan.num_frames[segment_index],
    )
    policy = plan.segment_transition_policies[segment_index]
    # Interaction target is encoded in planner request segments; set later in per-tick update.
    slot.current_state = CharacterState.BUSY if policy != "cut" else CharacterState.TRANSITIONING


def run_scheduled_scene(
    model: Any,
    request: GeneratorRequest,
    *,
    fps: float,
    seed: int = 42,
    conflict_policy: ConflictResolutionPolicy = ConflictResolutionPolicy.COOLDOWN,
    diffusion_steps: int = 100,
    cfg_weight: Optional[list[float]] = None,
    cfg_type: Optional[str] = None,
    post_processing: bool = True,
    root_margin: float = 0.04,
    constraint_resolver: Optional[Any] = None,
    continue_on_error: bool = False,
) -> SceneScheduleResult:
    """Run generation then deterministic timeline scheduling for all characters in a scene."""
    LOGGER.info("card8.run_scheduled_scene.start scene_id=%s chars=%s", request.scene_id, len(request.characters))

    outputs, errors, plans = run_multi_character_generation(
        model,
        request,
        fps=fps,
        diffusion_steps=diffusion_steps,
        cfg_weight=cfg_weight,
        cfg_type=cfg_type,
        post_processing=post_processing,
        root_margin=root_margin,
        constraint_resolver=constraint_resolver,
        continue_on_error=continue_on_error,
    )

    loop = DeterministicLoop(
        fps=int(fps),
        seed=seed,
        conflict_policy=conflict_policy,
    )

    for priority, character in enumerate(request.characters):
        loop.register_character(character.character_id, character.skeleton_type, priority=priority)

    segment_indices = {character.character_id: 0 for character in request.characters}
    completed_segments = {character.character_id: 0 for character in request.characters}

    for character in request.characters:
        plan = plans.get(character.character_id)
        if plan is None:
            continue
        if not plan.num_frames:
            continue
        _activate_next_segment(loop, character.character_id, plan, segment_index=0)
        first_segment = character.segments[0]
|
| 98 |
+
loop.characters[character.character_id].interaction_target = first_segment.interaction_target
|
| 99 |
+
|
| 100 |
+
total_scene_frames = max((plan.total_frames for plan in plans.values()), default=0)
|
| 101 |
+
state_hashes: list[str] = []
|
| 102 |
+
interactions: list[tuple[int, str, str]] = []
|
| 103 |
+
|
| 104 |
+
for _ in range(total_scene_frames):
|
| 105 |
+
tick = loop.advance_tick({})
|
| 106 |
+
state_hashes.append(loop.get_state_hash())
|
| 107 |
+
|
| 108 |
+
for winner, loser in tick.interactions:
|
| 109 |
+
interactions.append((tick.tick_number, winner, loser))
|
| 110 |
+
|
| 111 |
+
for character_id in tick.completed_segments:
|
| 112 |
+
plan = plans.get(character_id)
|
| 113 |
+
if plan is None:
|
| 114 |
+
continue
|
| 115 |
+
completed_segments[character_id] += 1
|
| 116 |
+
next_index = segment_indices[character_id] + 1
|
| 117 |
+
if next_index < len(plan.num_frames):
|
| 118 |
+
segment_indices[character_id] = next_index
|
| 119 |
+
_activate_next_segment(loop, character_id, plan, next_index)
|
| 120 |
+
source_char = next(c for c in request.characters if c.character_id == character_id)
|
| 121 |
+
loop.characters[character_id].interaction_target = source_char.segments[next_index].interaction_target
|
| 122 |
+
else:
|
| 123 |
+
loop.characters[character_id].segment_state = None
|
| 124 |
+
loop.characters[character_id].interaction_target = None
|
| 125 |
+
|
| 126 |
+
LOGGER.info(
|
| 127 |
+
"card8.run_scheduled_scene.exit scene_id=%s hashes=%s interactions=%s",
|
| 128 |
+
request.scene_id,
|
| 129 |
+
len(state_hashes),
|
| 130 |
+
len(interactions),
|
| 131 |
+
)
|
| 132 |
+
return SceneScheduleResult(
|
| 133 |
+
outputs=outputs,
|
| 134 |
+
errors=errors,
|
| 135 |
+
plans=plans,
|
| 136 |
+
state_hashes=state_hashes,
|
| 137 |
+
interactions=interactions,
|
| 138 |
+
completed_segments=completed_segments,
|
| 139 |
+
)
|
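For context, a hedged driver sketch for run_scheduled_scene. The GeneratorRequest construction is left elided because its full schema is not part of this diff, and the demo_scene helper name is hypothetical; only load_model (exported from the package root) and the SceneScheduleResult fields above are taken from the source:

    # Hypothetical sketch: generate all characters, then replay the deterministic timeline.
    from kimodo import load_model
    from kimodo.pipeline.scheduler_runtime import run_scheduled_scene


    def demo_scene(model, request):
        """Run one scene end-to-end and print a scheduling summary."""
        result = run_scheduled_scene(model, request, fps=30.0, seed=42)
        # state_hashes has one entry per tick; interactions records (tick, winner, loser).
        print(len(result.state_hashes), "ticks;", len(result.interactions), "conflict resolutions")
        for character_id, count in result.completed_segments.items():
            print(character_id, "completed", count, "segment(s)")
        return result

A caller would load a model (the load_model signature is not shown in this diff) and pass a fully populated GeneratorRequest; because the loop is seeded and the conflict policy is explicit, repeated runs with the same inputs should produce identical state_hashes sequences, which is what makes the scheduling deterministic.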