Kimodo Bot committed
Commit 6d5047c · 1 Parent(s): d6cb863

Add core kimodo package modules required by the native demo

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. kimodo/__init__.py +11 -0
  2. kimodo/assets.py +19 -0
  3. kimodo/constraints.py +625 -0
  4. kimodo/exports/__init__.py +65 -0
  5. kimodo/exports/bvh.py +282 -0
  6. kimodo/exports/motion_convert_lib.py +155 -0
  7. kimodo/exports/motion_formats.py +78 -0
  8. kimodo/exports/motion_io.py +443 -0
  9. kimodo/exports/mujoco.py +588 -0
  10. kimodo/exports/smplx.py +251 -0
  11. kimodo/geometry.py +216 -0
  12. kimodo/meta.py +80 -0
  13. kimodo/metrics/__init__.py +39 -0
  14. kimodo/metrics/base.py +66 -0
  15. kimodo/metrics/constraints.py +87 -0
  16. kimodo/metrics/foot_skate.py +232 -0
  17. kimodo/metrics/tmr.py +530 -0
  18. kimodo/model/__init__.py +31 -0
  19. kimodo/model/backbone.py +312 -0
  20. kimodo/model/cfg.py +133 -0
  21. kimodo/model/common.py +48 -0
  22. kimodo/model/diffusion.py +133 -0
  23. kimodo/model/kimodo_model.py +605 -0
  24. kimodo/model/llm2vec/README.md +1 -0
  25. kimodo/model/llm2vec/__init__.py +11 -0
  26. kimodo/model/llm2vec/llm2vec.py +477 -0
  27. kimodo/model/llm2vec/llm2vec_wrapper.py +73 -0
  28. kimodo/model/llm2vec/models/__init__.py +4 -0
  29. kimodo/model/llm2vec/models/attn_mask_utils.py +181 -0
  30. kimodo/model/llm2vec/models/bidirectional_llama.py +224 -0
  31. kimodo/model/llm2vec/models/utils.py +32 -0
  32. kimodo/model/load_model.py +194 -0
  33. kimodo/model/loading.py +81 -0
  34. kimodo/model/registry.py +473 -0
  35. kimodo/model/text_encoder_api.py +74 -0
  36. kimodo/model/tmr.py +382 -0
  37. kimodo/model/twostage_denoiser.py +153 -0
  38. kimodo/motion_rep/__init__.py +11 -0
  39. kimodo/motion_rep/conditioning.py +28 -0
  40. kimodo/motion_rep/feature_utils.py +212 -0
  41. kimodo/motion_rep/feet.py +60 -0
  42. kimodo/motion_rep/reps/__init__.py +13 -0
  43. kimodo/motion_rep/reps/base.py +300 -0
  44. kimodo/motion_rep/reps/kimodo_motionrep.py +301 -0
  45. kimodo/motion_rep/reps/tmr_motionrep.py +222 -0
  46. kimodo/motion_rep/smooth_root.py +234 -0
  47. kimodo/motion_rep/stats.py +123 -0
  48. kimodo/pipeline/__init__.py +28 -0
  49. kimodo/pipeline/blend_quality.py +116 -0
  50. kimodo/pipeline/scheduler_runtime.py +139 -0
kimodo/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Kimodo: text-driven and constrained motion generation model."""
4
+
5
+ from .model.load_model import AVAILABLE_MODELS, DEFAULT_MODEL, load_model
6
+
7
+ __all__ = [
8
+ "AVAILABLE_MODELS",
9
+ "DEFAULT_MODEL",
10
+ "load_model",
11
+ ]
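For orientation, a minimal usage sketch of the package-level API exported here. Only the names are visible in this commit, so the call is hedged: model keys come from AVAILABLE_MODELS, and any extra load_model arguments are assumptions.

    import kimodo

    # List the registered checkpoints and load the default one.
    print(kimodo.AVAILABLE_MODELS)
    model = kimodo.load_model(kimodo.DEFAULT_MODEL)  # extra kwargs (device, dtype, ...) are assumed, not shown in this commit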
kimodo/assets.py ADDED
@@ -0,0 +1,19 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from pathlib import Path
5
+
6
+ PACKAGE_ROOT = Path(__file__).resolve().parent
7
+ ASSETS_ROOT = PACKAGE_ROOT / "assets"
8
+ DEMO_ASSETS_ROOT = ASSETS_ROOT / "demo"
9
+ DEMO_EXAMPLES_ROOT = DEMO_ASSETS_ROOT / "examples"
10
+ SKELETONS_ROOT = ASSETS_ROOT / "skeletons"
11
+ SOMA_ASSETS_ROOT = ASSETS_ROOT / "SOMA"
12
+
13
+
14
+ def skeleton_asset_path(*parts: str) -> Path:
15
+ return SKELETONS_ROOT.joinpath(*parts)
16
+
17
+
18
+ def demo_asset_path(*parts: str) -> Path:
19
+ return DEMO_ASSETS_ROOT.joinpath(*parts)
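The helpers above only join path segments under the packaged assets directory. A short sketch, with hypothetical file names:

    from kimodo.assets import demo_asset_path, skeleton_asset_path

    # <package>/assets/skeletons/soma77.json and <package>/assets/demo/examples/walk.npz (hypothetical files)
    skeleton_file = skeleton_asset_path("soma77.json")
    example_clip = demo_asset_path("examples", "walk.npz")
    print(skeleton_file.is_file(), example_clip.is_file())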
kimodo/constraints.py ADDED
@@ -0,0 +1,625 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Constraint sets for conditioning motion generation (root 2D, full body, end-effectors)."""
4
+
5
+ from typing import Optional, Union
6
+
7
+ import torch
8
+ from torch import Tensor
9
+
10
+ from kimodo.motion_rep.feature_utils import compute_heading_angle
11
+ from kimodo.skeleton import SkeletonBase, SOMASkeleton30, SOMASkeleton77
12
+ from kimodo.tools import ensure_batched, load_json, save_json
13
+
14
+ from .geometry import axis_angle_to_matrix, matrix_to_axis_angle
15
+
16
+
17
+ def _convert_constraint_local_rots_to_skeleton(local_rot_mats: Tensor, skeleton: SkeletonBase) -> Tensor:
18
+ """Convert loaded local rotation matrices to match the skeleton's joint count.
19
+
20
+ Handles SOMA 30↔77: constraint files may have been saved with 30 or 77 joints while the session
21
+ skeleton (e.g. from the SOMA30 model) uses SOMASkeleton77.
22
+ """
23
+ n_joints = local_rot_mats.shape[-3]
24
+ skeleton_joints = skeleton.nbjoints
25
+ if n_joints == skeleton_joints:
26
+ return local_rot_mats
27
+ if n_joints == 77 and skeleton_joints == 30 and isinstance(skeleton, SOMASkeleton30):
28
+ return skeleton.from_SOMASkeleton77(local_rot_mats)
29
+ if n_joints == 30 and skeleton_joints == 77 and isinstance(skeleton, SOMASkeleton77):
30
+ skel30 = SOMASkeleton30()
31
+ return skel30.to_SOMASkeleton77(local_rot_mats)
32
+ raise ValueError(
33
+ f"Constraint joint count ({n_joints}) does not match skeleton joint count "
34
+ f"({skeleton_joints}). Only SOMA 30↔77 conversion is supported."
35
+ )
36
+
37
+
38
+ def create_pairs(tensor_A: Tensor, tensor_B: Tensor) -> Tensor:
39
+ """Form all (a, b) pairs from two 1D tensors; output shape (len(A)*len(B), 2)."""
40
+ pairs = torch.stack(
41
+ (
42
+ tensor_A[:, None].expand(-1, len(tensor_B)),
43
+ tensor_B.expand(len(tensor_A), -1),
44
+ ),
45
+ dim=-1,
46
+ ).reshape(-1, 2)
47
+ return pairs
48
+
49
+
50
+ def compute_global_heading(global_joints_positions: Tensor, skeleton: SkeletonBase) -> Tensor:
51
+ """Compute global root heading (cos, sin) from global joint positions using skeleton."""
52
+ root_heading_angle = compute_heading_angle(global_joints_positions, skeleton)
53
+ global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1)
54
+ return global_root_heading
55
+
56
+
57
+ def _tensor_to(
58
+ t: Tensor,
59
+ device: Optional[Union[str, torch.device]] = None,
60
+ dtype: Optional[torch.dtype] = None,
61
+ ) -> Tensor:
62
+ """Move tensor to device and/or dtype.
63
+
64
+ Returns the same tensor if neither device nor dtype is given.
65
+ """
66
+ if device is not None and dtype is not None:
67
+ return t.to(device=device, dtype=dtype)
68
+ if device is not None:
69
+ return t.to(device=device)
70
+ if dtype is not None:
71
+ return t.to(dtype=dtype)
72
+ return t
73
+
74
+
75
+ class Root2DConstraintSet:
76
+ """Constraint set fixing root (x, z) trajectory and optionally global heading on given
77
+ frames."""
78
+
79
+ name = "root2d"
80
+
81
+ def __init__(
82
+ self,
83
+ skeleton: SkeletonBase,
84
+ frame_indices: Tensor,
85
+ smooth_root_2d: Tensor,
86
+ to_crop: bool = False,
87
+ global_root_heading: Optional[Tensor] = None,
88
+ ) -> None:
89
+ self.skeleton = skeleton
90
+
91
+ # if we pass the full smooth root 3D as input
92
+ if smooth_root_2d.shape[-1] == 3:
93
+ smooth_root_2d = smooth_root_2d[..., [0, 1]]
94
+
95
+ if to_crop:
96
+ smooth_root_2d = smooth_root_2d[frame_indices]
97
+ if global_root_heading is not None:
98
+ global_root_heading = global_root_heading[frame_indices]
99
+ else:
100
+ assert len(smooth_root_2d) == len(
101
+ frame_indices
102
+ ), "The number of smooth root 2d should be match the number of frames"
103
+ if global_root_heading is not None:
104
+ assert len(global_root_heading) == len(
105
+ frame_indices
106
+ ), "The number of global root heading should be match the number of frames"
107
+
108
+ self.smooth_root_2d = smooth_root_2d
109
+ self.global_root_heading = global_root_heading
110
+ self.frame_indices = frame_indices
111
+
112
+ def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
113
+ """Append this constraint's smooth_root_2d (and optional global_root_heading) to data/index
114
+ dicts."""
115
+ data_dict["smooth_root_2d"].append(self.smooth_root_2d)
116
+ index_dict["smooth_root_2d"].append(self.frame_indices)
117
+
118
+ if self.global_root_heading is not None:
119
+ # constrain the global heading
120
+ data_dict["global_root_heading"].append(self.global_root_heading)
121
+ index_dict["global_root_heading"].append(self.frame_indices)
122
+
123
+ def crop_move(self, start: int, end: int) -> "Root2DConstraintSet":
124
+ """Return a new constraint set for the cropped frame range [start, end)."""
125
+ mask = (self.frame_indices >= start) & (self.frame_indices < end)
126
+
127
+ if self.global_root_heading is not None:
128
+ masked_global_root_heading = self.global_root_heading[mask]
129
+ else:
130
+ masked_global_root_heading = None
131
+
132
+ return Root2DConstraintSet(
133
+ self.skeleton,
134
+ self.frame_indices[mask] - start,
135
+ self.smooth_root_2d[mask],
136
+ global_root_heading=masked_global_root_heading,
137
+ )
138
+
139
+ def get_save_info(self) -> dict:
140
+ """Return a dict suitable for JSON serialization (frame_indices, smooth_root_2d, optional
141
+ global_root_heading)."""
142
+ out = {
143
+ "type": self.name,
144
+ "frame_indices": self.frame_indices,
145
+ "smooth_root_2d": self.smooth_root_2d,
146
+ }
147
+ if self.global_root_heading is not None:
148
+ out["global_root_heading"] = self.global_root_heading
149
+ return out
150
+
151
+ def to(
152
+ self,
153
+ device: Optional[Union[str, torch.device]] = None,
154
+ dtype: Optional[torch.dtype] = None,
155
+ ) -> "Root2DConstraintSet":
156
+ self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
157
+ self.frame_indices = _tensor_to(self.frame_indices, device, dtype)
158
+ if self.global_root_heading is not None:
159
+ self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
160
+ if device is not None and hasattr(self.skeleton, "to"):
161
+ self.skeleton = self.skeleton.to(device)
162
+ return self
163
+
164
+ @classmethod
165
+ def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "Root2DConstraintSet":
166
+ """Build a Root2DConstraintSet from a dict (e.g. loaded from JSON)."""
167
+ device = skeleton.device if hasattr(skeleton, "device") else "cpu"
168
+
169
+ if "global_root_heading" in dico:
170
+ global_root_heading = torch.tensor(dico["global_root_heading"], device=device)
171
+ else:
172
+ global_root_heading = None
173
+
174
+ return cls(
175
+ skeleton,
176
+ frame_indices=torch.tensor(dico["frame_indices"]),
177
+ smooth_root_2d=torch.tensor(dico["smooth_root_2d"], device=device),
178
+ global_root_heading=global_root_heading,
179
+ )
180
+
181
+
182
+ class FullBodyConstraintSet:
183
+ """Constraint set fixing full-body global positions and rotations on given keyframes."""
184
+
185
+ name = "fullbody"
186
+
187
+ def __init__(
188
+ self,
189
+ skeleton: SkeletonBase,
190
+ frame_indices: Tensor,
191
+ global_joints_positions: Tensor,
192
+ global_joints_rots: Tensor,
193
+ smooth_root_2d: Optional[Tensor] = None,
194
+ to_crop: bool = False,
195
+ ):
196
+ self.skeleton = skeleton
197
+ self.frame_indices = frame_indices
198
+
199
+ # if we pass the full smooth root 3D as input
200
+ if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3:
201
+ smooth_root_2d = smooth_root_2d[..., [0, 1]]
202
+
203
+ if to_crop:
204
+ global_joints_positions = global_joints_positions[frame_indices]
205
+ global_joints_rots = global_joints_rots[frame_indices]
206
+ if smooth_root_2d is not None:
207
+ smooth_root_2d = smooth_root_2d[frame_indices]
208
+ else:
209
+ assert len(global_joints_positions) == len(
210
+ frame_indices
211
+ ), "The number of global positions should be match the number of frames"
212
+ assert len(global_joints_rots) == len(
213
+ frame_indices
214
+ ), "The number of global joint rotations should be match the number of frames"
215
+
216
+ if smooth_root_2d is not None:
217
+ assert len(smooth_root_2d) == len(
218
+ frame_indices
219
+ ), "The number of smooth root 2d (if specified) should be match the number of frames"
220
+
221
+ if smooth_root_2d is None:
222
+ # substitute the smooth root 2d with the real root
223
+ smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]]
224
+
225
+ # root y: the same whether taken from the smooth root or the pelvis
226
+ self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1]
227
+
228
+ self.global_joints_positions = global_joints_positions
229
+ self.global_joints_rots = global_joints_rots
230
+ self.global_root_heading = compute_global_heading(global_joints_positions, skeleton)
231
+ self.smooth_root_2d = smooth_root_2d
232
+
233
+ def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
234
+ """Append global positions, smooth root 2D, root y, and global heading to data/index
235
+ dicts."""
236
+ nbjoints = self.skeleton.nbjoints
237
+ indices_lst = create_pairs(
238
+ self.frame_indices,
239
+ torch.arange(nbjoints, device=self.frame_indices.device),
240
+ )
241
+ data_dict["global_joints_positions"].append(
242
+ self.global_joints_positions.reshape(-1, 3)
243
+ ) # flatten the global positions
244
+ index_dict["global_joints_positions"].append(indices_lst)
245
+
246
+ # global rotations are not used here
247
+
248
+ # as we use smooth root, also constrain the smooth root to get the same full body
249
+ # maybe keep storing the hips offset, if we smooth it ourselves
250
+ data_dict["smooth_root_2d"].append(self.smooth_root_2d)
251
+ index_dict["smooth_root_2d"].append(self.frame_indices)
252
+
253
+ # constrain the y pos of the root
254
+ data_dict["root_y_pos"].append(self.root_y_pos)
255
+ index_dict["root_y_pos"].append(self.frame_indices)
256
+
257
+ # constrain the global heading
258
+ data_dict["global_root_heading"].append(self.global_root_heading)
259
+ index_dict["global_root_heading"].append(self.frame_indices)
260
+
261
+ def crop_move(self, start: int, end: int) -> "FullBodyConstraintSet":
262
+ """Return a new FullBodyConstraintSet for the cropped frame range [start, end)."""
263
+ mask = (self.frame_indices >= start) & (self.frame_indices < end)
264
+ return FullBodyConstraintSet(
265
+ self.skeleton,
266
+ self.frame_indices[mask] - start,
267
+ self.global_joints_positions[mask],
268
+ self.global_joints_rots[mask],
269
+ self.smooth_root_2d[mask],
270
+ )
271
+
272
+ def get_save_info(self) -> dict:
273
+ """Return a dict for JSON save: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d."""
274
+ local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots)
275
+ if isinstance(self.skeleton, SOMASkeleton30):
276
+ local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot)
277
+ local_joints_rot = matrix_to_axis_angle(local_joints_rot)
278
+
279
+ root_positions = self.global_joints_positions[:, self.skeleton.root_idx]
280
+ return {
281
+ "type": self.name,
282
+ "frame_indices": self.frame_indices,
283
+ "local_joints_rot": local_joints_rot,
284
+ "root_positions": root_positions,
285
+ "smooth_root_2d": self.smooth_root_2d,
286
+ }
287
+
288
+ def to(
289
+ self,
290
+ device: Optional[Union[str, torch.device]] = None,
291
+ dtype: Optional[torch.dtype] = None,
292
+ ) -> "FullBodyConstraintSet":
293
+ self.frame_indices = _tensor_to(self.frame_indices, device, dtype)
294
+ self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype)
295
+ self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype)
296
+ self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype)
297
+ self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
298
+ self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
299
+ if device is not None and hasattr(self.skeleton, "to"):
300
+ self.skeleton = self.skeleton.to(device)
301
+ return self
302
+
303
+ @classmethod
304
+ def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "FullBodyConstraintSet":
305
+ """Build a FullBodyConstraintSet from a dict (e.g. loaded from JSON)."""
306
+ frame_indices = torch.tensor(dico["frame_indices"])
307
+ device = skeleton.device if hasattr(skeleton, "device") else "cpu"
308
+ local_rot = torch.tensor(dico["local_joints_rot"], device=device)
309
+ local_rot_mats = axis_angle_to_matrix(local_rot)
310
+ local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton)
311
+ global_joints_rots, global_joints_positions, _ = skeleton.fk(
312
+ local_rot_mats,
313
+ torch.tensor(dico["root_positions"], device=device),
314
+ )
315
+ smooth_root_2d = None
316
+ if "smooth_root_2d" in dico:
317
+ smooth_root_2d = torch.tensor(dico["smooth_root_2d"], device=device)
318
+
319
+ return cls(
320
+ skeleton,
321
+ frame_indices=frame_indices,
322
+ global_joints_positions=global_joints_positions,
323
+ global_joints_rots=global_joints_rots,
324
+ smooth_root_2d=smooth_root_2d,
325
+ )
326
+
327
+
328
+ class EndEffectorConstraintSet:
329
+ """Constraint set fixing selected end-effector positions and rotations on given frames."""
330
+
331
+ name = "end-effector"
332
+
333
+ def __init__(
334
+ self,
335
+ skeleton: SkeletonBase,
336
+ frame_indices: Tensor,
337
+ global_joints_positions: Tensor,
338
+ global_joints_rots: Tensor,
339
+ smooth_root_2d: Optional[Tensor],
340
+ *,
341
+ joint_names: list[str],
342
+ to_crop: bool = False,
343
+ ) -> None:
344
+ self.skeleton = skeleton
345
+ self.frame_indices = frame_indices
346
+ self.joint_names = joint_names
347
+
348
+ # joint_names are constant for all the frames
349
+ rot_joint_names, pos_joint_names = self.skeleton.expand_joint_names(self.joint_names)
350
+ # indexing works for motion_rep with smooth root only (contains pelvis index)
351
+ self.pos_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in pos_joint_names])
352
+ self.rot_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in rot_joint_names])
353
+
354
+ # if we pass the full smooth root 3D as input
355
+ if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3:
356
+ smooth_root_2d = smooth_root_2d[..., [0, 1]]
357
+
358
+ if to_crop:
359
+ global_joints_positions = global_joints_positions[frame_indices]
360
+ global_joints_rots = global_joints_rots[frame_indices]
361
+ if smooth_root_2d is not None:
362
+ smooth_root_2d = smooth_root_2d[frame_indices]
363
+ else:
364
+ assert len(global_joints_positions) == len(
365
+ frame_indices
366
+ ), "The number of global positions should be match the number of frames"
367
+ assert len(global_joints_rots) == len(
368
+ frame_indices
369
+ ), "The number of global joint rotations should be match the number of frames"
370
+ if smooth_root_2d is not None:
371
+ assert len(smooth_root_2d) == len(
372
+ frame_indices
373
+ ), "The number of smooth root 2d (if specified) should be match the number of frames"
374
+
375
+ if smooth_root_2d is None:
376
+ # substitute the smooth root 2d with the real root
377
+ smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]]
378
+
379
+ # root y: the same whether taken from the smooth root or the pelvis
380
+ self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1]
381
+
382
+ self.global_joints_positions = global_joints_positions
383
+ self.global_root_heading = compute_global_heading(global_joints_positions, skeleton)
384
+ self.global_joints_rots = global_joints_rots
385
+ self.smooth_root_2d = smooth_root_2d
386
+
387
+ def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
388
+ """Append constrained joint positions/rots, smooth root 2D, root y, and heading to
389
+ data/index dicts."""
390
+ crop_frames_indexing = torch.arange(len(self.frame_indices), device=self.frame_indices.device)
391
+
392
+ # constrain positions
393
+ pos_indices_real = create_pairs(
394
+ self.frame_indices,
395
+ self.pos_indices,
396
+ )
397
+ pos_indices_crop = create_pairs(
398
+ crop_frames_indexing,
399
+ self.pos_indices,
400
+ )
401
+ data_dict["global_joints_positions"].append(self.global_joints_positions[tuple(pos_indices_crop.T)])
402
+ index_dict["global_joints_positions"].append(pos_indices_real)
403
+
404
+ # constrain rotations
405
+ rot_indices_real = create_pairs(
406
+ self.frame_indices,
407
+ self.rot_indices,
408
+ )
409
+ rot_indices_crop = create_pairs(
410
+ crop_frames_indexing,
411
+ self.rot_indices,
412
+ )
413
+ data_dict["global_joints_rots"].append(self.global_joints_rots[tuple(rot_indices_crop.T)])
414
+ index_dict["global_joints_rots"].append(rot_indices_real)
415
+
416
+ # as we use smooth root, also constrain the smooth root to get the same full body
417
+ # maybe keep storing the hips offset, if we smooth it ourselves
418
+ data_dict["smooth_root_2d"].append(self.smooth_root_2d)
419
+ index_dict["smooth_root_2d"].append(self.frame_indices)
420
+
421
+ # constrain the y pos of the root
422
+ data_dict["root_y_pos"].append(self.root_y_pos)
423
+ index_dict["root_y_pos"].append(self.frame_indices)
424
+
425
+ # constrain the global heading
426
+ data_dict["global_root_heading"].append(self.global_root_heading)
427
+ index_dict["global_root_heading"].append(self.frame_indices)
428
+
429
+ def crop_move(self, start: int, end: int) -> "EndEffectorConstraintSet":
430
+ """Return a new EndEffectorConstraintSet for the cropped frame range [start, end)."""
431
+ mask = (self.frame_indices >= start) & (self.frame_indices < end)
432
+
433
+ cls = type(self)
434
+ kwargs = {}
435
+ if not hasattr(cls, "joint_names"):
436
+ kwargs["joint_names"] = self.joint_names
437
+
438
+ return cls(
439
+ self.skeleton,
440
+ self.frame_indices[mask] - start,
441
+ self.global_joints_positions[mask],
442
+ self.global_joints_rots[mask],
443
+ self.smooth_root_2d[mask],
444
+ **kwargs,
445
+ )
446
+
447
+ def get_save_info(self) -> dict:
448
+ """Return a dict for JSON save: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d, joint_names."""
449
+ local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots)
450
+ if isinstance(self.skeleton, SOMASkeleton30):
451
+ local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot)
452
+ local_joints_rot = matrix_to_axis_angle(local_joints_rot)
453
+
454
+ root_positions = self.global_joints_positions[:, self.skeleton.root_idx]
455
+ output = {
456
+ "type": self.name,
457
+ "frame_indices": self.frame_indices,
458
+ "local_joints_rot": local_joints_rot,
459
+ "root_positions": root_positions,
460
+ "smooth_root_2d": self.smooth_root_2d,
461
+ }
462
+ if not hasattr(self.__class__, "joint_names"):
463
+ # save the joint_names for this base class
464
+ # but not for children
465
+ output["joint_names"] = self.joint_names
466
+ return output
467
+
468
+ def to(
469
+ self,
470
+ device: Optional[Union[str, torch.device]] = None,
471
+ dtype: Optional[torch.dtype] = None,
472
+ ) -> "EndEffectorConstraintSet":
473
+ self.frame_indices = _tensor_to(self.frame_indices, device, dtype)
474
+ self.pos_indices = _tensor_to(self.pos_indices, device, dtype)
475
+ self.rot_indices = _tensor_to(self.rot_indices, device, dtype)
476
+ self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype)
477
+ self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype)
478
+ self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
479
+ self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype)
480
+ self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
481
+ if device is not None and hasattr(self.skeleton, "to"):
482
+ self.skeleton = self.skeleton.to(device)
483
+ return self
484
+
485
+ @classmethod
486
+ def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "EndEffectorConstraintSet":
487
+ """Build an EndEffectorConstraintSet from a dict (e.g. loaded from JSON)."""
488
+ frame_indices = torch.tensor(dico["frame_indices"])
489
+ device = skeleton.device if hasattr(skeleton, "device") else "cpu"
490
+ local_rot = torch.tensor(dico["local_joints_rot"], device=device)
491
+ local_rot_mats = axis_angle_to_matrix(local_rot)
492
+ local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton)
493
+ global_joints_rots, global_joints_positions, _ = skeleton.fk(
494
+ local_rot_mats,
495
+ torch.tensor(dico["root_positions"], device=device),
496
+ )
497
+ smooth_root_2d = None
498
+ if "smooth_root_2d" in dico:
499
+ smooth_root_2d = torch.tensor(dico["smooth_root_2d"], device=device)
500
+
501
+ kwargs = {}
502
+ if not hasattr(cls, "joint_names"):
503
+ kwargs["joint_names"] = dico["joint_names"]
504
+
505
+ return cls(
506
+ skeleton,
507
+ frame_indices=frame_indices,
508
+ global_joints_positions=global_joints_positions,
509
+ global_joints_rots=global_joints_rots,
510
+ smooth_root_2d=smooth_root_2d,
511
+ **kwargs,
512
+ )
513
+
514
+
515
+ class LeftHandConstraintSet(EndEffectorConstraintSet):
516
+ """End-effector constraint for the left hand only."""
517
+
518
+ name = "left-hand"
519
+ joint_names: list[str] = ["LeftHand"]
520
+
521
+ def __init__(self, *args, **kwargs: dict):
522
+ super().__init__(*args, joint_names=self.joint_names, **kwargs)
523
+
524
+
525
+ class RightHandConstraintSet(EndEffectorConstraintSet):
526
+ """End-effector constraint for the right hand only."""
527
+
528
+ name = "right-hand"
529
+ joint_names: list[str] = ["RightHand"]
530
+
531
+ def __init__(self, *args, **kwargs: dict):
532
+ super().__init__(*args, joint_names=self.joint_names, **kwargs)
533
+
534
+
535
+ class LeftFootConstraintSet(EndEffectorConstraintSet):
536
+ """End-effector constraint for the left foot only."""
537
+
538
+ name = "left-foot"
539
+ joint_names: list[str] = ["LeftFoot"]
540
+
541
+ def __init__(self, *args, **kwargs: dict):
542
+ super().__init__(*args, joint_names=self.joint_names, **kwargs)
543
+
544
+
545
+ class RightFootConstraintSet(EndEffectorConstraintSet):
546
+ """End-effector constraint for the right foot only."""
547
+
548
+ name = "right-foot"
549
+ joint_names: list[str] = ["RightFoot"]
550
+
551
+ def __init__(self, *args, **kwargs: dict):
552
+ super().__init__(*args, joint_names=self.joint_names, **kwargs)
553
+
554
+
555
+ TYPE_TO_CLASS = {
556
+ "root2d": Root2DConstraintSet,
557
+ "fullbody": FullBodyConstraintSet,
558
+ "left-hand": LeftHandConstraintSet,
559
+ "right-hand": RightHandConstraintSet,
560
+ "left-foot": LeftFootConstraintSet,
561
+ "right-foot": RightFootConstraintSet,
562
+ "end-effector": EndEffectorConstraintSet,
563
+ }
564
+
565
+
566
+ def load_constraints_lst(
567
+ path_or_data: str | list,
568
+ skeleton: SkeletonBase,
569
+ device: Optional[Union[str, torch.device]] = None,
570
+ dtype: Optional[torch.dtype] = None,
571
+ ):
572
+ """Load a list of constraints from JSON path or list of dicts.
573
+
574
+ Args:
575
+ path_or_data: Path to constraints.json or list of constraint dicts.
576
+ skeleton: Skeleton instance (used for from_dict).
577
+ device: If set, move all constraint tensors and skeleton to this device.
578
+ dtype: If set, cast constraint tensors to this dtype.
579
+ """
580
+ if isinstance(path_or_data, str):
581
+ saved = load_json(path_or_data)
582
+ else:
583
+ saved = path_or_data
584
+
585
+ constraints_lst = []
586
+ for el in saved:
587
+ cls = TYPE_TO_CLASS[el["type"]]
588
+ c = cls.from_dict(skeleton, el)
589
+ if device is not None or dtype is not None:
590
+ c.to(device=device, dtype=dtype)
591
+ constraints_lst.append(c)
592
+ return constraints_lst
593
+
594
+
595
+ def save_constraints_lst(path: str, constraints_lst: list) -> list | None:
596
+ """Save a list of constraint sets to a JSON file.
597
+
598
+ Returns None if list is empty.
599
+ """
600
+ if not constraints_lst:
601
+ print("The constraints lst is empty. Skip saving")
602
+ return
603
+
604
+ to_save = []
605
+
606
+ def tensor_to_list(obj):
607
+ """Recursively convert tensors to lists for JSON serialization."""
608
+ if isinstance(obj, Tensor):
609
+ return obj.cpu().tolist()
610
+ elif isinstance(obj, dict):
611
+ return {k: tensor_to_list(v) for k, v in obj.items()}
612
+ elif isinstance(obj, list):
613
+ return [tensor_to_list(v) for v in obj]
614
+ else:
615
+ return obj
616
+
617
+ for constraint in constraints_lst:
618
+ constraint_info = constraint.get_save_info()
619
+ # Convert all tensors to lists for JSON serialization
620
+ constraint_info = tensor_to_list(constraint_info)
621
+ to_save.append(constraint_info)
622
+
623
+ save_json(path, to_save)
624
+ print(f"Saved constraints to {path}")
625
+ return to_save
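A small end-to-end sketch of the constraint API added in this file: build a root-trajectory constraint, save it, and load it back. The skeleton factory and the tensor shapes are illustrative assumptions (build_skeleton is the registry helper used elsewhere in this commit):

    import torch
    from kimodo.constraints import Root2DConstraintSet, load_constraints_lst, save_constraints_lst
    from kimodo.skeleton.registry import build_skeleton

    skeleton = build_skeleton(77)
    frames = torch.arange(0, 60, 10)        # constrain every 10th frame of a 60-frame clip
    root_xz = torch.zeros(len(frames), 2)   # pin the root trajectory at the origin

    constraint = Root2DConstraintSet(skeleton, frames, root_xz)
    save_constraints_lst("constraints.json", [constraint])
    restored = load_constraints_lst("constraints.json", skeleton, device="cpu")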
kimodo/exports/__init__.py ADDED
@@ -0,0 +1,65 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Export utilities: MuJoCo, BVH, SMPLX/AMASS, and motion I/O helpers."""
4
+
5
+ from .bvh import bvh_to_kimodo_motion, motion_to_bvh_bytes, read_bvh_frame_time_seconds, save_motion_bvh
6
+ from .motion_convert_lib import convert_motion_files
7
+ from .motion_formats import (
8
+ infer_npz_kind,
9
+ infer_source_format_from_path,
10
+ infer_target_format_from_path,
11
+ resolve_source_fps,
12
+ )
13
+ from .motion_io import (
14
+ KIMODO_CONVERT_TARGET_FPS,
15
+ amass_npz_to_bytes,
16
+ complete_motion_dict,
17
+ g1_csv_to_bytes,
18
+ kimodo_npz_to_bytes,
19
+ load_amass_npz,
20
+ load_g1_csv,
21
+ load_kimodo_npz,
22
+ load_kimodo_npz_as_torch,
23
+ load_motion_file,
24
+ motion_dict_to_numpy,
25
+ save_kimodo_npz,
26
+ save_kimodo_npz_at_target_fps,
27
+ )
28
+ from .mujoco import MujocoQposConverter, apply_g1_real_robot_projection
29
+ from .smplx import (
30
+ AMASSConverter,
31
+ amass_npz_to_kimodo_motion,
32
+ get_amass_parameters,
33
+ kimodo_y_up_to_amass_coord_rotation_matrix,
34
+ )
35
+
36
+ __all__ = [
37
+ "AMASSConverter",
38
+ "KIMODO_CONVERT_TARGET_FPS",
39
+ "MujocoQposConverter",
40
+ "amass_npz_to_bytes",
41
+ "amass_npz_to_kimodo_motion",
42
+ "apply_g1_real_robot_projection",
43
+ "bvh_to_kimodo_motion",
44
+ "complete_motion_dict",
45
+ "convert_motion_files",
46
+ "g1_csv_to_bytes",
47
+ "get_amass_parameters",
48
+ "infer_npz_kind",
49
+ "infer_source_format_from_path",
50
+ "infer_target_format_from_path",
51
+ "kimodo_npz_to_bytes",
52
+ "kimodo_y_up_to_amass_coord_rotation_matrix",
53
+ "load_amass_npz",
54
+ "load_g1_csv",
55
+ "load_kimodo_npz",
56
+ "load_kimodo_npz_as_torch",
57
+ "load_motion_file",
58
+ "motion_dict_to_numpy",
59
+ "motion_to_bvh_bytes",
60
+ "read_bvh_frame_time_seconds",
61
+ "resolve_source_fps",
62
+ "save_kimodo_npz",
63
+ "save_kimodo_npz_at_target_fps",
64
+ "save_motion_bvh",
65
+ ]
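Because of these re-exports, downstream code can import the conversion and I/O helpers from one place rather than from the individual modules:

    from kimodo.exports import convert_motion_files, load_motion_file, save_motion_bvh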
kimodo/exports/bvh.py ADDED
@@ -0,0 +1,282 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Export utilities for converting internal motion representations into common file formats.
4
+
5
+ This module is intended to hold lightweight serialization / export helpers that can be reused
6
+ outside of interactive demos.
7
+ """
8
+
9
+ import os
10
+ import tempfile
11
+ from pathlib import Path
12
+ from typing import Tuple, Union
13
+
14
+ import numpy as np
15
+ import torch
16
+
17
+ from kimodo.geometry import matrix_to_quaternion as _matrix_to_quaternion
18
+
19
+
20
+ def _strip_end_site_blocks(bvh_text: str) -> str:
21
+ """Remove all 'End Site { ... }' blocks from BVH text so output matches original format.
22
+
23
+ bvhio adds an End Site for every leaf joint when writing; we do not set EndSite on joints, so we
24
+ post-process the string to remove these blocks for Blender/original compatibility.
25
+ """
26
+ lines = bvh_text.splitlines(keepends=True)
27
+ result = []
28
+ i = 0
29
+ while i < len(lines):
30
+ line = lines[i]
31
+ if "End Site" in line:
32
+ # Skip this line and the following block { ... }; brace-count to find closing }
33
+ i += 1
34
+ if i < len(lines) and "{" in lines[i]:
35
+ i += 1
36
+ depth = 1
37
+ while i < len(lines) and depth > 0:
38
+ if "{" in lines[i]:
39
+ depth += 1
40
+ if "}" in lines[i]:
41
+ depth -= 1
42
+ i += 1
43
+ continue
44
+ result.append(line)
45
+ i += 1
46
+ return "".join(result)
47
+
48
+
49
+ def _coerce_batch(name: str, x: torch.Tensor, *, expected_ndim: int) -> torch.Tensor:
50
+ """Coerce (T, ...) or (1, T, ...) into (T, ...)."""
51
+ if x.ndim == expected_ndim:
52
+ return x
53
+ if x.ndim == expected_ndim + 1:
54
+ if int(x.shape[0]) != 1:
55
+ raise ValueError(
56
+ f"{name} has batch dimension B={int(x.shape[0])}, but BVH export " "only supports a single clip (B==1)."
57
+ )
58
+ return x[0]
59
+ raise ValueError(f"{name} must have shape (T, ...) or (1, T, ...); got {tuple(x.shape)}")
60
+
61
+
62
+ def motion_to_bvh(
63
+ local_rot_mats: torch.Tensor,
64
+ root_positions: torch.Tensor,
65
+ *,
66
+ skeleton,
67
+ fps: float,
68
+ ) -> str:
69
+ """Convert local rotations and root positions to BVH format; return UTF-8 string.
70
+
71
+ Args:
72
+ local_rot_mats: (T, J, 3, 3) or (1, T, J, 3, 3) local rotation matrices.
73
+ root_positions: (T, 3) or (1, T, 3) root joint positions (e.g. from posed joints).
74
+ skeleton: Skeleton with bone_order_names, bvh_neutral_joints, etc.
75
+ fps: Frames per second for the motion.
76
+
77
+ Notes:
78
+ BVH is plain-text. Root is named "Root" with ZYX rotation order; leaf joints
79
+ have no End Site block.
80
+ """
81
+ try:
82
+ import bvhio # type: ignore[import-not-found]
83
+ import glm # type: ignore[import-not-found]
84
+ from SpatialTransform import Pose # type: ignore[import-not-found]
85
+ except Exception as e: # pragma: no cover
86
+ raise ImportError(
87
+ "BVH export requires `bvhio` (and its deps `PyGLM` + `SpatialTransform`). "
88
+ "Install with: `pip install bvhio`."
89
+ ) from e
90
+
91
+ local_rot_mats = local_rot_mats.detach()
92
+ root_positions = root_positions.detach()
93
+ # SOMA: accept either somaskel30 (convert to 77) or somaskel77 (use as-is)
94
+ if skeleton.name == "somaskel30":
95
+ local_rot_mats = skeleton.to_SOMASkeleton77(local_rot_mats)
96
+ skeleton = skeleton.somaskel77
97
+
98
+ local_rot_mats, _ = skeleton.from_standard_tpose(local_rot_mats)
99
+
100
+ neutral = skeleton.bvh_neutral_joints.detach().cpu().numpy()
101
+ joint_names = list(skeleton.bone_order_names)
102
+ parents = skeleton.joint_parents.detach().cpu().numpy().astype(int)
103
+ root_idx = int(skeleton.root_idx)
104
+
105
+ local_rot_mats = _coerce_batch("local_rot_mats", local_rot_mats, expected_ndim=4)
106
+ T, J = local_rot_mats.shape[:2]
107
+ q_wxyz = _matrix_to_quaternion(local_rot_mats).detach().cpu().numpy() # [T, J, 4]
108
+
109
+ root_xyz = _coerce_batch("root_positions", root_positions, expected_ndim=2)
110
+ root_xyz = root_xyz.cpu().numpy() # [T, 3]
111
+
112
+ # Build BVH hierarchy: Root (wrapper at origin) -> Hips (pelvis with offset in meters) -> ...
113
+ # Offsets are in meters to match the original format.
114
+ children: dict[int, list[int]] = {i: [] for i in range(J)}
115
+ for i, p in enumerate(parents):
116
+ if p >= 0:
117
+ children[int(p)].append(int(i))
118
+
119
+ _ROOT_CHANNELS = [
120
+ "Xposition",
121
+ "Yposition",
122
+ "Zposition",
123
+ "Zrotation",
124
+ "Yrotation",
125
+ "Xrotation",
126
+ ]
127
+ _JOINT_CHANNELS = ["Zrotation", "Yrotation", "Xrotation"]
128
+
129
+ # Scale from meters to centimeters (match original BVH scale).
130
+ neutral = neutral * 100
131
+ root_xyz = root_xyz * 100
132
+
133
+ # Hips offset from Root: use skeleton neutral; if root is at origin (zeros), use a
134
+ # nominal pelvis height so the hierarchy is non-degenerate in Blender.
135
+ hips_offset = neutral[root_idx]
136
+ if (hips_offset == 0).all():
137
+ hips_offset = np.array([0.0, 100.0, 0.0], dtype=neutral.dtype) # 1 m in cm
138
+
139
+ def _make_joint(i: int) -> "bvhio.BvhJoint":
140
+ name = joint_names[i]
141
+ j = bvhio.BvhJoint(name, offset=glm.vec3(0, 0, 0))
142
+ if i == root_idx:
143
+ # Hips: offset from Root (origin) in cm
144
+ off = hips_offset
145
+ j.Offset = glm.vec3(float(off[0]), float(off[1]), float(off[2]))
146
+ j.Channels = _ROOT_CHANNELS.copy()
147
+ else:
148
+ p = int(parents[i])
149
+ off = neutral[i] - neutral[p]
150
+ j.Offset = glm.vec3(float(off[0]), float(off[1]), float(off[2]))
151
+ j.Channels = _JOINT_CHANNELS.copy()
152
+
153
+ for c in children[i]:
154
+ j.Children.append(_make_joint(c))
155
+ return j
156
+
157
+ # Wrapper Root at origin; single child is Hips (skeleton root).
158
+ root_wrapper = bvhio.BvhJoint("Root", offset=glm.vec3(0.0, 0.0, 0.0))
159
+ root_wrapper.Channels = _ROOT_CHANNELS.copy()
160
+ root_wrapper.Children.append(_make_joint(root_idx))
161
+ root_joint = root_wrapper
162
+
163
+ # Populate keyframes: Root = identity/zero, Hips = root motion, others = local rotation.
164
+ bvh_layout = root_joint.layout()
165
+ name_to_id = {n: idx for idx, n in enumerate(joint_names)}
166
+ ordered_joint_ids = []
167
+ for bj, _, _ in bvh_layout:
168
+ if bj.Name == "Root":
169
+ ordered_joint_ids.append(None)
170
+ else:
171
+ ordered_joint_ids.append(name_to_id[bj.Name])
172
+
173
+ bvh_joints = [bj for bj, _, _ in bvh_layout]
174
+ for bj in bvh_joints:
175
+ bj.Keyframes = [None] * T # type: ignore[list-item]
176
+
177
+ identity_quat = glm.quat(1.0, 0.0, 0.0, 0.0)
178
+ zero_vec = glm.vec3(0.0, 0.0, 0.0)
179
+ for t in range(T):
180
+ for bj, jid in zip(bvh_joints, ordered_joint_ids):
181
+ if jid is None:
182
+ position = zero_vec
183
+ rotation = identity_quat
184
+ elif jid == root_idx:
185
+ pos = root_xyz[t]
186
+ position = glm.vec3(float(pos[0]), float(pos[1]), float(pos[2]))
187
+ qw, qx, qy, qz = q_wxyz[t, jid]
188
+ rotation = glm.quat(float(qw), float(qx), float(qy), float(qz))
189
+ else:
190
+ position = zero_vec
191
+ qw, qx, qy, qz = q_wxyz[t, jid]
192
+ rotation = glm.quat(float(qw), float(qx), float(qy), float(qz))
193
+ bj.Keyframes[t] = Pose(position, rotation) # type: ignore[index]
194
+
195
+ container = bvhio.BvhContainer(root_joint, frameCount=T, frameTime=1.0 / float(fps))
196
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".bvh", delete=False, encoding="utf-8") as f:
197
+ tmp_path = f.name
198
+ try:
199
+ bvhio.writeBvh(tmp_path, container, percision=6)
200
+ bvh_text = Path(tmp_path).read_text(encoding="utf-8")
201
+ return _strip_end_site_blocks(bvh_text)
202
+ finally:
203
+ try:
204
+ os.remove(tmp_path)
205
+ except Exception:
206
+ pass
207
+
208
+
209
+ def motion_to_bvh_bytes(
210
+ local_rot_mats: torch.Tensor,
211
+ root_positions: torch.Tensor,
212
+ *,
213
+ skeleton,
214
+ fps: float,
215
+ ) -> bytes:
216
+ """Convert local rotations and root positions to BVH bytes (UTF-8).
217
+
218
+ Convenience wrapper around :func:`motion_to_bvh`.
219
+ """
220
+ return motion_to_bvh(local_rot_mats, root_positions, skeleton=skeleton, fps=fps).encode("utf-8")
221
+
222
+
223
+ def save_motion_bvh(
224
+ path: Union[str, Path],
225
+ local_rot_mats: torch.Tensor,
226
+ root_positions: torch.Tensor,
227
+ *,
228
+ skeleton,
229
+ fps: float,
230
+ ) -> None:
231
+ """Write local rotations and root positions to a BVH file at the given path."""
232
+ Path(path).write_text(
233
+ motion_to_bvh(local_rot_mats, root_positions, skeleton=skeleton, fps=fps),
234
+ encoding="utf-8",
235
+ )
236
+
237
+
238
+ def read_bvh_frame_time_seconds(path: Union[str, Path]) -> float:
239
+ """Read ``Frame Time`` from a BVH file (seconds per frame)."""
240
+ with open(path, encoding="utf-8") as f:
241
+ for line in f:
242
+ if "Frame Time:" in line:
243
+ parts = line.split()
244
+ return float(parts[-1])
245
+ raise ValueError(f"Could not find 'Frame Time:' in {path}")
246
+
247
+
248
+ def bvh_to_kimodo_motion(
249
+ path: Union[str, Path],
250
+ skeleton=None,
251
+ ) -> Tuple:
252
+ """Load a Kimodo-style SOMA BVH into a Kimodo motion dict.
253
+
254
+ Expects the same hierarchy as :func:`save_motion_bvh` (``Root`` wrapper + SOMA77 joints).
255
+ The frame rate is always read from the BVH ``Frame Time`` header. Callers
256
+ that need a different playback rate should resample the returned motion dict
257
+ (see :func:`~kimodo.exports.motion_io.resample_motion_dict_to_kimodo_fps`).
258
+
259
+ Returns:
260
+ ``(motion_dict, source_fps)`` where ``source_fps`` is the native BVH
261
+ frame rate read from the file header.
262
+ """
263
+ from kimodo.exports.motion_io import complete_motion_dict
264
+ from kimodo.skeleton.bvh import parse_bvh_motion
265
+ from kimodo.skeleton.registry import build_skeleton
266
+
267
+ if skeleton is None:
268
+ skeleton = build_skeleton(77)
269
+ device = skeleton.neutral_joints.device
270
+
271
+ local_rot_mats, root_trans, bvh_fps = parse_bvh_motion(str(path))
272
+ local_rot_mats = local_rot_mats.to(device=device)
273
+ root_trans = root_trans.to(device=device)
274
+
275
+ if int(local_rot_mats.shape[1]) != int(skeleton.nbjoints):
276
+ raise ValueError(
277
+ f"BVH has {local_rot_mats.shape[1]} joints but skeleton has {skeleton.nbjoints}; "
278
+ "use a Kimodo-exported SOMA BVH or matching skeleton."
279
+ )
280
+ local_rot_mats, _ = skeleton.to_standard_tpose(local_rot_mats)
281
+
282
+ return complete_motion_dict(local_rot_mats, root_trans, skeleton, float(bvh_fps)), bvh_fps
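A round-trip sketch for the BVH helpers above (requires the optional bvhio dependency). The identity-pose motion and the file name are placeholders; tensor shapes follow the docstrings:

    import torch
    from kimodo.exports.bvh import bvh_to_kimodo_motion, save_motion_bvh
    from kimodo.skeleton.registry import build_skeleton

    skeleton = build_skeleton(77)
    T = 30
    local_rot_mats = torch.eye(3).expand(T, skeleton.nbjoints, 3, 3).clone()  # identity rotations
    root_positions = torch.zeros(T, 3)                                        # root held at the origin

    save_motion_bvh("pose.bvh", local_rot_mats, root_positions, skeleton=skeleton, fps=30.0)
    motion_dict, source_fps = bvh_to_kimodo_motion("pose.bvh", skeleton=skeleton)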
kimodo/exports/motion_convert_lib.py ADDED
@@ -0,0 +1,155 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Library API for converting between Kimodo NPZ, AMASS NPZ, SOMA BVH, and G1 MuJoCo CSV."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import warnings
8
+
9
+ import numpy as np
10
+
11
+ from kimodo.exports.bvh import bvh_to_kimodo_motion, save_motion_bvh
12
+ from kimodo.exports.motion_formats import (
13
+ infer_source_format_from_path,
14
+ infer_target_format_from_path,
15
+ resolve_source_fps,
16
+ )
17
+ from kimodo.exports.motion_io import (
18
+ load_amass_npz,
19
+ load_g1_csv,
20
+ load_kimodo_npz_as_torch,
21
+ save_kimodo_npz_at_target_fps,
22
+ )
23
+ from kimodo.exports.mujoco import MujocoQposConverter
24
+ from kimodo.exports.smplx import AMASSConverter
25
+ from kimodo.skeleton.registry import build_skeleton
26
+
27
+
28
+ def convert_motion_files(
29
+ input_path: str,
30
+ output_path: str,
31
+ *,
32
+ from_fmt: str | None = None,
33
+ to_fmt: str | None = None,
34
+ source_fps: float | None = None,
35
+ z_up: bool = True,
36
+ mujoco_rest_zero: bool = False,
37
+ ) -> None:
38
+ """Convert a motion file between Kimodo-supported formats.
39
+
40
+ Supported pairs (hub-and-spoke through Kimodo NPZ):
41
+
42
+ - amass <-> kimodo
43
+ - soma-bvh <-> kimodo
44
+ - g1-csv <-> kimodo
45
+
46
+ Args:
47
+ input_path: Source file (``.npz``, ``.bvh``, or ``.csv``).
48
+ output_path: Destination file.
49
+ from_fmt: Source format; inferred from extension/contents when ``None``.
50
+ to_fmt: Target format; inferred from extension when ``None``.
51
+ source_fps: Source motion frame rate (Hz). If provided, trusted as-is.
52
+ If ``None``, auto-detected from BVH ``Frame Time``, AMASS
53
+ ``mocap_frame_rate``, or default 30.
54
+ z_up: For AMASS conversions, apply the Z-up <-> Kimodo Y-up transform.
55
+ mujoco_rest_zero: For G1 CSV, joint angles relative to MuJoCo rest pose.
56
+ """
57
+ from_fmt = from_fmt or infer_source_format_from_path(input_path)
58
+ to_fmt = to_fmt or infer_target_format_from_path(output_path, from_fmt)
59
+
60
+ _validate_output_extension(to_fmt, output_path)
61
+
62
+ pair = (from_fmt, to_fmt)
63
+
64
+ if pair == ("amass", "kimodo"):
65
+ sk = build_skeleton(22)
66
+ effective_source = source_fps
67
+ if effective_source is None:
68
+ with np.load(input_path, allow_pickle=True) as z:
69
+ effective_source = float(z["mocap_frame_rate"]) if "mocap_frame_rate" in z.files else 30.0
70
+ motion = load_amass_npz(input_path, source_fps=effective_source, z_up=z_up)
71
+ save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path)
72
+ return
73
+
74
+ if pair == ("kimodo", "amass"):
75
+ data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False)
76
+ if J != 22:
77
+ raise ValueError(f"Kimodo→AMASS requires 22 joints (SMPL-X); this file has J={J}.")
78
+ sk = build_skeleton(22)
79
+ effective_source = resolve_source_fps(source_fps, "kimodo", input_path, None)
80
+ converter = AMASSConverter(fps=effective_source, skeleton=sk)
81
+ converter.convert_save_npz(data, output_path, z_up=z_up)
82
+ return
83
+
84
+ if pair == ("soma-bvh", "kimodo"):
85
+ sk = build_skeleton(77)
86
+ motion, bvh_fps = bvh_to_kimodo_motion(input_path, skeleton=sk)
87
+ effective_source = source_fps if source_fps is not None else bvh_fps
88
+ save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path)
89
+ return
90
+
91
+ if pair == ("kimodo", "soma-bvh"):
92
+ data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False)
93
+ if J == 30:
94
+ warnings.warn(
95
+ f"Input has 30 joints (somaskel30); expanding to somaskel77 for BVH export.",
96
+ UserWarning,
97
+ stacklevel=2,
98
+ )
99
+ sk = build_skeleton(30)
100
+ elif J == 77:
101
+ sk = build_skeleton(77)
102
+ else:
103
+ raise ValueError(f"Kimodo→BVH requires a SOMA skeleton (30 or 77 joints); this file has J={J}.")
104
+ effective_source = resolve_source_fps(source_fps, "kimodo", input_path, None)
105
+ save_motion_bvh(
106
+ output_path,
107
+ data["local_rot_mats"],
108
+ data["root_positions"],
109
+ skeleton=sk,
110
+ fps=effective_source,
111
+ )
112
+ return
113
+
114
+ if pair == ("g1-csv", "kimodo"):
115
+ sk = build_skeleton(34)
116
+ effective_source = resolve_source_fps(source_fps, "g1-csv", input_path, None)
117
+ motion = load_g1_csv(input_path, source_fps=effective_source, mujoco_rest_zero=mujoco_rest_zero)
118
+ save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path)
119
+ return
120
+
121
+ if pair == ("kimodo", "g1-csv"):
122
+ data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False)
123
+ if J != 34:
124
+ raise ValueError(f"Kimodo→CSV requires G1 with 34 joints; this file has J={J}.")
125
+ sk = build_skeleton(34)
126
+ effective_source = resolve_source_fps(source_fps, "kimodo", input_path, None)
127
+ converter = MujocoQposConverter(sk)
128
+ qpos = converter.dict_to_qpos(
129
+ {k: v for k, v in data.items() if k in ("local_rot_mats", "root_positions")},
130
+ device=str(sk.neutral_joints.device),
131
+ numpy=True,
132
+ mujoco_rest_zero=mujoco_rest_zero,
133
+ )
134
+ converter.save_csv(qpos, output_path)
135
+ return
136
+
137
+ raise ValueError(
138
+ f"Unsupported conversion {from_fmt!r} → {to_fmt!r}. "
139
+ "Supported: amass↔kimodo (SMPL-X NPZ), soma-bvh↔kimodo, g1-csv↔kimodo."
140
+ )
141
+
142
+
143
+ def _validate_output_extension(to_fmt: str, output_path: str) -> None:
144
+ lower = output_path.lower()
145
+ if to_fmt == "kimodo" and lower.endswith(".npz"):
146
+ return
147
+ if to_fmt == "amass":
148
+ if not lower.endswith(".npz"):
149
+ raise ValueError("AMASS output must use a .npz path.")
150
+ elif to_fmt == "soma-bvh":
151
+ if not lower.endswith(".bvh"):
152
+ raise ValueError("SOMA BVH output must use a .bvh path.")
153
+ elif to_fmt == "g1-csv":
154
+ if not lower.endswith(".csv"):
155
+ raise ValueError("G1 CSV output must use a .csv path.")
kimodo/exports/motion_formats.py ADDED
@@ -0,0 +1,78 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Infer motion file formats from paths and NPZ contents."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ from typing import Literal
9
+
10
+ import numpy as np
11
+
12
+ MotionSourceFormat = Literal["amass", "kimodo", "soma-bvh", "g1-csv"]
13
+ MotionTargetFormat = Literal["amass", "kimodo", "soma-bvh", "g1-csv"]
14
+ NpzMotionKind = Literal["amass", "kimodo"]
15
+
16
+
17
+ def infer_npz_kind(path: str) -> NpzMotionKind:
18
+ """Classify a ``.npz`` as AMASS SMPL-X or Kimodo from required array keys."""
19
+ with np.load(path, allow_pickle=False) as z:
20
+ keys = set(z.files)
21
+ if "trans" in keys and "pose_body" in keys and "root_orient" in keys:
22
+ return "amass"
23
+ if "local_rot_mats" in keys or "posed_joints" in keys:
24
+ return "kimodo"
25
+ raise ValueError(
26
+ f"Unrecognized NPZ {path!r}: expected AMASS keys (trans, pose_body, ...) "
27
+ "or Kimodo keys (local_rot_mats, posed_joints, ...)."
28
+ )
29
+
30
+
31
+ def infer_source_format_from_path(path: str) -> MotionSourceFormat:
32
+ """Infer converter input format from file extension and NPZ contents when needed."""
33
+ ext = os.path.splitext(path)[1].lower()
34
+ if ext == ".bvh":
35
+ return "soma-bvh"
36
+ if ext == ".csv":
37
+ return "g1-csv"
38
+ if ext == ".npz":
39
+ return infer_npz_kind(path) # type: ignore[return-value]
40
+ raise ValueError(f"Cannot infer format from extension of {path!r}")
41
+
42
+
43
+ def infer_target_format_from_path(path: str, from_fmt: MotionSourceFormat) -> MotionTargetFormat:
44
+ """Infer converter output format from destination path and source format."""
45
+ ext = os.path.splitext(path)[1].lower()
46
+ if ext == ".bvh":
47
+ return "soma-bvh"
48
+ if ext == ".csv":
49
+ return "g1-csv"
50
+ if ext == ".npz":
51
+ if from_fmt == "amass":
52
+ return "kimodo"
53
+ if from_fmt == "kimodo":
54
+ return "amass"
55
+ if from_fmt in ("g1-csv", "soma-bvh"):
56
+ return "kimodo"
57
+ raise ValueError(
58
+ "Ambiguous .npz output: set --to to 'kimodo' or 'amass' when the input format is not amass/kimodo."
59
+ )
60
+ raise ValueError(f"Cannot infer output format from extension of {path!r}")
61
+
62
+
63
+ def resolve_source_fps(
64
+ fps: float | None,
65
+ from_kind: str,
66
+ input_path: str,
67
+ data: dict | None,
68
+ ) -> float:
69
+ """Resolve source frame rate (Hz) for conversion when ``fps`` is not overridden."""
70
+ if fps is not None:
71
+ return float(fps)
72
+ if data is not None and "mocap_frame_rate" in data:
73
+ return float(np.asarray(data["mocap_frame_rate"]).item())
74
+ if from_kind == "soma-bvh":
75
+ from kimodo.exports.bvh import read_bvh_frame_time_seconds
76
+
77
+ return 1.0 / read_bvh_frame_time_seconds(input_path)
78
+ return 30.0
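A quick sketch of how callers resolve a conversion pair with these helpers before doing any work (paths are placeholders):

    from kimodo.exports.motion_formats import infer_source_format_from_path, infer_target_format_from_path

    src = infer_source_format_from_path("take01.bvh")       # -> "soma-bvh"
    dst = infer_target_format_from_path("take01.npz", src)  # -> "kimodo" for BVH/CSV inputs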
kimodo/exports/motion_io.py ADDED
@@ -0,0 +1,443 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Assemble Kimodo NPZ-compatible motion dicts from local rotations + root trajectory."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ import warnings
9
+ from typing import Any, Dict, Tuple
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ from kimodo.geometry import matrix_to_quaternion, quaternion_to_matrix
15
+ from kimodo.motion_rep.feature_utils import compute_heading_angle, compute_vel_xyz
16
+ from kimodo.motion_rep.feet import foot_detect_from_pos_and_vel
17
+ from kimodo.motion_rep.smooth_root import get_smooth_root_pos
18
+ from kimodo.skeleton import SkeletonBase
19
+ from kimodo.skeleton.registry import build_skeleton
20
+ from kimodo.tools import to_numpy
21
+
22
+ # Default motion rate for Kimodo NPZ produced by format conversion (matches common model FPS).
23
+ KIMODO_CONVERT_TARGET_FPS = 30.0
24
+
25
+
26
+ def _quaternion_slerp(q0: torch.Tensor, q1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
27
+ """Spherical linear interpolation; ``q0``, ``q1`` (..., 4) wxyz; ``t`` broadcastable to (...,
28
+ 1)."""
29
+ if t.dim() < q0.dim():
30
+ t = t.unsqueeze(-1)
31
+ dot = (q0 * q1).sum(dim=-1, keepdim=True)
32
+ q1 = torch.where(dot < 0, -q1, q1)
33
+ dot = torch.abs(dot).clamp(-1.0, 1.0)
34
+ theta_0 = torch.acos(dot)
35
+ sin_theta = torch.sin(theta_0)
36
+ s0 = torch.sin((1.0 - t) * theta_0) / sin_theta.clamp(min=1e-8)
37
+ s1 = torch.sin(t * theta_0) / sin_theta.clamp(min=1e-8)
38
+ q = s0 * q0 + s1 * q1
39
+ return q / torch.linalg.norm(q, dim=-1, keepdim=True).clamp(min=1e-8)
40
+
41
+
42
+ def resample_motion_dict_to_kimodo_fps(
43
+ motion_dict: Dict[str, torch.Tensor],
44
+ skeleton: SkeletonBase,
45
+ source_fps: float,
46
+ target_fps: float = KIMODO_CONVERT_TARGET_FPS,
47
+ ) -> Tuple[Dict[str, torch.Tensor], bool]:
48
+ """Resample a Kimodo motion dict to ``target_fps``.
49
+
50
+ When the fps ratio is close to an integer (e.g. 120 / 30 = 4), the faster
51
+ stepping method is used (take every *step*-th frame). Otherwise falls back
52
+ to linear interp (root) + quaternion slerp (joints).
53
+
54
+ Re-runs :func:`complete_motion_dict` at the target rate so derived channels stay consistent.
55
+
56
+ Returns:
57
+ The motion dict and ``True`` if time resampling was applied, else ``False`` (already at
58
+ ``target_fps`` with matching frame count; only re-derived via FK).
59
+ """
60
+ local_rot_mats = motion_dict["local_rot_mats"]
61
+ root_positions = motion_dict["root_positions"]
62
+ local_rot_mats, root_positions = _coerce_time_local_root(local_rot_mats, root_positions)
63
+ t_in = int(local_rot_mats.shape[0])
64
+ if t_in < 1:
65
+ raise ValueError("Motion must have at least one frame.")
66
+ if source_fps <= 0:
67
+ raise ValueError(f"source_fps must be positive; got {source_fps}")
68
+
69
+ t_out = max(1, int(round(t_in * target_fps / source_fps)))
70
+ if t_out == t_in and abs(float(source_fps) - float(target_fps)) < 1e-3:
71
+ return complete_motion_dict(local_rot_mats, root_positions, skeleton, float(target_fps)), False
72
+
73
+ ratio = source_fps / target_fps
74
+ step = round(ratio)
75
+ if step >= 2 and abs(ratio - step) < 0.05:
76
+ local_out = local_rot_mats[::step]
77
+ root_out = root_positions[::step]
78
+ else:
79
+ device = local_rot_mats.device
80
+ dtype = local_rot_mats.dtype
81
+ u = torch.linspace(0, t_in - 1, t_out, device=device, dtype=dtype)
82
+ i0 = u.floor().long().clamp(0, t_in - 1)
83
+ i1 = torch.minimum(i0 + 1, torch.tensor(t_in - 1, device=device))
84
+ tau_1d = (u - i0.float()).unsqueeze(-1)
85
+ rp0 = root_positions[i0]
86
+ rp1 = root_positions[i1]
87
+ root_out = (1.0 - tau_1d) * rp0 + tau_1d * rp1
88
+
89
+ quats = matrix_to_quaternion(local_rot_mats)
90
+ q0 = quats[i0]
91
+ q1 = quats[i1]
92
+ tau_q = (u - i0.float()).view(t_out, 1, 1)
93
+ quat_out = _quaternion_slerp(q0, q1, tau_q)
94
+ local_out = quaternion_to_matrix(quat_out)
95
+
96
+ return complete_motion_dict(local_out, root_out, skeleton, float(target_fps)), True
97
+
98
+
99
+ def warn_kimodo_npz_framerate(source_fps: float, t_before: int, t_after: int) -> None:
100
+ """Emit a warning after time resampling for Kimodo NPZ (linear root, quaternion slerp per
101
+ joint)."""
102
+ warnings.warn(
103
+ f"Resampled motion to {KIMODO_CONVERT_TARGET_FPS:.0f} Hz for Kimodo NPZ "
104
+ f"(source ~{source_fps:.4g} Hz, {t_before} input frames → {t_after} output frames). "
105
+ "Pass --source-fps if the detected source rate is wrong.",
106
+ UserWarning,
107
+ stacklevel=3,
108
+ )
109
+
110
+
111
+ def _coerce_time_local_root(
112
+ local_rot_mats: torch.Tensor,
113
+ root_positions: torch.Tensor,
114
+ ) -> tuple[torch.Tensor, torch.Tensor]:
115
+ """Normalize to shapes (T, J, 3, 3) and (T, 3)."""
116
+ if local_rot_mats.dim() == 5:
117
+ if int(local_rot_mats.shape[0]) != 1:
118
+ raise ValueError(f"local_rot_mats batch size must be 1 for single clip; got {local_rot_mats.shape[0]}")
119
+ local_rot_mats = local_rot_mats[0]
120
+ if root_positions.dim() == 3:
121
+ if int(root_positions.shape[0]) != 1:
122
+ raise ValueError(f"root_positions batch size must be 1; got {root_positions.shape[0]}")
123
+ root_positions = root_positions[0]
124
+ if local_rot_mats.dim() != 4:
125
+ raise ValueError(f"local_rot_mats must be (T,J,3,3); got {tuple(local_rot_mats.shape)}")
126
+ if root_positions.dim() != 2 or int(root_positions.shape[-1]) != 3:
127
+ raise ValueError(f"root_positions must be (T,3); got {tuple(root_positions.shape)}")
128
+ if int(local_rot_mats.shape[0]) != int(root_positions.shape[0]):
129
+ raise ValueError("local_rot_mats and root_positions must have the same number of frames")
130
+ return local_rot_mats, root_positions
131
+
132
+
133
+ def complete_motion_dict(
134
+ local_rot_mats: torch.Tensor,
135
+ root_positions: torch.Tensor,
136
+ skeleton: SkeletonBase,
137
+ fps: float,
138
+ ) -> Dict[str, torch.Tensor]:
139
+ """Build the Kimodo motion output dict from local rotations and root positions.
140
+
141
+ Matches keys written by CLI generation (see docs/source/user_guide/output_formats.md).
142
+
143
+ Args:
144
+ local_rot_mats: (T, J, 3, 3) or (1, T, J, 3, 3) local rotation matrices.
145
+ root_positions: (T, 3) or (1, T, 3) root / pelvis world positions (meters).
146
+ skeleton: Skeleton instance (SOMA77, G1, SMPL-X, etc.).
147
+ fps: Sampling rate (Hz).
148
+
149
+ Returns:
150
+ Dict with tensors ``posed_joints``, ``global_rot_mats``, ``local_rot_mats``,
151
+ ``foot_contacts``, ``smooth_root_pos``, ``root_positions``, ``global_root_heading``.
152
+ """
153
+ device = local_rot_mats.device
154
+ dtype = local_rot_mats.dtype
155
+ local_rot_mats, root_positions = _coerce_time_local_root(
156
+ local_rot_mats.to(device=device, dtype=dtype),
157
+ root_positions.to(device=device, dtype=dtype),
158
+ )
159
+
160
+ global_rot_mats, posed_joints, _ = skeleton.fk(local_rot_mats, root_positions)
161
+
162
+ smooth_root_pos = get_smooth_root_pos(root_positions.unsqueeze(0)).squeeze(0)
163
+
164
+ lengths = torch.tensor([posed_joints.shape[0]], device=device)
165
+ velocities = compute_vel_xyz(posed_joints.unsqueeze(0), fps, lengths=lengths).squeeze(0)
166
+
167
+ heading_angle = compute_heading_angle(posed_joints.unsqueeze(0), skeleton).squeeze(0)
168
+ global_root_heading = torch.stack([torch.cos(heading_angle), torch.sin(heading_angle)], dim=-1)
169
+
170
+ foot_contacts = foot_detect_from_pos_and_vel(
171
+ posed_joints.unsqueeze(0),
172
+ velocities.unsqueeze(0),
173
+ skeleton,
174
+ 0.15,
175
+ 0.10,
176
+ ).squeeze(0)
177
+
178
+ return {
179
+ "posed_joints": posed_joints,
180
+ "global_rot_mats": global_rot_mats,
181
+ "local_rot_mats": local_rot_mats,
182
+ "foot_contacts": foot_contacts,
183
+ "smooth_root_pos": smooth_root_pos,
184
+ "root_positions": root_positions,
185
+ "global_root_heading": global_root_heading,
186
+ }
187
+
188
+
189
+ def motion_dict_to_numpy(d: Dict[str, Any]) -> Dict[str, np.ndarray]:
190
+ """Convert motion dict values to numpy arrays for ``np.savez``."""
191
+ out: Dict[str, np.ndarray] = {}
192
+ for k, v in d.items():
193
+ if hasattr(v, "detach"):
194
+ out[k] = to_numpy(v)
195
+ elif isinstance(v, np.ndarray):
196
+ out[k] = v
197
+ else:
198
+ out[k] = np.asarray(v)
199
+ return out
200
+
201
+
202
+ def save_kimodo_npz(path: str, motion_dict: Dict[str, Any]) -> None:
203
+ """Save a Kimodo-compatible motion dict to ``.npz`` (numpy arrays)."""
204
+ np.savez(path, **motion_dict_to_numpy(motion_dict))
205
+
206
+
207
+ def load_kimodo_npz(path: str) -> Dict[str, np.ndarray]:
208
+ """Load arrays from a Kimodo ``.npz`` file."""
209
+ with np.load(path, allow_pickle=False) as data:
210
+ return {k: np.asarray(data[k]) for k in data.files}
211
+
212
+
213
+ def load_g1_csv(
214
+ path: str,
215
+ source_fps: float = KIMODO_CONVERT_TARGET_FPS,
216
+ *,
217
+ mujoco_rest_zero: bool = False,
218
+ ) -> Dict[str, torch.Tensor]:
219
+ """Load a G1 MuJoCo ``qpos`` CSV (``(T, 36)``) into a Kimodo motion dict.
220
+
221
+ Args:
222
+ path: CSV path (comma-separated, no header).
223
+ source_fps: Source frame rate (Hz) of the CSV data.
224
+ mujoco_rest_zero: Must match how the CSV was written (see :class:`MujocoQposConverter`).
225
+ """
226
+ from kimodo.exports.mujoco import MujocoQposConverter
227
+
228
+ qpos = np.loadtxt(path, delimiter=",")
229
+ if qpos.ndim != 2 or qpos.shape[-1] != 36:
230
+ raise ValueError(f"Expected G1 CSV with shape (T, 36); got {qpos.shape}")
231
+ sk = build_skeleton(34)
232
+ converter = MujocoQposConverter(sk)
233
+ return converter.qpos_to_motion_dict(qpos, float(source_fps), mujoco_rest_zero=mujoco_rest_zero)
234
+
235
+
236
+ def load_amass_npz(
237
+ path: str,
238
+ source_fps: float | None = None,
239
+ *,
240
+ z_up: bool = True,
241
+ ) -> Dict[str, torch.Tensor]:
242
+ """Load an AMASS-style SMPL-X ``.npz`` into a Kimodo motion dict (22 joints).
243
+
244
+ Args:
245
+ path: NPZ with ``trans``, ``root_orient``, ``pose_body``, etc.
246
+ source_fps: Source frame rate (Hz); if ``None``, uses ``mocap_frame_rate``
247
+ from the file when present, else 30 Hz.
248
+ z_up: If ``True``, apply AMASS Z-up to Kimodo Y-up transform (same as CLI).
249
+ """
250
+ from kimodo.exports.smplx import amass_npz_to_kimodo_motion
251
+
252
+ sk = build_skeleton(22)
253
+ return amass_npz_to_kimodo_motion(path, sk, source_fps=source_fps, z_up=z_up)
254
+
255
+
256
+ def load_kimodo_npz_as_torch(
257
+ path: str,
258
+ source_fps: float = KIMODO_CONVERT_TARGET_FPS,
259
+ *,
260
+ ensure_complete: bool = True,
261
+ ) -> tuple[Dict[str, torch.Tensor], int]:
262
+ """Load a Kimodo NPZ and return all arrays as torch tensors on the skeleton device.
263
+
264
+ Args:
265
+ path: Kimodo NPZ file path.
266
+ source_fps: Source frame rate (Hz) used for derived channels when
267
+ ``ensure_complete=True``.
268
+ ensure_complete: If ``True`` and the NPZ lacks derived channels
269
+ (``posed_joints``, ``global_rot_mats``, …), run :func:`complete_motion_dict`
270
+ to fill them from ``local_rot_mats`` + ``root_positions``.
271
+ If ``False``, load all arrays verbatim (requires ``local_rot_mats``).
272
+
273
+ Returns:
274
+ ``(tensor_dict, num_joints)``
275
+ """
276
+ raw = load_kimodo_npz(path)
277
+ if "local_rot_mats" in raw:
278
+ j = int(raw["local_rot_mats"].shape[1])
279
+ elif "posed_joints" in raw:
280
+ j = int(raw["posed_joints"].shape[1])
281
+ else:
282
+ raise ValueError("Kimodo NPZ must contain 'local_rot_mats' or 'posed_joints'.")
283
+ sk = build_skeleton(j)
284
+ device = sk.neutral_joints.device
285
+ dtype = torch.float32
286
+
287
+ if not ensure_complete:
288
+ if "local_rot_mats" not in raw:
289
+ raise ValueError("Kimodo NPZ must contain 'local_rot_mats' (and typically 'root_positions').")
290
+ out: Dict[str, torch.Tensor] = {}
291
+ for k, v in raw.items():
292
+ out[k] = torch.from_numpy(np.asarray(v)).to(device=device, dtype=dtype)
293
+ return out, j
294
+
295
+ if "posed_joints" in raw and "global_rot_mats" in raw:
296
+ out = {}
297
+ for k, v in raw.items():
298
+ out[k] = torch.from_numpy(np.asarray(v)).to(device=device, dtype=dtype)
299
+ return out, j
300
+
301
+ if "local_rot_mats" not in raw or "root_positions" not in raw:
302
+ raise ValueError("Kimodo NPZ must contain posed_joints+global_rot_mats, or local_rot_mats+root_positions.")
303
+ local = torch.from_numpy(np.asarray(raw["local_rot_mats"])).to(device=device, dtype=dtype)
304
+ root = torch.from_numpy(np.asarray(raw["root_positions"])).to(device=device, dtype=dtype)
305
+ return complete_motion_dict(local, root, sk, float(source_fps)), j
306
+
307
+
308
+ def save_kimodo_npz_at_target_fps(
309
+ motion: Dict[str, torch.Tensor],
310
+ skeleton: SkeletonBase,
311
+ source_fps: float,
312
+ output_path: str,
313
+ target_fps: float = KIMODO_CONVERT_TARGET_FPS,
314
+ ) -> None:
315
+ """Resample a motion dict to ``target_fps`` when needed, then save Kimodo NPZ."""
316
+ t_before = int(motion["local_rot_mats"].shape[0])
317
+ motion, did_resample = resample_motion_dict_to_kimodo_fps(motion, skeleton, source_fps, target_fps)
318
+ t_after = int(motion["local_rot_mats"].shape[0])
319
+ if did_resample:
320
+ warn_kimodo_npz_framerate(source_fps, t_before, t_after)
321
+ save_kimodo_npz(output_path, motion)
322
+
323
+
324
+ def kimodo_npz_to_bytes(motion_dict: Dict[str, Any]) -> bytes:
325
+ """Serialize a Kimodo motion dict to in-memory NPZ bytes."""
326
+ import io
327
+
328
+ buf = io.BytesIO()
329
+ np.savez(buf, **motion_dict_to_numpy(motion_dict))
330
+ return buf.getvalue()
331
+
332
+
333
+ def g1_csv_to_bytes(motion_dict: Dict[str, Any], skeleton: SkeletonBase, device: Any) -> bytes:
334
+ """Convert a motion dict to G1 MuJoCo CSV bytes via :class:`MujocoQposConverter`."""
335
+ import io
336
+
337
+ from kimodo.exports.mujoco import MujocoQposConverter
338
+
339
+ converter = MujocoQposConverter(skeleton)
340
+ qpos = converter.dict_to_qpos(
341
+ {k: v for k, v in motion_dict.items() if k in ("local_rot_mats", "root_positions")},
342
+ device,
343
+ numpy=True,
344
+ )
345
+ buf = io.StringIO()
346
+ np.savetxt(buf, qpos, delimiter=",")
347
+ return buf.getvalue().encode("utf-8")
348
+
349
+
350
+ def amass_npz_to_bytes(motion_dict: Dict[str, Any], skeleton: SkeletonBase, fps: float) -> bytes:
351
+ """Convert a motion dict to AMASS NPZ bytes via :class:`AMASSConverter`."""
352
+ import io
353
+
354
+ from kimodo.exports.smplx import AMASSConverter
355
+
356
+ converter = AMASSConverter(skeleton=skeleton, fps=fps)
357
+ buf = io.BytesIO()
358
+ converter.convert_save_npz(
359
+ {k: v for k, v in motion_dict.items() if k in ("local_rot_mats", "root_positions")},
360
+ buf,
361
+ )
362
+ return buf.getvalue()
363
+
364
+
365
+ def _read_amass_source_fps(path: str) -> float:
366
+ """Read the source frame rate from an AMASS NPZ, defaulting to 30 Hz."""
367
+ with np.load(path, allow_pickle=True) as z:
368
+ if "mocap_frame_rate" in z.files:
369
+ return float(z["mocap_frame_rate"])
370
+ return 30.0
371
+
372
+
373
+ def load_motion_file(
374
+ path: str,
375
+ source_fps: float | None = None,
376
+ target_fps: float | None = None,
377
+ *,
378
+ z_up: bool = True,
379
+ mujoco_rest_zero: bool = False,
380
+ ) -> tuple[Dict[str, torch.Tensor], int]:
381
+ """Load a motion file and return a Kimodo motion dict plus joint count.
382
+
383
+ Supports SOMA BVH (``.bvh``), G1 MuJoCo CSV (``.csv``), Kimodo NPZ, and AMASS SMPL-X NPZ
384
+ (``.npz``).
385
+
386
+ The motion is loaded at its native (or overridden) source rate, then
387
+ resampled to ``target_fps`` when they differ.
388
+
389
+ Args:
390
+ path: Path to ``.bvh``, ``.csv``, or ``.npz``.
391
+ source_fps: Source frame rate (Hz). If provided, trusted as-is.
392
+ If ``None``, auto-detected per format: BVH ``Frame Time`` header,
393
+ AMASS ``mocap_frame_rate``, or :data:`KIMODO_CONVERT_TARGET_FPS`
394
+ (30 Hz) for CSV / Kimodo NPZ.
395
+ target_fps: Desired output frame rate (Hz). Defaults to
396
+ :data:`KIMODO_CONVERT_TARGET_FPS` (30 Hz). The motion is
397
+ resampled when ``source_fps`` and ``target_fps`` differ.
398
+ z_up: AMASS NPZ only; passed to :func:`load_amass_npz`.
399
+ mujoco_rest_zero: G1 CSV only; passed to :func:`load_g1_csv`.
400
+
401
+ Returns:
402
+ ``(motion_dict, num_joints)`` with the same keys as :func:`complete_motion_dict`.
403
+ """
404
+ from kimodo.exports.motion_formats import infer_npz_kind
405
+
406
+ if target_fps is None:
407
+ target_fps = KIMODO_CONVERT_TARGET_FPS
408
+
409
+ ext = os.path.splitext(path)[1].lower()
410
+ if ext == ".bvh":
411
+ from kimodo.exports.bvh import bvh_to_kimodo_motion
412
+
413
+ motion_dict, bvh_fps = bvh_to_kimodo_motion(path)
414
+ effective_source = source_fps if source_fps is not None else bvh_fps
415
+ num_joints = int(motion_dict["local_rot_mats"].shape[1])
416
+ elif ext == ".csv":
417
+ effective_source = source_fps if source_fps is not None else KIMODO_CONVERT_TARGET_FPS
418
+ motion_dict = load_g1_csv(path, source_fps=effective_source, mujoco_rest_zero=mujoco_rest_zero)
419
+ num_joints = 34
420
+ elif ext == ".npz":
421
+ kind = infer_npz_kind(path)
422
+ if kind == "amass":
423
+ effective_source = source_fps if source_fps is not None else _read_amass_source_fps(path)
424
+ motion_dict = load_amass_npz(path, source_fps=effective_source, z_up=z_up)
425
+ num_joints = 22
426
+ else:
427
+ effective_source = source_fps if source_fps is not None else KIMODO_CONVERT_TARGET_FPS
428
+ motion_dict, num_joints = load_kimodo_npz_as_torch(path, source_fps=effective_source)
429
+ else:
430
+ raise ValueError(f"Unsupported motion file {path!r}; expected .bvh, .csv, or .npz")
431
+
432
+ if abs(effective_source - target_fps) > 0.5:
433
+ sk = build_skeleton(num_joints)
434
+ motion_dict, did_resample = resample_motion_dict_to_kimodo_fps(motion_dict, sk, effective_source, target_fps)
435
+ if did_resample:
436
+ t_out = int(motion_dict["local_rot_mats"].shape[0])
437
+ warnings.warn(
438
+ f"Resampled motion from {effective_source:.4g} Hz to " f"{target_fps:.0f} Hz ({t_out} frames).",
439
+ UserWarning,
440
+ stacklevel=2,
441
+ )
442
+
443
+ return motion_dict, num_joints
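A minimal usage sketch for the loaders above (illustrative only, not part of the committed file; the paths are placeholders). It loads any supported motion file, resampled to 30 Hz, and writes it back out as a Kimodo NPZ:

# illustrative sketch -- "walk.npz" / "walk_30hz.npz" are placeholder paths
from kimodo.exports.motion_io import load_motion_file, save_kimodo_npz

motion, num_joints = load_motion_file("walk.npz")   # format and source fps auto-detected
print(num_joints, motion["posed_joints"].shape)     # e.g. 22, (T, 22, 3)
save_kimodo_npz("walk_30hz.npz", motion)            # tensors are converted to numpy on save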
kimodo/exports/mujoco.py ADDED
@@ -0,0 +1,588 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Convert kimodo motion (y-up, z-forward) to MuJoCo qpos (z-up, x-forward) for G1 skeleton."""
4
+
5
+ import os
6
+ import xml.etree.ElementTree as ET
7
+ from typing import Optional
8
+
9
+ import numpy as np
10
+ import torch
11
+ from scipy.spatial.transform import Rotation
12
+
13
+ from kimodo.assets import skeleton_asset_path
14
+ from kimodo.geometry import (
15
+ axis_angle_to_matrix,
16
+ matrix_to_axis_angle,
17
+ matrix_to_quaternion,
18
+ quaternion_to_matrix,
19
+ )
20
+ from kimodo.skeleton import G1Skeleton34, SkeletonBase, global_rots_to_local_rots
21
+ from kimodo.tools import ensure_batched, to_numpy, to_torch
22
+
23
+ # Cache so that the same (skeleton, xml_path) returns the same converter instance.
24
+ _converter_cache: dict[tuple[int, str], "MujocoQposConverter"] = {}
25
+
26
+
27
+ class MujocoQposConverter:
28
+ """Fast batch converter from our dictionary format to mujoco qpos with precomputed transforms.
29
+
30
+ In MuJoCo, the coordinate system is z-up and x-forward, right-handed.
31
+
32
+ Features (30 joints):
33
+ - root (pelvis, 7 qpos = 3 translation + 4 quaternion) + 29 hinge joints (1 qpos each) = 36 qpos
34
+
35
+ In kimodo, the coordinate system is y-up and z-forward, right-handed.
36
+ Features (34 joints):
37
+ - root (pelvis) + (34 - 1) joints; among these joints, 4 are end-effector joints added by kimodo.
38
+
39
+ Cached by (input_skeleton id, xml_path); repeated calls with the same args return the same instance.
40
+ """
41
+
42
+ def __new__(
43
+ cls,
44
+ input_skeleton: SkeletonBase,
45
+ xml_path: str = str(skeleton_asset_path("g1skel34", "xml", "g1.xml")),
46
+ ):
47
+ key = (id(input_skeleton), xml_path)
48
+ if key not in _converter_cache:
49
+ inst = object.__new__(cls)
50
+ _converter_cache[key] = inst
51
+ return _converter_cache[key]
52
+
53
+ def __init__(
54
+ self,
55
+ input_skeleton: SkeletonBase,
56
+ xml_path: str = str(skeleton_asset_path("g1skel34", "xml", "g1.xml")),
57
+ ):
58
+ """Initialize converter with precomputed transforms.
59
+
60
+ Args:
61
+ xml_path: Path to the mujoco XML file containing joint definitions
62
+ """
63
+ if getattr(self, "_initialized", False):
64
+ return
65
+ self.xml_path = xml_path
66
+ self.skeleton = input_skeleton
67
+ self._prepare_transforms()
68
+ self._subtree_joints = {}
69
+ self._initialized = True
70
+
71
+ def _prepare_transforms(self):
72
+ """Precompute all necessary transforms for efficient batch processing."""
73
+ # Define coordinate transformations between mujoco and kimodo space
74
+ # 1) R_zup_to_yup: rotation around x-axis by -90 degrees
75
+ # 2) x_forward_to_y_forward: rotation around z-axis by -90 degrees
76
+ # Combined transformation matrix: mujoco_to_kimodo = R_zup_to_yup * x_forward_to_y_forward
77
+ self.mujoco_to_kimodo_matrix = torch.tensor(
78
+ [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], dtype=torch.float32
79
+ )
80
+ self.kimodo_to_mujoco_matrix = self.mujoco_to_kimodo_matrix.T # Inverse transformation: kimodo_to_mujoco
81
+
82
+ # Parse XML once and extract joint information
83
+ tree = ET.parse(self.xml_path)
84
+ root = tree.getroot()
85
+
86
+ xml_classes = [x for x in tree.findall(".//default") if "class" in x.attrib]
87
+ joint_axes = dict()
88
+ class_ranges: dict[str, tuple[float, float]] = {}
89
+ for xml_class in xml_classes:
90
+ j = xml_class.findall("joint")
91
+ if j:
92
+ joint_axes[xml_class.get("class")] = j[0].get("axis")
93
+ range_str = j[0].get("range")
94
+ if range_str:
95
+ range_vals = [float(x) for x in range_str.split()]
96
+ if len(range_vals) == 2:
97
+ class_ranges[xml_class.get("class")] = (
98
+ range_vals[0],
99
+ range_vals[1],
100
+ )
101
+
102
+ mujoco_hinge_joints = root.find("worldbody").findall(".//joint") # skip the base joint
103
+ self._mujoco_joint_axis_values_kimodo_space = torch.zeros(
104
+ (len(mujoco_hinge_joints), 3), dtype=torch.float32
105
+ ) # mujoco order but kimodo space
106
+ self._mujoco_joint_axis_values_mujoco_space = torch.zeros(
107
+ (len(mujoco_hinge_joints), 3), dtype=torch.float32
108
+ ) # mujoco order but mujoco space
109
+
110
+ # for the below indices, mujoco_indices_to_kimodo_indices does not include mujoco root (30 - 1 = 29 elements),
111
+ # while kimodo_indices_to_mujoco_indices includes the kimodo root (nbjoints = 34 elements).
112
+ self._mujoco_indices_to_kimodo_indices = torch.zeros((len(mujoco_hinge_joints),), dtype=torch.int32)
113
+ self._kimodo_indices_to_mujoco_indices = (
114
+ torch.ones((self.skeleton.nbjoints,), dtype=torch.int32) * -1
115
+ ) # -1 means not in the csv skeleton
116
+
117
+ self._nb_joints_mujoco = len(mujoco_hinge_joints) + 1
118
+ self._nb_joints_kimodo = self.skeleton.nbjoints
119
+ self._mujoco_joint_including_root_parent_list = torch.full(
120
+ (len(mujoco_hinge_joints) + 1,), -1, dtype=torch.int32
121
+ )
122
+ self._mujoco_joint_including_root_list = ["pelvis_skel"]
123
+
124
+ for joint_id_in_csv, joint in enumerate(mujoco_hinge_joints):
125
+ joint_name_in_skeleton = joint.get("name").replace("_joint", "_skel")
126
+ joint_parent_name_in_skeleton = self.skeleton.bone_parents[joint_name_in_skeleton]
127
+
128
+ self._mujoco_joint_including_root_list.append(joint_name_in_skeleton)
129
+ self._mujoco_joint_including_root_parent_list[joint_id_in_csv + 1] = (
130
+ self._mujoco_joint_including_root_list.index(joint_parent_name_in_skeleton)
131
+ )
132
+
133
+ joint_idx_in_kimodo_skeleton = self.skeleton.bone_order_names.index(joint_name_in_skeleton)
134
+ axis_values = [float(x) for x in (joint.get("axis") or joint_axes[joint.get("class")]).split(" ")]
135
+
136
+ # the mapped axis in kimodo skeleton space is calculated as bones_axis = mujoco_to_kimodo.apply(axis_values)
137
+ # [1, 0, 0] -> [0, 0, 1]; [0, 1, 0] -> [1, 0, 0]; [0, 0, 1] -> [0, 1, 0]
138
+ mujoco_joint_axis_mapping_kimodo_space = [
139
+ torch.tensor([0, 0, 1]),
140
+ torch.tensor([1, 0, 0]),
141
+ torch.tensor([0, 1, 0]),
142
+ ][np.argmax(axis_values)]
143
+
144
+ self._mujoco_joint_axis_values_kimodo_space[joint_id_in_csv] = mujoco_joint_axis_mapping_kimodo_space
145
+ self._mujoco_joint_axis_values_mujoco_space[joint_id_in_csv] = torch.tensor(axis_values)
146
+
147
+ self._mujoco_indices_to_kimodo_indices[joint_id_in_csv] = joint_idx_in_kimodo_skeleton
148
+ self._kimodo_indices_to_mujoco_indices[joint_idx_in_kimodo_skeleton] = (
149
+ joint_id_in_csv + 1
150
+ ) # +1 for the root
151
+ self._kimodo_indices_to_mujoco_indices[0] = 0 # the root joint mapping
152
+
153
+ # Joint limits (min, max) in radians for each mujoco hinge, for clamping
154
+ self._joint_limits_min = torch.full((len(mujoco_hinge_joints),), float("-inf"), dtype=torch.float32)
155
+ self._joint_limits_max = torch.full((len(mujoco_hinge_joints),), float("inf"), dtype=torch.float32)
156
+ for joint_id_in_csv, joint in enumerate(mujoco_hinge_joints):
157
+ range_vals = None
158
+ if joint.get("range"):
159
+ range_vals = [float(x) for x in joint.get("range").split()]
160
+ elif joint.get("class") and joint.get("class") in class_ranges:
161
+ lo, hi = class_ranges[joint.get("class")]
162
+ range_vals = [lo, hi]
163
+ if range_vals is not None and len(range_vals) == 2:
164
+ self._joint_limits_min[joint_id_in_csv] = range_vals[0]
165
+ self._joint_limits_max[joint_id_in_csv] = range_vals[1]
166
+
167
+ # load the offset matrices from the xml
168
+ R_zup_to_yup = Rotation.from_euler("x", -90, degrees=True)
169
+ x_forward_to_y_forward = Rotation.from_euler("z", -90, degrees=True)
170
+ mujoco_to_kimodo = R_zup_to_yup * x_forward_to_y_forward
171
+
172
+ self._rot_offsets_q2t = torch.zeros(len(self._kimodo_indices_to_mujoco_indices), 3, 3, dtype=torch.float32)
173
+ self._rot_offsets_q2t[...] = torch.eye(3)[None]
174
+
175
+ self._rot_offsets_f2q = torch.zeros(len(self._kimodo_indices_to_mujoco_indices), 3, 3, dtype=torch.float32)
176
+ self._rot_offsets_f2q[...] = torch.eye(3)[None]
177
+ parent_map = {child: parent for parent in root.iter() for child in parent}
178
+ for i, joint in enumerate(mujoco_hinge_joints):
179
+ body = parent_map[joint]
180
+ if "quat" in body.attrib:
181
+ rot = Rotation.from_quat(
182
+ [float(x) for x in body.get("quat").strip().split(" ")],
183
+ scalar_first=True,
184
+ )
185
+ idx = self._mujoco_indices_to_kimodo_indices[i]
186
+ self._rot_offsets_q2t[idx] = torch.from_numpy(rot.as_matrix())
187
+ rot = mujoco_to_kimodo * rot * mujoco_to_kimodo.inv()
188
+ self._rot_offsets_f2q[idx] = torch.from_numpy(rot.as_matrix().T)
189
+
190
+ # Hinge axis in f2q space so extraction uses the same frame as joint_rot_f2q.
191
+ # Then extract(offset) gives the angle s.t. axis_angle(angle * axis_f2q) = offset, and
192
+ # reconstruction R_local = offset.T @ axis_angle(angle * axis_f2q) = I when input is identity.
193
+ axis_kimodo = self._mujoco_joint_axis_values_kimodo_space
194
+ self._mujoco_joint_axis_values_f2q_space = torch.zeros_like(axis_kimodo)
195
+ for i in range(len(mujoco_hinge_joints)):
196
+ j = self._mujoco_indices_to_kimodo_indices[i].item()
197
+ axis_f2q = torch.mv(self._rot_offsets_f2q[j], axis_kimodo[i])
198
+ n = axis_f2q.norm()
199
+ if n > 1e-8:
200
+ axis_f2q = axis_f2q / n
201
+ self._mujoco_joint_axis_values_f2q_space[i] = axis_f2q
202
+
203
+ # Rest-pose DOFs: angle we extract when R_local = I (t-pose). MuJoCo limits are
204
+ # relative to joint zero (rest pose), so we must clamp in MuJoCo space: convert
205
+ # joint_dofs to mujoco_angle = joint_dofs - rest_dofs, clamp, then back.
206
+ rest_rot_f2q = self._rot_offsets_f2q[self._mujoco_indices_to_kimodo_indices]
207
+ rest_rot_f2q = rest_rot_f2q.unsqueeze(0).unsqueeze(0)
208
+ self._rest_dofs = self._local_rots_f2q_to_joint_dofs(rest_rot_f2q).squeeze(0).squeeze(0)
209
+ # Axis-angle rest DOFs: angle s.t. axis_angle(angle * axis_f2q) = offset. Used in
210
+ # project_to_real_robot_rotations so extract+reconstruct round-trip and t-pose is preserved.
211
+ rest_rot_f2q_flat = self._rot_offsets_f2q[self._mujoco_indices_to_kimodo_indices]
212
+ full_aa = matrix_to_axis_angle(rest_rot_f2q_flat)
213
+ self._rest_dofs_axis_angle = (full_aa * self._mujoco_joint_axis_values_f2q_space).sum(dim=-1)
214
+
215
+ def dict_to_qpos(
216
+ self,
217
+ output: dict,
218
+ device: Optional[str] = None,
219
+ root_quat_w_first: bool = True,
220
+ numpy: bool = True,
221
+ mujoco_rest_zero: bool = False,
222
+ ):
223
+ """Convert kimodo output dict to mujoco qpos format.
224
+
225
+ Args:
226
+ output: dict with keys "local_rot_mats" and "root_positions".
227
+ device: device to use for the output.
228
+ root_quat_w_first: If True, quaternion in qpos is (w,x,y,z).
229
+ numpy: If True, convert the output to numpy array.
230
+ mujoco_rest_zero: If True, joint angles are written so that kimodo rest (t-pose)
231
+ maps to q=0 in MuJoCo. If False, write raw joint_dofs.
232
+
233
+ Returns:
234
+ qpos: (B, T, 36) mujoco qpos (3 root translation + 4 root quaternion + 29 joint angles).
235
+ """
236
+ local_rot_mats = to_torch(output["local_rot_mats"], device)
237
+ root_positions = to_torch(output["root_positions"], device)
238
+
239
+ qpos = self.to_qpos(
240
+ local_rot_mats,
241
+ root_positions,
242
+ root_quat_w_first=root_quat_w_first,
243
+ mujoco_rest_zero=mujoco_rest_zero,
244
+ )
245
+ if numpy:
246
+ qpos = to_numpy(qpos)
247
+ return qpos
248
+
249
+ def qpos_to_motion_dict(
250
+ self,
251
+ qpos: torch.Tensor | np.ndarray,
252
+ source_fps: float,
253
+ *,
254
+ root_quat_w_first: bool = True,
255
+ mujoco_rest_zero: bool = False,
256
+ ):
257
+ """Inverse of :meth:`to_qpos` / :meth:`dict_to_qpos` for MuJoCo CSV ``(T, 36)`` rows.
258
+
259
+ Args:
260
+ qpos: Shape ``(T, 36)`` or ``(1, T, 36)`` (root xyz, root quat wxyz, 29 joint angles).
261
+ source_fps: Source frame rate (Hz) of the qpos data.
262
+ root_quat_w_first: Must match how the CSV was written (default ``True``).
263
+ mujoco_rest_zero: Must match :meth:`dict_to_qpos` / :meth:`to_qpos`.
264
+
265
+ Returns:
266
+ Kimodo motion dict (see :func:`kimodo.exports.motion_io.complete_motion_dict`).
267
+ """
268
+ from kimodo.exports.motion_io import complete_motion_dict
269
+
270
+ qpos = to_torch(qpos, None)
271
+ if qpos.dim() == 2:
272
+ qpos = qpos.unsqueeze(0)
273
+ device = qpos.device
274
+ dtype = qpos.dtype
275
+ batch_size, num_frames, ncols = qpos.shape
276
+ if ncols != 36:
277
+ raise ValueError(f"Expected qpos last dim 36; got {ncols}")
278
+
279
+ kimodo_to_mujoco_matrix = self.kimodo_to_mujoco_matrix.to(device=device, dtype=dtype)
280
+ mujoco_to_kimodo_matrix = kimodo_to_mujoco_matrix.T
281
+
282
+ root_mujoco = qpos[..., :3]
283
+ root_positions = torch.matmul(mujoco_to_kimodo_matrix[None, None, ...], root_mujoco[..., None]).squeeze(-1)
284
+
285
+ quat = qpos[..., 3:7]
286
+ if root_quat_w_first:
287
+ root_rot_mujoco = quaternion_to_matrix(quat)
288
+ else:
289
+ quat_wxyz = quat[..., [3, 0, 1, 2]]
290
+ root_rot_mujoco = quaternion_to_matrix(quat_wxyz)
291
+
292
+ O0 = self._rot_offsets_f2q[0].to(device=device, dtype=dtype)
293
+ # root_rot_mujoco is (..., 3, 3) after optional batch unsqueeze (e.g. (1, T, 3, 3)).
294
+ # Use ``...il`` so ``k`` sums with ``kl``; ``...ik`` incorrectly keeps ``k`` in the output.
295
+ R_f2q_root = torch.einsum(
296
+ "ij,...jk,kl->...il",
297
+ mujoco_to_kimodo_matrix,
298
+ root_rot_mujoco,
299
+ kimodo_to_mujoco_matrix,
300
+ )
301
+ R_kimodo_root = torch.einsum("ij,...jk->...ik", O0.T, R_f2q_root)
302
+
303
+ joint_dofs = qpos[..., 7:]
304
+ if mujoco_rest_zero:
305
+ rest_dofs = self._rest_dofs.to(device=device, dtype=dtype)
306
+ angles = joint_dofs + rest_dofs[None, None, :]
307
+ use_relative = True
308
+ else:
309
+ angles = joint_dofs
310
+ use_relative = False
311
+
312
+ nb_joints = self.skeleton.nbjoints
313
+ template = torch.eye(3, device=device, dtype=dtype).expand(batch_size, num_frames, nb_joints, 3, 3).contiguous()
314
+ template[:, :, 0] = R_kimodo_root
315
+
316
+ local_rot_mats = self._joint_dofs_to_local_rot_mats(
317
+ angles,
318
+ template,
319
+ device,
320
+ dtype,
321
+ use_relative=use_relative,
322
+ )
323
+
324
+ if batch_size != 1:
325
+ raise ValueError(f"Only a single clip is supported; got batch_size={batch_size}")
326
+
327
+ return complete_motion_dict(local_rot_mats[0], root_positions[0], self.skeleton, source_fps)
328
+
329
+ def save_csv(self, qpos: torch.Tensor | np.ndarray, csv_path):
330
+ """Save qpos to CSV; a (B, T, 36) batch with B > 1 is split into one CSV per motion (suffix _00, _01, ...)."""
331
+ qpos = to_numpy(qpos)
332
+ shape = qpos.shape
333
+ if len(shape) == 2:
334
+ # only one motion: save it
335
+ np.savetxt(csv_path, qpos, delimiter=",")
336
+ if len(shape) == 3:
337
+ # batch of motions
338
+ if shape[0] == 1:
339
+ # if only one motion, just save it
340
+ np.savetxt(csv_path, qpos[0], delimiter=",")
341
+ else:
342
+ csv_path_base, ext = os.path.splitext(csv_path)
343
+ for i in range(shape[0]):
344
+ self.save_csv(qpos[i], csv_path_base + "_" + str(i).zfill(2) + ext)
345
+
346
+ def _local_rots_to_joint_dofs(
347
+ self,
348
+ local_rot_mats: torch.Tensor,
349
+ axis_vals: torch.Tensor,
350
+ ) -> torch.Tensor:
351
+ """Extract per-joint single-DoF angles (radians) via Euler projection (for to_qpos/f2q)."""
352
+ x_joint_dof = torch.atan2(local_rot_mats[..., 2, 1], local_rot_mats[..., 2, 2])
353
+ y_joint_dof = torch.atan2(local_rot_mats[..., 0, 2], local_rot_mats[..., 0, 0])
354
+ z_joint_dof = torch.atan2(local_rot_mats[..., 1, 0], local_rot_mats[..., 1, 1])
355
+ xyz_joint_dofs = torch.stack([x_joint_dof, y_joint_dof, z_joint_dof], dim=-1)
356
+ axis_vals = axis_vals.to(device=local_rot_mats.device, dtype=local_rot_mats.dtype)
357
+ joint_dofs = (xyz_joint_dofs * axis_vals[None, None, :, :]).sum(dim=-1)
358
+ return joint_dofs
359
+
360
+ def _local_rots_to_joint_dofs_axis_angle(
361
+ self,
362
+ local_rot_mats: torch.Tensor,
363
+ axis_vals: torch.Tensor,
364
+ ) -> torch.Tensor:
365
+ """Extract per-joint single-DoF angles (radians) via axis-angle; round-trips with
366
+ axis_angle_to_matrix.
367
+
368
+ Args:
369
+ local_rot_mats: (..., num_hinges, 3, 3) in same frame as axis_vals.
370
+ axis_vals: (num_hinges, 3) unit axis per hinge.
371
+ Returns:
372
+ joint_dofs: (..., num_hinges) signed angle = dot(axis_angle(R), axis).
373
+ """
374
+ axis_vals = axis_vals.to(device=local_rot_mats.device, dtype=local_rot_mats.dtype)
375
+ full_aa = matrix_to_axis_angle(local_rot_mats)
376
+ joint_dofs = (full_aa * axis_vals).sum(dim=-1)
377
+ return joint_dofs
378
+
379
+ def _local_rots_f2q_to_joint_dofs(self, local_rot_mats_f2q: torch.Tensor) -> torch.Tensor:
380
+ """Extract per-joint single-DoF angles from local rotations in f2q space (for to_qpos)."""
381
+ axis_vals = self._mujoco_joint_axis_values_f2q_space
382
+ return self._local_rots_to_joint_dofs(local_rot_mats_f2q, axis_vals)
383
+
384
+ def _clamp_to_limits(self, joint_dofs: torch.Tensor) -> torch.Tensor:
385
+ """Clamp joint angles to XML limits (radians).
386
+
387
+ Angles are in kimodo convention (0 = rest).
388
+ """
389
+ device = joint_dofs.device
390
+ lo = self._joint_limits_min.to(device=device, dtype=joint_dofs.dtype)
391
+ hi = self._joint_limits_max.to(device=device, dtype=joint_dofs.dtype)
392
+ return torch.clamp(joint_dofs, lo[None, None, :], hi[None, None, :])
393
+
394
+ def _clamp_joint_dofs(self, joint_dofs: torch.Tensor, rest_dofs: torch.Tensor) -> torch.Tensor:
395
+ """Clamp joint angles to MuJoCo limits (radians), with rest_dofs conversion."""
396
+ device = joint_dofs.device
397
+ rest_dofs = rest_dofs.to(device=device, dtype=joint_dofs.dtype)
398
+ mujoco_dofs = joint_dofs - rest_dofs[None, None, :]
399
+ lo = self._joint_limits_min.to(device=device, dtype=joint_dofs.dtype)
400
+ hi = self._joint_limits_max.to(device=device, dtype=joint_dofs.dtype)
401
+ mujoco_dofs = torch.clamp(mujoco_dofs, lo[None, None, :], hi[None, None, :])
402
+ return mujoco_dofs + rest_dofs[None, None, :]
403
+
404
+ def _joint_dofs_to_local_rot_mats(
405
+ self,
406
+ joint_dofs: torch.Tensor,
407
+ original_local_rot_mats: torch.Tensor,
408
+ device: torch.device,
409
+ dtype: torch.dtype,
410
+ use_relative: bool = False,
411
+ ) -> torch.Tensor:
412
+ """Reconstruct full local rotation matrices from 1-DoF angles."""
413
+ out = original_local_rot_mats.clone()
414
+ axis_kimodo = self._mujoco_joint_axis_values_kimodo_space.to(device=device, dtype=dtype)
415
+ for i in range(joint_dofs.shape[-1]):
416
+ j = self._mujoco_indices_to_kimodo_indices[i].item()
417
+ angle = joint_dofs[..., i]
418
+ axis = axis_kimodo[i]
419
+ if use_relative:
420
+ axis_angle = angle[..., None] * axis[None, None, :]
421
+ R_local = axis_angle_to_matrix(axis_angle)
422
+ else:
423
+ rot_offsets_f2q = self._rot_offsets_f2q.to(device=device, dtype=dtype)
424
+ axis_in_f2q = torch.mv(rot_offsets_f2q[j], axis)
425
+ axis_angle = angle[..., None] * axis_in_f2q[None, None, :]
426
+ R_f2q = axis_angle_to_matrix(axis_angle)
427
+ R_local = torch.einsum("ij,btjk->btik", rot_offsets_f2q[j].T, R_f2q)
428
+ out[:, :, j, :, :] = R_local
429
+ return out
430
+
431
+ @ensure_batched(local_rot_mats=5, root_positions=3, lengths=1)
432
+ def project_to_real_robot_rotations(
433
+ self,
434
+ local_rot_mats: torch.Tensor,
435
+ root_positions: torch.Tensor,
436
+ clamp_to_limits: bool = True,
437
+ mujoco_rest_zero: bool = False,
438
+ ) -> dict:
439
+ """Project full 3D local rotations to G1 real robot DoF and back to 3D for viz.
440
+
441
+ Joint angles are extracted along each hinge axis, optionally clamped to XML limits, then
442
+ reconstructed to 3D rotations. When mujoco_rest_zero=False (default), raw angles are used
443
+ (baked-with-quat). When True, angles are relative to rest (0 = T-pose in MuJoCo).
444
+ """
445
+ device = local_rot_mats.device
446
+ dtype = local_rot_mats.dtype
447
+
448
+ # Transform to f2q frame and extract 1-DoF angles (axis-angle projection).
449
+ local_rot_f2q = torch.matmul(self._rot_offsets_f2q.to(device=device, dtype=dtype), local_rot_mats)
450
+ hinge_rots = local_rot_f2q[:, :, self._mujoco_indices_to_kimodo_indices, :, :]
451
+ axis_f2q = self._mujoco_joint_axis_values_f2q_space.to(device=device, dtype=dtype)
452
+ joint_dofs = self._local_rots_to_joint_dofs_axis_angle(hinge_rots, axis_f2q)
453
+
454
+ # Optionally express angles relative to rest (MuJoCo q=0 at T-pose).
455
+ if mujoco_rest_zero:
456
+ rest_dofs = self._rest_dofs_axis_angle.to(device=device, dtype=dtype)
457
+ angles = joint_dofs - rest_dofs[None, None, :]
458
+ use_relative = True
459
+ else:
460
+ angles = joint_dofs
461
+ use_relative = False
462
+
463
+ if clamp_to_limits:
464
+ if mujoco_rest_zero:
465
+ angles = self._clamp_to_limits(angles)
466
+ else:
467
+ rest_dofs_aa = self._rest_dofs_axis_angle.to(device=device, dtype=dtype)
468
+ angles = self._clamp_joint_dofs(angles, rest_dofs_aa)
469
+
470
+ # Reconstruct 3D local rotations from 1-DoF angles and run FK.
471
+ local_rot_mats_proj = self._joint_dofs_to_local_rot_mats(
472
+ angles, local_rot_mats, device, dtype, use_relative=use_relative
473
+ )
474
+ global_rot_mats, posed_joints, _ = self.skeleton.fk(local_rot_mats_proj, root_positions)
475
+ return {
476
+ "local_rot_mats": local_rot_mats_proj,
477
+ "global_rot_mats": global_rot_mats,
478
+ "posed_joints": posed_joints,
479
+ "root_positions": root_positions,
480
+ }
481
+
482
+ @ensure_batched(local_rot_mats=5, root_positions=3, lengths=1)
483
+ def to_qpos(
484
+ self,
485
+ local_rot_mats: torch.Tensor,
486
+ root_positions: torch.Tensor,
487
+ root_quat_w_first: bool = True,
488
+ mujoco_rest_zero: bool = False,
489
+ ) -> torch.Tensor:
490
+ """Fast batch conversion from kimodo features to mujoco qpos format.
491
+
492
+ Args:
493
+ local_rot_mats: (B, T, J, 3, 3) local rotation matrices (kimodo convention).
494
+ root_positions: (B, T, 3) root positions.
495
+ root_quat_w_first: If True, quaternion in qpos is (w,x,y,z).
496
+ mujoco_rest_zero: If True, joint angles are written so that kimodo rest (t-pose)
497
+ maps to q=0 in MuJoCo. If False, write raw joint_dofs.
498
+
499
+ Returns:
500
+ torch.Tensor of shape [batch, numFrames, 36] containing mujoco qpos data:
501
+ - root_trans (3) + root_quat (4) + joint_dofs (29) = 36 columns
502
+ """
503
+
504
+ batch_size, num_frames, nb_joints = local_rot_mats.shape[:3]
505
+ device, dtype = local_rot_mats.device, local_rot_mats.dtype
506
+
507
+ local_rot_mats = torch.matmul(self._rot_offsets_f2q.to(device), local_rot_mats)
508
+
509
+ batch_size, num_frames = root_positions.shape[0], root_positions.shape[1]
510
+
511
+ # Move precomputed matrices to the same device/dtype
512
+ kimodo_to_mujoco_matrix = self.kimodo_to_mujoco_matrix.to(device=device, dtype=dtype)
513
+
514
+ # Initialize output tensor: [batch, numFrames, 36]
515
+ qpos = torch.zeros((batch_size, num_frames, 36), dtype=dtype, device=device)
516
+
517
+ # Convert root translation: apply coordinate transformation
518
+ root_positions_mujoco = torch.matmul(kimodo_to_mujoco_matrix[None, None, ...], root_positions[..., None])
519
+ qpos[:, :, :3] = root_positions_mujoco.view(batch_size, num_frames, 3)
520
+
521
+ # Convert root rotation: apply coordinate transformation to rotation matrix
522
+ root_rot = local_rot_mats[:, :, 0, :] # [batch, numFrames, 3, 3]
523
+
524
+ # Apply coordinate transformation: R_mujoco = kimodo_to_mujoco * R_kimodo * kimodo_to_mujoco^T
525
+ mujoco_to_kimodo_matrix = kimodo_to_mujoco_matrix.T
526
+ root_rot_mujoco = torch.matmul(
527
+ torch.matmul(kimodo_to_mujoco_matrix[None, None, ...], root_rot),
528
+ mujoco_to_kimodo_matrix[None, None, ...],
529
+ )
530
+ root_rot_quat = matrix_to_quaternion(root_rot_mujoco) # [w, x, y, z]
531
+ if root_quat_w_first:
532
+ qpos[:, :, 3:7] = root_rot_quat[:, :, [0, 1, 2, 3]] # [w, x, y, z]
533
+ else:
534
+ qpos[:, :, 3:7] = root_rot_quat[:, :, [1, 2, 3, 0]] # [w, x, y, z] -> [x, y, z, w]
535
+
536
+ # Joint DOFs: raw angles or relative to rest (rest = q=0 in MuJoCo).
537
+ joint_rot_f2q = local_rot_mats[:, :, self._mujoco_indices_to_kimodo_indices, :, :]
538
+ joint_dofs = self._local_rots_f2q_to_joint_dofs(joint_rot_f2q)
539
+ if mujoco_rest_zero:
540
+ rest_dofs = self._rest_dofs.to(device=device, dtype=dtype)
541
+ qpos[:, :, 7:] = joint_dofs - rest_dofs[None, None, :]
542
+ else:
543
+ qpos[:, :, 7:] = joint_dofs
544
+ return qpos
545
+
546
+
547
+ def apply_g1_real_robot_projection(
548
+ skeleton: G1Skeleton34,
549
+ joints_pos: torch.Tensor,
550
+ joints_rot: torch.Tensor,
551
+ clamp_to_limits: bool = True,
552
+ ) -> tuple[torch.Tensor, torch.Tensor]:
553
+ """Project G1 motion to real robot DoF (1-DoF per joint) with optional axis limits.
554
+
555
+ Extracts a single angle per hinge along its axis (1-DoF), optionally clamps to
556
+ joint limits from the MuJoCo XML (when clamp_to_limits=True), then reconstructs
557
+ 3D rotations and runs FK. T-pose (identity local rotations) is preserved.
558
+
559
+ Args:
560
+ skeleton: G1 skeleton instance.
561
+ joints_pos: (T, J, 3) or (B, T, J, 3) joint positions in global space.
562
+ joints_rot: (T, J, 3, 3) or (B, T, J, 3, 3) global rotation matrices.
563
+ clamp_to_limits: If True, clamp joint angles to XML axis limits (default True).
564
+
565
+ Returns:
566
+ (posed_joints, global_rot_mats) as tensors, same shape as inputs (batch preserved).
567
+ """
568
+
569
+ local_rot_mats = global_rots_to_local_rots(joints_rot, skeleton)
570
+ root_positions = joints_pos[..., skeleton.root_idx, :]
571
+
572
+ # Converter expects batch dim (B, T, ...); add and remove if single sequence.
573
+ single_sequence = local_rot_mats.dim() == 4
574
+ if single_sequence:
575
+ local_rot_mats = local_rot_mats.unsqueeze(0)
576
+ root_positions = root_positions.unsqueeze(0)
577
+
578
+ converter = MujocoQposConverter(skeleton)
579
+ projected = converter.project_to_real_robot_rotations(
580
+ local_rot_mats, root_positions, clamp_to_limits=clamp_to_limits
581
+ )
582
+
583
+ out_pos = projected["posed_joints"]
584
+ out_rot = projected["global_rot_mats"]
585
+ if single_sequence:
586
+ out_pos = out_pos.squeeze(0)
587
+ out_rot = out_rot.squeeze(0)
588
+ return out_pos, out_rot
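A usage sketch for the converter above (illustrative, not part of the committed file): given a Kimodo motion dict with ``local_rot_mats`` and ``root_positions`` (e.g. from ``complete_motion_dict``), export a G1 qpos CSV and convert it back:

# illustrative sketch -- motion_dict and the CSV path are placeholders
from kimodo.exports.mujoco import MujocoQposConverter
from kimodo.skeleton.registry import build_skeleton

converter = MujocoQposConverter(build_skeleton(34))       # cached per (skeleton, xml_path)
qpos = converter.dict_to_qpos(motion_dict, numpy=True)    # (B, T, 36): root xyz + wxyz quat + 29 hinge angles
converter.save_csv(qpos, "motion_g1.csv")
roundtrip = converter.qpos_to_motion_dict(qpos[0], source_fps=30.0)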
kimodo/exports/smplx.py ADDED
@@ -0,0 +1,251 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Convert kimodo motion to AMASS/SMPL-X compatible parameters (axis-angle, Y-up or Z-up)."""
4
+
5
+ import os
6
+ from typing import Optional
7
+
8
+ import einops
9
+ import numpy as np
10
+ import torch
11
+
12
+ from kimodo.assets import skeleton_asset_path
13
+ from kimodo.geometry import axis_angle_to_matrix, matrix_to_axis_angle
14
+ from kimodo.tools import ensure_batched, to_numpy, to_torch
15
+
16
+
17
+ def kimodo_y_up_to_amass_coord_rotation_matrix() -> np.ndarray:
18
+ """3x3 rotation mapping Kimodo Y-up (+Z forward) to AMASS Z-up (+Y forward).
19
+
20
+ Used by :func:`get_amass_parameters` and :func:`amass_arrays_to_kimodo_motion` (inverse).
21
+ """
22
+ y_up_to_z_up = np.array(
23
+ [
24
+ [1.0, 0.0, 0.0],
25
+ [0.0, 0.0, -1.0],
26
+ [0.0, 1.0, 0.0],
27
+ ],
28
+ dtype=np.float32,
29
+ )
30
+ rot_z_180 = np.array(
31
+ [
32
+ [-1.0, 0.0, 0.0],
33
+ [0.0, -1.0, 0.0],
34
+ [0.0, 0.0, 1.0],
35
+ ],
36
+ dtype=np.float32,
37
+ )
38
+ return np.matmul(rot_z_180, y_up_to_z_up).astype(np.float32)
39
+
40
+
41
+ @ensure_batched(local_rot_mats=5, root_positions=3, lengths=1)
42
+ def get_amass_parameters(
43
+ local_rot_mats,
44
+ root_positions,
45
+ skeleton,
46
+ z_up=True,
47
+ ):
48
+ """Convert local rot mats and root positions to AMASS-style trans and pose_body; optional z_up
49
+ coordinate transform.
50
+
51
+ Our method generates motions with Y-up and +Z forward; if z_up=True, transform to Z-up and +Y
52
+ forward as in AMASS.
53
+ """
54
+ # Our method generates motions with Y-up and +Z forward
55
+ # if z_up = True, we transform this to: Z-up with +Y forward, as in AMASS
56
+ # Remove the root offset; SMPL-X FK adds pelvis offset back.
57
+ pelvis_offset = skeleton.neutral_joints[skeleton.root_idx].cpu().numpy()
58
+ trans = root_positions - pelvis_offset
59
+
60
+ root_rot_mats = to_numpy(local_rot_mats[:, :, 0])
61
+ local_rot_axis_angle = to_numpy(matrix_to_axis_angle(to_torch(local_rot_mats)))
62
+ pose_body = einops.rearrange(local_rot_axis_angle[:, :, 1:], "b t j d -> b t (j d)")
63
+
64
+ # Optionally convert from Y-up to Z-up coordinates.
65
+ if z_up:
66
+ y_up_to_z_up = kimodo_y_up_to_amass_coord_rotation_matrix()
67
+ root_rot_mats = np.matmul(y_up_to_z_up, root_rot_mats)
68
+ trans = np.matmul(trans + pelvis_offset, y_up_to_z_up.T) - pelvis_offset
69
+
70
+ root_orient = to_numpy(matrix_to_axis_angle(to_torch(root_rot_mats)))
71
+ return trans, root_orient, pose_body
72
+
73
+
74
+ def amass_arrays_to_kimodo_motion(
75
+ trans: np.ndarray,
76
+ root_orient: np.ndarray,
77
+ pose_body: np.ndarray,
78
+ skeleton,
79
+ source_fps: float,
80
+ *,
81
+ z_up: bool = True,
82
+ ):
83
+ """Inverse of :func:`get_amass_parameters` for a single sequence (AMASS → Kimodo motion dict).
84
+
85
+ Args:
86
+ trans: ``(T, 3)`` AMASS root translation (same as ``trans`` in AMASS NPZ).
87
+ root_orient: ``(T, 3)`` axis-angle root orientation in AMASS coordinates (z-up when ``z_up``).
88
+ pose_body: ``(T, 63)`` body pose axis-angle (21 joints × 3).
89
+ skeleton: :class:`~kimodo.skeleton.definitions.SMPLXSkeleton22` instance.
90
+ source_fps: Source frame rate (Hz) of the AMASS recording.
91
+ z_up: If ``True``, invert the same Y-up↔Z-up transform as ``get_amass_parameters(..., z_up=True)``.
92
+
93
+ Returns:
94
+ Motion dict compatible with :func:`kimodo.exports.motion_io.save_kimodo_npz`.
95
+ """
96
+ from kimodo.exports.motion_io import complete_motion_dict
97
+
98
+ trans = np.asarray(trans, dtype=np.float32)
99
+ root_orient = np.asarray(root_orient, dtype=np.float32)
100
+ pose_body = np.asarray(pose_body, dtype=np.float32)
101
+ if trans.ndim != 2 or trans.shape[-1] != 3:
102
+ raise ValueError(f"trans must be (T, 3); got {trans.shape}")
103
+ if root_orient.shape != trans.shape:
104
+ raise ValueError(f"root_orient shape {root_orient.shape} must match trans {trans.shape}")
105
+ t = trans.shape[0]
106
+ if pose_body.shape != (t, 63):
107
+ raise ValueError(f"pose_body must be (T, 63); got {pose_body.shape}")
108
+
109
+ pelvis_offset = skeleton.neutral_joints[skeleton.root_idx].detach().cpu().numpy().astype(np.float32)
110
+ device = skeleton.neutral_joints.device
111
+ dtype = torch.float32
112
+
113
+ Y_np = kimodo_y_up_to_amass_coord_rotation_matrix()
114
+ if z_up:
115
+ y_up_to_z_up = torch.from_numpy(Y_np).to(device=device, dtype=dtype)
116
+ # trans_amass = root_kimodo @ Y.T - pelvis_offset => root_kimodo = (trans_amass + pelvis_offset) @ Y
117
+ root_positions_np = (trans + pelvis_offset) @ Y_np
118
+ else:
119
+ root_positions_np = trans + pelvis_offset
120
+
121
+ root_positions = torch.from_numpy(root_positions_np).to(device=device, dtype=dtype)
122
+
123
+ R_amass_root = axis_angle_to_matrix(torch.from_numpy(root_orient).to(device=device, dtype=dtype))
124
+ if z_up:
125
+ R_kimodo_root = torch.einsum("ij,tjk->tik", y_up_to_z_up.T, R_amass_root)
126
+ else:
127
+ R_kimodo_root = R_amass_root
128
+
129
+ nb = skeleton.nbjoints
130
+ if nb != 22:
131
+ raise ValueError(f"Expected SMPL-X body skeleton with 22 joints; got {nb}")
132
+
133
+ local_rot_mats = torch.zeros((t, nb, 3, 3), device=device, dtype=dtype)
134
+ local_rot_mats[:, 0] = R_kimodo_root
135
+
136
+ pose_aa = torch.from_numpy(pose_body.reshape(t, 21, 3)).to(device=device, dtype=dtype)
137
+ local_rot_mats[:, 1:] = axis_angle_to_matrix(pose_aa.reshape(-1, 3)).reshape(t, 21, 3, 3)
138
+
139
+ return complete_motion_dict(local_rot_mats, root_positions, skeleton, source_fps)
140
+
141
+
142
+ def amass_npz_to_kimodo_motion(npz_path: str, skeleton, source_fps: Optional[float] = None, *, z_up: bool = True):
143
+ """Load an AMASS-style ``.npz`` and return a Kimodo motion dict.
144
+
145
+ Args:
146
+ npz_path: Path to AMASS NPZ (``trans``, ``root_orient``, ``pose_body``, ...).
147
+ skeleton: SMPL-X skeleton instance.
148
+ source_fps: Source frame rate (Hz); if ``None``, uses ``mocap_frame_rate``
149
+ from the file when present, else ``30.0``.
150
+ z_up: Same meaning as :func:`amass_arrays_to_kimodo_motion`.
151
+ """
152
+ with np.load(npz_path, allow_pickle=True) as data:
153
+ trans = np.asarray(data["trans"], dtype=np.float32)
154
+ root_orient = np.asarray(data["root_orient"], dtype=np.float32)
155
+ pose_body = np.asarray(data["pose_body"], dtype=np.float32)
156
+ if source_fps is None:
157
+ source_fps = float(data["mocap_frame_rate"]) if "mocap_frame_rate" in data.files else 30.0
158
+
159
+ return amass_arrays_to_kimodo_motion(trans, root_orient, pose_body, skeleton, source_fps, z_up=z_up)
160
+
161
+
162
+ class AMASSConverter:
163
+ def __init__(
164
+ self,
165
+ fps,
166
+ skeleton,
167
+ beta_path=str(skeleton_asset_path("smplx22", "beta.npy")),
168
+ mean_hands_path=str(skeleton_asset_path("smplx22", "mean_hands.npy")),
169
+ ):
170
+ self.fps = fps
171
+ self.skeleton = skeleton
172
+ # Load betas
173
+ if os.path.exists(beta_path):
174
+ # only use first 16 betas to match AMASS
175
+ betas = np.load(beta_path)[:16]
176
+ else:
177
+ betas = np.zeros(16)
178
+
179
+ # Load mean hands
180
+ if os.path.exists(mean_hands_path):
181
+ mean_hands = np.load(mean_hands_path)
182
+ else:
183
+ mean_hands = np.zeros(90)
184
+
185
+ self.default_frame_params = {
186
+ "pose_jaw": np.zeros(3),
187
+ "pose_eye": np.zeros(6),
188
+ "pose_hand": mean_hands,
189
+ }
190
+ self.output_dict_base = {
191
+ "gender": "neutral",
192
+ "surface_model_type": "smplx",
193
+ "betas": betas,
194
+ "num_betas": len(betas),
195
+ "mocap_frame_rate": float(fps),
196
+ }
197
+
198
+ def convert_save_npz(self, output: dict, npz_path, z_up=True):
199
+ trans, root_orient, pose_body = get_amass_parameters(
200
+ output["local_rot_mats"],
201
+ output["root_positions"],
202
+ self.skeleton,
203
+ z_up=z_up,
204
+ )
205
+ nb_frames = trans.shape[-2]
206
+
207
+ amass_output_base = self.output_dict_base.copy()
208
+ for key, val in self.default_frame_params.items():
209
+ amass_output_base[key] = einops.repeat(val, "d -> t d", t=nb_frames)
210
+
211
+ amass_output_base["mocap_time_length"] = nb_frames / self.fps
212
+ self.save_npz(trans, root_orient, pose_body, amass_output_base, npz_path)
213
+
214
+ def save_npz(self, trans, root_orient, pose_body, base_output, npz_path):
215
+ shape = trans.shape
216
+ if len(shape) == 3 and shape[0] == 1:
217
+ # if only one motion, squeeze the data
218
+ trans = trans[0]
219
+ root_orient = root_orient[0]
220
+ pose_body = pose_body[0]
221
+ shape = trans.shape
222
+ if len(shape) == 2:
223
+ amass_output = {
224
+ "trans": trans,
225
+ "root_orient": root_orient,
226
+ "pose_body": pose_body,
227
+ } | base_output
228
+ np.savez(npz_path, **amass_output)
229
+
230
+ elif len(shape) == 3:
231
+ # real batch of motions
232
+ npz_path_base, ext = os.path.splitext(npz_path)
233
+ for i in range(shape[0]):
234
+ npz_path_i = npz_path_base + "_" + str(i).zfill(2) + ext
235
+ self.save_npz(trans[i], root_orient[i], pose_body[i], base_output, npz_path_i)
236
+
237
+
238
+ # amass_output = {
239
+ # "gender": "neutral",
240
+ # "surface_model_type": "smplx",
241
+ # "mocap_frame_rate": float(fps),
242
+ # "mocap_time_length": len(motion) / float(fps)
243
+ # "trans": trans,
244
+ # "betas": betas,
245
+ # "num_betas": len(betas),
246
+ # "root_orient": np.array([T, 3]), # axis angle
247
+ # "pose_body": np.array([T, 63]), # 63=21*3, axis angle 21 = 22 - root
248
+ # "pose_hand": np.array([T, 90]), # 90=30*3=15*2*3 axis angle (load from mean_hands)
249
+ # "pose_jaw": np.array([T, 3]), # all zeros is fine
250
+ # "pose_eye": np.array([T, 6]), # all zeros is fine
251
+ # }
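A usage sketch for the AMASS converters above (illustrative, not part of the committed file; paths and frame rate are placeholders). It loads an AMASS clip into the Kimodo format and writes it back out as an AMASS-style NPZ:

# illustrative sketch -- "clip.npz" / "clip_roundtrip.npz" are placeholder paths
from kimodo.exports.smplx import AMASSConverter, amass_npz_to_kimodo_motion
from kimodo.skeleton.registry import build_skeleton

skeleton = build_skeleton(22)                              # SMPL-X body skeleton
motion = amass_npz_to_kimodo_motion("clip.npz", skeleton)  # Z-up AMASS -> Y-up Kimodo dict
converter = AMASSConverter(fps=30.0, skeleton=skeleton)
converter.convert_save_npz(motion, "clip_roundtrip.npz", z_up=True)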
kimodo/geometry.py ADDED
@@ -0,0 +1,216 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Rotation and representation conversions: axis-angle, quaternion, matrix, 6D continuous."""
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def angle_to_Y_rotation_matrix(angle: torch.Tensor) -> torch.Tensor:
10
+ """Build a rotation matrix around the Y axis from a scalar angle (radians).
11
+
12
+ Shape: angle.shape + (3, 3).
13
+ """
14
+ cos, sin = torch.cos(angle), torch.sin(angle)
15
+ one, zero = torch.ones_like(angle), torch.zeros_like(angle)
16
+ mat = torch.stack((cos, zero, sin, zero, one, zero, -sin, zero, cos), -1)
17
+ mat = mat.reshape(angle.shape + (3, 3))
18
+ return mat
19
+
20
+
21
+ def matrix_to_cont6d(matrix: torch.Tensor) -> torch.Tensor:
22
+ """Convert rotation matrix to 6D continuous representation (first two columns).
23
+
24
+ Shape: (..., 3, 3) -> (..., 6).
25
+ """
26
+ cont_6d = torch.concat([matrix[..., 0], matrix[..., 1]], dim=-1)
27
+ return cont_6d
28
+
29
+
30
+ def cont6d_to_matrix(cont6d: torch.Tensor) -> torch.Tensor:
31
+ """Convert 6D continuous representation to rotation matrix (Gram–Schmidt on two columns).
32
+
33
+ Last dim must be 6.
34
+ """
35
+ assert cont6d.shape[-1] == 6, "The last dimension must be 6"
36
+ x_raw = cont6d[..., 0:3]
37
+ y_raw = cont6d[..., 3:6]
38
+
39
+ x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
40
+ z = torch.cross(x, y_raw, dim=-1)
41
+ z = z / torch.norm(z, dim=-1, keepdim=True)
42
+
43
+ y = torch.cross(z, x, dim=-1)
44
+
45
+ x = x[..., None]
46
+ y = y[..., None]
47
+ z = z[..., None]
48
+
49
+ mat = torch.cat([x, y, z], dim=-1)
50
+ return mat
51
+
52
+
53
+ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
54
+ """Convert axis-angle to rotation matrix.
55
+
56
+ Args:
57
+ axis_angle: (..., 3) axis-angle vectors (angle = norm, axis = normalized)
58
+ Returns:
59
+ rotmat: (..., 3, 3) rotation matrices
60
+ """
61
+ eps = 1e-6
62
+ angle = torch.norm(axis_angle, dim=-1, keepdim=True) # (..., 1)
63
+ axis = axis_angle / (angle + eps)
64
+
65
+ x, y, z = axis.unbind(-1)
66
+
67
+ zero = torch.zeros_like(x)
68
+ K = torch.stack([zero, -z, y, z, zero, -x, -y, x, zero], dim=-1).reshape(*axis.shape[:-1], 3, 3)
69
+
70
+ eye = torch.eye(3, device=axis.device, dtype=axis.dtype)
71
+ eye = eye.expand(*axis.shape[:-1], 3, 3)
72
+
73
+ sin = torch.sin(angle)[..., None]
74
+ cos = torch.cos(angle)[..., None]
75
+
76
+ R = eye + sin * K + (1 - cos) * (K @ K)
77
+ return R
78
+
79
+
80
+ def matrix_to_axis_angle(R: torch.Tensor) -> torch.Tensor:
81
+ """Convert rotation matrix to axis-angle via quaternions (more numerically stable).
82
+
83
+ Args:
84
+ R: (..., 3, 3) rotation matrices
85
+ Returns:
86
+ axis_angle: (..., 3)
87
+ """
88
+ # Go through quaternions for numerical stability
89
+ quat = matrix_to_quaternion(R) # (..., 4) with (w, x, y, z)
90
+ return quaternion_to_axis_angle(quat)
91
+
92
+
93
+ def quaternion_to_axis_angle(quat: torch.Tensor) -> torch.Tensor:
94
+ """Convert quaternion to axis-angle representation.
95
+
96
+ Args:
97
+ quat: (..., 4) quaternions with real part first (w, x, y, z)
98
+ Returns:
99
+ axis_angle: (..., 3)
100
+ """
101
+ eps = 1e-6
102
+
103
+ # Ensure canonical form to avoid sign ambiguity.
104
+ # Primary: prefer w > 0. When w ≈ 0 (angle ≈ π), prefer first nonzero xyz > 0.
105
+ w = quat[..., 0:1]
106
+ xyz = quat[..., 1:]
107
+
108
+ # Find first significant component of xyz for tie-breaking when w ≈ 0
109
+ first_significant = xyz[..., 0:1] # use x component as tie-breaker
110
+
111
+ # Flip if: w < 0, OR (w ≈ 0 AND first xyz component < 0)
112
+ should_flip = (w < -eps) | ((w.abs() <= eps) & (first_significant < 0))
113
+ quat = torch.where(should_flip, -quat, quat)
114
+
115
+ w = quat[..., 0]
116
+ xyz = quat[..., 1:]
117
+
118
+ # sin(angle/2) = ||xyz||
119
+ sin_half_angle = xyz.norm(dim=-1)
120
+
121
+ # angle = 2 * atan2(sin(angle/2), cos(angle/2))
122
+ # This is more stable than 2 * acos(w) near angle=0
123
+ angle = 2.0 * torch.atan2(sin_half_angle, w)
124
+
125
+ # axis = xyz / sin(angle/2), but handle small angles
126
+ # For small angles: axis-angle ≈ 2 * xyz (since sin(x) ≈ x for small x)
127
+ small_angle = sin_half_angle.abs() < eps
128
+
129
+ # Safe division
130
+ scale = torch.where(
131
+ small_angle,
132
+ 2.0 * torch.ones_like(angle), # small angle: axis_angle ≈ 2 * xyz
133
+ angle / sin_half_angle.clamp(min=eps),
134
+ )
135
+
136
+ return xyz * scale.unsqueeze(-1)
137
+
138
+
139
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
140
+ """Returns torch.sqrt(torch.max(0, x)) subgradient is zero where x is 0."""
141
+ return torch.sqrt(x * (x > 0).to(x.dtype))
142
+
143
+
144
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
145
+ """Convert rotations given as rotation matrices to quaternions.
146
+
147
+ Args:
148
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
149
+ Returns:
150
+ quaternions with real part first, as tensor of shape (..., 4).
151
+ """
152
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
153
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
154
+
155
+ batch_dim = matrix.shape[:-2]
156
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
157
+
158
+ q_abs = _sqrt_positive_part(
159
+ torch.stack(
160
+ [
161
+ 1.0 + m00 + m11 + m22,
162
+ 1.0 + m00 - m11 - m22,
163
+ 1.0 - m00 + m11 - m22,
164
+ 1.0 - m00 - m11 + m22,
165
+ ],
166
+ dim=-1,
167
+ )
168
+ )
169
+
170
+ quat_by_rijk = torch.stack(
171
+ [
172
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
173
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
174
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
175
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
176
+ ],
177
+ dim=-2,
178
+ )
179
+
180
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
181
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
182
+
183
+ return (
184
+ (F.one_hot(q_abs.argmax(dim=-1), num_classes=4)[..., None] * quat_candidates)
185
+ .sum(dim=-2)
186
+ .reshape(batch_dim + (4,))
187
+ )
188
+
189
+
190
+ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
191
+ """Convert rotations given as quaternions to rotation matrices.
192
+
193
+ Args:
194
+ quaternions: quaternions with real part first,
195
+ as tensor of shape (..., 4).
196
+ Returns:
197
+ Rotation matrices as tensor of shape (..., 3, 3).
198
+ """
199
+ r, i, j, k = torch.unbind(quaternions, -1)
200
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
201
+
202
+ o = torch.stack(
203
+ (
204
+ 1 - two_s * (j * j + k * k),
205
+ two_s * (i * j - k * r),
206
+ two_s * (i * k + j * r),
207
+ two_s * (i * j + k * r),
208
+ 1 - two_s * (i * i + k * k),
209
+ two_s * (j * k - i * r),
210
+ two_s * (i * k - j * r),
211
+ two_s * (j * k + i * r),
212
+ 1 - two_s * (i * i + j * j),
213
+ ),
214
+ -1,
215
+ )
216
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
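Usage sketch (illustrative, not part of this commit): a round trip through the conversions above, assuming they are importable from kimodo.geometry as listed in the changed-files summary.

    import torch
    from kimodo.geometry import (
        axis_angle_to_matrix,
        matrix_to_axis_angle,
        matrix_to_cont6d,
        cont6d_to_matrix,
    )

    aa = 0.5 * torch.randn(4, 3)                   # axis-angle: angle = norm, axis = direction
    R = axis_angle_to_matrix(aa)                   # (4, 3, 3)
    aa_back = matrix_to_axis_angle(R)              # (4, 3), via quaternions
    assert torch.allclose(aa, aa_back, atol=1e-4)

    six_d = matrix_to_cont6d(R)                    # (4, 6): first two columns of R
    R_back = cont6d_to_matrix(six_d)               # (4, 3, 3), re-orthonormalized
    assert torch.allclose(R, R_back, atol=1e-4)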
kimodo/meta.py ADDED
@@ -0,0 +1,80 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Parse and normalize prompt text/duration data from meta dicts."""
4
+
5
+ import os
6
+ from typing import Any, Optional
7
+
8
+ from kimodo.tools import load_json
9
+
10
+ from .sanitize import sanitize_text, sanitize_texts
11
+
12
+
13
+ def load_prompts_from_meta(meta_path: str, **kwargs):
14
+ """Load prompts from a meta dict or file. If fps is provided, the durations are converted to
15
+ frames.
16
+
17
+ Args:
18
+ meta_path: Path to the meta file.
19
+ **kwargs: Additional arguments to pass to parse_prompts_from_meta.
20
+
21
+ Returns:
22
+ texts: List of texts.
23
+ durations: List of durations in seconds or frames.
24
+ """
25
+ if not os.path.exists(meta_path):
26
+ raise FileNotFoundError(f"meta.json not found in input folder: {meta_path}")
27
+
28
+ meta = load_json(meta_path)
29
+ return parse_prompts_from_meta(meta, **kwargs)
30
+
31
+
32
+ def parse_prompts_from_meta(
33
+ meta: dict[str, Any],
34
+ fps: Optional[float] = None,
35
+ sanitize: bool = False,
36
+ ) -> tuple[list[str], list[float]]:
37
+ """Parse prompt texts and durations from a meta dict into normalized lists. If fps is provided,
38
+ the durations are converted to frames.
39
+
40
+ Accepts either:
41
+ - Single prompt: "text" (str) and "duration" (float) in seconds.
42
+ - Multiple prompts: "texts" (list of str) and "durations" (list of float) in seconds.
43
+
44
+ Returns:
45
+ (texts, durations): texts as list of str, durations as list of float (seconds or frames).
46
+ Lengths of both lists are equal.
47
+
48
+ Raises:
49
+ ValueError: If meta does not contain a recognized format.
50
+ """
51
+ # Single prompt
52
+ if "text" in meta and "duration" in meta:
53
+ text = meta["text"]
54
+ duration = float(meta["duration"])
55
+ if fps is not None:
56
+ duration = int(duration * fps)
57
+ if isinstance(text, list):
58
+ raise ValueError("meta has 'text' but it is a list; use 'texts' for multiple prompts")
59
+
60
+ if sanitize:
61
+ text = sanitize_text(text)
62
+ return ([text], [duration])
63
+
64
+ # Multiple prompts
65
+ if "texts" in meta and "durations" in meta:
66
+ texts = meta["texts"]
67
+ durations = meta["durations"]
68
+ if not isinstance(texts, list) or not isinstance(durations, list):
69
+ raise ValueError("meta 'texts' and 'durations' must be lists")
70
+ if len(texts) != len(durations):
71
+ raise ValueError(f"meta 'texts' and 'durations' length mismatch: {len(texts)} vs {len(durations)}")
72
+ durations = [float(d) for d in durations]
73
+ if fps is not None:
74
+ durations = [int(d * fps) for d in durations]
75
+
76
+ if sanitize:
77
+ texts = sanitize_texts(texts)
78
+ return texts, durations
79
+
80
+ raise ValueError("meta must contain either 'text' and 'duration', or 'texts' and 'durations'.")
kimodo/metrics/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Evaluation metrics for motion quality (foot skate, contact consistency, constraint following)."""
4
+
5
+ from .base import (
6
+ Metric,
7
+ aggregate_metrics,
8
+ clear_metrics,
9
+ compute_metrics,
10
+ )
11
+ from .constraints import ContraintFollow
12
+ from .foot_skate import (
13
+ FootContactConsistency,
14
+ FootSkateFromContacts,
15
+ FootSkateFromHeight,
16
+ FootSkateRatio,
17
+ )
18
+ from .tmr import (
19
+ TMR_EmbeddingMetric,
20
+ TMR_Metric,
21
+ compute_tmr_per_sample_retrieval,
22
+ compute_tmr_retrieval_metrics,
23
+ )
24
+
25
+ __all__ = [
26
+ "Metric",
27
+ "ContraintFollow",
28
+ "FootContactConsistency",
29
+ "FootSkateFromContacts",
30
+ "FootSkateFromHeight",
31
+ "FootSkateRatio",
32
+ "TMR_EmbeddingMetric",
33
+ "TMR_Metric",
34
+ "aggregate_metrics",
35
+ "clear_metrics",
36
+ "compute_metrics",
37
+ "compute_tmr_per_sample_retrieval",
38
+ "compute_tmr_retrieval_metrics",
39
+ ]
kimodo/metrics/base.py ADDED
@@ -0,0 +1,66 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Base metric class and batch/aggregate helpers."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections import defaultdict
8
+ from typing import Dict, List
9
+
10
+ import torch
11
+
12
+
13
+ class Metric:
14
+ """Base class for metrics that accumulate results over multiple __call__ and expose
15
+ aggregate()."""
16
+
17
+ def __init__(self, **kwargs):
18
+ self.clear()
19
+
20
+ def __call__(self, *args, **kwargs):
21
+ """Compute metric for current batch, append to saved_metrics, and return the batch
22
+ result."""
23
+ metrics = self._compute(*args, **kwargs)
24
+ for key, val in metrics.items():
25
+ self.saved_metrics[key].append(val.detach().cpu().float())
26
+ return metrics
27
+
28
+ def _compute(self, **kwargs):
29
+ """Subclasses implement this to compute metric dict from batch inputs."""
30
+ raise NotImplementedError()
31
+
32
+ def clear(self):
33
+ """Reset all accumulated metric values."""
34
+ self.saved_metrics = defaultdict(list)
35
+
36
+ def aggregate(self):
37
+ """Return a dict of concatenated/stacked tensors over all accumulated batches."""
38
+ output = {}
39
+ for key, lst in self.saved_metrics.items():
40
+ try:
41
+ output[key] = torch.cat(lst)
42
+ except RuntimeError:
43
+ output[key] = torch.stack(lst)
44
+ return output
45
+
46
+
47
+ def compute_metrics(metrics_list: List[Metric], metrics_in: Dict) -> Dict:
48
+ """Run each metric on metrics_in and return the combined dict of batch results."""
49
+ metrics_out = {}
50
+ for metric in metrics_list:
51
+ metrics_out.update(metric(**metrics_in))
52
+ return metrics_out
53
+
54
+
55
+ def aggregate_metrics(metrics_list: List[Metric]) -> Dict:
56
+ """Return combined aggregated results (concatenated over batches) for all metrics."""
57
+ metrics_out = {}
58
+ for metric in metrics_list:
59
+ metrics_out.update(metric.aggregate())
60
+ return metrics_out
61
+
62
+
63
+ def clear_metrics(metrics_list: List[Metric]) -> None:
64
+ """Clear accumulated values for all metrics in the list."""
65
+ for metric in metrics_list:
66
+ metric.clear()
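Usage sketch (illustrative, not part of this commit) of the accumulate/aggregate pattern above, with a made-up per-sample metric.

    import torch
    from kimodo.metrics.base import Metric, aggregate_metrics, clear_metrics, compute_metrics

    class MeanRootHeight(Metric):
        """Toy metric: mean height of joint 0 over each sequence."""
        def _compute(self, posed_joints, **kwargs):
            # posed_joints: [B, T, J, 3], y-up
            return {"mean_root_height": posed_joints[..., 0, 1].mean(dim=1)}

    metrics = [MeanRootHeight()]
    for _ in range(3):  # pretend these are evaluation batches
        batch = {"posed_joints": torch.rand(2, 16, 22, 3)}
        compute_metrics(metrics, batch)        # per-batch values, also accumulated internally

    aggregate_metrics(metrics)                 # {"mean_root_height": tensor of shape [6]}
    clear_metrics(metrics)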
kimodo/metrics/constraints.py ADDED
@@ -0,0 +1,87 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Constraint-following metrics."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections import defaultdict
8
+ from typing import Dict, List, Optional
9
+
10
+ import torch
11
+ from torch import Tensor
12
+
13
+ from kimodo.constraints import (
14
+ EndEffectorConstraintSet,
15
+ FullBodyConstraintSet,
16
+ Root2DConstraintSet,
17
+ )
18
+ from kimodo.tools import ensure_batched
19
+
20
+ from .base import Metric
21
+
22
+
23
+ class ContraintFollow(Metric):
24
+ """Constraint-following metric dispatcher for kimodo constraint sets."""
25
+
26
+ def __init__(
27
+ self,
28
+ skeleton,
29
+ root_threshold: float = 0.10,
30
+ **kwargs,
31
+ ):
32
+ super().__init__(**kwargs)
33
+ self.skeleton = skeleton
34
+ self.root_threshold = root_threshold
35
+
36
+ @ensure_batched(posed_joints=4, constraints_lst=2, lengths=1)
37
+ def _compute(
38
+ self,
39
+ posed_joints: Tensor,
40
+ constraints_lst: Optional[List],
41
+ lengths: Optional[Tensor] = None,
42
+ **kwargs,
43
+ ) -> Dict:
44
+ if not constraints_lst:
45
+ return {}
46
+
47
+ root_idx = self.skeleton.root_idx
48
+ output = defaultdict(list)
49
+
50
+ for posed_joints_s, constraint_lst_s, lengths_s in zip(posed_joints, constraints_lst, lengths):
51
+ output_seq = defaultdict(list)
52
+ for constraint in constraint_lst_s:
53
+ frame_idx = constraint.frame_indices.to(device=posed_joints_s.device, dtype=torch.long)
54
+ assert frame_idx.max() < lengths_s, "The constraint is defined outside the length of the motion."
55
+ if frame_idx.numel() == 0:
56
+ continue
57
+
58
+ if isinstance(constraint, Root2DConstraintSet):
59
+ pred_root2d = posed_joints_s[frame_idx, root_idx][:, [0, 2]]
60
+ target = constraint.smooth_root_2d.to(posed_joints_s.device)
61
+
62
+ dist = torch.norm(pred_root2d - target, dim=-1)
63
+ output_seq["constraint_root2d_err"].append(dist)
64
+ hit = (dist <= self.root_threshold).float()
65
+ output_seq["constraint_root2d_acc"].append(hit)
66
+
67
+ elif isinstance(constraint, FullBodyConstraintSet):
68
+ pred = posed_joints_s[frame_idx]
69
+ target = constraint.global_joints_positions.to(posed_joints_s.device)
70
+ err = torch.norm(pred - target, dim=-1)
71
+ output_seq["constraint_fullbody_keyframe"].append(err)
72
+
73
+ elif isinstance(constraint, EndEffectorConstraintSet):
74
+ pos_idx = constraint.pos_indices.to(device=posed_joints_s.device, dtype=torch.long)
75
+ pred = posed_joints_s[frame_idx].index_select(1, pos_idx)
76
+ target = constraint.global_joints_positions.to(posed_joints_s.device).index_select(1, pos_idx)
77
+ err = torch.norm(pred - target, dim=-1)
78
+ output_seq["constraint_end_effector"].append(err)
79
+
80
+ # in case we have several same constraints in the list
81
+ for key, val in output_seq.items():
82
+ output[key].append(torch.cat(val).mean())
83
+
84
+ reduced = {}
85
+ for key, vals in output.items():
86
+ reduced[key] = torch.stack(vals, dim=0)
87
+ return reduced
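Illustrative sketch (not part of this commit) of the root-2D term computed above, written directly on tensors so it does not depend on the constraint-set classes; all names and numbers are made up.

    import torch

    posed_joints_s = torch.rand(60, 22, 3)        # [T, J, 3], y-up, one sequence
    root_idx = 0
    frame_idx = torch.tensor([10, 25, 40])        # frames carrying a root-2D constraint
    target_root_2d = torch.rand(3, 2)             # desired (x, z) per constrained frame

    pred_root_2d = posed_joints_s[frame_idx, root_idx][:, [0, 2]]  # project onto ground plane
    dist = torch.norm(pred_root_2d - target_root_2d, dim=-1)       # per-frame error
    acc = (dist <= 0.10).float().mean()                            # fraction within the 10 cm threshold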
kimodo/metrics/foot_skate.py ADDED
@@ -0,0 +1,232 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Foot skate and contact consistency metrics."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import Dict, Optional
8
+
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ from kimodo.motion_rep.feature_utils import compute_vel_xyz
13
+ from kimodo.motion_rep.feet import foot_detect_from_pos_and_vel
14
+ from kimodo.skeleton import SkeletonBase
15
+ from kimodo.tools import ensure_batched
16
+
17
+ from .base import Metric
18
+
19
+
20
+ class FootSkateFromHeight(Metric):
21
+ """When toe joint is near the floor, measures mean velocity of the toes."""
22
+
23
+ def __init__(
24
+ self,
25
+ skeleton: SkeletonBase,
26
+ fps: float,
27
+ height_thresh: float = 0.05,
28
+ **kwargs,
29
+ ):
30
+ super().__init__(**kwargs)
31
+ self.height_thresh = height_thresh
32
+ self.skeleton = skeleton
33
+ self.fps = fps
34
+
35
+ @ensure_batched(posed_joints=4, lengths=1)
36
+ def _compute(
37
+ self,
38
+ posed_joints: Tensor,
39
+ lengths: Optional[Tensor] = None,
40
+ **kwargs,
41
+ ) -> Dict:
42
+ fidx = self.skeleton.foot_joint_idx
43
+ if len(fidx) != 4:
44
+ raise ValueError("FootSkateFromHeight expects four foot joints (heel/toe per foot)")
45
+
46
+ feet_pos = posed_joints[:, :, fidx]
47
+ toe_pos = feet_pos[:, :, [1, 3]]
48
+
49
+ toe_on_floor = (toe_pos[..., 1] < self.height_thresh)[:, :-1] # y-up, [B, T-1, 2] as [left, right]
50
+
51
+ dt = 1.0 / self.fps
52
+ toe_vel = torch.norm(toe_pos[:, 1:] - toe_pos[:, :-1], dim=-1) / dt # [B, nframes-1, 2]
53
+
54
+ # compute err
55
+ contact_toe_vel = toe_vel * toe_on_floor # vel when corresponding toe is on ground
56
+
57
+ # account for generated length
58
+ # since they are velocities use length-1 to avoid inaccurate vel going one frame past len
59
+ device = toe_on_floor.device
60
+ len_mask = torch.arange(toe_on_floor.shape[1], device=device)[None, :, None].expand(toe_on_floor.shape) < (
61
+ lengths[:, None, None] - 1
62
+ )
63
+ toe_on_floor = toe_on_floor * len_mask
64
+ contact_toe_vel = contact_toe_vel * len_mask
65
+
66
+ mean_vel = torch.sum(contact_toe_vel, (1, 2)) / (torch.sum(toe_on_floor, (1, 2)) + 1e-6)
67
+ return {"foot_skate_from_height": mean_vel}
68
+
69
+
70
+ class FootSkateFromContacts(Metric):
71
+ """Measures velocity of the toes and ankles when predicted to be in contact."""
72
+
73
+ def __init__(
74
+ self,
75
+ skeleton: SkeletonBase,
76
+ fps: float,
77
+ **kwargs,
78
+ ):
79
+ super().__init__(**kwargs)
80
+ self.skeleton = skeleton
81
+ self.fps = fps
82
+
83
+ @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1)
84
+ def _compute(
85
+ self,
86
+ posed_joints: Tensor,
87
+ foot_contacts: Tensor,
88
+ lengths: Optional[Tensor] = None,
89
+ **kwargs,
90
+ ) -> Dict:
91
+ fidx = self.skeleton.foot_joint_idx
92
+ feet_pos = posed_joints[:, :, fidx]
93
+ dt = 1.0 / self.fps
94
+ foot_vel = torch.norm(feet_pos[:, 1:] - feet_pos[:, :-1], dim=-1) / dt
95
+
96
+ foot_contacts = foot_contacts[:, :-1]
97
+ vel_err = foot_vel * foot_contacts
98
+
99
+ # account for generated length
100
+ # since they are velocities use length-1 to avoid inaccurate vel going one frame past len
101
+ device = foot_contacts.device
102
+ len_mask = torch.arange(foot_contacts.shape[1], device=device)[None, :, None].expand(foot_contacts.shape) < (
103
+ lengths[:, None, None] - 1
104
+ )
105
+ foot_contacts = foot_contacts * len_mask
106
+ vel_err = vel_err * len_mask
107
+
108
+ mean_vel = torch.sum(vel_err, (1, 2)) / (torch.sum(foot_contacts, (1, 2)) + 1e-6) # mean over contacting frames
109
+
110
+ # Compute max velocity error across all feet and frames (per batch)
111
+ max_vel = vel_err.amax(dim=(1, 2)) # [B]
112
+
113
+ return {
114
+ "foot_skate_from_pred_contacts": mean_vel,
115
+ "foot_skate_max_vel": max_vel,
116
+ }
117
+
118
+
119
+ class FootSkateRatio(Metric):
120
+ """Compute fraction of frames where the foot skates when it is on the ground.
121
+
122
+ Inspired by GMD: https://github.com/korrawe/guided-motion-diffusion/blob/main/data_loaders/humanml/utils/metrics.py#L204
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ skeleton: SkeletonBase,
128
+ fps: float,
129
+ height_thresh=0.05,
130
+ vel_thresh=0.2,
131
+ **kwargs,
132
+ ):
133
+ super().__init__(**kwargs)
134
+ self.height_thresh = height_thresh
135
+ self.vel_thresh = vel_thresh
136
+
137
+ self.skeleton = skeleton
138
+ self.fps = fps
139
+
140
+ @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1)
141
+ def _compute(
142
+ self,
143
+ posed_joints: Tensor,
144
+ foot_contacts: Tensor,
145
+ lengths: Optional[Tensor] = None,
146
+ **kwargs,
147
+ ) -> Dict:
148
+ fidx = self.skeleton.foot_joint_idx
149
+ assert len(fidx) == 4, "This metric assumes 4 foot joints: heel, toe, heel, toe"
150
+
151
+ feet_pos = posed_joints[:, :, fidx]
152
+ toe_pos = feet_pos[:, :, [1, 3]]
153
+
154
+ toe_on_floor = toe_pos[..., 1] < self.height_thresh # y-up [B, T, 2] where [left right]
155
+ # current and next frame on floor to consider it in contact
156
+ toe_on_floor = torch.logical_and(toe_on_floor[:, :-1], toe_on_floor[:, 1:]) # [B, T-1, 2]
157
+
158
+ dt = 1.0 / self.fps
159
+ toe_vel = torch.norm(toe_pos[:, 1:] - toe_pos[:, :-1], dim=-1) / dt # [B, nframes-1, 2]
160
+
161
+ # compute err
162
+ contact_toe_vel = toe_vel * toe_on_floor # vel when corresponding toe is on ground
163
+
164
+ # account for generated length
165
+ # since they are velocities use length-1 to avoid inaccurate vel going one frame past len
166
+ device = toe_on_floor.device
167
+ len_mask = torch.arange(toe_on_floor.shape[1], device=device)[None, :, None].expand(toe_on_floor.shape) < (
168
+ lengths[:, None, None] - 1
169
+ )
170
+ toe_on_floor = toe_on_floor * len_mask
171
+ contact_toe_vel = contact_toe_vel * len_mask
172
+
173
+ # skating if velocity during contact > thresh
174
+ toe_skate = contact_toe_vel > self.vel_thresh
175
+ skate_ratio = torch.sum(toe_skate, (1, 2)) / (torch.sum(toe_on_floor, (1, 2)) + 1e-6)
176
+ return {"foot_skate_ratio": skate_ratio}
177
+
178
+
179
+ class FootContactConsistency(Metric):
180
+ """Measures consistency between heuristic detected foot contacts (from height and velocity) and
181
+ predicted foot contacts.
182
+
183
+ i.e. accuracy of how well predicted matches heuristic.
184
+ """
185
+
186
+ def __init__(
187
+ self,
188
+ skeleton: SkeletonBase,
189
+ fps: float,
190
+ vel_thresh: float = 0.15,
191
+ height_thresh: float = 0.10,
192
+ **kwargs,
193
+ ):
194
+ super().__init__(**kwargs)
195
+ self.vel_thresh = vel_thresh
196
+ self.height_thresh = height_thresh
197
+
198
+ self.skeleton = skeleton
199
+ self.fps = fps
200
+
201
+ @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1)
202
+ def _compute(
203
+ self,
204
+ posed_joints: Tensor,
205
+ foot_contacts: Tensor,
206
+ lengths: Optional[Tensor] = None,
207
+ **kwargs,
208
+ ) -> Dict:
209
+ velocity = compute_vel_xyz(posed_joints, float(self.fps), lengths=lengths)
210
+ heuristic_contacts = foot_detect_from_pos_and_vel(
211
+ posed_joints,
212
+ velocity,
213
+ self.skeleton,
214
+ self.vel_thresh,
215
+ self.height_thresh,
216
+ )
217
+
218
+ # compute accuracy of predicted, treating heuristic as ground truth
219
+ num_contacts = foot_contacts.shape[-1]
220
+ incorrect = torch.logical_xor(heuristic_contacts, foot_contacts)
221
+ # account for generated length
222
+ # since they are velocities, use length-1 to avoid inaccurate vel going one frame past len
223
+ device = foot_contacts.device
224
+ len_mask = torch.arange(foot_contacts.shape[1], device=device)[None, :, None].expand(foot_contacts.shape) < (
225
+ lengths[:, None, None] - 1
226
+ )
227
+ incorrect = incorrect * len_mask
228
+
229
+ incorrect_ratio = torch.sum(incorrect, (1, 2)) / (num_contacts * (lengths - 1))
230
+ accuracy = 1 - incorrect_ratio
231
+
232
+ return {"foot_contact_consistency": accuracy}
kimodo/metrics/tmr.py ADDED
@@ -0,0 +1,530 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """TMR evaluation metrics: text-motion retrieval, R-Precision, and related scores."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections import defaultdict
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ import numpy as np
11
+ import torch
12
+ from scipy import linalg
13
+ from torch import Tensor
14
+
15
+ from kimodo.model.tmr import TMR
16
+
17
+ from .base import Metric
18
+
19
+
20
+ # Scores are between 0 and 1
21
+ def get_score_matrix_unit(x, y):
22
+ sim_matrix = np.einsum("b i, c i -> b c", x, y)
23
+ scores = sim_matrix / 2 + 0.5
24
+ return scores
25
+
26
+
27
+ def get_scores_unit(x, y):
28
+ similarity = np.einsum("... i, ... i", x, y)
29
+ scores = similarity / 2 + 0.5
30
+ return scores
31
+
32
+
33
+ def compute_tmr_per_sample_retrieval(
34
+ motion_emb: np.ndarray,
35
+ text_emb: np.ndarray,
36
+ sample_ids: List[str],
37
+ texts: List[str],
38
+ top_k: int = 5,
39
+ ) -> List[Dict[str, Any]]:
40
+ """For each sample (text query i), compute t2m rank of motion i and top-k retrieved motions with
41
+ ids and texts.
42
+
43
+ Returns list of dicts: [{"rank": int, "top_k": [{"id": str, "text": str}, ...]}, ...].
44
+ """
45
+ motion_emb = np.asarray(motion_emb).squeeze()
46
+ text_emb = np.asarray(text_emb).squeeze()
47
+ if motion_emb.ndim == 1:
48
+ motion_emb = motion_emb[np.newaxis, :]
49
+ if text_emb.ndim == 1:
50
+ text_emb = text_emb[np.newaxis, :]
51
+ n = motion_emb.shape[0]
52
+ assert text_emb.shape[0] == n and len(sample_ids) == n and len(texts) == n
53
+ scores = get_score_matrix_unit(text_emb, motion_emb)
54
+ out: List[Dict[str, Any]] = []
55
+ for i in range(n):
56
+ row = np.asarray(scores[i])
57
+ order = np.argsort(row)[::-1]
58
+ rank = int(np.where(order == i)[0][0]) + 1
59
+ top_indices = order[:top_k]
60
+ top_k_list = [{"id": sample_ids[j], "text": texts[j]} for j in top_indices]
61
+ out.append({"rank": rank, "top_k": top_k_list})
62
+ return out
63
+
64
+
65
+ class TMR_Metric(Metric):
66
+ def __init__(
67
+ self,
68
+ tmr_model: TMR,
69
+ ranks: List = [1, 2, 3, 5, 10],
70
+ ranks_rounding=2,
71
+ **kwargs,
72
+ ):
73
+ super().__init__(**kwargs)
74
+ self.tmr_model = tmr_model
75
+ self.ranks = ranks
76
+ self.ranks_rounding = ranks_rounding
77
+
78
+ def clear(self):
79
+ self.saved_metrics = defaultdict(list)
80
+ self.saved_text_latents = []
81
+ self.saved_motion_gen_latents = []
82
+ self.saved_motion_gt_latents = []
83
+
84
+ def _compute(
85
+ self,
86
+ motion_rep,
87
+ pred_joints_output: Dict,
88
+ gt_joints_output: Dict,
89
+ text_x_dict: Dict,
90
+ lengths: Tensor,
91
+ **kwargs,
92
+ ) -> Dict:
93
+ pred_posed_joints = pred_joints_output["posed_joints"]
94
+ original_skeleton = motion_rep.skeleton if motion_rep is not None else None
95
+ latents_motion = self.tmr_model.encode_motion(
96
+ pred_posed_joints,
97
+ lengths=lengths,
98
+ original_skeleton=original_skeleton,
99
+ unit_vector=True,
100
+ )
101
+ latents_motion = latents_motion.cpu().numpy()
102
+
103
+ if isinstance(text_x_dict, dict) and "texts" in text_x_dict:
104
+ latents_text = self.tmr_model.encode_raw_text(text_x_dict["texts"], unit_vector=True)
105
+ else:
106
+ latents_text = self.tmr_model.encode_text(text_x_dict, unit_vector=True)
107
+ if latents_text.dim() == 1:
108
+ latents_text = latents_text.unsqueeze(0)
109
+ latents_text = latents_text.cpu().numpy()
110
+
111
+ self.saved_text_latents.append(latents_text)
112
+ self.saved_motion_gen_latents.append(latents_motion)
113
+
114
+ scores_text = get_scores_unit(latents_motion, latents_text)
115
+ output = {"TMR/t2m_sim": scores_text}
116
+
117
+ if gt_joints_output is not None and "posed_joints" in gt_joints_output:
118
+ gt_posed_joints = gt_joints_output["posed_joints"]
119
+ gt_latents_motion = self.tmr_model.encode_motion(
120
+ gt_posed_joints,
121
+ lengths=lengths,
122
+ original_skeleton=original_skeleton,
123
+ unit_vector=True,
124
+ )
125
+ gt_latents_motion = gt_latents_motion.cpu().numpy()
126
+ self.saved_motion_gt_latents.append(gt_latents_motion)
127
+
128
+ gt_scores_text = get_scores_unit(gt_latents_motion, latents_text)
129
+ scores_motion = get_scores_unit(latents_motion, gt_latents_motion)
130
+
131
+ output["TMR/t2m_gt_sim"] = gt_scores_text
132
+ output["TMR/m2m_sim"] = scores_motion
133
+
134
+ # pytorch tensors
135
+ for key, val in output.items():
136
+ output[key] = torch.tensor(val)
137
+ return output
138
+
139
+ def aggregate(self):
140
+ output = {}
141
+ for key, lst in self.saved_metrics.items():
142
+ output[key] = np.concatenate(lst)
143
+
144
+ assert self.saved_text_latents, "Should call the metric at least once."
145
+
146
+ text_latents = np.concatenate(self.saved_text_latents)
147
+ motion_gen_latents = np.concatenate(self.saved_motion_gen_latents)
148
+
149
+ batch_size = len(text_latents)
150
+ assert text_latents.shape == motion_gen_latents.shape
151
+
152
+ scores_t2m = get_score_matrix_unit(text_latents, motion_gen_latents)
153
+ scores_t2t = get_score_matrix_unit(text_latents, text_latents)
154
+
155
+ t2m_metrics = contrastive_metrics(
156
+ scores=scores_t2m,
157
+ scores_t2t=scores_t2t,
158
+ threshold=0.99,
159
+ rounding=2,
160
+ )
161
+
162
+ for key, val in t2m_metrics.items():
163
+ output["TMR/t2m_R/" + key] = val
164
+
165
+ mu_gen, cov_gen = calculate_activation_statistics(motion_gen_latents)
166
+ mu_text, cov_text = calculate_activation_statistics(text_latents)
167
+
168
+ fid_gen_text = calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text)
169
+ output["TMR/FID/gen_text"] = fid_gen_text
170
+
171
+ if self.saved_motion_gt_latents:
172
+ motion_gt_latents = np.concatenate(self.saved_motion_gt_latents)
173
+ assert motion_gt_latents.shape == motion_gen_latents.shape
174
+
175
+ scores_m2gm = get_score_matrix_unit(motion_gen_latents, motion_gt_latents)
176
+ scores_t2gm = get_score_matrix_unit(text_latents, motion_gt_latents)
177
+
178
+ m2gm_metrics = contrastive_metrics(
179
+ scores=scores_m2gm,
180
+ scores_t2t=scores_t2t,
181
+ threshold=0.99,
182
+ rounding=2,
183
+ )
184
+ for key, val in m2gm_metrics.items():
185
+ output["TMR/m2m_R/" + key] = val
186
+
187
+ t2gm_metrics = contrastive_metrics(
188
+ scores=scores_t2gm,
189
+ scores_t2t=scores_t2t,
190
+ threshold=0.99,
191
+ rounding=2,
192
+ )
193
+ for key, val in t2gm_metrics.items():
194
+ output["TMR/t2m_gt_R/" + key] = val
195
+
196
+ mu_gt_motion, cov_gt_motion = calculate_activation_statistics(motion_gt_latents)
197
+ fid_gen_motion = calculate_frechet_distance(
198
+ mu_gen,
199
+ cov_gen,
200
+ mu_gt_motion,
201
+ cov_gt_motion,
202
+ )
203
+ output["TMR/FID/gen_gt"] = fid_gen_motion
204
+
205
+ fid_gt_text = calculate_frechet_distance(
206
+ mu_gt_motion,
207
+ cov_gt_motion,
208
+ mu_text,
209
+ cov_text,
210
+ )
211
+ output["TMR/FID/gt_text"] = fid_gt_text
212
+
213
+ for key, val in output.items():
214
+ if isinstance(val, (int, float, np.integer, np.floating)):
215
+ val = torch.tensor([val for _ in range(batch_size)])
216
+
217
+ if isinstance(val, np.ndarray):
218
+ val = torch.from_numpy(val)
219
+
220
+ output[key] = val.cpu().float()
221
+ return output
222
+
223
+
224
+ class TMR_EmbeddingMetric(Metric):
225
+ """TMR metrics from precomputed motion and text embeddings (no model load).
226
+
227
+ Use in the loop: pass motion_emb and text_emb per sample; aggregate() computes retrieval metrics.
228
+ """
229
+
230
+ def __init__(self, ranks_rounding: int = 2, **kwargs):
231
+ super().__init__(**kwargs)
232
+ self.ranks_rounding = ranks_rounding
233
+
234
+ def clear(self):
235
+ self.saved_metrics = defaultdict(list)
236
+ self.saved_text_latents = []
237
+ self.saved_motion_gen_latents = []
238
+ self.saved_motion_gt_latents = []
239
+
240
+ def _compute(
241
+ self,
242
+ motion_emb=None,
243
+ text_emb=None,
244
+ gt_motion_emb=None,
245
+ **kwargs,
246
+ ) -> Dict:
247
+ if motion_emb is None or text_emb is None:
248
+ return {}
249
+ motion_emb = np.asarray(motion_emb)
250
+ text_emb = np.asarray(text_emb)
251
+ if motion_emb.ndim == 1:
252
+ motion_emb = motion_emb[np.newaxis, :]
253
+ if text_emb.ndim == 1:
254
+ text_emb = text_emb[np.newaxis, :]
255
+ self.saved_text_latents.append(text_emb)
256
+ self.saved_motion_gen_latents.append(motion_emb)
257
+ if gt_motion_emb is not None:
258
+ gt_motion_emb = np.asarray(gt_motion_emb)
259
+ if gt_motion_emb.ndim == 1:
260
+ gt_motion_emb = gt_motion_emb[np.newaxis, :]
261
+ self.saved_motion_gt_latents.append(gt_motion_emb)
262
+ scores = get_scores_unit(motion_emb, text_emb)
263
+ return {"TMR/t2m_sim": torch.tensor(scores, dtype=torch.float32)}
264
+
265
+ def aggregate(self):
266
+ output = {}
267
+ for key, lst in self.saved_metrics.items():
268
+ output[key] = np.concatenate(lst)
269
+ if not self.saved_text_latents:
270
+ return output
271
+ text_latents = np.concatenate(self.saved_text_latents)
272
+ motion_gen_latents = np.concatenate(self.saved_motion_gen_latents)
273
+ batch_size = len(text_latents)
274
+ assert text_latents.shape == motion_gen_latents.shape
275
+ scores_t2m = get_score_matrix_unit(text_latents, motion_gen_latents)
276
+ scores_t2t = get_score_matrix_unit(text_latents, text_latents)
277
+ t2m_metrics = contrastive_metrics(
278
+ scores=scores_t2m,
279
+ scores_t2t=scores_t2t,
280
+ threshold=0.99,
281
+ rounding=self.ranks_rounding,
282
+ )
283
+ for key, val in t2m_metrics.items():
284
+ output["TMR/t2m_R/" + key] = val
285
+ mu_gen, cov_gen = calculate_activation_statistics(motion_gen_latents)
286
+ mu_text, cov_text = calculate_activation_statistics(text_latents)
287
+ output["TMR/FID/gen_text"] = calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text)
288
+ if self.saved_motion_gt_latents:
289
+ motion_gt_latents = np.concatenate(self.saved_motion_gt_latents)
290
+ assert motion_gt_latents.shape == motion_gen_latents.shape
291
+ scores_m2gm = get_score_matrix_unit(motion_gen_latents, motion_gt_latents)
292
+ scores_t2gm = get_score_matrix_unit(text_latents, motion_gt_latents)
293
+ m2gm_metrics = contrastive_metrics(
294
+ scores=scores_m2gm,
295
+ scores_t2t=scores_t2t,
296
+ threshold=0.99,
297
+ rounding=self.ranks_rounding,
298
+ )
299
+ for key, val in m2gm_metrics.items():
300
+ output["TMR/m2m_R/" + key] = val
301
+ t2gm_metrics = contrastive_metrics(
302
+ scores=scores_t2gm,
303
+ scores_t2t=scores_t2t,
304
+ threshold=0.99,
305
+ rounding=self.ranks_rounding,
306
+ )
307
+ for key, val in t2gm_metrics.items():
308
+ output["TMR/t2m_gt_R/" + key] = val
309
+ mu_gt_motion, cov_gt_motion = calculate_activation_statistics(motion_gt_latents)
310
+ output["TMR/FID/gen_gt"] = calculate_frechet_distance(mu_gen, cov_gen, mu_gt_motion, cov_gt_motion)
311
+ output["TMR/FID/gt_text"] = calculate_frechet_distance(mu_gt_motion, cov_gt_motion, mu_text, cov_text)
312
+ for key, val in output.items():
313
+ if isinstance(val, (int, float, np.integer, np.floating)):
314
+ val = torch.tensor([val for _ in range(batch_size)])
315
+ if isinstance(val, np.ndarray):
316
+ val = torch.from_numpy(val)
317
+ output[key] = val.cpu().float()
318
+ return output
319
+
320
+
321
+ def compute_tmr_retrieval_metrics(
322
+ motion_emb: np.ndarray,
323
+ text_emb: np.ndarray,
324
+ gt_motion_emb: Optional[np.ndarray] = None,
325
+ rounding: int = 2,
326
+ ) -> Dict[str, float]:
327
+ """Compute TMR retrieval metrics from precomputed embeddings."""
328
+ if motion_emb.shape != text_emb.shape:
329
+ raise ValueError(f"Expected same shape for motion/text embeddings, got {motion_emb.shape} vs {text_emb.shape}")
330
+
331
+ scores_t2m = get_score_matrix_unit(text_emb, motion_emb)
332
+ scores_t2t = get_score_matrix_unit(text_emb, text_emb)
333
+
334
+ output: Dict[str, float] = {}
335
+ t2m_metrics = contrastive_metrics(
336
+ scores=scores_t2m,
337
+ scores_t2t=scores_t2t,
338
+ threshold=0.99,
339
+ rounding=rounding,
340
+ )
341
+ for key, val in t2m_metrics.items():
342
+ output[f"TMR/t2m_R/{key}"] = float(val)
343
+
344
+ mu_gen, cov_gen = calculate_activation_statistics(motion_emb)
345
+ mu_text, cov_text = calculate_activation_statistics(text_emb)
346
+ output["TMR/FID/gen_text"] = float(calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text))
347
+
348
+ if gt_motion_emb is not None:
349
+ if gt_motion_emb.shape != motion_emb.shape:
350
+ raise ValueError(f"Expected gt motion embeddings shape {motion_emb.shape}, got {gt_motion_emb.shape}")
351
+
352
+ scores_m2gm = get_score_matrix_unit(motion_emb, gt_motion_emb)
353
+ scores_t2gm = get_score_matrix_unit(text_emb, gt_motion_emb)
354
+
355
+ m2gm_metrics = contrastive_metrics(
356
+ scores=scores_m2gm,
357
+ scores_t2t=scores_t2t,
358
+ threshold=0.99,
359
+ rounding=rounding,
360
+ )
361
+ for key, val in m2gm_metrics.items():
362
+ output[f"TMR/m2m_R/{key}"] = float(val)
363
+
364
+ t2gm_metrics = contrastive_metrics(
365
+ scores=scores_t2gm,
366
+ scores_t2t=scores_t2t,
367
+ threshold=0.99,
368
+ rounding=rounding,
369
+ )
370
+ for key, val in t2gm_metrics.items():
371
+ output[f"TMR/t2m_gt_R/{key}"] = float(val)
372
+
373
+ mu_gt_motion, cov_gt_motion = calculate_activation_statistics(gt_motion_emb)
374
+ output["TMR/FID/gen_gt"] = float(calculate_frechet_distance(mu_gen, cov_gen, mu_gt_motion, cov_gt_motion))
375
+ output["TMR/FID/gt_text"] = float(calculate_frechet_distance(mu_gt_motion, cov_gt_motion, mu_text, cov_text))
376
+
377
+ return output
378
+
379
+
380
+ def all_contrastive_metrics(sims, emb=None, threshold=None, rounding=2, return_cols=False):
381
+ text_selfsim = None
382
+ if emb is not None:
383
+ text_selfsim = emb @ emb.T
384
+
385
+ t2m_m, t2m_cols = contrastive_metrics(sims, text_selfsim, threshold, return_cols=True, rounding=rounding)
386
+ m2t_m, m2t_cols = contrastive_metrics(sims.T, text_selfsim, threshold, return_cols=True, rounding=rounding)
387
+
388
+ all_m = {}
389
+ for key in t2m_m:
390
+ all_m[f"t2m/{key}"] = t2m_m[key]
391
+ all_m[f"m2t/{key}"] = m2t_m[key]
392
+
393
+ all_m["t2m/len"] = float(len(sims))
394
+ all_m["m2t/len"] = float(len(sims[0]))
395
+ if return_cols:
396
+ return all_m, t2m_cols, m2t_cols
397
+ return all_m
398
+
399
+
400
+ def contrastive_metrics(
401
+ scores,
402
+ scores_t2t=None,
403
+ threshold=None,
404
+ rounding=2,
405
+ ):
406
+ n, m = scores.shape
407
+ assert n == m
408
+ num_queries = n
409
+
410
+ dists = -scores
411
+ sorted_dists = np.sort(dists, axis=1)
412
+ # GT is in the diagonal
413
+ gt_dists = np.diag(dists)[:, None]
414
+
415
+ if scores_t2t is not None and threshold is not None:
416
+ real_threshold = 2 * threshold - 1
417
+ idx = np.argwhere(scores_t2t > real_threshold)
418
+ partition = np.unique(idx[:, 0], return_index=True)[1]
419
+ # take as GT the minimum score of similar values
420
+ gt_dists = np.minimum.reduceat(dists[tuple(idx.T)], partition)
421
+ gt_dists = gt_dists[:, None]
422
+
423
+ rows, cols = np.where((sorted_dists - gt_dists) == 0) # find column position of GT
424
+
425
+ # if there are ties
426
+ if rows.size > num_queries:
427
+ assert np.unique(rows).size == num_queries, "issue in metric evaluation"
428
+ avg_cols = break_ties_average(sorted_dists, gt_dists)
429
+ cols = avg_cols
430
+
431
+ msg = "expected ranks to match queries ({} vs {}) "
432
+ assert cols.size == num_queries, msg.format(cols.size, num_queries)
433
+
434
+ metrics = {}
435
+ vals = [str(x).zfill(2) for x in [1, 2, 3, 5, 10]]
436
+ for val in vals:
437
+ metrics[f"R{val}"] = 100 * float(np.sum(cols < int(val))) / num_queries
438
+
439
+ metrics["MedR"] = float(np.median(cols) + 1)
440
+ metrics["len"] = num_queries
441
+
442
+ if rounding is not None:
443
+ for key in metrics:
444
+ metrics[key] = round(metrics[key], rounding)
445
+ return metrics
446
+
447
+
448
+ def break_ties_average(sorted_dists, gt_dists):
449
+ # fast implementation, based on this code:
450
+ # https://stackoverflow.com/a/49239335
451
+ locs = np.argwhere((sorted_dists - gt_dists) == 0)
452
+
453
+ # Find the split indices
454
+ steps = np.diff(locs[:, 0])
455
+ splits = np.nonzero(steps)[0] + 1
456
+ splits = np.insert(splits, 0, 0)
457
+
458
+ # Compute the result columns
459
+ summed_cols = np.add.reduceat(locs[:, 1], splits)
460
+ counts = np.diff(np.append(splits, locs.shape[0]))
461
+ avg_cols = summed_cols / counts
462
+ return avg_cols
463
+
464
+
465
+ def calculate_activation_statistics(activations):
466
+ """
467
+ Params:
468
+ -- activation: num_samples x dim_feat
469
+ Returns:
470
+ -- mu: dim_feat
471
+ -- sigma: dim_feat x dim_feat
472
+ """
473
+ mu = np.mean(activations, axis=0)
474
+ cov = np.cov(activations, rowvar=False)
475
+ return mu, cov
476
+
477
+
478
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
479
+ """Numpy implementation of the Frechet Distance. The Frechet distance between two multivariate
480
+ Gaussians X_1 ~ N(mu_1, C_1)
481
+
482
+ and X_2 ~ N(mu_2, C_2) is
483
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
484
+ Stable version by Dougal J. Sutherland.
485
+ Params:
486
+ -- mu1 : Numpy array containing the activations of a layer of the
487
+ inception net (like returned by the function 'get_predictions')
488
+ for generated samples.
489
+ -- mu2 : The sample mean over activations, precalculated on a
490
+ representative dataset.
491
+ -- sigma1: The covariance matrix over activations for generated samples.
492
+ -- sigma2: The covariance matrix over activations, precalculated on a
493
+ representative dataset.
494
+ Returns:
495
+ -- : The Frechet Distance.
496
+ """
497
+
498
+ mu1 = np.atleast_1d(mu1)
499
+ mu2 = np.atleast_1d(mu2)
500
+
501
+ sigma1 = np.atleast_2d(sigma1)
502
+ sigma2 = np.atleast_2d(sigma2)
503
+
504
+ assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
505
+ assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
506
+
507
+ diff = mu1 - mu2
508
+
509
+ # Product might be almost singular
510
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
511
+ if not np.isfinite(covmean).all():
512
+ msg = ("fid calculation produces singular product; " "adding %s to diagonal of cov estimates") % eps
513
+ print(msg)
514
+ offset = np.eye(sigma1.shape[0]) * eps
515
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
516
+
517
+ # Numerical error might give slight imaginary component
518
+ if np.iscomplexobj(covmean):
519
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
520
+ # try again with a small offset added to the diagonal
521
+ offset = np.eye(sigma1.shape[0]) * eps
522
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
523
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
524
+ m = np.max(np.abs(covmean.imag))
525
+ raise ValueError("Imaginary component {}".format(m))
526
+ covmean = covmean.real
527
+
528
+ tr_covmean = np.trace(covmean)
529
+
530
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
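Usage sketch (illustrative, not part of this commit): driving compute_tmr_retrieval_metrics with random unit-normalized embeddings (the score helpers above expect unit vectors); the sizes are arbitrary and it assumes the package and its dependencies import cleanly.

    import numpy as np
    from kimodo.metrics.tmr import compute_tmr_retrieval_metrics

    def unit(x):
        return x / np.linalg.norm(x, axis=-1, keepdims=True)

    rng = np.random.default_rng(0)
    motion_emb = unit(rng.standard_normal((128, 32)))                    # generated-motion latents
    text_emb = unit(motion_emb + 0.1 * rng.standard_normal((128, 32)))   # correlated text latents

    metrics = compute_tmr_retrieval_metrics(motion_emb, text_emb)
    # e.g. metrics["TMR/t2m_R/R01"], metrics["TMR/t2m_R/MedR"], metrics["TMR/FID/gen_text"]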
kimodo/model/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Kimodo model package: main model class, text encoders, and loading utilities."""
4
+
5
+ from .common import resolve_target
6
+ from .kimodo_model import Kimodo
7
+ from .llm2vec import LLM2VecEncoder
8
+ from .load_model import load_model
9
+ from .loading import (
10
+ AVAILABLE_MODELS,
11
+ DEFAULT_MODEL,
12
+ DEFAULT_TEXT_ENCODER_URL,
13
+ MODEL_NAMES,
14
+ load_checkpoint_state_dict,
15
+ )
16
+ from .tmr import TMR
17
+ from .twostage_denoiser import TwostageDenoiser
18
+
19
+ __all__ = [
20
+ "Kimodo",
21
+ "LLM2VecEncoder",
22
+ "TMR",
23
+ "TwostageDenoiser",
24
+ "load_model",
25
+ "load_checkpoint_state_dict",
26
+ "resolve_target",
27
+ "AVAILABLE_MODELS",
28
+ "DEFAULT_MODEL",
29
+ "DEFAULT_TEXT_ENCODER_URL",
30
+ "MODEL_NAMES",
31
+ ]
kimodo/model/backbone.py ADDED
@@ -0,0 +1,312 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Transformer backbone: padding, masking, and encoder stack for the denoiser."""
4
+
5
+ import logging
6
+ from typing import Optional, Union
7
+
8
+ import torch
9
+ from omegaconf import ListConfig
10
+ from pydantic.dataclasses import dataclass
11
+ from torch import Tensor, nn
12
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
13
+
14
+ from kimodo.tools import validate
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+
19
+ def pad_x_and_mask_to_fixed_size(x: Tensor, mask: Tensor, size: int):
20
+ """Pad a feature vector x and the mask to always have the same size.
21
+
22
+ Args:
23
+ x (torch.Tensor): [B, T, D]
24
+ mask (torch.Tensor): [B, T]
25
+ size (int)
26
+ Returns:
27
+ torch.Tensor: [B, size, D]
28
+ torch.Tensor: [B, size]
29
+ """
30
+
31
+ batch_size, cur_max_size, dim = x.shape[0], x.shape[1], x.shape[2]
32
+
33
+ if cur_max_size == size:
34
+ # already padded to this size, probably in the collate function
35
+ return x, mask
36
+
37
+ if cur_max_size > size:
38
+ # This issue should have been handled in the collate function
39
+ # useful as a check at test time
40
+ log.warning("The size of the tensor is larger than the maximum size. Cropping the input.")
41
+ cur_max_size = size
42
+
43
+ new_x = torch.zeros(
44
+ (batch_size, size, dim),
45
+ dtype=x.dtype,
46
+ device=x.device,
47
+ )
48
+ new_x[:, :cur_max_size] = x
49
+
50
+ # same for the mask
51
+ new_mask = torch.zeros(
52
+ (batch_size, size),
53
+ dtype=mask.dtype,
54
+ device=mask.device,
55
+ )
56
+ new_mask[:, :cur_max_size] = mask
57
+ return new_x, new_mask
58
+
59
+
60
+ @dataclass(frozen=True, config=dict(extra="forbid", arbitrary_types_allowed=True))
61
+ class TransformerEncoderBlockConfig:
62
+ """Configuration for the transformer encoder backbone."""
63
+
64
+ # input features dimension
65
+ input_dim: int
66
+ # output features dimension
67
+ output_dim: int
68
+
69
+ # skeleton object
70
+ skeleton: object
71
+
72
+ # dimension of the text embeddings
73
+ llm_shape: Union[list[int], ListConfig]
74
+
75
+ # mask the text or not
76
+ use_text_mask: bool
77
+
78
+ # latent dimension of the model
79
+ latent_dim: int
80
+ # dimension of the feedforward network in transformer
81
+ ff_size: int
82
+ # num layers in transformer
83
+ num_layers: int
84
+ # num heads in transformer
85
+ num_heads: int
86
+ # activation in transformer
87
+ activation: str
88
+ # dropout rate for the transformer
89
+ dropout: float
90
+ # dropout rate for the positional embeddings
91
+ pe_dropout: float
92
+ # use norm first or not
93
+ norm_first: bool = False
94
+ # artificially extend the number of text tokens
95
+ num_text_tokens_override: Optional[int] = None
96
+
97
+ # Input first heading angle
98
+ input_first_heading_angle: bool = False
99
+
100
+
101
+ class TransformerEncoderBlock(nn.Module):
102
+ @validate(TransformerEncoderBlockConfig, save_args=True, super_init=True)
103
+ def __init__(self, conf):
104
+ self.nbjoints = self.skeleton.nbjoints
105
+ llm_dim = self.llm_shape[-1]
106
+ self.embed_text = nn.Linear(llm_dim, self.latent_dim)
107
+
108
+ self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.pe_dropout)
109
+
110
+ # maximum number of tokens
111
+ self.num_text_tokens = self.llm_shape[0]
112
+ if self.num_text_tokens_override is not None:
113
+ self.num_text_tokens = self.num_text_tokens_override
114
+
115
+ self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
116
+
117
+ self.input_linear = nn.Linear(self.input_dim, self.latent_dim)
118
+ self.output_linear = nn.Linear(self.latent_dim, self.output_dim)
119
+ self.linear_first_heading_angle = nn.Linear(2, self.latent_dim)
120
+
121
+ trans_enc_layer = TransformerEncoderLayer(
122
+ d_model=self.latent_dim,
123
+ nhead=self.num_heads,
124
+ dim_feedforward=self.ff_size,
125
+ dropout=self.dropout,
126
+ activation=self.activation,
127
+ batch_first=True,
128
+ norm_first=self.norm_first,
129
+ )
130
+ self.seqTransEncoder = TransformerEncoder(
131
+ trans_enc_layer,
132
+ num_layers=self.num_layers,
133
+ enable_nested_tensor=False,
134
+ )
135
+
136
+ def forward(
137
+ self,
138
+ x: Tensor,
139
+ x_pad_mask: torch.Tensor,
140
+ text_feat: torch.Tensor,
141
+ text_feat_pad_mask: torch.Tensor,
142
+ timesteps: Tensor,
143
+ first_heading_angle: Optional[Tensor] = None,
144
+ ) -> Tensor:
145
+ """
146
+ Args:
147
+ x (torch.Tensor): [B, T, dim_motion] current noisy motion
148
+ x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not
149
+ text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts
150
+ text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not
151
+ timesteps (torch.Tensor): [B,] current denoising step
152
+
153
+ Returns:
154
+ torch.Tensor: [B, T, output_dim]
155
+ """
156
+ batch_size = len(x)
157
+ x = self.input_linear(x) # [B, T, D]
158
+
159
+ # Pad the text tokens + mask to always have the same size == self.num_text_tokens
160
+ # done here if it was not done in the collate function
161
+ if self.num_text_tokens is not None:
162
+ text_feat, text_feat_pad_mask = pad_x_and_mask_to_fixed_size(
163
+ text_feat,
164
+ text_feat_pad_mask,
165
+ self.num_text_tokens,
166
+ )
167
+
168
+ # Encode the text features and the time information
169
+ emb_text = self.embed_text(text_feat) # [B, max_text_len, D]
170
+ emb_time = self.embed_timestep(timesteps) # [B, 1, D]
171
+
172
+ # Create mask for the time information
173
+ time_mask = torch.ones((batch_size, 1), dtype=bool, device=x.device)
174
+
175
+ # Create the prefix features (text, time, etc): [B, max_text_len + 1 + etc]
176
+ prefix_feats = torch.cat((emb_text, emb_time), axis=1)
177
+
178
+ # Behavior from old code: not use text mask -> True for all the tokens
179
+ if not self.use_text_mask:
180
+ text_feat_pad_mask = torch.ones(
181
+ (batch_size, emb_text.shape[1]),
182
+ dtype=torch.bool,
183
+ device=x.device,
184
+ )
185
+
186
+ prefix_mask = torch.cat((text_feat_pad_mask, time_mask), axis=1)
187
+
188
+ # add the input first heading angle
189
+ if self.input_first_heading_angle:
190
+ assert first_heading_angle is not None, "The first heading angle is mandatory for this model"
191
+ # cos(angle) / sin(angle)
192
+ first_heading_angle_feats = torch.stack(
193
+ [
194
+ torch.cos(first_heading_angle),
195
+ torch.sin(first_heading_angle),
196
+ ],
197
+ axis=-1,
198
+ )
199
+
200
+ first_heading_angle_feats = self.linear_first_heading_angle(first_heading_angle_feats)
201
+ first_heading_angle_feats = first_heading_angle_feats[:, None] # for cat
202
+ first_heading_angle_mask = torch.ones(
203
+ (batch_size, 1),
204
+ dtype=bool,
205
+ device=x.device,
206
+ )
207
+ prefix_feats = torch.cat((prefix_feats, first_heading_angle_feats), axis=1)
208
+ prefix_mask = torch.cat((prefix_mask, first_heading_angle_mask), axis=1)
209
+
210
+ # compute the number of prefix features
211
+ pose_start_ind = prefix_feats.shape[1]
212
+
213
+ # Concatenate prefix and x: [B, len(prefix) + T, D]
214
+ xseq = torch.cat((prefix_feats, x), axis=1)
215
+
216
+ # Concatenate the masks and negate them: [B, len(prefix) + T]
217
+ src_key_padding_mask = ~torch.cat((prefix_mask, x_pad_mask), axis=1)
218
+
219
+ # Add positional encoding
220
+ xseq = self.sequence_pos_encoder(xseq)
221
+
222
+ # Input to the transformer and keep the motion indexes
223
+ if isinstance(self.seqTransEncoder, nn.TransformerEncoder):
224
+ assert not self.seqTransEncoder.use_nested_tensor, "Flash attention should be disabled due to bug!"
225
+
226
+ output = self.seqTransEncoder(
227
+ xseq,
228
+ src_key_padding_mask=src_key_padding_mask,
229
+ )
230
+ output = output[:, pose_start_ind:] # [B, T, D]
231
+ output = self.output_linear(output) # [B, T, OD]
232
+ return output
233
+
234
+
235
+ class PositionalEncoding(nn.Module):
236
+ """Non-learned positional encoding."""
237
+
238
+ def __init__(
239
+ self,
240
+ d_model: int,
241
+ dropout: Optional[float] = 0.1,
242
+ max_len: Optional[int] = 5000,
243
+ ):
244
+ """
245
+ Args:
246
+ d_model (int): input dim
247
+ dropout (Optional[float] = 0.1): dropout probability on output
248
+ max_len (Optional[int] = 5000): maximum sequence length
249
+ """
250
+ super(PositionalEncoding, self).__init__()
251
+ self.dropout = nn.Dropout(p=dropout)
252
+
253
+ pe = torch.zeros(max_len, d_model)
254
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
255
+
256
+ # Note: have to replace torch.exp() and math.log() with torch.pow()
257
+ # due to MKL exp() and ln() throws floating point exceptions on certain CPUs
258
+ # see corresponding commit and MR
259
+ div_term = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model)
260
+ # div_term = torch.exp(
261
+ # torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
262
+ # )
263
+
264
+ pe[:, 0::2] = torch.sin(position * div_term)
265
+ pe[:, 1::2] = torch.cos(position * div_term)
266
+ pe = pe.unsqueeze(0) # [1, T, D]
267
+
268
+ self.register_buffer("pe", pe, persistent=False)
269
+
270
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
271
+ """Apply positional encoding to input sequence.
272
+
273
+ Args:
274
+ x (torch.Tensor): [B, T, D] input motion sequence
275
+
276
+ Returns:
277
+ torch.Tensor: [B, T, D] input motion with PE added to it (and optionally dropout)
278
+ """
279
+ x = x + self.pe[:, : x.shape[1], :]
280
+ return self.dropout(x)
281
+
282
+
283
+ class TimestepEmbedder(nn.Module):
284
+ """Encoder for diffusion step."""
285
+
286
+ def __init__(self, latent_dim: int, sequence_pos_encoder: PositionalEncoding):
287
+ """
288
+ Args:
289
+ latent_dim (int): dim to encode to
290
+ sequence_pos_encoder (PositionalEncoding): the PE to use on timesteps
291
+ """
292
+ super().__init__()
293
+ self.latent_dim = latent_dim
294
+ self.sequence_pos_encoder = sequence_pos_encoder
295
+
296
+ time_embed_dim = self.latent_dim
297
+ self.time_embed = nn.Sequential(
298
+ nn.Linear(self.latent_dim, time_embed_dim),
299
+ nn.SiLU(),
300
+ nn.Linear(time_embed_dim, time_embed_dim),
301
+ )
302
+
303
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
304
+ """Embed timesteps by adding PE then going through linear layers.
305
+
306
+ Args:
307
+ timesteps (torch.Tensor): [B]
308
+
309
+ Returns:
310
+ torch.Tensor: [B, 1, D]
311
+ """
312
+ return self.time_embed(self.sequence_pos_encoder.pe.transpose(0, 1)[timesteps])
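Small sketch (not part of this commit) of the padding helper and positional encoding defined above; shapes are arbitrary.

    import torch
    from kimodo.model.backbone import PositionalEncoding, pad_x_and_mask_to_fixed_size

    x = torch.rand(2, 50, 64)                      # [B, T, D] token features
    mask = torch.ones(2, 50, dtype=torch.bool)     # True = valid position

    x_pad, mask_pad = pad_x_and_mask_to_fixed_size(x, mask, size=77)
    # x_pad: [2, 77, 64]; mask_pad: [2, 77] with the last 27 positions False

    pe = PositionalEncoding(d_model=64, dropout=0.0)
    x_pe = pe(x_pad)                               # [2, 77, 64] with sinusoidal PE added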
kimodo/model/cfg.py ADDED
@@ -0,0 +1,133 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Classifier-free guidance wrapper for the denoiser at sampling time."""
4
+
5
+ from typing import Dict, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ CFG_TYPES = ["nocfg", "regular", "separated"]
11
+
12
+
13
+ class ClassifierFreeGuidedModel(nn.Module):
14
+ """Wrapper around denoiser to use classifier-free guidance at sampling time."""
15
+
16
+ def __init__(self, model: nn.Module, cfg_type: Optional[str] = "separated"):
17
+ """Wrap the denoiser for classifier-free guidance; cfg_type in CFG_TYPES (e.g. 'regular',
18
+ 'nocfg')."""
19
+ super().__init__()
20
+ self.model = model
21
+ assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}"
22
+ self.cfg_type_default = cfg_type
23
+
24
+ def forward(
25
+ self,
26
+ cfg_weight: Union[float, Tuple[float, float]],
27
+ x: torch.Tensor,
28
+ x_pad_mask: torch.Tensor,
29
+ text_feat: torch.Tensor,
30
+ text_feat_pad_mask: torch.Tensor,
31
+ timesteps: torch.Tensor,
32
+ first_heading_angle: Optional[torch.Tensor] = None,
33
+ motion_mask: Optional[torch.Tensor] = None,
34
+ observed_motion: Optional[torch.Tensor] = None,
35
+ cfg_type: Optional[str] = None,
36
+ ) -> torch.Tensor:
37
+ """
38
+ Args:
39
+ cfg_weight (float): guidance weight float or tuple of floats with (text, constraint) weights if using separated cfg
40
+ x (torch.Tensor): [B, T, dim_motion] current noisy motion
41
+ x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not
42
+ text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts
43
+ text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not
44
+ timesteps (torch.Tensor): [B,] current denoising step
45
+ motion_mask
46
+ observed_motion
47
+ neutral_joints (torch.Tensor): [B, nbjoints] The neutral joints of the motions
48
+
49
+ Returns:
50
+ torch.Tensor: same size as input x
51
+ """
52
+
53
+ if cfg_type is None:
54
+ cfg_type = self.cfg_type_default
55
+
56
+ assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}"
57
+
58
+ # batched conditional and uncond pass together
59
+ if cfg_type == "nocfg":
60
+ return self.model(
61
+ x,
62
+ x_pad_mask,
63
+ text_feat,
64
+ text_feat_pad_mask,
65
+ timesteps,
66
+ first_heading_angle=first_heading_angle,
67
+ motion_mask=motion_mask,
68
+ observed_motion=observed_motion,
69
+ )
70
+ elif cfg_type == "regular":
71
+ assert isinstance(cfg_weight, (float, int)), "cfg_weight must be a single float for regular CFG"
72
+ # out_uncond + w * (out_text_and_constraint - out_uncond)
73
+ text_feat = torch.concatenate([text_feat, 0 * text_feat], dim=0)
74
+ if motion_mask is not None:
75
+ motion_mask = torch.concatenate([motion_mask, 0 * motion_mask], dim=0)
76
+ if observed_motion is not None:
77
+ observed_motion = torch.concatenate([observed_motion, observed_motion], dim=0)
78
+ if first_heading_angle is not None:
79
+ first_heading_angle = torch.concatenate([first_heading_angle, first_heading_angle], dim=0)
80
+
81
+ out_cond_uncond = self.model(
82
+ torch.concatenate([x, x], dim=0),
83
+ torch.concatenate([x_pad_mask, x_pad_mask], dim=0),
84
+ text_feat,
85
+ torch.concatenate([text_feat_pad_mask, False * text_feat_pad_mask], dim=0),
86
+ torch.concatenate([timesteps, timesteps], dim=0),
87
+ first_heading_angle=first_heading_angle,
88
+ motion_mask=motion_mask,
89
+ observed_motion=observed_motion,
90
+ )
91
+
92
+ out, out_uncond = torch.chunk(out_cond_uncond, 2)
93
+ out_new = out_uncond + (cfg_weight * (out - out_uncond))
94
+ elif cfg_type == "separated":
95
+ assert len(cfg_weight) == 2, "cfg_weight must be a tuple of two floats for separated CFG"
96
+ # out_uncond + w_text * (out_text - out_uncond) + w_constraint * (out_constraint - out_uncond)
97
+ text_feat = torch.concatenate([text_feat, 0 * text_feat, 0 * text_feat], dim=0)
98
+ if motion_mask is not None:
99
+ motion_mask = torch.concatenate([0 * motion_mask, motion_mask, 0 * motion_mask], dim=0)
100
+ if observed_motion is not None:
101
+ observed_motion = torch.concatenate([observed_motion, observed_motion, observed_motion], dim=0)
102
+ if first_heading_angle is not None:
103
+ first_heading_angle = torch.concatenate(
104
+ [first_heading_angle, first_heading_angle, first_heading_angle],
105
+ dim=0,
106
+ )
107
+
108
+ out_cond_uncond = self.model(
109
+ torch.concatenate([x, x, x], dim=0),
110
+ torch.concatenate([x_pad_mask, x_pad_mask, x_pad_mask], dim=0),
111
+ text_feat,
112
+ torch.concatenate(
113
+ [
114
+ text_feat_pad_mask,
115
+ False * text_feat_pad_mask,
116
+ False * text_feat_pad_mask,
117
+ ],
118
+ dim=0,
119
+ ),
120
+ torch.concatenate([timesteps, timesteps, timesteps], dim=0),
121
+ first_heading_angle=first_heading_angle,
122
+ motion_mask=motion_mask,
123
+ observed_motion=observed_motion,
124
+ )
125
+
126
+ out_text, out_constraint, out_uncond = torch.chunk(out_cond_uncond, 3)
127
+ out_new = (
128
+ out_uncond + (cfg_weight[0] * (out_text - out_uncond)) + (cfg_weight[1] * (out_constraint - out_uncond))
129
+ )
130
+ else:
131
+ raise ValueError(f"Invalid cfg_type: {cfg_type}")
132
+
133
+ return out_new
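For reference, the guidance arithmetic above in isolation: a standalone sketch with dummy tensors standing in for the denoiser's text-conditioned, constraint-conditioned, and unconditional outputs (shapes are placeholders).

```python
import torch

# Dummy denoiser outputs shaped like [B, T, dim_motion]; values are placeholders.
out_text, out_constraint, out_uncond = torch.randn(3, 2, 60, 32).unbind(0)
w_text, w_constraint = 2.0, 2.0

# "separated" CFG: text and constraint guidance are weighted independently.
guided = (
    out_uncond
    + w_text * (out_text - out_uncond)
    + w_constraint * (out_constraint - out_uncond)
)

# "regular" CFG is the single-weight special case:
# guided = out_uncond + w * (out_cond - out_uncond)
```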
kimodo/model/common.py ADDED
@@ -0,0 +1,48 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Config hydration: env vars, _target_ resolution, and recursive instantiation."""
4
+
5
+ import importlib
6
+ import os
7
+
8
+
9
+ def get_env_var(name: str, default=None):
10
+ """Read env var by name and by lowercased name; return default if neither set."""
11
+ return os.getenv(name, os.getenv(name.lower(), default))
12
+
13
+
14
+ def resolve_target(target: str):
15
+ """Import module and return the attribute named by a dotted path (e.g. 'pkg.mod.Class')."""
16
+ module_name, attr_name = target.rsplit(".", 1)
17
+ module = importlib.import_module(module_name)
18
+ return getattr(module, attr_name)
19
+
20
+
21
+ def materialize_value(value):
22
+ """Recursively turn dicts with '_target_' into instances; lists/dicts traversed; leaves
23
+ unchanged."""
24
+ if isinstance(value, dict):
25
+ if "_target_" in value:
26
+ return instantiate_from_dict(value)
27
+ return {k: materialize_value(v) for k, v in value.items()}
28
+ if isinstance(value, list):
29
+ return [materialize_value(v) for v in value]
30
+ return value
31
+
32
+
33
+ def instantiate_from_dict(node, overrides=None):
34
+ """Build an instance from a config dict: '_target_' gives the class, other keys are kwargs; overrides merged in."""
35
+ if not isinstance(node, dict) or "_target_" not in node:
36
+ raise ValueError("Config node must be a dict with a '_target_' key.")
37
+
38
+ target = resolve_target(node["_target_"])
39
+ kwargs = {}
40
+ for key, value in node.items():
41
+ if key == "_target_":
42
+ continue
43
+ kwargs[key] = materialize_value(value)
44
+
45
+ if overrides:
46
+ kwargs.update({k: v for k, v in overrides.items() if v is not None})
47
+
48
+ return target(**kwargs)
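A small usage sketch of the helpers above, with a hypothetical config node (not one of Kimodo's real config files):

```python
from kimodo.model.common import instantiate_from_dict, materialize_value

# Hypothetical config node: '_target_' names the class, the remaining keys become kwargs.
cfg = {"_target_": "torch.nn.Linear", "in_features": 16, "out_features": 8}

layer = instantiate_from_dict(cfg, overrides={"out_features": 4})  # -> nn.Linear(16, 4)

# Nested containers are traversed recursively; only dicts with '_target_' become objects.
tree = materialize_value({"encoder": cfg, "lr": 1e-4, "tags": ["demo"]})
```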
kimodo/model/diffusion.py ADDED
@@ -0,0 +1,133 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Diffusion process and DDIM sampling for motion generation."""
4
+
5
+ import math
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+
12
+ def get_beta_schedule(
13
+ num_diffusion_timesteps: int,
14
+ max_beta: Optional[float] = 0.999,
15
+ ) -> torch.Tensor:
16
+ """Get cosine beta schedule."""
17
+
18
+ def alpha_bar(t):
19
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
20
+
21
+ betas = []
22
+ for i in range(num_diffusion_timesteps):
23
+ t1 = i / num_diffusion_timesteps
24
+ t2 = (i + 1) / num_diffusion_timesteps
25
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
26
+ return torch.tensor(betas, dtype=torch.float)
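In equation form, the schedule built above is the cosine schedule popularized by improved DDPM, with T = num_diffusion_timesteps:

```latex
\bar{\alpha}(s) = \cos^{2}\!\left(\frac{s + 0.008}{1.008}\cdot\frac{\pi}{2}\right),
\qquad
\beta_i = \min\!\left(1 - \frac{\bar{\alpha}\!\left((i+1)/T\right)}{\bar{\alpha}\!\left(i/T\right)},\; \beta_{\max}\right),
\quad i = 0, \dots, T-1 .
```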
27
+
28
+
29
+ class Diffusion(torch.nn.Module):
30
+ """Cosine-schedule diffusion process: betas, alphas, and DDIM step mapping."""
31
+
32
+ def __init__(self, num_base_steps: int):
33
+ """Set up cosine beta schedule and precompute diffusion variables for num_base_steps."""
34
+ super().__init__()
35
+ self.num_base_steps = num_base_steps
36
+ betas_base = get_beta_schedule(self.num_base_steps)
37
+ self.register_buffer("betas_base", betas_base, persistent=False)
38
+ alphas_cumprod_base = torch.cumprod(1.0 - self.betas_base, dim=0)
39
+ self.register_buffer("alphas_cumprod_base", alphas_cumprod_base, persistent=False)
40
+ use_timesteps, _ = self.space_timesteps(self.num_base_steps)
41
+ self.calc_diffusion_vars(use_timesteps)
42
+
43
+ def extra_repr(self) -> str:
44
+ return f"num_base_steps={self.num_base_steps}"
45
+
46
+ @property
47
+ def device(self):
48
+ return self.betas_base.device
49
+
50
+ def space_timesteps(self, num_denoising_steps: int) -> Tuple[torch.Tensor, torch.Tensor]:
51
+ """Return (use_timesteps, map_tensor) for a subsampled denoising schedule of
52
+ num_denoising_steps."""
53
+ nsteps_train = self.num_base_steps
54
+ frac_stride = (nsteps_train - 1) / max(1, num_denoising_steps - 1)
55
+ use_timesteps = torch.round(torch.arange(nsteps_train, device=self.device) * frac_stride).to(torch.long)
56
+ use_timesteps = torch.clamp(use_timesteps, max=nsteps_train - 1)
57
+ map_tensor = torch.arange(nsteps_train, device=self.device, dtype=torch.long)[use_timesteps]
58
+ return use_timesteps, map_tensor
59
+
60
+ def calc_diffusion_vars(self, use_timesteps: torch.Tensor) -> None:
61
+ """Update buffers (betas, alphas, alphas_cumprod, etc.) for the given subsampled
62
+ timesteps."""
63
+ alphas_cumprod = self.alphas_cumprod_base[use_timesteps]
64
+ last_alpha_cumprod = torch.cat([torch.tensor([1.0]).to(alphas_cumprod), alphas_cumprod[:-1]])
65
+ betas = 1.0 - alphas_cumprod / last_alpha_cumprod
66
+ self.register_buffer("betas", betas, persistent=False)
67
+
68
+ alphas = 1.0 - self.betas
69
+ self.register_buffer("alphas", alphas, persistent=False)
70
+ alphas_cumprod = torch.cumprod(self.alphas, dim=0)
71
+ alphas_cumprod = torch.clamp(alphas_cumprod, min=1e-9)
72
+ self.register_buffer("alphas_cumprod", alphas_cumprod, persistent=False)
73
+
74
+ alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(self.alphas_cumprod), self.alphas_cumprod[:-1]])
75
+ self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev, persistent=False)
76
+
77
+ sqrt_recip_alphas_cumprod = torch.rsqrt(self.alphas_cumprod)
78
+ self.register_buffer("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod, persistent=False)
79
+
80
+ sqrt_recipm1_alphas_cumprod = torch.rsqrt(self.alphas_cumprod / (1.0 - self.alphas_cumprod))
81
+ self.register_buffer("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod, persistent=False)
82
+
83
+ posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
84
+ self.register_buffer("posterior_variance", posterior_variance, persistent=False)
85
+
86
+ sqrt_alphas_cumprod = torch.rsqrt(1.0 / self.alphas_cumprod)
87
+ self.register_buffer("sqrt_alphas_cumprod", sqrt_alphas_cumprod, persistent=False)
88
+
89
+ sqrt_one_minus_alphas_cumprod = torch.rsqrt(1.0 / (1.0 - self.alphas_cumprod))
90
+ self.register_buffer(
91
+ "sqrt_one_minus_alphas_cumprod",
92
+ sqrt_one_minus_alphas_cumprod,
93
+ persistent=False,
94
+ )
95
+
96
+ def q_sample(
97
+ self,
98
+ x_start: torch.Tensor,
99
+ t: torch.Tensor,
100
+ noise: torch.Tensor = None,
101
+ ):
102
+ if noise is None:
103
+ noise = torch.randn_like(x_start)
104
+ assert noise.shape == x_start.shape
105
+
106
+ xt = (
107
+ self.sqrt_alphas_cumprod[t, None, None] * x_start
108
+ + self.sqrt_one_minus_alphas_cumprod[t, None, None] * noise
109
+ )
110
+ return xt
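q_sample implements the standard closed-form forward-noising step, drawing x_t directly from the clean motion x_0:

```latex
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I),
```

where the two square-root factors are the precomputed `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` buffers.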
111
+
112
+
113
+ class DDIMSampler(nn.Module):
114
+ """Deterministic DDIM sampler (eta = 0)."""
115
+
116
+ def __init__(self, diffusion: Diffusion):
117
+ super().__init__()
118
+ self.diffusion = diffusion
119
+
120
+ def __call__(
121
+ self,
122
+ use_timesteps: torch.Tensor,
123
+ x_t: torch.Tensor,
124
+ pred_xstart: torch.Tensor,
125
+ t: torch.Tensor,
126
+ ) -> torch.Tensor:
127
+ self.diffusion.calc_diffusion_vars(use_timesteps)
128
+ eps = (
129
+ self.diffusion.sqrt_recip_alphas_cumprod[t, None, None] * x_t - pred_xstart
130
+ ) / self.diffusion.sqrt_recipm1_alphas_cumprod[t, None, None]
131
+ alpha_bar_prev = self.diffusion.alphas_cumprod_prev[t, None, None]
132
+ x = pred_xstart * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev) * eps
133
+ return x
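A minimal sketch of one deterministic DDIM update using the two classes above, with random tensors standing in for the noisy motion and the denoiser's clean prediction (shapes are placeholders):

```python
import torch

from kimodo.model.diffusion import DDIMSampler, Diffusion

diffusion = Diffusion(num_base_steps=1000)
sampler = DDIMSampler(diffusion)

num_denoising_steps = 50
use_timesteps, _ = diffusion.space_timesteps(num_denoising_steps)

x_t = torch.randn(2, 60, 32)      # [B, T, dim_motion] noisy motion (dummy)
pred_x0 = torch.randn(2, 60, 32)  # denoiser's x0 prediction (dummy)
t = torch.full((2,), num_denoising_steps - 1, dtype=torch.long)

x_tm1 = sampler(use_timesteps, x_t, pred_x0, t)  # input to the next denoising step
```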
kimodo/model/kimodo_model.py ADDED
@@ -0,0 +1,605 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Kimodo model: denoiser, text encoder, diffusion sampling, and post-processing."""
4
+
5
+ import logging
6
+ from typing import Dict, List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from torch import nn
10
+ from tqdm.auto import tqdm
11
+
12
+ from kimodo.constraints import FullBodyConstraintSet
13
+ from kimodo.motion_rep.feature_utils import compute_heading_angle, length_to_mask
14
+ from kimodo.postprocess import post_process_motion
15
+ from kimodo.sanitize import sanitize_texts
16
+ from kimodo.skeleton import SOMASkeleton30
17
+ from kimodo.tools import to_numpy
18
+
19
+ from .cfg import ClassifierFreeGuidedModel
20
+ from .diffusion import DDIMSampler, Diffusion
21
+
22
+ log = logging.getLogger(__name__)
23
+
24
+
25
+ class Kimodo(nn.Module):
26
+ """Helper class for test time."""
27
+
28
+ def __init__(
29
+ self,
30
+ denoiser: nn.Module,
31
+ text_encoder: nn.Module,
32
+ num_base_steps: int,
33
+ device: Optional[Union[str, torch.device]] = None,
34
+ cfg_type: Optional[str] = "separated",
35
+ ):
36
+ super().__init__()
37
+
38
+ self.denoiser = denoiser.eval()
39
+
40
+ if cfg_type is None:
41
+ cfg_type = "nocfg"
42
+
43
+ # Add Classifier-free guidance to the model if needed
44
+ self.denoiser = ClassifierFreeGuidedModel(self.denoiser, cfg_type=cfg_type)
45
+
46
+ self.motion_rep = denoiser.motion_rep
47
+ self.skeleton = self.motion_rep.skeleton
48
+
49
+ self.fps = denoiser.motion_rep.fps
50
+
51
+ self.diffusion = Diffusion(num_base_steps=num_base_steps)
52
+ self.sampler = DDIMSampler(self.diffusion)
53
+ self.text_encoder = text_encoder
54
+
55
+ self.device = device
56
+ # for classifier-free guidance
57
+
58
+ self.to(device)
59
+
60
+ @property
61
+ def output_skeleton(self):
62
+ """Skeleton used for model output (somaskel77 for SOMA, else unchanged)."""
63
+ if isinstance(self.skeleton, SOMASkeleton30):
64
+ return self.skeleton.somaskel77
65
+ return self.skeleton
66
+
67
+ def train(self, mode: bool):
68
+ self.denoiser.train(mode)
69
+ return self
70
+
71
+ def eval(self):
72
+ self.denoiser.eval()
73
+ return self
74
+
75
+ def denoising_step(
76
+ self,
77
+ motion: torch.Tensor,
78
+ pad_mask: torch.Tensor,
79
+ text_feat: torch.Tensor,
80
+ text_pad_mask: torch.Tensor,
81
+ t: torch.Tensor,
82
+ first_heading_angle: Optional[torch.Tensor],
83
+ motion_mask: torch.Tensor,
84
+ observed_motion: torch.Tensor,
85
+ num_denoising_steps: torch.Tensor,
86
+ cfg_weight: Union[float, Tuple[float, float]],
87
+ guide_masks: Optional[Dict] = None,
88
+ cfg_type: Optional[str] = None,
89
+ ) -> torch.Tensor:
90
+ """Single denoising step.
91
+
92
+ Returns:
93
+ torch.Tensor: [B, T, D] noisy motion input to t-1
94
+ """
95
+ # subsample timesteps
96
+ # NOTE: do this at every step because of ONNX export, i.e. num_denoising_steps may change dynamically
97
+ # when running the ONNX version, so we need to account for that.
98
+ num_denoising_steps = num_denoising_steps[0]
99
+ use_timesteps, map_tensor = self.diffusion.space_timesteps(num_denoising_steps)
100
+ self.diffusion.calc_diffusion_vars(use_timesteps)
101
+
102
+ # first compute initial clean prediction from denoiser
103
+ t_map = map_tensor[t]
104
+
105
+ with torch.inference_mode():
106
+ pred_clean = self.denoiser(
107
+ cfg_weight,
108
+ motion,
109
+ pad_mask,
110
+ text_feat,
111
+ text_pad_mask,
112
+ t_map,
113
+ first_heading_angle,
114
+ motion_mask,
115
+ observed_motion,
116
+ cfg_type=cfg_type,
117
+ )
118
+
119
+ # sampler computes next step noisy motion
120
+ x_tm1 = self.sampler(use_timesteps, motion, pred_clean, t)
121
+ return x_tm1
122
+
123
+ def _multiprompt(
124
+ self,
125
+ prompts: list[str],
126
+ num_frames: int | list[int],
127
+ num_denoising_steps: int,
128
+ constraint_lst: Optional[list] = [],
129
+ cfg_weight: Optional[float] = [2.0, 2.0],
130
+ num_samples: Optional[int] = None,
131
+ cfg_type: Optional[str] = None,
132
+ return_numpy: bool = False,
133
+ first_heading_angle: Optional[torch.Tensor] = None,
134
+ # for transitioning
135
+ num_transition_frames: int = 5,
136
+ share_transition: bool = True,
137
+ percentage_transition_override=0.10,
138
+ # for postprocess
139
+ post_processing: bool = False,
140
+ root_margin: float = 0.04,
141
+ # progress bar
142
+ progress_bar=tqdm,
143
+ ) -> torch.Tensor:
144
+ device = self.device
145
+
146
+ tosqueeze = False
147
+ if num_samples is None:
148
+ num_samples = 1
149
+ tosqueeze = True
150
+
151
+ bs = num_samples
152
+ texts = sanitize_texts(prompts)
153
+
154
+ if isinstance(num_frames, int):
155
+ # same duration for all the segments
156
+ num_frames = [num_frames for _ in range(num_samples)]
157
+
158
+ if constraint_lst is None:
159
+ constraint_lst = []
160
+
161
+ # Generate one chunk at a time
162
+ current_frame = 0
163
+ generated_motions = []
164
+
165
+ for idx, (text, num_frame) in enumerate(zip(texts, num_frames)):
166
+ texts_bs = [text for _ in range(num_samples)]
167
+
168
+ lengths = torch.tensor(
169
+ [num_frame for _ in range(num_samples)],
170
+ device=device,
171
+ )
172
+
173
+ is_first_motion = not generated_motions
174
+
175
+ observed_motion, motion_mask = None, None
176
+
177
+ # filter the constraint_lst to only keep the relevant ones
178
+ constraint_lst_base = [
179
+ constraint.crop_move(current_frame, current_frame + num_frame) for constraint in constraint_lst
180
+ ] # this move temporally but not spatially
181
+
182
+ observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched(
183
+ constraint_lst_base,
184
+ lengths,
185
+ to_normalize=False, # don't normalize yet, it needs to be moved around
186
+ device=device,
187
+ )
188
+
189
+ if not is_first_motion:
190
+ prev_num_frame = num_frames[idx - 1]
191
+ if share_transition:
192
+ # starting the transitioning earlier, to "share" the transition between A and B
193
+ # in any case, we still use "num_transition_frames" for conditioning
194
+ # we don't condition until the end of A
195
+ # we compute the number of frames of transition as a percentage of the last motion
196
+ nb_transition_frames = num_transition_frames + int(prev_num_frame * percentage_transition_override)
197
+ else:
198
+ nb_transition_frames = num_transition_frames
199
+
200
+ latest_motions = generated_motions.pop()
201
+ # remove the transition part of A (will be put back afterward)
202
+ generated_motions.append(latest_motions[:, :-nb_transition_frames])
203
+ latest_frames = latest_motions[:, -nb_transition_frames:]
204
+ # latest_frames[..., 2] += 0.5
205
+
206
+ last_output = self.motion_rep.inverse(
207
+ latest_frames,
208
+ is_normalized=False,
209
+ return_numpy=False,
210
+ )
211
+ smooth_root_2d = last_output["smooth_root_pos"][..., [0, 2]]
212
+
213
+ # add constraints at the beginning to allow natural transitions
214
+ constraint_lst_transition = []
215
+ for batch_id in range(bs):
216
+ new_constraint = FullBodyConstraintSet(
217
+ self.skeleton,
218
+ torch.arange(num_transition_frames),
219
+ last_output["posed_joints"][batch_id, :num_transition_frames],
220
+ last_output["local_rot_mats"][batch_id, :num_transition_frames],
221
+ smooth_root_2d[batch_id, :num_transition_frames],
222
+ )
223
+
224
+ # new lists
225
+ constraint_lst_transition.append([new_constraint])
226
+
227
+ transition_lengths = torch.tensor(
228
+ [nb_transition_frames for _ in range(num_samples)],
229
+ device=device,
230
+ )
231
+
232
+ observed_motion_transition, motion_mask_transition = (
233
+ self.motion_rep.create_conditions_from_constraints_batched(
234
+ constraint_lst_transition,
235
+ transition_lengths,
236
+ to_normalize=False, # don't normalize yet
237
+ device=device,
238
+ )
239
+ )
240
+
241
+ # concatenate the observed motion / motion mask
242
+ observed_motion = torch.cat([observed_motion_transition, observed_motion], axis=1)
243
+ motion_mask = torch.cat([motion_mask_transition, motion_mask], axis=1)
244
+
245
+ # we need to move each observed motion in the batch to the new starting points
246
+ last_smooth_root_2d = smooth_root_2d[:, 0]
247
+ observed_motion = self.motion_rep.translate_2d(
248
+ observed_motion, -last_smooth_root_2d
249
+ ) # equivalent to: self.motion_rep.translate_2d_to_zero(observed_motion)
250
+
251
+ # remove dummy values after moving
252
+ observed_motion = observed_motion * motion_mask
253
+
254
+ lengths = lengths + transition_lengths
255
+ first_heading_angle = compute_heading_angle(last_output["posed_joints"], self.skeleton)[:, 0]
256
+ else:
257
+ if first_heading_angle is None:
258
+ # Start at 0 angle, but this will change afterward
259
+ first_heading_angle = torch.tensor([0.0] * bs, device=device)
260
+ else:
261
+ first_heading_angle = torch.as_tensor(first_heading_angle, device=device)
262
+ if first_heading_angle.numel() == 1:
263
+ first_heading_angle = first_heading_angle.repeat(bs)
264
+
265
+ observed_motion = self.motion_rep.normalize(observed_motion)
266
+
267
+ max_frames = max(lengths)
268
+ motion_pad_mask = length_to_mask(lengths)
269
+
270
+ motion = self._generate(
271
+ texts_bs,
272
+ max_frames,
273
+ num_denoising_steps=num_denoising_steps,
274
+ pad_mask=motion_pad_mask,
275
+ first_heading_angle=first_heading_angle,
276
+ motion_mask=motion_mask,
277
+ observed_motion=observed_motion,
278
+ cfg_weight=cfg_weight,
279
+ cfg_type=cfg_type,
280
+ )
281
+
282
+ motion = self.motion_rep.unnormalize(motion)
283
+
284
+ if not is_first_motion:
285
+ motion_with_transition = self.motion_rep.translate_2d(
286
+ motion,
287
+ last_smooth_root_2d,
288
+ )
289
+
290
+ motion = motion_with_transition[:, num_transition_frames:]
291
+ transition_frames = motion_with_transition[:, :num_transition_frames]
292
+ # for sharing = True, the new motion contains the very last of A
293
+
294
+ # linearly combine the previously generated transitions with the newly generated ones
295
+ # so that we linearly go from previous gen to new gen
296
+ alpha = torch.linspace(1, 0, num_transition_frames, device=device)[:, None]
297
+ new_transition_frames = (
298
+ latest_frames[:, :num_transition_frames] * alpha + (1 - alpha) * transition_frames
299
+ )
300
+
301
+ # add new transition frames for A (merging with B's prediction of the history)
302
+ # for share_transition == True, this removes (does not add back) a small part of the end of A
303
+ # the small last part of A has been re-generated by B
304
+ generated_motions.append(new_transition_frames)
305
+
306
+ # motion[..., 2] += 0.5
307
+
308
+ generated_motions.append(motion)
309
+ current_frame += num_frame
310
+
311
+ generated_motions = torch.cat(generated_motions, axis=1) # temporal axis (b, t, d)
312
+
313
+ if tosqueeze:
314
+ generated_motions = generated_motions[0]
315
+
316
+ output = self.motion_rep.inverse(
317
+ generated_motions,
318
+ is_normalized=False,
319
+ return_numpy=False,
320
+ )
321
+
322
+ # Apply post-processing if requested
323
+ if post_processing:
324
+ corrected = post_process_motion(
325
+ output["local_rot_mats"],
326
+ output["root_positions"],
327
+ output["foot_contacts"],
328
+ self.skeleton,
329
+ constraint_lst,
330
+ root_margin=root_margin,
331
+ )
332
+ output.update(corrected)
333
+
334
+ # Convert SOMA output to somaskel77 for external API
335
+ if isinstance(self.skeleton, SOMASkeleton30):
336
+ output = self.skeleton.output_to_SOMASkeleton77(output)
337
+
338
+ # Convert to numpy if requested
339
+ if return_numpy:
340
+ output = to_numpy(output)
341
+ return output
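The stitching above boils down to a linear cross-fade over the overlapping frames. In isolation (dummy tensors, feature dimension arbitrary):

```python
import torch

num_transition_frames = 5
prev_tail = torch.randn(2, num_transition_frames, 32)  # end of segment A (already generated)
new_head = torch.randn(2, num_transition_frames, 32)   # re-generated overlap from segment B

alpha = torch.linspace(1, 0, num_transition_frames)[:, None]  # ramps 1 -> 0 over the overlap
blended = alpha * prev_tail + (1 - alpha) * new_head          # first frame ~ A, last frame ~ B
```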
342
+
343
+ def __call__(
344
+ self,
345
+ prompts: str | list[str],
346
+ num_frames: int | list[int],
347
+ num_denoising_steps: int,
348
+ multi_prompt: bool = False,
349
+ constraint_lst: Optional[list] = [],
350
+ cfg_weight: Optional[float] = [2.0, 2.0],
351
+ num_samples: Optional[int] = None,
352
+ cfg_type: Optional[str] = None,
353
+ return_numpy: bool = False,
354
+ first_heading_angle: Optional[torch.Tensor] = None,
355
+ # for transitioning
356
+ num_transition_frames: int = 5,
357
+ share_transition: bool = True,
358
+ percentage_transition_override=0.10,
359
+ # for postprocess
360
+ post_processing: bool = False,
361
+ root_margin: float = 0.04,
362
+ # progress bar
363
+ progress_bar=tqdm,
364
+ ) -> dict:
365
+ """Generate motion from text prompts and optional kinematic constraints.
366
+
367
+ When a single prompt/num_frames pair is given, one motion is generated.
368
+ Passing lists of prompts and/or num_frames produces a batch of
369
+ independent motions. With ``multi_prompt=True``, the prompts are
370
+ treated as sequential segments that are generated and stitched together
371
+ with smooth transitions.
372
+
373
+ Args:
374
+ prompts: One or more text descriptions of the desired motion.
375
+ A single string generates one sample; a list generates a batch
376
+ (or sequential segments when ``multi_prompt=True``).
377
+ num_frames: Duration of the generated motion in frames. Can be a
378
+ single int applied to every prompt or a per-prompt list.
379
+ num_denoising_steps: Number of DDIM denoising steps. More steps
380
+ generally improve quality at the cost of speed.
381
+ multi_prompt: If ``True``, treat ``prompts`` as an ordered sequence
382
+ of segments and concatenate them with transitions.
383
+ constraint_lst: Per-sample list of kinematic constraints (e.g.
384
+ keyframe poses, end-effector targets, 2-D paths). Pass an
385
+ empty list for unconstrained generation.
386
+ cfg_weight: Classifier-free guidance scale(s). A two-element list
387
+ ``[text_cfg, constraint_cfg]`` controls text and constraint
388
+ guidance independently.
389
+ num_samples: Number of samples to generate.
390
+ cfg_type: Override the default CFG strategy set at init
391
+ (e.g. ``"separated"``).
392
+ return_numpy: If ``True``, convert all output tensors to numpy
393
+ arrays.
394
+ first_heading_angle: Initial body heading in radians. Shape
395
+ ``(B,)`` or scalar. Defaults to ``0`` (facing +Z).
396
+ num_transition_frames: Number of overlapping frames used to blend
397
+ consecutive segments in multi-prompt mode.
398
+ share_transition: If ``True``, transition frames are shared between
399
+ adjacent segments rather than appended.
400
+ percentage_transition_override: Fraction of each segment's length
401
+ that may be overridden by the transition blend.
402
+ post_processing: If ``True``, apply post-processing
403
+ (foot-skate cleanup and constraint enforcement).
404
+ root_margin: Horizontal margin (in meters) used by the post-processor
405
+ to determine when to correct root motion. When root deviates more than
406
+ margin from the constraint, the post-processor will correct it.
407
+ progress_bar: Callable wrapping an iterable to display progress
408
+ (default: ``tqdm``). Pass a no-op to silence output.
409
+
410
+ Returns:
411
+ dict: A dictionary of motion tensors (or numpy arrays if
412
+ ``return_numpy=True``) with the following keys:
413
+
414
+ - ``local_rot_mats`` – Local joint rotations as rotation matrices.
415
+ - ``global_rot_mats`` – Global joint rotations as rotation matrices.
416
+ - ``posed_joints`` – Joint positions in world space.
417
+ - ``root_positions`` – Root joint positions.
418
+ - ``smooth_root_pos`` – Smoothed root trajectory.
419
+ - ``foot_contacts`` – Boolean foot-contact labels [left heel, left toe, right heel, right toe].
420
+ - ``global_root_heading`` – Root heading angle over time.
421
+ """
422
+ device = self.device
423
+
424
+ if multi_prompt:
425
+ # multi prompt generation
426
+ return self._multiprompt(
427
+ prompts,
428
+ num_frames,
429
+ num_denoising_steps,
430
+ constraint_lst,
431
+ cfg_weight,
432
+ num_samples,
433
+ cfg_type,
434
+ return_numpy,
435
+ first_heading_angle,
436
+ num_transition_frames,
437
+ share_transition,
438
+ percentage_transition_override,
439
+ post_processing,
440
+ root_margin,
441
+ progress_bar,
442
+ )
443
+
444
+ # Input checking
445
+ tosqueeze = False
446
+ if isinstance(prompts, list) and isinstance(num_frames, list):
447
+ assert len(prompts) == len(num_frames), "The number of prompts should match the number of num_frames."
448
+ num_samples = len(prompts)
449
+ elif isinstance(prompts, list):
450
+ num_samples = len(prompts)
451
+ num_frames = [num_frames for _ in range(num_samples)]
452
+ elif isinstance(num_frames, list):
453
+ num_samples = len(num_frames)
454
+ prompts = [prompts for _ in range(num_samples)]
455
+ else:
456
+ if num_samples is None:
457
+ tosqueeze = True
458
+ num_samples = 1
459
+ prompts = [prompts for _ in range(num_samples)]
460
+ num_frames = [num_frames for _ in range(num_samples)]
461
+
462
+ bs = num_samples
463
+ texts = sanitize_texts(prompts)
464
+
465
+ lengths = torch.tensor(
466
+ num_frames,
467
+ device=device,
468
+ )
469
+ max_frames = max(lengths)
470
+ motion_pad_mask = length_to_mask(lengths)
471
+
472
+ if first_heading_angle is None:
473
+ # Start at 0 angle
474
+ first_heading_angle = torch.tensor([0.0] * bs, device=device)
475
+ else:
476
+ first_heading_angle = torch.as_tensor(first_heading_angle, device=device)
477
+ if first_heading_angle.numel() == 1:
478
+ first_heading_angle = first_heading_angle.repeat(bs)
479
+
480
+ observed_motion, motion_mask = None, None
481
+ if constraint_lst:
482
+ observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched(
483
+ constraint_lst,
484
+ lengths,
485
+ to_normalize=True,
486
+ device=device,
487
+ )
488
+
489
+ motion = self._generate(
490
+ texts,
491
+ max_frames,
492
+ num_denoising_steps=num_denoising_steps,
493
+ pad_mask=motion_pad_mask,
494
+ first_heading_angle=first_heading_angle,
495
+ motion_mask=motion_mask,
496
+ observed_motion=observed_motion,
497
+ cfg_weight=cfg_weight,
498
+ cfg_type=cfg_type,
499
+ progress_bar=progress_bar,
500
+ )
501
+
502
+ if tosqueeze:
503
+ motion = motion[0]
504
+
505
+ output = self.motion_rep.inverse(
506
+ motion,
507
+ is_normalized=True,
508
+ return_numpy=False, # Keep as tensor for potential post-processing
509
+ )
510
+
511
+ # Apply post-processing if requested
512
+ if post_processing:
513
+ corrected = post_process_motion(
514
+ output["local_rot_mats"],
515
+ output["root_positions"],
516
+ output["foot_contacts"],
517
+ self.skeleton,
518
+ constraint_lst,
519
+ root_margin=root_margin,
520
+ )
521
+ # key frame outputs / foot contacts are not changed
522
+ output.update(corrected)
523
+
524
+ # Convert SOMA output to somaskel77 for external API
525
+ if isinstance(self.skeleton, SOMASkeleton30):
526
+ output = self.skeleton.output_to_SOMASkeleton77(output)
527
+
528
+ # Convert to numpy if requested
529
+ if return_numpy:
530
+ output = to_numpy(output)
531
+ return output
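A usage sketch for the call signature documented above, assuming `model` is an already-constructed Kimodo instance (building the denoiser and text encoder is out of scope here):

```python
# Single prompt, unconstrained generation.
output = model(
    "a person walks forward and waves",
    num_frames=120,
    num_denoising_steps=50,
    post_processing=True,
    return_numpy=True,
)
joints = output["posed_joints"]      # world-space joint positions
contacts = output["foot_contacts"]   # [left heel, left toe, right heel, right toe]

# Sequential segments stitched together with transitions.
multi = model(
    ["a person walks forward", "the person sits down"],
    num_frames=[120, 90],
    num_denoising_steps=50,
    multi_prompt=True,
)
```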
532
+
533
+ def _generate(
534
+ self,
535
+ texts: List[str],
536
+ max_frames: int,
537
+ num_denoising_steps: int,
538
+ pad_mask: torch.Tensor,
539
+ first_heading_angle: Optional[torch.Tensor],
540
+ motion_mask: torch.Tensor,
541
+ observed_motion: torch.Tensor,
542
+ cfg_weight: Optional[float] = 2.0,
543
+ text_feat: Optional[torch.Tensor] = None,
544
+ text_pad_mask: Optional[torch.Tensor] = None,
545
+ guide_masks: Optional[Dict] = None,
546
+ cfg_type: Optional[str] = None,
547
+ progress_bar=tqdm,
548
+ ) -> torch.Tensor:
549
+ """Sample full denoising loop.
550
+
551
+ Args:
552
+ texts (List[str]): batch of text prompts to use for sampling (if text_feat is not passed in)
553
+ """
554
+
555
+ device = self.device
556
+ if text_feat is None:
557
+ assert text_pad_mask is None
558
+ log.info("Encoding text...")
559
+ text_feat, text_length = self.text_encoder(texts)
560
+ text_feat = text_feat.to(device)
561
+
562
+ # handle empty string (set to zero)
563
+ empty_text_mask = [len(text.strip()) == 0 for text in texts]
564
+ text_feat[empty_text_mask] = 0
565
+
566
+ # Create the pad mask for the text
567
+ batch_size, maxlen = text_feat.shape[:2]
568
+ tensor_text_length = torch.tensor(text_length, device=device)
569
+ tensor_text_length[empty_text_mask] = 0
570
+ text_pad_mask = torch.arange(maxlen, device=device).expand(batch_size, maxlen) < tensor_text_length[:, None]
571
+
572
+ if motion_mask is not None:
573
+ if motion_mask.dtype == torch.bool:
574
+ motion_mask = 1 * motion_mask
575
+
576
+ batch_size = text_feat.shape[0]
577
+
578
+ # sample loop
579
+ indices = list(range(num_denoising_steps))[::-1]
580
+ shape = (batch_size, max_frames, self.motion_rep.motion_rep_dim)
581
+ cur_mot = torch.randn(shape, device=self.device)
582
+ num_denoising_steps = torch.tensor(
583
+ [num_denoising_steps], device=self.device
584
+ ) # this and t need to be tensor for onnx export
585
+ # init diffusion with correct num steps before looping
586
+ use_timesteps = self.diffusion.space_timesteps(num_denoising_steps[0])[0]
587
+ self.diffusion.calc_diffusion_vars(use_timesteps)
588
+ for i in progress_bar(indices):
589
+ t = torch.tensor([i] * cur_mot.size(0), device=self.device)
590
+ with torch.inference_mode():
591
+ cur_mot = self.denoising_step(
592
+ cur_mot,
593
+ pad_mask,
594
+ text_feat,
595
+ text_pad_mask,
596
+ t,
597
+ first_heading_angle,
598
+ motion_mask,
599
+ observed_motion,
600
+ num_denoising_steps,
601
+ cfg_weight,
602
+ guide_masks=guide_masks,
603
+ cfg_type=cfg_type,
604
+ )
605
+ return cur_mot
kimodo/model/llm2vec/README.md ADDED
@@ -0,0 +1 @@
1
+ This is a patched version of the original [LLM2Vec](https://github.com/McGill-NLP/llm2vec) codebase so that `McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised` works with `transformers==5.0.0rc3`.
kimodo/model/llm2vec/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """LLM2Vec text encoder and wrapper for Kimodo."""
4
+
5
+ from .llm2vec import LLM2Vec
6
+ from .llm2vec_wrapper import LLM2VecEncoder
7
+
8
+ __all__ = [
9
+ "LLM2Vec",
10
+ "LLM2VecEncoder",
11
+ ]
kimodo/model/llm2vec/llm2vec.py ADDED
@@ -0,0 +1,477 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+
22
+
23
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
24
+ # SPDX-License-Identifier: Apache-2.0
25
+ #
26
+ # Licensed under the Apache License, Version 2.0 (the "License");
27
+ # you may not use this file except in compliance with the License.
28
+ # You may obtain a copy of the License at
29
+ #
30
+ # http://www.apache.org/licenses/LICENSE-2.0
31
+ #
32
+ # Unless required by applicable law or agreed to in writing, software
33
+ # distributed under the License is distributed on an "AS IS" BASIS,
34
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35
+ # See the License for the specific language governing permissions and
36
+ # limitations under the License.
37
+
38
+ import json
39
+ import logging
40
+ import os
41
+ from functools import partial
42
+ from typing import Dict, List, Optional, Union
43
+
44
+ import numpy as np
45
+ import torch
46
+ import torch.multiprocessing as mp
47
+ from peft import PeftModel
48
+ from torch import Tensor, device, nn
49
+ from tqdm.autonotebook import tqdm, trange
50
+ from transformers import (
51
+ AutoConfig,
52
+ AutoModel,
53
+ AutoTokenizer,
54
+ GemmaConfig,
55
+ LlamaConfig,
56
+ MistralConfig,
57
+ PretrainedConfig,
58
+ Qwen2Config,
59
+ )
60
+
61
+ logger = logging.getLogger(__name__)
62
+
63
+
64
+ def batch_to_device(batch, target_device: device):
65
+ """Send a pytorch batch to a device (CPU/GPU)"""
66
+ for key in batch:
67
+ if isinstance(batch[key], Tensor):
68
+ batch[key] = batch[key].to(target_device)
69
+ return batch
70
+
71
+
72
+ class LLM2Vec(nn.Module):
73
+ def __init__(
74
+ self,
75
+ model: AutoModel,
76
+ tokenizer: AutoTokenizer,
77
+ pooling_mode: str = "mean",
78
+ max_length: int = 512,
79
+ doc_max_length: int = 400,
80
+ skip_instruction: bool = True,
81
+ ):
82
+ super().__init__()
83
+ self.model = model
84
+ self.tokenizer = tokenizer
85
+ self.pooling_mode = pooling_mode
86
+ self.skip_instruction = skip_instruction
87
+ self.max_length = max_length
88
+ self.doc_max_length = doc_max_length
89
+ self.config = model.config
90
+
91
+ @classmethod
92
+ def _get_model_class(cls, config_class_name, enable_bidirectional):
93
+ if not enable_bidirectional:
94
+ return AutoModel
95
+ if config_class_name == "MistralConfig":
96
+ from .models.bidirectional_mistral import MistralBiModel
97
+
98
+ return MistralBiModel
99
+ elif config_class_name == "LlamaConfig":
100
+ from .models.bidirectional_llama import LlamaBiModel
101
+
102
+ return LlamaBiModel
103
+ elif config_class_name == "GemmaConfig":
104
+ from .models.bidirectional_gemma import GemmaBiModel
105
+
106
+ return GemmaBiModel
107
+ elif config_class_name == "Qwen2Config":
108
+ from .models.bidirectional_qwen2 import Qwen2BiModel
109
+
110
+ return Qwen2BiModel
111
+ else:
112
+ raise ValueError(f"{config_class_name} is not supported yet with bidirectional models.")
113
+
114
+ @classmethod
115
+ def from_pretrained(
116
+ cls,
117
+ base_model_name_or_path,
118
+ peft_model_name_or_path=None,
119
+ merge_peft=False,
120
+ enable_bidirectional=True,
121
+ **kwargs,
122
+ ):
123
+ # pop out encoder args
124
+ keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"]
125
+ encoder_args = {key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None}
126
+
127
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
128
+ tokenizer.pad_token = tokenizer.eos_token
129
+ tokenizer.padding_side = "left"
130
+
131
+ config = AutoConfig.from_pretrained(base_model_name_or_path)
132
+ config_class_name = config.__class__.__name__
133
+
134
+ model_class = cls._get_model_class(config_class_name, enable_bidirectional=enable_bidirectional)
135
+
136
+ model = model_class.from_pretrained(base_model_name_or_path, **kwargs)
137
+
138
+ if os.path.isdir(base_model_name_or_path) and os.path.exists(f"{base_model_name_or_path}/config.json"):
139
+ with open(f"{base_model_name_or_path}/config.json", "r") as fIn:
140
+ config_dict = json.load(fIn)
141
+ config = PretrainedConfig.from_dict(config_dict)
142
+ model.config._name_or_path = config._name_or_path
143
+
144
+ # For special case where config.json and adapter weights are in the same directory
145
+ if hasattr(model, "peft_config"):
146
+ model = PeftModel.from_pretrained(
147
+ model,
148
+ base_model_name_or_path,
149
+ )
150
+ model = model.merge_and_unload()
151
+
152
+ if peft_model_name_or_path is not None:
153
+ model = PeftModel.from_pretrained(
154
+ model,
155
+ peft_model_name_or_path,
156
+ )
157
+ if merge_peft:
158
+ model = model.merge_and_unload()
159
+
160
+ config = {}
161
+ config_addr = peft_model_name_or_path if peft_model_name_or_path is not None else base_model_name_or_path
162
+ if os.path.exists(f"{config_addr}/llm2vec_config.json"):
163
+ with open(f"{config_addr}/llm2vec_config.json", "r") as fIn:
164
+ llm2vec_config = json.load(fIn)
165
+ config.update(llm2vec_config)
166
+
167
+ for key, value in encoder_args.items():
168
+ config[key] = value
169
+
170
+ return cls(model=model, tokenizer=tokenizer, **config)
171
+
172
+ def prepare_for_tokenization(self, text):
173
+ if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
174
+ text = "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
175
+ return text
176
+ if self.model.config._name_or_path in [
177
+ "mistralai/Mistral-7B-Instruct-v0.2",
178
+ "meta-llama/Llama-2-7b-chat-hf",
179
+ ]:
180
+ text = "[INST] " + text.strip() + " [/INST]"
181
+ if self.model.config._name_or_path in [
182
+ "google/gemma-2-9b-it",
183
+ ]:
184
+ text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
185
+ if self.model.config._name_or_path in [
186
+ "Qwen/Qwen2-1.5B-Instruct",
187
+ "Qwen/Qwen2-7B-Instruct",
188
+ ]:
189
+ text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
190
+ if self.pooling_mode == "eos_token":
191
+ if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
192
+ text = text.strip() + "<|end_of_text|>"
193
+ elif isinstance(self.model.config, LlamaConfig) or isinstance(self.model.config, MistralConfig):
194
+ text = text.strip() + " </s>"
195
+ elif isinstance(self.model.config, GemmaConfig):
196
+ text = text.strip() + "<eos>"
197
+ elif isinstance(self.model.config, Qwen2Config):
198
+ text = text.strip() + "<|endoftext|>"
199
+ return text
200
+
201
+ def tokenize(self, texts):
202
+ texts_2 = []
203
+ original_texts = []
204
+ for text in texts:
205
+ t = text.split("!@#$%^&*()")
206
+ texts_2.append(t[1] if len(t) > 1 else "")
207
+ original_texts.append("".join(t))
208
+
209
+ original = self.tokenizer(
210
+ original_texts,
211
+ return_tensors="pt",
212
+ padding=True,
213
+ truncation=True,
214
+ max_length=self.max_length,
215
+ )
216
+ embed_mask = None
217
+ for t_i, t in enumerate(texts_2):
218
+ ids = self.tokenizer(
219
+ [t],
220
+ return_tensors="pt",
221
+ padding=True,
222
+ truncation=True,
223
+ max_length=self.max_length,
224
+ add_special_tokens=False,
225
+ )
226
+ if embed_mask is None:
227
+ e_m = torch.zeros_like(original["attention_mask"][t_i])
228
+ if len(ids["input_ids"][0]) > 0:
229
+ e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0]))
230
+ embed_mask = e_m.unsqueeze(0)
231
+ else:
232
+ e_m = torch.zeros_like(original["attention_mask"][t_i])
233
+ if len(ids["input_ids"][0]) > 0:
234
+ e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0]))
235
+ embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)
236
+
237
+ original["embed_mask"] = embed_mask
238
+ return original
239
+
240
+ def _skip_instruction(self, sentence_feature):
241
+ assert sentence_feature["attention_mask"].shape == sentence_feature["embed_mask"].shape
242
+ sentence_feature["attention_mask"] = sentence_feature["embed_mask"]
243
+
244
+ def forward(self, sentence_feature: Dict[str, Tensor]):
245
+ embed_mask = None
246
+ if "embed_mask" in sentence_feature:
247
+ embed_mask = sentence_feature.pop("embed_mask")
248
+ reps = self.model(**sentence_feature)
249
+ sentence_feature["embed_mask"] = embed_mask
250
+
251
+ return self.get_pooling(sentence_feature, reps.last_hidden_state)
252
+
253
+ def get_pooling(self, features, last_hidden_states): # All models padded from left
254
+ assert self.tokenizer.padding_side == "left", "Pooling modes are implemented for padding from left."
255
+ if self.skip_instruction:
256
+ self._skip_instruction(features)
257
+ seq_lengths = features["attention_mask"].sum(dim=-1)
258
+ if self.pooling_mode == "mean":
259
+ return torch.stack(
260
+ [last_hidden_states[i, -length:, :].mean(dim=0) for i, length in enumerate(seq_lengths)],
261
+ dim=0,
262
+ )
263
+ elif self.pooling_mode == "weighted_mean":
264
+ bs, l, _ = last_hidden_states.shape
265
+ complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)
266
+ for i, seq_l in enumerate(seq_lengths):
267
+ if seq_l > 0:
268
+ complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1
269
+ complete_weights[i] /= torch.clamp(complete_weights[i].sum(), min=1e-9)
270
+ return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
271
+ elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token":
272
+ return last_hidden_states[:, -1]
273
+ elif self.pooling_mode == "bos_token":
274
+ return last_hidden_states[features["input_ids"] == self.tokenizer.bos_token_id]
275
+ else:
276
+ raise ValueError(f"{self.pooling_mode} is not implemented yet.")
277
+
278
+ def _convert_to_str(self, instruction, text):
279
+ tokenized_q = self.tokenizer(
280
+ text,
281
+ return_tensors="pt",
282
+ padding=True,
283
+ truncation=True,
284
+ max_length=self.max_length,
285
+ add_special_tokens=False,
286
+ )
287
+ tokenized_q_length = len(tokenized_q["input_ids"][0])
288
+
289
+ while tokenized_q_length > self.doc_max_length:
290
+ reduction_ratio = self.doc_max_length / tokenized_q_length
291
+ reduced_length = int(len(text.split()) * reduction_ratio)
292
+ text = " ".join(text.split()[:reduced_length])
293
+ tokenized_q = self.tokenizer(
294
+ text,
295
+ return_tensors="pt",
296
+ padding=True,
297
+ truncation=True,
298
+ max_length=self.max_length,
299
+ add_special_tokens=False,
300
+ )
301
+ tokenized_q_length = len(tokenized_q["input_ids"][0])
302
+
303
+ return f"{instruction.strip()} !@#$%^&*(){text}" if instruction else f"!@#$%^&*(){text}"
304
+
305
+ def encode(
306
+ self,
307
+ sentences: Union[str, List[str]],
308
+ batch_size: int = 32,
309
+ show_progress_bar: bool = True,
310
+ convert_to_numpy: bool = False,
311
+ convert_to_tensor: bool = False,
312
+ device: Optional[str] = None,
313
+ ):
314
+ """
315
+ Encode a list of sentences to their respective embeddings. The sentences can be a list of strings or a string.
316
+ Args:
317
+ sentences: sentence or sentences to encode.
318
+ batch_size: batch size for turning sentence tokens into embeddings.
319
+ show_progress_bar: whether to show progress bars during encoding steps.
320
+ convert_to_numpy: If true, return numpy arrays instead of torch tensors.
321
+ convert_to_tensor: If true, return torch tensors (default).
322
+ device: torch backend device identifier (e.g., 'cuda', 'cpu','mps' etc.). If not specified,
323
+ the default is to use cuda when available, otherwise cpu. Note that only the choice of 'cuda' supports
324
+ multiprocessing as currently implemented.
325
+
326
+ Returns: embeddings of the sentences. Embeddings are detached and always on the CPU (see _encode implementation).
327
+
328
+ """
329
+ if isinstance(sentences[0], str) and isinstance(sentences[-1], int):
330
+ sentences = [sentences]
331
+ # required for MEDI version of MTEB
332
+ if isinstance(sentences[0], str):
333
+ sentences = [[""] + [sentence] for sentence in sentences]
334
+
335
+ if device is None:
336
+ device = "cuda" if torch.cuda.is_available() else "cpu"
337
+
338
+ concatenated_input_texts = []
339
+ for sentence in sentences:
340
+ assert isinstance(sentence[0], str)
341
+ assert isinstance(sentence[1], str)
342
+ concatenated_input_texts.append(self._convert_to_str(sentence[0], sentence[1]))
343
+ sentences = concatenated_input_texts
344
+
345
+ self.eval()
346
+
347
+ if convert_to_tensor:
348
+ convert_to_numpy = False
349
+
350
+ length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
351
+ sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
352
+ all_embeddings = []
353
+
354
+ if torch.cuda.device_count() <= 1:
355
+ # This branch also support mps devices
356
+ self.to(device)
357
+ for start_index in trange(
358
+ 0,
359
+ len(sentences),
360
+ batch_size,
361
+ desc="Batches",
362
+ disable=not show_progress_bar,
363
+ ):
364
+ sentences_batch = sentences_sorted[start_index : start_index + batch_size]
365
+ embeddings = self._encode(sentences_batch, device=device, convert_to_numpy=convert_to_numpy)
366
+ all_embeddings.append(embeddings)
367
+ else:
368
+ num_proc = torch.cuda.device_count()
369
+ cuda_compatible_multiprocess = mp.get_context("spawn")
370
+ with cuda_compatible_multiprocess.Pool(num_proc) as p:
371
+ sentences_batches = [
372
+ sentences_sorted[start_index : start_index + batch_size]
373
+ for start_index in range(0, len(sentences), batch_size)
374
+ ]
375
+
376
+ progress_bar = tqdm(
377
+ total=len(sentences_batches),
378
+ desc="Batches",
379
+ disable=not show_progress_bar,
380
+ )
381
+ results = []
382
+
383
+ def update(*args):
384
+ progress_bar.update()
385
+
386
+ for batch in sentences_batches:
387
+ results.append(
388
+ p.apply_async(
389
+ self._encode,
390
+ args=(batch, None, convert_to_numpy, True),
391
+ callback=update,
392
+ )
393
+ )
394
+
395
+ all_embeddings = [result.get() for result in results]
396
+ progress_bar.close()
397
+
398
+ all_embeddings = torch.cat(all_embeddings, dim=0)
399
+ all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]
400
+ all_embeddings = all_embeddings.to(torch.float32)
401
+ if convert_to_numpy:
402
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
403
+ return all_embeddings
404
+
405
+ def save(self, output_path, merge_before_save=False, save_config=True):
406
+ if merge_before_save and isinstance(self.model, PeftModel):
407
+ self.model = self.model.merge_and_unload()
408
+ # Fixes the issue of saving - https://huggingface.co/McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse/discussions/1
409
+ if hasattr(self.model, "_hf_peft_config_loaded"):
410
+ self.model._hf_peft_config_loaded = False
411
+
412
+ self.model.save_pretrained(output_path)
413
+ self.tokenizer.save_pretrained(output_path)
414
+
415
+ llm2vec_config = {
416
+ "pooling_mode": self.pooling_mode,
417
+ "max_length": self.max_length,
418
+ "doc_max_length": self.doc_max_length,
419
+ "skip_instruction": self.skip_instruction,
420
+ }
421
+
422
+ if save_config:
423
+ os.makedirs(output_path, exist_ok=True)
424
+ with open(f"{output_path}/llm2vec_config.json", "w") as fOut:
425
+ json.dump(llm2vec_config, fOut, indent=4)
426
+
427
+ def _encode(
428
+ self,
429
+ sentences_batch,
430
+ device: Optional[str] = None,
431
+ convert_to_numpy: bool = False,
432
+ multiprocessing=False,
433
+ ):
434
+ if multiprocessing:
435
+ # multiprocessing only supports CUDA devices at this time, so we ignore the value of device
436
+ # and use cuda:rank for the device
437
+ rank = mp.current_process()._identity[0]
438
+ if device is None and torch.cuda.is_available():
439
+ device = f"cuda:{rank % torch.cuda.device_count()}"
440
+
441
+ self.to(device)
442
+ features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch])
443
+ features = batch_to_device(features, device)
444
+
445
+ with torch.no_grad():
446
+ embeddings = self.forward(features)
447
+ embeddings = embeddings.detach()
448
+ embeddings = embeddings.cpu()
449
+
450
+ return embeddings
451
+
452
+ def _text_length(self, text: Union[List[int], List[List[int]]]):
453
+ """Help function to get the length for the input text.
454
+
455
+ Text can be either a string (which means a single text) a list of ints (which means a single
456
+ tokenized text), or a tuple of list of ints (representing several text inputs to the model).
457
+ """
458
+ if (
459
+ isinstance(text, str) or (isinstance(text, list) and isinstance(text[0], int)) or len(text) == 0
460
+ ): # Single text, list of ints, or empty
461
+ return len(text)
462
+ if isinstance(text, dict): # {key: value} case
463
+ return len(next(iter(text.values())))
464
+ elif not hasattr(text, "__len__"): # Object has no len() method
465
+ return 1
466
+ else:
467
+ return sum([len(t) for t in text])
468
+
469
+ def resize_token_embeddings(
470
+ self,
471
+ new_num_tokens: Optional[int] = None,
472
+ pad_to_multiple_of: Optional[int] = None,
473
+ ) -> nn.Embedding:
474
+ return self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)
475
+
476
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
477
+ self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
kimodo/model/llm2vec/llm2vec_wrapper.py ADDED
@@ -0,0 +1,73 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """LLM2Vec encoder wrapper for Kimodo text conditioning."""
4
+
5
+ import os
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from .llm2vec import LLM2Vec
11
+
12
+
13
+ class LLM2VecEncoder:
14
+ """LLM2Vec text embeddings."""
15
+
16
+ def __init__(
17
+ self,
18
+ base_model_name_or_path: str,
19
+ peft_model_name_or_path: str,
20
+ dtype: str,
21
+ llm_dim: int,
22
+ ) -> None:
23
+ torch_dtype = getattr(torch, dtype)
24
+ self.llm_dim = llm_dim
25
+
26
+ cache_dir = os.environ.get("HUGGINGFACE_CACHE_DIR")
27
+
28
+ if "TEXT_ENCODERS_DIR" in os.environ:
29
+ base_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], base_model_name_or_path)
30
+ peft_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], peft_model_name_or_path)
31
+
32
+ self.model = LLM2Vec.from_pretrained(
33
+ base_model_name_or_path=base_model_name_or_path,
34
+ peft_model_name_or_path=peft_model_name_or_path,
35
+ torch_dtype=torch_dtype,
36
+ cache_dir=cache_dir,
37
+ )
38
+ self.model.eval()
39
+ for p in self.model.parameters():
40
+ p.requires_grad = False
41
+
42
+ def to(self, device: torch.device):
43
+ self.model = self.model.to(device)
44
+ return self
45
+
46
+ def eval(self):
47
+ self.model.eval()
48
+ return self
49
+
50
+ def get_device(self):
51
+ return self.model.model.device
52
+
53
+ def __call__(self, text: list[str] | str):
54
+ is_string = False
55
+ if isinstance(text, str):
56
+ text = [text]
57
+ is_string = True
58
+
59
+ with torch.no_grad():
60
+ encoded_text = self.model.encode(text, batch_size=len(text), show_progress_bar=False)
61
+
62
+ assert len(encoded_text.shape)
63
+ assert self.llm_dim == encoded_text.shape[-1]
64
+
65
+ encoded_text = encoded_text[:, None]
66
+ lengths = np.ones(len(encoded_text), dtype=int).tolist()
67
+
68
+ if is_string:
69
+ encoded_text = encoded_text[0]
70
+ lengths = lengths[0]
71
+
72
+ encoded_text = torch.tensor(encoded_text).to(self.get_device())
73
+ return encoded_text, lengths
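Usage sketch for the wrapper above. The checkpoint names follow the patched-LLM2Vec README earlier in this commit; `llm_dim=4096` is assumed here (the hidden size of Llama-3-8B), and the first call downloads the full model weights:

```python
from kimodo.model.llm2vec import LLM2VecEncoder

encoder = LLM2VecEncoder(
    base_model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",
    peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
    dtype="bfloat16",
    llm_dim=4096,
)

text_feat, lengths = encoder(["a person walks forward"])
# text_feat: [1, 1, 4096] pooled embedding per prompt; lengths: [1]
```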
kimodo/model/llm2vec/models/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # from .bidirectional_gemma import GemmaBiForMNTP, GemmaBiModel
2
+ # from .bidirectional_llama import LlamaBiForMNTP, LlamaBiModel
3
+ # from .bidirectional_mistral import MistralBiForMNTP, MistralBiModel
4
+ # from .bidirectional_qwen2 import Qwen2BiForMNTP, Qwen2BiModel
kimodo/model/llm2vec/models/attn_mask_utils.py ADDED
@@ -0,0 +1,181 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
26
+
27
+
28
+ def _prepare_4d_causal_attention_mask(
29
+ attention_mask: Optional[torch.Tensor],
30
+ input_shape: Union[torch.Size, Tuple, List],
31
+ inputs_embeds: torch.Tensor,
32
+ past_key_values_length: int,
33
+ sliding_window: Optional[int] = None,
34
+ ):
35
+ """Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D
36
+ mask of shape `(batch_size, key_value_length)`
37
+
38
+ Args:
39
+ attention_mask (`torch.Tensor` or `None`):
40
+ A 2D attention mask of shape `(batch_size, key_value_length)`
41
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
42
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
43
+ inputs_embeds (`torch.Tensor`):
44
+ The embedded inputs as a torch Tensor.
45
+ past_key_values_length (`int`):
46
+ The length of the key value cache.
47
+ sliding_window (`int`, *optional*):
48
+ If the model uses windowed attention, a sliding window should be passed.
49
+ """
50
+ attn_mask_converter = AttentionMaskConverter(
51
+ is_causal=False, sliding_window=sliding_window
52
+ ) # is_causal=True in original implementation
53
+
54
+ key_value_length = input_shape[-1] + past_key_values_length
55
+
56
+ # 4d mask is passed through the layers
57
+ if attention_mask is not None and len(attention_mask.shape) == 2:
58
+ attention_mask = attn_mask_converter.to_4d(
59
+ attention_mask,
60
+ input_shape[-1],
61
+ key_value_length=key_value_length,
62
+ dtype=inputs_embeds.dtype,
63
+ )
64
+ elif attention_mask is not None and len(attention_mask.shape) == 4:
65
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
66
+ if tuple(attention_mask.shape) != expected_shape:
67
+ raise ValueError(
68
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
69
+ )
70
+ else:
71
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
72
+ inverted_mask = 1.0 - attention_mask
73
+ attention_mask = inverted_mask.masked_fill(
74
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
75
+ )
76
+ else:
77
+ attention_mask = attn_mask_converter.to_causal_4d(
78
+ input_shape[0],
79
+ input_shape[-1],
80
+ key_value_length,
81
+ dtype=inputs_embeds.dtype,
82
+ device=inputs_embeds.device,
83
+ )
84
+
85
+ return attention_mask
86
+
87
+
88
+ # Adapted from _prepare_4d_causal_attention_mask
89
+ def _prepare_4d_causal_attention_mask_for_sdpa(
90
+ attention_mask: Optional[torch.Tensor],
91
+ input_shape: Union[torch.Size, Tuple, List],
92
+ inputs_embeds: torch.Tensor,
93
+ past_key_values_length: int,
94
+ sliding_window: Optional[int] = None,
95
+ ):
96
+ """Prepares the correct `attn_mask` argument to be used by
97
+ `torch.nn.functional.scaled_dot_product_attention`.
98
+
99
+ In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
100
+ `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
101
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
102
+ """
103
+ attn_mask_converter = AttentionMaskConverter(
104
+ is_causal=False, sliding_window=sliding_window
105
+ ) # is_causal=True in original implementation
106
+
107
+ key_value_length = input_shape[-1] + past_key_values_length
108
+ batch_size, query_length = input_shape
109
+
110
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
111
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
112
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
113
+ is_tracing = (
114
+ torch.jit.is_tracing()
115
+ or isinstance(inputs_embeds, torch.fx.Proxy)
116
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
117
+ )
118
+
119
+ if attention_mask is not None:
120
+ # 4d mask is passed through
121
+ if len(attention_mask.shape) == 4:
122
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
123
+ if tuple(attention_mask.shape) != expected_shape:
124
+ raise ValueError(
125
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
126
+ )
127
+ else:
128
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
129
+ inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
130
+ attention_mask = inverted_mask.masked_fill(
131
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
132
+ )
133
+ return attention_mask
134
+
135
+ elif not is_tracing and torch.all(attention_mask == 1):
136
+ if query_length == 1:
137
+ # For query_length == 1, causal attention and bi-directional attention are the same.
138
+ attention_mask = None
139
+ elif key_value_length == query_length:
140
+ attention_mask = None
141
+ else:
142
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
143
+ # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
144
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
145
+ pass
146
+ elif query_length > 1 and key_value_length != query_length:
147
+ # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
148
+ # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
149
+ attention_mask = True
150
+ elif is_tracing:
151
+ raise ValueError(
152
+ 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
153
+ )
154
+
155
+ if attention_mask is None:
156
+ expanded_4d_mask = None
157
+ elif attention_mask is True:
158
+ expanded_4d_mask = attn_mask_converter.to_causal_4d(
159
+ input_shape[0],
160
+ input_shape[-1],
161
+ key_value_length,
162
+ dtype=inputs_embeds.dtype,
163
+ device=inputs_embeds.device,
164
+ )
165
+ else:
166
+ expanded_4d_mask = attn_mask_converter.to_4d(
167
+ attention_mask,
168
+ input_shape[-1],
169
+ dtype=inputs_embeds.dtype,
170
+ key_value_length=key_value_length,
171
+ )
172
+
173
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
174
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
175
+ # Details: https://github.com/pytorch/pytorch/issues/110213
176
+ if not is_tracing and expanded_4d_mask.device.type == "cuda":
177
+ expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
178
+ expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
179
+ )
180
+
181
+ return expanded_4d_mask
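
To make the helpers above concrete, a small sketch with a plain 2D padding mask; because the converter is created with is_causal=False, the resulting 4D additive mask carries no lower-triangular bias and only penalizes padded key positions:

    import torch
    from kimodo.model.llm2vec.models.attn_mask_utils import _prepare_4d_causal_attention_mask

    batch, seq, dim = 2, 4, 8
    inputs_embeds = torch.zeros(batch, seq, dim)
    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [1, 1, 1, 0]])  # second sequence ends with one pad token

    mask_4d = _prepare_4d_causal_attention_mask(
        attention_mask, (batch, seq), inputs_embeds, past_key_values_length=0
    )
    # (batch, 1, seq, seq): zeros at visible positions, finfo.min wherever the key is padding,
    # and identical rows across query positions (fully bidirectional attention).
    print(mask_4d.shape)
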
kimodo/model/llm2vec/models/bidirectional_llama.py ADDED
@@ -0,0 +1,224 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+
22
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
23
+ # SPDX-License-Identifier: Apache-2.0
24
+ #
25
+ # Licensed under the Apache License, Version 2.0 (the "License");
26
+ # you may not use this file except in compliance with the License.
27
+ # You may obtain a copy of the License at
28
+ #
29
+ # http://www.apache.org/licenses/LICENSE-2.0
30
+ #
31
+ # Unless required by applicable law or agreed to in writing, software
32
+ # distributed under the License is distributed on an "AS IS" BASIS,
33
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34
+ # See the License for the specific language governing permissions and
35
+ # limitations under the License.
36
+
37
+ import torch
38
+ from peft import PeftModel
39
+ from torch import nn
40
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel
41
+ from transformers.cache_utils import Cache, StaticCache
42
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
43
+ from transformers.models.llama.modeling_llama import (
44
+ LlamaAttention,
45
+ LlamaDecoderLayer,
46
+ # LlamaFlashAttention2,
47
+ LlamaMLP,
48
+ LlamaRMSNorm,
49
+ LlamaRotaryEmbedding,
50
+ # LlamaSdpaAttention,
51
+ )
52
+ from transformers.utils import logging
53
+
54
+ from .utils import is_transformers_attn_greater_or_equal_4_43_1
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+
59
+ class ModifiedLlamaAttention(LlamaAttention):
60
+ def __init__(self, *args, **kwargs):
61
+ super().__init__(*args, **kwargs)
62
+ self.is_causal = False
63
+
64
+
65
+ # class ModifiedLlamaFlashAttention2(LlamaFlashAttention2):
66
+ # def __init__(self, *args, **kwargs):
67
+ # super().__init__(*args, **kwargs)
68
+ # self.is_causal = False
69
+
70
+
71
+ # class ModifiedLlamaSdpaAttention(LlamaSdpaAttention):
72
+ # def __init__(self, *args, **kwargs):
73
+ # super().__init__(*args, **kwargs)
74
+ # self.is_causal = False
75
+
76
+
77
+ # LLAMA_ATTENTION_CLASSES = {
78
+ # "eager": ModifiedLlamaAttention,
79
+ # "flash_attention_2": ModifiedLlamaFlashAttention2,
80
+ # "sdpa": ModifiedLlamaSdpaAttention,
81
+ # }
82
+
83
+
84
+ class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
85
+ def __init__(self, config: LlamaConfig, layer_idx: int):
86
+ nn.Module.__init__(self)
87
+ self.hidden_size = config.hidden_size
88
+
89
+ self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)
90
+ # self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
91
+ # config=config, layer_idx=layer_idx
92
+ # )
93
+
94
+ self.mlp = LlamaMLP(config)
95
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
96
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
97
+
98
+
99
+ class LlamaBiModel(LlamaModel):
100
+ _no_split_modules = ["ModifiedLlamaDecoderLayer"]
101
+
102
+ def __init__(self, config: LlamaConfig):
103
+ if not is_transformers_attn_greater_or_equal_4_43_1():
104
+ raise ValueError(
105
+ "The current implementation of LlamaEncoderModel follows modeling_llama.py of transformers version >= 4.43.1"
106
+ )
107
+ LlamaPreTrainedModel.__init__(self, config)
108
+ self.padding_idx = config.pad_token_id
109
+ self.vocab_size = config.vocab_size
110
+
111
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
112
+ self.layers = nn.ModuleList(
113
+ [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
114
+ )
115
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
116
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
117
+ self.gradient_checkpointing = False
118
+
119
+ # Initialize weights and apply final processing
120
+ self.post_init()
121
+
122
+ def _update_causal_mask(
123
+ self,
124
+ attention_mask,
125
+ input_tensor,
126
+ cache_position,
127
+ past_key_values: Cache,
128
+ output_attentions: bool,
129
+ ):
130
+ if self.config._attn_implementation == "flash_attention_2":
131
+ if attention_mask is not None and 0.0 in attention_mask:
132
+ return attention_mask
133
+ return None
134
+
135
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
136
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
137
+ # to infer the attention mask.
138
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
139
+ using_static_cache = isinstance(past_key_values, StaticCache)
140
+
141
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
142
+ # if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
143
+ # if AttentionMaskConverter._ignore_causal_mask_sdpa(
144
+ # attention_mask,
145
+ # inputs_embeds=input_tensor,
146
+ # past_key_values_length=past_seen_tokens,
147
+ # is_training=self.training,
148
+ # ):
149
+ # return None
150
+
151
+ dtype, device = input_tensor.dtype, input_tensor.device
152
+ min_dtype = torch.finfo(dtype).min
153
+ sequence_length = input_tensor.shape[1]
154
+ if using_static_cache:
155
+ target_length = past_key_values.get_max_length()
156
+ else:
157
+ target_length = (
158
+ attention_mask.shape[-1]
159
+ if isinstance(attention_mask, torch.Tensor)
160
+ else past_seen_tokens + sequence_length + 1
161
+ )
162
+
163
+ causal_mask = torch.zeros(
164
+ (sequence_length, target_length), dtype=dtype, device=device
165
+ ) # in original implementation - torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
166
+ # Commenting out next 2 lines to disable causal masking
167
+ # if sequence_length != 1:
168
+ # causal_mask = torch.triu(causal_mask, diagonal=1)
169
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
170
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
171
+ if attention_mask is not None:
172
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
173
+ if attention_mask.dim() == 2:
174
+ mask_length = attention_mask.shape[-1]
175
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
176
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
177
+ elif attention_mask.dim() == 4:
178
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
179
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
180
+ if attention_mask.shape[-2] < cache_position[0] + sequence_length:
181
+ offset = cache_position[0]
182
+ else:
183
+ offset = 0
184
+ mask_shape = attention_mask.shape
185
+ mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
186
+ causal_mask[
187
+ : mask_shape[0],
188
+ : mask_shape[1],
189
+ offset : mask_shape[2] + offset,
190
+ : mask_shape[3],
191
+ ] = mask_slice
192
+
193
+ if (
194
+ self.config._attn_implementation == "sdpa"
195
+ and attention_mask is not None
196
+ and attention_mask.device.type == "cuda"
197
+ and not output_attentions
198
+ ):
199
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
200
+
201
+ return causal_mask
202
+
203
+
204
+ class LlamaBiForMNTP(LlamaForCausalLM):
205
+ def __init__(self, config):
206
+ LlamaPreTrainedModel.__init__(self, config)
207
+ self.model = LlamaBiModel(config)
208
+ self.vocab_size = config.vocab_size
209
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
210
+
211
+ # Initialize weights and apply final processing
212
+ self.post_init()
213
+
214
+ # getter for PEFT model
215
+ def get_model_for_peft(self):
216
+ return self.model
217
+
218
+ # setter for PEFT model
219
+ def set_model_for_peft(self, model: PeftModel):
220
+ self.model = model
221
+
222
+ # save the PEFT model
223
+ def save_peft_model(self, path):
224
+ self.model.save_pretrained(path)
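
A sketch for inspecting the bidirectional mask on a toy configuration; this assumes a transformers release inside the window the vendored code targets (>= 4.43.1, as enforced in LlamaBiModel.__init__) and uses deliberately tiny sizes so it stays CPU-friendly:

    import torch
    from transformers import LlamaConfig
    from kimodo.model.llm2vec.models.bidirectional_llama import LlamaBiModel

    config = LlamaConfig(vocab_size=128, hidden_size=32, intermediate_size=64,
                         num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4)
    model = LlamaBiModel(config).eval()

    hidden = torch.zeros(1, 5, config.hidden_size)
    mask = model._update_causal_mask(
        attention_mask=torch.tensor([[1, 1, 1, 1, 0]]),
        input_tensor=hidden,
        cache_position=torch.arange(5),
        past_key_values=None,
        output_attentions=False,
    )
    # No triangular bias anywhere; only the padded last key position is masked out.
    print(mask.shape)  # (1, 1, 5, 5)
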
kimodo/model/llm2vec/models/utils.py ADDED
@@ -0,0 +1,32 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+
22
+ import importlib.metadata
23
+
24
+ from packaging import version
25
+ from transformers.utils.import_utils import _is_package_available
26
+
27
+
28
+ def is_transformers_attn_greater_or_equal_4_43_1():
29
+ if not _is_package_available("transformers"):
30
+ return False
31
+
32
+ return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.43.1")
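
A short sketch of how this guard is meant to be used, mirroring the check in LlamaBiModel.__init__ so scripts can fail early with a clearer message:

    from kimodo.model.llm2vec.models.utils import is_transformers_attn_greater_or_equal_4_43_1

    # The bidirectional Llama encoder tracks modeling_llama.py from transformers >= 4.43.1.
    if not is_transformers_attn_greater_or_equal_4_43_1():
        raise RuntimeError("Upgrade transformers to >= 4.43.1 to use the LLM2Vec bidirectional encoder.")
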
kimodo/model/load_model.py ADDED
@@ -0,0 +1,194 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Load Kimodo diffusion models from local checkpoints or Hugging Face."""
4
+
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ from huggingface_hub import snapshot_download
9
+ from omegaconf import OmegaConf
10
+
11
+ from .loading import (
12
+ AVAILABLE_MODELS,
13
+ DEFAULT_MODEL,
14
+ DEFAULT_TEXT_ENCODER_URL,
15
+ MODEL_NAMES,
16
+ TMR_MODELS,
17
+ get_env_var,
18
+ instantiate_from_dict,
19
+ )
20
+ from .registry import get_model_info, resolve_model_name
21
+
22
+ DEFAULT_TEXT_ENCODER = "llm2vec"
23
+ TEXT_ENCODER_PRESETS = {
24
+ "llm2vec": {
25
+ "target": "kimodo.model.LLM2VecEncoder",
26
+ "kwargs": {
27
+ "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
28
+ "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
29
+ "dtype": "bfloat16",
30
+ "llm_dim": 4096,
31
+ },
32
+ }
33
+ }
34
+
35
+
36
+ def _resolve_hf_model_path(modelname: str) -> Path:
37
+ """Resolve model name to a local path, using Hugging Face cache or CHECKPOINT_DIR."""
38
+ try:
39
+ repo_id = MODEL_NAMES[modelname]
40
+ except KeyError:
41
+ raise ValueError(f"Model '{modelname}' not found. Available models: {MODEL_NAMES.keys()}")
42
+
43
+ local_cache = get_env_var("LOCAL_CACHE", "False").lower() == "true"
44
+ if not local_cache:
45
+ snapshot_dir = snapshot_download(repo_id=repo_id) # will check online no matter what
46
+ return Path(snapshot_dir)
47
+
48
+ try:
49
+ snapshot_dir = snapshot_download(repo_id=repo_id, local_files_only=True) # will check local cache only
50
+ return Path(snapshot_dir)
51
+ except Exception:
52
+ # if local cache is not found, download from online
53
+ try:
54
+ snapshot_dir = snapshot_download(repo_id=repo_id)
55
+ return Path(snapshot_dir)
56
+ except Exception:
57
+ raise RuntimeError(f"Could not resolve model '{modelname}' from Hugging Face (repo: {repo_id}). ") from None
58
+
59
+
60
+ def _build_api_text_encoder_conf(text_encoder_url: str) -> dict:
61
+ return {
62
+ "_target_": "kimodo.model.text_encoder_api.TextEncoderAPI",
63
+ "url": text_encoder_url,
64
+ }
65
+
66
+
67
+ def _build_local_text_encoder_conf() -> dict:
68
+ text_encoder_name = get_env_var("TEXT_ENCODER", DEFAULT_TEXT_ENCODER)
69
+ if text_encoder_name not in TEXT_ENCODER_PRESETS:
70
+ available = ", ".join(sorted(TEXT_ENCODER_PRESETS))
71
+ raise ValueError(f"Unknown TEXT_ENCODER='{text_encoder_name}'. Available: {available}")
72
+
73
+ preset = TEXT_ENCODER_PRESETS[text_encoder_name]
74
+ return {
75
+ "_target_": preset["target"],
76
+ **preset["kwargs"],
77
+ }
78
+
79
+
80
+ def _select_text_encoder_conf(text_encoder_url: str) -> dict:
81
+ # TEXT_ENCODER_MODE options:
82
+ # - "api": force TextEncoderAPI
83
+ # - "local": force local LLM2VecEncoder
84
+ # - "auto": try API first, fallback to local if unreachable
85
+ mode = get_env_var("TEXT_ENCODER_MODE", "auto").lower()
86
+ if mode == "local":
87
+ return _build_local_text_encoder_conf()
88
+ if mode == "api":
89
+ return _build_api_text_encoder_conf(text_encoder_url)
90
+
91
+ api_conf = _build_api_text_encoder_conf(text_encoder_url)
92
+ try:
93
+ text_encoder = instantiate_from_dict(api_conf)
94
+ # Probe availability early so inference doesn't fail later.
95
+ text_encoder(["healthcheck"])
96
+ return api_conf
97
+ except Exception as error:
98
+ print(
99
+ "Text encoder service is unreachable, falling back to local LLM2Vec "
100
+ f"encoder. ({type(error).__name__}: {error})"
101
+ )
102
+ return _build_local_text_encoder_conf()
103
+
104
+
105
+ def load_model(
106
+ modelname=None,
107
+ device=None,
108
+ eval_mode: bool = True,
109
+ default_family: Optional[str] = "Kimodo",
110
+ return_resolved_name: bool = False,
111
+ ):
112
+ """Load a kimodo model by name (e.g. 'g1', 'soma').
113
+
114
+ Resolution of partial/full names (e.g. Kimodo-SOMA-RP-v1, SOMA) is done
115
+ inside this function using default_family when the name is not a known
116
+ short key.
117
+
118
+ Args:
119
+ modelname: Model identifier; uses DEFAULT_MODEL if None. Can be a short key,
120
+ a full name (e.g. Kimodo-SOMA-RP-v1), or a partial name; unknown names
121
+ are resolved via resolve_model_name using default_family.
122
+ device: Target device for the model (e.g. 'cuda', 'cpu').
123
+ eval_mode: If True, set model to eval mode.
124
+ default_family: Used when modelname is not in AVAILABLE_MODELS to resolve
125
+ partial names ("Kimodo" for demo/generation, "TMR" for embed script).
126
+ Default "Kimodo".
127
+ return_resolved_name: If True, return (model, resolved_short_key). If False,
128
+ return only the model.
129
+
130
+ Returns:
131
+ Loaded model in eval mode, or (model, resolved short key) if
132
+ return_resolved_name is True.
133
+
134
+ Raises:
135
+ ValueError: If modelname is not in AVAILABLE_MODELS and cannot be resolved.
136
+ FileNotFoundError: If config.yaml is missing in the checkpoint folder.
137
+ """
138
+ if modelname is None:
139
+ modelname = DEFAULT_MODEL
140
+ if modelname not in AVAILABLE_MODELS:
141
+ if default_family is not None:
142
+ modelname = resolve_model_name(modelname, default_family)
143
+ else:
144
+ raise ValueError(
145
+ f"""The model is not recognized.
146
+ Please choose between: {AVAILABLE_MODELS}"""
147
+ )
148
+
149
+ resolved_modelname = modelname
150
+
151
+ # In case, we specify a custom checkpoint directory
152
+ configured_checkpoint_dir = get_env_var("CHECKPOINT_DIR")
153
+ if configured_checkpoint_dir:
154
+ print(f"CHECKPOINT_DIR is set to {configured_checkpoint_dir}, checking the local cache...")
155
+ # Checkpoint folders are named by display name (e.g. Kimodo-SOMA-RP-v1)
156
+ info = get_model_info(modelname)
157
+ checkpoint_folder_name = info.display_name if info is not None else modelname
158
+ model_path = Path(configured_checkpoint_dir) / checkpoint_folder_name
159
+ if not model_path.exists() and modelname != checkpoint_folder_name:
160
+ # Fallback: try short_key for backward compatibility
161
+ model_path = Path(configured_checkpoint_dir) / modelname
162
+ if not model_path.exists():
163
+ print(f"Model folder not found at '{model_path}', downloading it from Hugging Face...")
164
+ model_path = _resolve_hf_model_path(modelname)
165
+ else:
166
+ # Otherwise, we load the model from the local cache or download it from Hugging Face.
167
+ model_path = _resolve_hf_model_path(modelname)
168
+
169
+ model_config_path = model_path / "config.yaml"
170
+ if not model_config_path.exists():
171
+ raise FileNotFoundError(f"The model checkpoint folder exists but config.yaml is missing: {model_config_path}")
172
+
173
+ model_conf = OmegaConf.load(model_config_path)
174
+
175
+ if modelname in TMR_MODELS:
176
+ # Same process at the moment for TMR and Kimodo
177
+ pass
178
+
179
+ text_encoder_url = get_env_var("TEXT_ENCODER_URL", DEFAULT_TEXT_ENCODER_URL)
180
+ runtime_conf = OmegaConf.create(
181
+ {
182
+ "checkpoint_dir": str(model_path),
183
+ "text_encoder": _select_text_encoder_conf(text_encoder_url),
184
+ }
185
+ )
186
+ model_cfg = OmegaConf.to_container(OmegaConf.merge(model_conf, runtime_conf), resolve=True)
187
+ model_cfg.pop("checkpoint_dir", None)
188
+
189
+ model = instantiate_from_dict(model_cfg, overrides={"device": device})
190
+ if eval_mode:
191
+ model = model.eval()
192
+ if return_resolved_name:
193
+ return model, resolved_modelname
194
+ return model
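
A minimal sketch of this entry point, assuming the checkpoints can be fetched from Hugging Face (or CHECKPOINT_DIR points at a local copy) and that a text encoder is either reachable over TEXT_ENCODER_URL or loadable locally:

    import torch
    from kimodo.model.load_model import load_model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Partial names go through the registry: "soma" resolves to the latest Kimodo-SOMA-RP model.
    model, resolved = load_model("soma", device=device, return_resolved_name=True)
    print(resolved)  # kimodo-soma-rp
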
kimodo/model/loading.py ADDED
@@ -0,0 +1,81 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Model loading utilities: checkpoints, registry, env, and Hydra-based instantiation."""
4
+
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Union
8
+
9
+ import torch
10
+ from hydra.utils import instantiate
11
+ from omegaconf import OmegaConf
12
+ from safetensors.torch import load_file as load_safetensors
13
+
14
+ from .registry import (
15
+ AVAILABLE_MODELS,
16
+ DEFAULT_MODEL,
17
+ DEFAULT_TEXT_ENCODER_URL,
18
+ KIMODO_MODELS,
19
+ MODEL_NAMES,
20
+ TMR_MODELS,
21
+ )
22
+
23
+
24
+ def get_env_var(name: str, default: Optional[str] = None) -> Optional[str]:
25
+ """Return environment variable value, or default if unset/empty."""
26
+ return os.environ.get(name) or default
27
+
28
+
29
+ def instantiate_from_dict(
30
+ cfg: Dict[str, Any],
31
+ overrides: Optional[Dict[str, Any]] = None,
32
+ ):
33
+ """Instantiate an object from a config dict (e.g. from OmegaConf.to_container).
34
+
35
+ The dict must contain _target_ with a fully qualified class path. Nested configs are
36
+ instantiated recursively.
37
+ """
38
+ if overrides:
39
+ cfg = {**cfg, **overrides}
40
+ conf = OmegaConf.create(cfg)
41
+ return instantiate(conf)
42
+
43
+
44
+ def load_checkpoint_state_dict(ckpt_path: Union[str, Path]) -> dict:
45
+ """Load a state dict from a checkpoint file.
46
+
47
+ If the checkpoint is a dict with a 'state_dict' key (e.g. PyTorch Lightning),
48
+ that is returned; otherwise the whole checkpoint is treated as the state dict.
49
+
50
+ Args:
51
+ ckpt_path: Path to the checkpoint file.
52
+
53
+ Returns:
54
+ state_dict suitable for model.load_state_dict().
55
+ """
56
+ ckpt_path = str(ckpt_path)
57
+
58
+ if ckpt_path.endswith(".safetensors"):
59
+ state_dict = load_safetensors(ckpt_path)
60
+ else:
61
+ checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
62
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
63
+ state_dict = checkpoint["state_dict"]
64
+ elif isinstance(checkpoint, dict):
65
+ state_dict = checkpoint
66
+ else:
67
+ raise ValueError(f"Unsupported checkpoint format: {ckpt_path}")
68
+ return {key: val.detach().cpu() for key, val in state_dict.items()}
69
+
70
+
71
+ __all__ = [
72
+ "get_env_var",
73
+ "instantiate_from_dict",
74
+ "KIMODO_MODELS",
75
+ "TMR_MODELS",
76
+ "AVAILABLE_MODELS",
77
+ "MODEL_NAMES",
78
+ "DEFAULT_MODEL",
79
+ "DEFAULT_TEXT_ENCODER_URL",
80
+ "load_checkpoint_state_dict",
81
+ ]
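
A brief sketch of instantiate_from_dict, the Hydra-style mechanism load_model.py relies on; torch.nn.Linear is only a stand-in target here:

    from kimodo.model.loading import instantiate_from_dict

    # Any fully qualified class path works as _target_; overrides are merged on top of the config.
    cfg = {"_target_": "torch.nn.Linear", "in_features": 16, "out_features": 4}
    layer = instantiate_from_dict(cfg, overrides={"bias": False})
    print(layer)  # Linear(in_features=16, out_features=4, bias=False)
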
kimodo/model/registry.py ADDED
@@ -0,0 +1,473 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Registry of model names and Hugging Face repo IDs for Kimodo and TMR.
4
+
5
+ Canonical source of truth is the list of repo IDs. Short keys (e.g. kimodo-soma-rp) and metadata (dataset,
6
+ skeleton, version, display name) are derived by parsing.
7
+ """
8
+
9
+ import re
10
+ from dataclasses import dataclass
11
+ from typing import Optional
12
+
13
+ # Canonical list: repo IDs in the same syntax as Hugging Face (org/Model-Name-v1).
14
+ # Parser expects: org/Family-SKELETON-DATASET-version (e.g. Kimodo-SOMA-RP-v1).
15
+ KIMODO_REPO_IDS = [
16
+ "nvidia/Kimodo-SOMA-RP-v1",
17
+ "nvidia/Kimodo-SMPLX-RP-v1",
18
+ "nvidia/Kimodo-G1-RP-v1",
19
+ "nvidia/Kimodo-SOMA-SEED-v1",
20
+ "nvidia/Kimodo-G1-SEED-v1",
21
+ ]
22
+ TMR_REPO_IDS = [
23
+ "nvidia/TMR-SOMA-RP-v1",
24
+ ]
25
+
26
+ # Repo ID without org, for display (e.g. Kimodo-SOMA-RP-v1).
27
+ _REPO_NAME_PATTERN = re.compile(r"^(Kimodo|TMR)-([A-Za-z0-9]+)-(RP|SEED)-v(\d+)$")
28
+
29
+
30
+ @dataclass
31
+ class ModelInfo:
32
+ """Structured metadata for one model, derived from its repo ID."""
33
+
34
+ repo_id: str
35
+ short_key: str
36
+ family: str
37
+ skeleton: str
38
+ dataset: str
39
+ version: str
40
+ display_name: str
41
+
42
+ @property
43
+ def dataset_ui_label(self) -> str:
44
+ return "Rigplay" if self.dataset == "RP" else "SEED"
45
+
46
+
47
+ def _parse_repo_id(repo_id: str) -> Optional[ModelInfo]:
48
+ """Parse a repo ID into ModelInfo.
49
+
50
+ Returns None if format is unrecognized.
51
+ """
52
+ # repo_id is "org/Model-Name-v1"
53
+ if "/" in repo_id:
54
+ _, name = repo_id.split("/", 1)
55
+ else:
56
+ name = repo_id
57
+ m = _REPO_NAME_PATTERN.match(name)
58
+ if not m:
59
+ return None
60
+ family, skeleton, dataset, ver = m.groups()
61
+ # Normalize skeleton for display (as is for now)
62
+ skeleton_display = skeleton
63
+ # Include family so Kimodo-SOMA-RP and TMR-SOMA-RP have distinct keys.
64
+ short_key = f"{family.lower()}-{skeleton.lower()}-{dataset.lower()}"
65
+ return ModelInfo(
66
+ repo_id=repo_id,
67
+ short_key=short_key,
68
+ family=family,
69
+ skeleton=skeleton_display,
70
+ dataset=dataset,
71
+ version=f"v{ver}",
72
+ display_name=name,
73
+ )
74
+
75
+
76
+ def _build_registry() -> tuple[list[ModelInfo], dict[str, str], list[str]]:
77
+ """Build model infos, short_key -> repo_id map, and list of short keys.
78
+
79
+ When multiple versions exist for the same (family, skeleton, dataset), the base short_key (e.g.
80
+ kimodo-soma-rp) maps to the latest version's repo_id so that HF resolution finds the newest
81
+ model.
82
+ """
83
+
84
+ def _version_key(info: ModelInfo) -> int:
85
+ v = info.version
86
+ if v.startswith("v") and v[1:].isdigit():
87
+ return int(v[1:])
88
+ return 0
89
+
90
+ all_repos = KIMODO_REPO_IDS + TMR_REPO_IDS
91
+ infos: list[ModelInfo] = []
92
+ for repo_id in all_repos:
93
+ info = _parse_repo_id(repo_id)
94
+ if info is None:
95
+ raise ValueError(f"Registry repo ID does not match expected pattern: {repo_id}")
96
+ infos.append(info)
97
+
98
+ # Map each base short_key to the latest version's repo_id (by version number)
99
+ model_names: dict[str, str] = {}
100
+ seen_short_keys: set[str] = set()
101
+ for info in infos:
102
+ if info.short_key in seen_short_keys:
103
+ continue
104
+ seen_short_keys.add(info.short_key)
105
+ candidates = [
106
+ i for i in infos if i.family == info.family and i.skeleton == info.skeleton and i.dataset == info.dataset
107
+ ]
108
+ if candidates:
109
+ latest = max(candidates, key=_version_key)
110
+ model_names[info.short_key] = latest.repo_id
111
+
112
+ return infos, model_names, list(model_names.keys())
113
+
114
+
115
+ MODEL_INFOS, MODEL_NAMES, _SHORT_KEYS = _build_registry()
116
+ AVAILABLE_MODELS = _SHORT_KEYS
117
+
118
+ # Short-key lists for Kimodo vs TMR (load_model uses TMR_MODELS to branch).
119
+ KIMODO_MODELS = [info.short_key for info in MODEL_INFOS if info.family == "Kimodo"]
120
+ TMR_MODELS = [info.short_key for info in MODEL_INFOS if info.family == "TMR"]
121
+
122
+ # Backward compatibility: FRIENDLY_NAMES for any code that still expects it.
123
+ FRIENDLY_NAMES = {info.short_key: info.display_name for info in MODEL_INFOS}
124
+
125
+ DEFAULT_MODEL = "kimodo-soma-rp"
126
+ DEFAULT_TEXT_ENCODER_URL = "http://127.0.0.1:9550/"
127
+
128
+ # Friendly names for skeleton dropdown (key -> label).
129
+ SKELETON_DISPLAY_NAMES = {
130
+ "SOMA": "SOMA Human Body",
131
+ "SMPLX": "SMPLX Human Body",
132
+ "G1": "Unitree G1 Humanoid Robot",
133
+ }
134
+
135
+ # Order for skeleton dropdown: SOMA, SMPLX, G1.
136
+ SKELETON_ORDER = ("SOMA", "SMPLX", "G1")
137
+
138
+
139
+ def get_skeleton_display_name(skeleton_key: str) -> str:
140
+ """Return the UI label for a skeleton key (e.g. SOMA -> SOMA Human Body)."""
141
+ return SKELETON_DISPLAY_NAMES.get(skeleton_key, skeleton_key)
142
+
143
+
144
+ def get_skeleton_key_from_display_name(display_name: str) -> Optional[str]:
145
+ """Return the skeleton key for a UI label, or None."""
146
+ for key, label in SKELETON_DISPLAY_NAMES.items():
147
+ if label == display_name:
148
+ return key
149
+ return None
150
+
151
+
152
+ def get_skeleton_display_names_for_dataset(dataset_ui_label: str, family: Optional[str] = None) -> list[str]:
153
+ """Return skeleton UI labels for the given dataset.
154
+
155
+ If family is set (e.g. "Kimodo"), only skeletons with a model of that family are included.
156
+ """
157
+ keys = get_skeletons_for_dataset(dataset_ui_label, family=family)
158
+ return [get_skeleton_display_name(k) for k in keys]
159
+
160
+
161
+ def get_short_key(repo_id: str) -> Optional[str]:
162
+ """Return the short key for a repo ID, or None if not in registry."""
163
+ for info in MODEL_INFOS:
164
+ if info.repo_id == repo_id:
165
+ return info.short_key
166
+ return None
167
+
168
+
169
+ def get_model_info(short_key: str) -> Optional[ModelInfo]:
170
+ """Return ModelInfo for a short key, or None if not found.
171
+
172
+ When multiple versions share the same short_key, returns the one used for loading (the latest
173
+ version), so CHECKPOINT_DIR and HF use the same version.
174
+ """
175
+ repo_id = MODEL_NAMES.get(short_key)
176
+ if repo_id is None:
177
+ return None
178
+ for info in MODEL_INFOS:
179
+ if info.repo_id == repo_id:
180
+ return info
181
+ return None
182
+
183
+
184
+ def get_short_key_from_display_name(display_name: str) -> Optional[str]:
185
+ """Return short_key for a display name (e.g. Kimodo-SOMA-RP-v1), or None."""
186
+ for info in MODEL_INFOS:
187
+ if info.display_name == display_name:
188
+ return info.short_key
189
+ return None
190
+
191
+
192
+ def get_models_for_demo() -> list[ModelInfo]:
193
+ """Return all model infos in registry order (for demo model list)."""
194
+ return list(MODEL_INFOS)
195
+
196
+
197
+ def get_datasets(family: Optional[str] = None) -> list[str]:
198
+ """Return unique dataset UI labels (Rigplay, SEED) present in registry.
199
+
200
+ If family is set (e.g. "Kimodo"), only datasets that have a model of that family are included.
201
+ """
202
+ infos = MODEL_INFOS
203
+ if family is not None:
204
+ infos = [i for i in infos if i.family == family]
205
+ labels = set()
206
+ for info in infos:
207
+ labels.add(info.dataset_ui_label)
208
+ return sorted(labels)
209
+
210
+
211
+ def get_skeletons_for_dataset(dataset_ui_label: str, family: Optional[str] = None) -> list[str]:
212
+ """Return skeleton names that have a model for the given dataset.
213
+
214
+ Order: SOMA, SMPLX, G1 (only those present for the dataset).
215
+ If family is set (e.g. "Kimodo"), only skeletons with a model of that
216
+ family are included.
217
+ """
218
+ dataset = "RP" if dataset_ui_label == "Rigplay" else "SEED"
219
+ infos = MODEL_INFOS
220
+ if family is not None:
221
+ infos = [i for i in infos if i.family == family]
222
+ skeletons = set()
223
+ for info in infos:
224
+ if info.dataset == dataset:
225
+ skeletons.add(info.skeleton)
226
+ return [s for s in SKELETON_ORDER if s in skeletons]
227
+
228
+
229
+ def get_versions_for_dataset_skeleton(dataset_ui_label: str, skeleton: str) -> list[str]:
230
+ """Return version strings (e.g. v1) for the given dataset/skeleton.
231
+
232
+ Sorted by version number so the last element is the highest (e.g. v1, v2).
233
+ """
234
+ dataset = "RP" if dataset_ui_label == "Rigplay" else "SEED"
235
+ versions = []
236
+ for info in MODEL_INFOS:
237
+ if info.dataset == dataset and info.skeleton == skeleton:
238
+ versions.append(info.version)
239
+
240
+ # Sort by numeric part so v2 comes after v1.
241
+ def version_key(v: str) -> int:
242
+ if v.startswith("v") and v[1:].isdigit():
243
+ return int(v[1:])
244
+ return 0
245
+
246
+ return sorted(set(versions), key=version_key)
247
+
248
+
249
+ def get_models_for_dataset_skeleton(
250
+ dataset_ui_label: str, skeleton: str, family: Optional[str] = None
251
+ ) -> list[ModelInfo]:
252
+ """Return model infos for the given dataset/skeleton, sorted by version (max first).
253
+
254
+ Used to build the Version dropdown (options = full display names, one per model). If family is
255
+ set (e.g. "Kimodo"), only models of that family are returned.
256
+ """
257
+ dataset = "RP" if dataset_ui_label == "Rigplay" else "SEED"
258
+ infos = [info for info in MODEL_INFOS if info.dataset == dataset and info.skeleton == skeleton]
259
+ if family is not None:
260
+ infos = [i for i in infos if i.family == family]
261
+
262
+ def version_key(info: ModelInfo) -> int:
263
+ v = info.version
264
+ if v.startswith("v") and v[1:].isdigit():
265
+ return int(v[1:])
266
+ return 0
267
+
268
+ return sorted(infos, key=version_key, reverse=True)
269
+
270
+
271
+ def resolve_to_short_key(dataset_ui_label: str, skeleton: str, version: str) -> Optional[str]:
272
+ """Return the short key for (dataset, skeleton, version), or None."""
273
+ for info in MODEL_INFOS:
274
+ if info.dataset_ui_label == dataset_ui_label and info.skeleton == skeleton and info.version == version:
275
+ return info.short_key
276
+ return None
277
+
278
+
279
+ # -----------------------------------------------------------------------------
280
+ # Flexible model name resolution (partial names, case-insensitive, defaults)
281
+ # -----------------------------------------------------------------------------
282
+
283
+ _FAMILY_ALIASES = {"kimodo": "Kimodo", "tmr": "TMR"}
284
+ _DATASET_ALIASES = {"rp": "RP", "rigplay": "RP", "seed": "SEED"}
285
+ _SKELETON_ALIASES = {
286
+ "soma": "SOMA",
287
+ "smplx": "SMPLX",
288
+ "g1": "G1",
289
+ }
290
+
291
+
292
+ def _normalize_family(s: str) -> Optional[str]:
293
+ """Return canonical family (Kimodo/TMR) or None if unknown."""
294
+ return _FAMILY_ALIASES.get(s.strip().lower())
295
+
296
+
297
+ def _normalize_dataset(s: str) -> Optional[str]:
298
+ """Return canonical dataset (RP/SEED) or None if unknown."""
299
+ return _DATASET_ALIASES.get(s.strip().lower())
300
+
301
+
302
+ def _normalize_skeleton(s: str) -> Optional[str]:
303
+ """Return canonical skeleton (SOMA/SMPLX/G1) or None if unknown."""
304
+ return _SKELETON_ALIASES.get(s.strip().lower())
305
+
306
+
307
+ def _get_latest_for_family_skeleton_dataset(family: str, skeleton: str, dataset: str) -> Optional[ModelInfo]:
308
+ """Return the model info with the highest version for (family, skeleton, dataset)."""
309
+ candidates = [
310
+ info for info in MODEL_INFOS if info.family == family and info.skeleton == skeleton and info.dataset == dataset
311
+ ]
312
+ if not candidates:
313
+ return None
314
+
315
+ def version_key(info: ModelInfo) -> int:
316
+ v = info.version
317
+ if v.startswith("v") and v[1:].isdigit():
318
+ return int(v[1:])
319
+ return 0
320
+
321
+ return max(candidates, key=version_key)
322
+
323
+
324
+ def kimodo_short_key_for_skeleton_dataset(skeleton: str, dataset: str) -> Optional[str]:
325
+ """Return the latest Kimodo model short_key for ``skeleton`` and ``dataset`` (RP/SEED), or
326
+ None."""
327
+ info = _get_latest_for_family_skeleton_dataset("Kimodo", skeleton, dataset)
328
+ return info.short_key if info is not None else None
329
+
330
+
331
+ def registry_skeleton_for_joint_count(nb_joints: int) -> str:
332
+ """Map motion joint count to registry skeleton key (SOMA / SMPLX / G1)."""
333
+ if nb_joints == 34:
334
+ return "G1"
335
+ if nb_joints == 22:
336
+ return "SMPLX"
337
+ if nb_joints in (77, 30):
338
+ return "SOMA"
339
+ raise ValueError(f"No Kimodo model registered for motion with J={nb_joints}")
340
+
341
+
342
+ # Optional version: Family-Skeleton-Dataset-vN or Family-Skeleton-Dataset
343
+ _RESOLVE_FULL_PATTERN = re.compile(
344
+ r"^(Kimodo|TMR|kimodo|tmr)[\-_]" r"([A-Za-z0-9]+)[\-_]" r"(RP|SEED|rp|seed)" r"(?:[\-_]v(\d+))?$",
345
+ re.IGNORECASE,
346
+ )
347
+ # Partial: Skeleton-Dataset or Skeleton or Dataset (no family)
348
+ _RESOLVE_PARTIAL_PATTERN = re.compile(
349
+ r"^([A-Za-z0-9]+)(?:[\-_](RP|SEED|rp|seed))?(?:[\-_]v(\d+))?$",
350
+ re.IGNORECASE,
351
+ )
352
+
353
+
354
+ def resolve_model_name(name: Optional[str], default_family: Optional[str] = None) -> str:
355
+ """Resolve a user-facing model name to a short_key.
356
+
357
+ Accepts full names (e.g. Kimodo-SOMA-RP-v1), case-insensitive matching,
358
+ and partial names with defaults: dataset=RP, skeleton=SOMA, family from
359
+ default_family (Kimodo for demo/generation, TMR for embed script).
360
+ Omitted version resolves to the latest for that model.
361
+
362
+ Args:
363
+ name: User-provided name (can be None or empty).
364
+ default_family: "Kimodo" or "TMR" when name is empty or omits family.
365
+
366
+ Returns:
367
+ Short key (e.g. kimodo-soma-rp) for use with load_model / MODEL_NAMES.
368
+
369
+ Raises:
370
+ ValueError: If name cannot be resolved or default_family is missing when needed.
371
+ """
372
+ if name is not None:
373
+ name = name.strip()
374
+ if not name:
375
+ if default_family is None:
376
+ raise ValueError('Model name is empty; provide a name or set default_family ("Kimodo" or "TMR").')
377
+ fam = _normalize_family(default_family)
378
+ if fam is None:
379
+ raise ValueError(f"default_family must be 'Kimodo' or 'TMR', got {default_family!r}")
380
+ info = _get_latest_for_family_skeleton_dataset(fam, "SOMA", "RP")
381
+ if info is None:
382
+ raise ValueError(f"No model found for {fam}-SOMA-RP. Available: {list(MODEL_NAMES.keys())}")
383
+ return info.short_key
384
+
385
+ # Exact short_key
386
+ if name in MODEL_NAMES:
387
+ return name
388
+
389
+ # Case-insensitive match against short_key or display_name
390
+ name_lower = name.lower()
391
+ matches = []
392
+ for info in MODEL_INFOS:
393
+ if name_lower == info.short_key.lower():
394
+ matches.append(info)
395
+ disp = info.display_name.lower()
396
+ if name_lower == disp or name_lower == ("nvidia/" + disp):
397
+ matches.append(info)
398
+ if len(matches) == 1:
399
+ return matches[0].short_key
400
+ if len(matches) > 1:
401
+ return matches[0].short_key
402
+
403
+ # Parsed full form: Family-Skeleton-Dataset or Family-Skeleton-Dataset-vN
404
+ m = _RESOLVE_FULL_PATTERN.match(name)
405
+ if m:
406
+ fam_raw, skel_raw, ds_raw, ver_num = m.groups()
407
+ fam = _normalize_family(fam_raw)
408
+ skel = _normalize_skeleton(skel_raw)
409
+ ds = _normalize_dataset(ds_raw)
410
+ if fam is not None and skel is not None and ds is not None:
411
+ if ver_num is not None:
412
+ version = f"v{ver_num}"
413
+ for info in MODEL_INFOS:
414
+ if info.family == fam and info.skeleton == skel and info.dataset == ds and info.version == version:
415
+ return info.short_key
416
+ else:
417
+ info = _get_latest_for_family_skeleton_dataset(fam, skel, ds)
418
+ if info is not None:
419
+ return info.short_key
420
+
421
+ # Parsed partial: Skeleton-Dataset, Skeleton, or Dataset (use default_family)
422
+ if default_family is not None:
423
+ m = _RESOLVE_PARTIAL_PATTERN.match(name)
424
+ if m:
425
+ tok1, ds_raw, ver_num = m.groups()
426
+ fam = _normalize_family(default_family)
427
+ if fam is not None:
428
+ skel = _normalize_skeleton(tok1)
429
+ ds_candidate = _normalize_dataset(ds_raw) if ds_raw else None
430
+ if skel is not None and ds_candidate is not None:
431
+ ds = ds_candidate
432
+ elif skel is not None:
433
+ ds = "RP"
434
+ else:
435
+ skel = "SOMA"
436
+ ds = _normalize_dataset(tok1) if tok1 else "RP"
437
+ if ds is None:
438
+ ds = "RP"
439
+ if ver_num is not None:
440
+ version = f"v{ver_num}"
441
+ for info in MODEL_INFOS:
442
+ if (
443
+ info.family == fam
444
+ and info.skeleton == skel
445
+ and info.dataset == ds
446
+ and info.version == version
447
+ ):
448
+ return info.short_key
449
+ else:
450
+ info = _get_latest_for_family_skeleton_dataset(fam, skel, ds)
451
+ if info is not None:
452
+ return info.short_key
453
+
454
+ # Single token: skeleton or dataset
455
+ fam = _normalize_family(default_family)
456
+ if fam is not None:
457
+ skel = _normalize_skeleton(name)
458
+ if skel is not None:
459
+ info = _get_latest_for_family_skeleton_dataset(fam, skel, "RP")
460
+ if info is not None:
461
+ return info.short_key
462
+ ds = _normalize_dataset(name)
463
+ if ds is not None:
464
+ info = _get_latest_for_family_skeleton_dataset(fam, "SOMA", ds)
465
+ if info is not None:
466
+ return info.short_key
467
+
468
+ raise ValueError(
469
+ f"Model name {name!r} could not be resolved. "
470
+ f"Use a short key (e.g. {list(MODEL_NAMES.keys())[:3]}...), "
471
+ "a full name (e.g. Kimodo-SOMA-RP-v1), or a partial (e.g. SOMA-RP, SOMA) "
472
+ "with default_family set."
473
+ )
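
To make the resolution rules concrete, a small sketch (pure Python, no downloads) showing that full, partial, and case-insensitive spellings collapse to the same short key:

    from kimodo.model.registry import get_model_info, resolve_model_name

    key = resolve_model_name("Kimodo-SOMA-RP-v1")
    assert key == resolve_model_name("soma-rp", default_family="Kimodo")
    assert key == resolve_model_name("SOMA", default_family="Kimodo")

    info = get_model_info(key)
    print(key, info.repo_id)  # kimodo-soma-rp nvidia/Kimodo-SOMA-RP-v1
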
kimodo/model/text_encoder_api.py ADDED
@@ -0,0 +1,74 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Remote text encoder API client (Gradio) for motion generation."""
4
+
5
+ import logging
6
+
7
+ import numpy as np
8
+ import torch
9
+ from gradio_client import Client
10
+
11
+ # Suppress the [httpx] logs (GET requests)
12
+ logging.getLogger("httpx").setLevel(logging.WARNING)
13
+
14
+ # Suppress internal gradio_client logs
15
+ logging.getLogger("gradio_client").setLevel(logging.WARNING)
16
+
17
+
18
+ class TextEncoderAPI:
19
+ """Text encoder API client for motion generation."""
20
+
21
+ def __init__(self, url: str):
22
+ self.client = Client(url, verbose=False)
23
+ self.device = "cpu"
24
+ self.dtype = torch.float
25
+
26
+ def _create_np_random_name(self):
27
+ import uuid
28
+
29
+ return str(uuid.uuid4()) + ".npy"
30
+
31
+ def to(self, device=None, dtype=None):
32
+ if device is not None:
33
+ self.device = device
34
+ if dtype is not None:
35
+ self.dtype = dtype
36
+ return self
37
+
38
+ def __call__(self, texts):
39
+ """Encode text prompts into tensors.
40
+
41
+ Args:
42
+ texts (str | list[str]): text prompts to encode
43
+
44
+ Returns:
45
+ tuple[torch.Tensor, list[int]]: encoded text tensors and their lengths
46
+ """
47
+ if isinstance(texts, str):
48
+ texts = [texts]
49
+
50
+ tensors = []
51
+ lengths = []
52
+ for text in texts:
53
+ filename = self._create_np_random_name()
54
+
55
+ # Use a long result timeout to tolerate text-encoder cold-start (LLM2Vec model load ~60-120s).
56
+ result = self.client.submit(
57
+ text=text,
58
+ filename=filename,
59
+ api_name="/DemoWrapper",
60
+ ).result(timeout=300)
61
+ path = result[0]["value"]
62
+ tensor = np.load(path)
63
+ length = tensor.shape[0]
64
+
65
+ tensors.append(tensor)
66
+ lengths.append(length)
67
+
68
+ padded_tensor = np.zeros((len(lengths), max(lengths), tensors[0].shape[-1]), dtype=tensors[0].dtype)
69
+ for idx, (tensor, length) in enumerate(zip(tensors, lengths)):
70
+ padded_tensor[idx, :length] = tensor
71
+
72
+ padded_tensor = torch.from_numpy(padded_tensor)
73
+ padded_tensor = padded_tensor.to(device=self.device, dtype=self.dtype)
74
+ return padded_tensor, lengths
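
A usage sketch for the API client, assuming a text-encoder Gradio service is running at the default URL that load_model.py falls back to (http://127.0.0.1:9550/):

    import torch
    from kimodo.model.text_encoder_api import TextEncoderAPI

    encoder = TextEncoderAPI("http://127.0.0.1:9550/").to(device="cpu", dtype=torch.float)
    # Returns a zero-padded (batch, max_len, dim) tensor plus the true length of each prompt.
    embeddings, lengths = encoder(["a person waves", "a person jumps twice"])
    print(embeddings.shape, lengths)
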
kimodo/model/tmr.py ADDED
@@ -0,0 +1,382 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """TMR model: encoder, and text-to-motion retrieval head."""
4
+
5
+ import contextlib
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import repeat
12
+ from torch import Tensor
13
+
14
+ from kimodo.model import load_checkpoint_state_dict
15
+ from kimodo.motion_rep.feature_utils import length_to_mask
16
+ from kimodo.sanitize import sanitize_texts
17
+ from kimodo.skeleton import SkeletonBase, build_skeleton
18
+ from kimodo.tools import ensure_batched
19
+
20
+
21
+ class PositionalEncoding(nn.Module):
22
+ """Sinusoidal positional encoding for sequences (batch_first optional)."""
23
+
24
+ def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None:
25
+ super().__init__()
26
+ self.batch_first = batch_first
27
+
28
+ self.dropout = nn.Dropout(p=dropout)
29
+
30
+ pe = torch.zeros(max_len, d_model)
31
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
32
+ # Note: have to replace torch.exp() and math.log() with torch.pow()
33
+ # due to MKL exp() and ln() throws floating point exceptions on certain CPUs
34
+ div_term = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model)
35
+ # div_term = torch.exp(
36
+ # torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
37
+ # )
38
+
39
+ pe[:, 0::2] = torch.sin(position * div_term)
40
+ pe[:, 1::2] = torch.cos(position * div_term)
41
+ pe = pe.unsqueeze(0).transpose(0, 1)
42
+ self.register_buffer("pe", pe, persistent=False)
43
+
44
+ def forward(self, x: Tensor) -> Tensor:
45
+ if self.batch_first:
46
+ x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
47
+ else:
48
+ x = x + self.pe[: x.shape[0], :]
49
+ return self.dropout(x)
50
+
51
+
52
+ def load_ckpt(self, ckpt_path):
53
+ """Load model weights from checkpoint path."""
54
+ state_dict = load_checkpoint_state_dict(ckpt_path)
55
+ self.load_state_dict(state_dict)
56
+
57
+
58
+ class ACTORStyleEncoder(nn.Module):
59
+ """Motion encoder in ACTOR style: optional motion_rep projection, VAE/MLP tokens, transformer."""
60
+
61
+ def __init__(
62
+ self,
63
+ motion_rep: Optional[nn.Module],
64
+ llm_shape: Optional[Tuple],
65
+ vae: bool,
66
+ latent_dim: int = 256,
67
+ ff_size: int = 1024,
68
+ num_layers: int = 4,
69
+ num_heads: int = 4,
70
+ dropout: float = 0.1,
71
+ activation: str = "gelu",
72
+ ckpt_path: Optional[str] = None,
73
+ ) -> None:
74
+ super().__init__()
75
+
76
+ self.motion_rep = motion_rep
77
+ if motion_rep is not None and llm_shape is None:
78
+ nfeats = motion_rep.motion_rep_dim
79
+ elif motion_rep is None and llm_shape is not None:
80
+ nfeats = llm_shape[-1]
81
+ else:
82
+ raise ValueError("Exactly one of motion_rep or llm_shape must be provided.")
83
+
84
+ self.nfeats = nfeats
85
+ self.projection = nn.Linear(nfeats, latent_dim)
86
+
87
+ self.vae = vae
88
+ self.nbtokens = 2 if vae else 1
89
+ self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim))
90
+
91
+ self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout=dropout, batch_first=True)
92
+
93
+ seq_trans_encoder_layer = nn.TransformerEncoderLayer(
94
+ d_model=latent_dim,
95
+ nhead=num_heads,
96
+ dim_feedforward=ff_size,
97
+ dropout=dropout,
98
+ activation=activation,
99
+ batch_first=True,
100
+ )
101
+
102
+ self.seqTransEncoder = nn.TransformerEncoder(
103
+ seq_trans_encoder_layer,
104
+ num_layers=num_layers,
105
+ enable_nested_tensor=False,
106
+ )
107
+
108
+ if ckpt_path is not None:
109
+ load_ckpt(self, ckpt_path)
110
+
111
+ def forward(self, x_dict: Dict) -> Tensor:
112
+ x = x_dict["x"]
113
+ mask = x_dict["mask"]
114
+
115
+ x = self.projection(x)
116
+
117
+ device = x.device
118
+ bs = len(x)
119
+
120
+ tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs)
121
+ xseq = torch.cat((tokens, x), 1)
122
+
123
+ token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device)
124
+ aug_mask = torch.cat((token_mask, mask), 1)
125
+
126
+ # add positional encoding
127
+ xseq = self.sequence_pos_encoding(xseq)
128
+ final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
129
+ return final[:, : self.nbtokens]
130
+
131
+
132
+ class TMR(nn.Module):
133
+ r"""TMR: Text-to-Motion Retrieval inference code (no decoder)
134
+ Find more information about the model on the following website:
135
+ https://mathis.petrovich.fr/tmr
136
+ """
137
+
138
+ @classmethod
139
+ def from_args(
140
+ cls,
141
+ motion_rep: nn.Module,
142
+ llm_shape: tuple | list,
143
+ vae: bool,
144
+ latent_dim: int = 256,
145
+ ff_size: int = 1024,
146
+ num_layers: int = 4,
147
+ num_heads: int = 4,
148
+ dropout: float = 0.1,
149
+ activation: str = "gelu",
150
+ ckpt_folder: Optional[str] = None,
151
+ device: Optional[str] = None,
152
+ **kwargs,
153
+ ):
154
+ motion_encoder, top_text_encoder = None, None
155
+
156
+ motion_encoder = ACTORStyleEncoder(
157
+ motion_rep=motion_rep,
158
+ llm_shape=None,
159
+ vae=vae,
160
+ latent_dim=latent_dim,
161
+ ff_size=ff_size,
162
+ num_layers=num_layers,
163
+ num_heads=num_heads,
164
+ dropout=dropout,
165
+ activation=activation,
166
+ ckpt_path=Path(ckpt_folder) / "motion_encoder.pt",
167
+ ).to(device)
168
+
169
+ top_text_encoder = ACTORStyleEncoder(
170
+ motion_rep=None,
171
+ llm_shape=llm_shape,
172
+ vae=vae,
173
+ latent_dim=latent_dim,
174
+ ff_size=ff_size,
175
+ num_layers=num_layers,
176
+ num_heads=num_heads,
177
+ dropout=dropout,
178
+ activation=activation,
179
+ ckpt_path=Path(ckpt_folder) / "text_encoder.pt",
180
+ ).to(device)
181
+ return cls(
182
+ motion_encoder,
183
+ top_text_encoder,
184
+ vae,
185
+ device=device,
186
+ **kwargs,
187
+ )
188
+
189
+ def __init__(
190
+ self,
191
+ motion_encoder: nn.Module,
192
+ top_text_encoder: nn.Module,
193
+ vae: bool,
194
+ text_encoder: Optional[nn.Module] = None,
195
+ fact: Optional[float] = None,
196
+ sample_mean: Optional[bool] = True,
197
+ unit_vector: Optional[bool] = False,
198
+ compute_grads: bool = False,
199
+ device: Optional[str] = None,
200
+ ) -> None:
201
+ super().__init__()
202
+
203
+ self.motion_encoder = motion_encoder
204
+ self.text_encoder = top_text_encoder
205
+ self.raw_text_encoder = text_encoder
206
+
207
+ self.motion_rep = None
208
+ self.skeleton = None
209
+ if self.motion_encoder is not None:
210
+ self.motion_rep = self.motion_encoder.motion_rep
211
+ if self.motion_rep is not None:
212
+ self.skeleton = self.motion_rep.skeleton
213
+
214
+ self.compute_grads = compute_grads
215
+
216
+ self.device = device
217
+
218
+ # sampling parameters
219
+ self.vae = vae
220
+ self.fact = fact if fact is not None else 1.0
221
+ self.sample_mean = sample_mean
222
+ self.unit_vector = unit_vector
223
+
224
+ def full_text_encoder(self, texts: list[str]):
225
+ assert isinstance(texts, list), "The input should be batched."
226
+ # sanitize the texts first
227
+ # then encode the text, and then use the top text encoder
228
+ texts = sanitize_texts(texts)
229
+ text_feat, text_length = self.raw_text_encoder(texts)
230
+ if isinstance(text_length, list):
231
+ text_length = torch.tensor(text_length, device=self.device)
232
+ else:
233
+ text_length = text_length.to(self.device)
234
+ inputs = {
235
+ "x": text_feat.to(self.device),
236
+ "mask": length_to_mask(text_length, device=self.device),
237
+ }
238
+ return self.text_encoder(inputs)
239
+
240
+ def _find_encoder(self, inputs, modality):
241
+ assert modality in ["text", "motion", "raw_text", "auto"]
242
+
243
+ if modality == "text":
244
+ return self.text_encoder
245
+ elif modality == "motion":
246
+ return self.motion_encoder
247
+ elif modality == "raw_text":
248
+ return self.full_text_encoder
249
+
250
+ if isinstance(inputs, (list, tuple)) and isinstance(inputs[0], str):
251
+ return self.full_text_encoder
252
+
253
+ m_nfeats = self.motion_encoder.nfeats
254
+ t_nfeats = self.text_encoder.nfeats
255
+
256
+ if m_nfeats == t_nfeats:
257
+ raise ValueError("Cannot automatically find the encoder, as they share the same input space.")
258
+
259
+ nfeats = inputs["x"].shape[-1]
260
+ if nfeats == m_nfeats:
261
+ return self.motion_encoder
262
+ elif nfeats == t_nfeats:
263
+ return self.text_encoder
264
+ else:
265
+ raise ValueError("The inputs is not recognized.")
266
+
267
+ def _encode(
268
+ self,
269
+ inputs,
270
+ modality: str = "auto",
271
+ sample_mean: Optional[bool] = None,
272
+ fact: Optional[float] = None,
273
+ return_distribution: bool = False,
274
+ unit_vector: Optional[bool] = None,
275
+ ):
276
+ sample_mean = self.sample_mean if sample_mean is None else sample_mean
277
+ fact = self.fact if fact is None else fact
278
+ unit_vector = self.unit_vector if unit_vector is None else unit_vector
279
+
280
+ # Encode the inputs
281
+ encoder = self._find_encoder(inputs, modality)
282
+ encoded = encoder(inputs)
283
+
284
+ # Sampling
285
+ if self.vae:
286
+ dists = encoded.unbind(1)
287
+ mu, logvar = dists
288
+ if sample_mean:
289
+ latent_vectors = mu
290
+ else:
291
+ # Reparameterization trick
292
+ std = logvar.exp().pow(0.5)
293
+ eps = std.data.new(std.size()).normal_()
294
+ latent_vectors = mu + fact * eps * std
295
+ else:
296
+ dists = None
297
+ (latent_vectors,) = encoded.unbind(1)
298
+
299
+ if unit_vector:
300
+ latent_vectors = torch.nn.functional.normalize(latent_vectors, dim=-1)
301
+
302
+ if return_distribution:
303
+ return latent_vectors, dists
304
+
305
+ return latent_vectors
306
+
307
+ @ensure_batched(posed_joints=4, lengths=1)
308
+ def encode_motion(
309
+ self,
310
+ posed_joints: torch.Tensor,
311
+ original_skeleton: Optional[SkeletonBase] = None,
312
+ lengths: Optional[torch.Tensor] = None,
313
+ unit_vector: Optional[bool] = None,
314
+ ):
315
+ # TODO here.
316
+ convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext()
317
+
318
+ if original_skeleton is None:
319
+ original_skeleton = build_skeleton(posed_joints.shape[-2])
320
+
321
+ if lengths is None:
322
+ nbatch, nbframes = posed_joints.shape[:2]
323
+ device = posed_joints.device
324
+ assert nbatch == 1, "If lengths is not provided, the input should not be batched."
325
+ lengths = torch.tensor([nbframes], device=device)
326
+
327
+ # slice the posed joints if we use less joints
328
+ skel_slice = self.motion_rep.skeleton.get_skel_slice(original_skeleton)
329
+ posed_joints = posed_joints[..., skel_slice, :]
330
+
331
+ with convert_ctx:
332
+ features = self.motion_rep(
333
+ posed_joints=posed_joints,
334
+ to_normalize=True,
335
+ lengths=lengths,
336
+ )
337
+ mask = length_to_mask(lengths, device=features.device)
338
+ x_dict = {"x": features, "mask": mask}
339
+ latent_vectors = self._encode(
340
+ x_dict,
341
+ modality="motion",
342
+ unit_vector=unit_vector,
343
+ )
344
+ return latent_vectors
345
+
346
+ def encode_text(
347
+ self,
348
+ x_dict: Dict,
349
+ unit_vector: Optional[bool] = None,
350
+ ):
351
+ # TODO: make it ensure batched
352
+ convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext()
353
+
354
+ with convert_ctx:
355
+ latent_vectors = self._encode(
356
+ x_dict,
357
+ modality="text",
358
+ unit_vector=unit_vector,
359
+ )
360
+ return latent_vectors
361
+
362
+ def encode_raw_text(
363
+ self,
364
+ texts: List[str],
365
+ unit_vector: Optional[bool] = None,
366
+ ):
367
+ is_batched = True
368
+ if isinstance(texts, str):
369
+ is_batched = False
370
+ texts = [texts]
371
+
372
+ convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext()
373
+
374
+ with convert_ctx:
375
+ latent_vectors = self._encode(
376
+ texts,
377
+ modality="raw_text",
378
+ unit_vector=unit_vector,
379
+ )
380
+ if not is_batched:
381
+ latent_vectors = latent_vectors[0]
382
+ return latent_vectors
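A quick standalone check (not part of the package) of the div_term workaround in PositionalEncoding above: the torch.pow() form is numerically identical to the exp()/log() formulation it replaces, so the MKL workaround does not change the encoding.

    import math
    import torch

    d_model = 256
    k = torch.arange(0, d_model, 2).float()
    # torch.pow() form used in PositionalEncoding
    div_pow = torch.pow(10000.0, -k / d_model)
    # conventional exp()/log() form it replaces
    div_exp = torch.exp(k * (-math.log(10000.0) / d_model))
    assert torch.allclose(div_pow, div_exp, atol=1e-6)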
kimodo/model/twostage_denoiser.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Two-stage transformer denoiser: root stage then body stage for motion diffusion."""
4
+
5
+ import contextlib
6
+ from typing import Optional
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+ from .backbone import TransformerEncoderBlock
12
+ from .loading import load_checkpoint_state_dict
13
+
14
+
15
+ class TwostageDenoiser(nn.Module):
16
+ """Two-stage denoiser: first predicts global root features, then body features conditioned on local root."""
17
+
18
+ def __init__(
19
+ self,
20
+ motion_rep,
21
+ motion_mask_mode,
22
+ ckpt_path: Optional[str] = None,
23
+ **kwargs,
24
+ ):
25
+ """Build root and body transformer blocks; optionally load checkpoint from ckpt_path."""
26
+ super().__init__()
27
+ self.motion_rep = motion_rep
28
+ self.motion_mask_mode = motion_mask_mode
29
+
30
+ # it should be a dual motion_rep
31
+ # and be global by default
32
+ # global motion_rep as input
33
+ input_dim = motion_rep.motion_rep_dim
34
+ will_concatenate = motion_mask_mode == "concat"
35
+
36
+ # stage 1: root only
37
+ root_input_dim = input_dim * 2 if will_concatenate else input_dim
38
+ root_output_dim = motion_rep.global_root_dim
39
+
40
+ self.root_model = TransformerEncoderBlock(
41
+ input_dim=root_input_dim,
42
+ output_dim=root_output_dim,
43
+ skeleton=self.motion_rep.skeleton,
44
+ **kwargs,
45
+ )
46
+
47
+ # replace the global root by the local root
48
+ local_motion_rep_dim = input_dim - motion_rep.global_root_dim + motion_rep.local_root_dim
49
+
50
+ # stage 2: local body
51
+ body_input_dim = local_motion_rep_dim + (
52
+ input_dim if will_concatenate else 0
53
+ ) # body stage always takes in local root info for motion (but still the global mask)
54
+
55
+ body_output_dim = input_dim - motion_rep.global_root_dim
56
+ self.body_model = TransformerEncoderBlock(
57
+ input_dim=body_input_dim,
58
+ output_dim=body_output_dim,
59
+ skeleton=self.motion_rep.skeleton,
60
+ **kwargs,
61
+ )
62
+
63
+ if ckpt_path:
64
+ self.load_ckpt(ckpt_path)
65
+
66
+ def load_ckpt(self, ckpt_path: str) -> None:
67
+ """Load checkpoint from path; state dict keys are stripped of 'denoiser.backbone.'
68
+ prefix."""
69
+ state_dict = load_checkpoint_state_dict(ckpt_path)
70
+ state_dict = {key.replace("denoiser.backbone.", ""): val for key, val in state_dict.items()}
71
+ self.load_state_dict(state_dict)
72
+
73
+ def forward(
74
+ self,
75
+ x: torch.Tensor,
76
+ x_pad_mask: torch.Tensor,
77
+ text_feat: torch.Tensor,
78
+ text_feat_pad_mask: torch.Tensor,
79
+ timesteps: torch.Tensor,
80
+ first_heading_angle: Optional[torch.Tensor] = None,
81
+ motion_mask: Optional[torch.Tensor] = None,
82
+ observed_motion: Optional[torch.Tensor] = None,
83
+ ) -> torch.Tensor:
84
+ """
85
+ Args:
86
+ x (torch.Tensor): [B, T, dim_motion] current noisy motion
87
+ x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not
88
+ text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts
89
+ text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not
90
+ timesteps (torch.Tensor): [B,] current denoising step
91
+ motion_mask (Optional[torch.Tensor]): [B, T, dim_motion] mask where 1 marks observed/constrained features
92
+ observed_motion (Optional[torch.Tensor]): [B, T, dim_motion] observed feature values applied where motion_mask is 1
93
+
94
+ Returns:
95
+ torch.Tensor: same size as input x
96
+ """
97
+
98
+ if self.motion_mask_mode == "concat":
99
+ if motion_mask is None or observed_motion is None:
100
+ motion_mask = torch.zeros_like(x)
101
+ observed_motion = torch.zeros_like(x)
102
+ x = x * (1 - motion_mask) + observed_motion * motion_mask
103
+ x_extended = torch.cat([x, motion_mask], axis=-1)
104
+ else:
105
+ x_extended = x
106
+
107
+ # Stage 1: predict root motion in global
108
+ root_motion_pred = self.root_model(
109
+ x_extended,
110
+ x_pad_mask,
111
+ text_feat,
112
+ text_feat_pad_mask,
113
+ timesteps,
114
+ first_heading_angle,
115
+ ) # [B, T, 5]
116
+
117
+ # Maybe pass this as argument instead of recomputing it
118
+ lengths = x_pad_mask.sum(-1)
119
+
120
+ # Convert root pred to local rep
121
+ # At test time we want to allow gradients through for guidance
122
+ convert_ctx = torch.no_grad() if self.training else contextlib.nullcontext()
123
+ with convert_ctx:
124
+ root_motion_local = self.motion_rep.global_root_to_local_root(
125
+ root_motion_pred,
126
+ normalized=True,
127
+ lengths=lengths,
128
+ )
129
+ if self.training:
130
+ root_motion_local = root_motion_local.detach()
131
+
132
+ # concatenate the predicted local root with the body motion
133
+ body_x = x[..., self.motion_rep.body_slice]
134
+ x_new = torch.cat([root_motion_local, body_x], axis=-1)
135
+
136
+ if self.motion_mask_mode == "concat":
137
+ x_new_extended = torch.cat([x_new, motion_mask], axis=-1)
138
+ else:
139
+ x_new_extended = x_new
140
+
141
+ # Stage 2: predict local body motion based on local root
142
+ predicted_body = self.body_model(
143
+ x_new_extended,
144
+ x_pad_mask,
145
+ text_feat,
146
+ text_feat_pad_mask,
147
+ timesteps,
148
+ first_heading_angle,
149
+ )
150
+
151
+ # concatenate the predicted local body with the predicted root
152
+ output = torch.cat([root_motion_pred, predicted_body], axis=-1)
153
+ return output
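A toy, self-contained sketch (made-up shapes and values) of the "concat" conditioning step at the top of TwostageDenoiser.forward: observed features overwrite the noisy motion wherever the mask is 1, and the mask itself is concatenated along the feature axis, doubling the input width.

    import torch

    B, T, D = 2, 8, 16                      # made-up batch, frames, feature dim
    x = torch.randn(B, T, D)                # current noisy motion
    motion_mask = torch.zeros(B, T, D)      # 1 where a feature is observed
    observed_motion = torch.zeros(B, T, D)
    motion_mask[:, 0, :3] = 1.0             # e.g. constrain 3 root features at frame 0
    observed_motion[:, 0, :3] = torch.tensor([0.0, 0.9, 0.0])

    x = x * (1 - motion_mask) + observed_motion * motion_mask
    x_extended = torch.cat([x, motion_mask], dim=-1)
    assert x_extended.shape == (B, T, 2 * D)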
kimodo/motion_rep/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Motion representation utilities."""
4
+
5
+ from .reps import KimodoMotionRep, MotionRepBase, TMRMotionRep
6
+
7
+ __all__ = [
8
+ "MotionRepBase",
9
+ "KimodoMotionRep",
10
+ "TMRMotionRep",
11
+ ]
kimodo/motion_rep/conditioning.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Constraint conditioning: build index and data dicts from constraint sets for the denoiser."""
4
+
5
+ from collections import defaultdict
6
+
7
+ import torch
8
+
9
+
10
+ def build_condition_dicts(constraints_lst: list):
11
+ index_dict = defaultdict(list)
12
+ data_dict = defaultdict(list)
13
+ for constraint in constraints_lst:
14
+ constraint.update_constraints(data_dict, index_dict)
15
+ return index_dict, data_dict
16
+
17
+
18
+ def get_unique_index_and_data(indices_lst, data):
19
+ # unique + sort them by t
20
+ indices_unique, inverse = torch.unique(indices_lst, dim=0, return_inverse=True)
21
+ # pick first value for each unique (t, j)
22
+ first_idx = torch.zeros(indices_unique.size(0), dtype=torch.long, device=inverse.device)
23
+ first_idx.scatter_(0, inverse, torch.arange(len(inverse), device=inverse.device))
24
+ assert (indices_lst[first_idx] == indices_unique).all()
25
+ # get the data
26
+ indices_lst = indices_lst[first_idx]
27
+ data = data[first_idx]
28
+ return indices_lst, data
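A small usage sketch for get_unique_index_and_data, assuming the kimodo package is importable: duplicate (t, j) index pairs collapse to a single row each. Which duplicate's data survives depends on the scatter order, so only the shapes are checked here.

    import torch
    from kimodo.motion_rep.conditioning import get_unique_index_and_data

    # (t, j) index pairs with one duplicate, plus per-pair data
    indices = torch.tensor([[0, 1], [2, 0], [0, 1]])
    data = torch.tensor([[10.0], [20.0], [30.0]])

    unique_indices, unique_data = get_unique_index_and_data(indices, data)
    assert unique_indices.shape[0] == 2
    assert unique_data.shape[0] == 2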
kimodo/motion_rep/feature_utils.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Motion representation helpers: velocity, heading, masks, and rotation of features."""
4
+
5
+ from typing import List, Optional, Union
6
+
7
+ import einops
8
+ import torch
9
+
10
+ from kimodo.geometry import cont6d_to_matrix, matrix_to_cont6d
11
+ from kimodo.skeleton import SkeletonBase
12
+ from kimodo.tools import ensure_batched
13
+
14
+
15
+ def diff_angles(angles: torch.Tensor, fps: float) -> torch.Tensor:
16
+ """Compute frame-to-frame angular differences in radians, scaled by fps.
17
+
18
+ Args:
19
+ angles: [..., T] batched sequences of rotation angles in radians.
20
+ fps: Sampling rate used to convert frame differences to per-second rate.
21
+
22
+ Returns:
23
+ [..., T-1] difference between consecutive angles (rad/s).
24
+ """
25
+
26
+ cos = torch.cos(angles)
27
+ sin = torch.sin(angles)
28
+
29
+ cos_diff = cos[..., 1:] * cos[..., :-1] + sin[..., 1:] * sin[..., :-1]
30
+ sin_diff = sin[..., 1:] * cos[..., :-1] - cos[..., 1:] * sin[..., :-1]
31
+
32
+ # should be close to angles.diff() but more robust
33
+ # multiply by fps = 1 / dt
34
+ angles_diff = fps * torch.arctan2(sin_diff, cos_diff)
35
+ return angles_diff
36
+
37
+
38
+ @ensure_batched(positions=4, lengths=1)
39
+ def compute_vel_xyz(
40
+ positions: torch.Tensor,
41
+ fps: float,
42
+ lengths: Optional[torch.Tensor] = None,
43
+ ) -> torch.Tensor:
44
+ """Compute the velocities from positions: dx/dt. Works with batches. The last velocity is duplicated to keep the same size.
45
+
46
+ Args:
47
+ positions (torch.Tensor): [..., T, J, 3] xyz positions of a human skeleton
48
+ fps (float): frame per seconds
49
+ lengths (Optional[torch.Tensor]): [...] size of each input batched. If not provided, positions should not be batched
50
+
51
+ Returns:
52
+ velocity (torch.Tensor): [..., T, J, 3] velocities computed from the positions
53
+ """
54
+ device = positions.device
55
+
56
+ if lengths is None:
57
+ assert positions.shape[0] == 1, "If lengths is not provided, the input should not be batched."
58
+ lengths = torch.tensor([positions.shape[1]], device=device)
59
+
60
+ # useful for indexing
61
+ range_len = torch.arange(len(lengths))
62
+
63
+ # compute velocities with fps
64
+ velocity = fps * (positions[:, 1:] - positions[:, :-1])
65
+ # pad the velocity with one extra frame
66
+ vel_pad = torch.zeros_like(velocity[:, 0])
67
+ velocity, _ = einops.pack([velocity, vel_pad], "batch * nbjoints dim")
68
+
69
+ # repeat the last velocities
70
+ # with special care for different lengths with batches
71
+ velocity[(range_len, lengths - 1)] = velocity[(range_len, lengths - 2)]
72
+ return velocity
73
+
74
+
75
+ @ensure_batched(root_rot_angles=2, lengths=1)
76
+ def compute_vel_angle(
77
+ root_rot_angles: torch.Tensor,
78
+ fps: float,
79
+ lengths: Optional[torch.Tensor] = None,
80
+ ) -> torch.Tensor:
81
+ """Compute the local root rotation velocity: dtheta/dt.
82
+
83
+ Args:
84
+ root_rot_angles (torch.Tensor): [..., T] rotation angle (in radian)
85
+ fps (float): frame per seconds
86
+ lengths (Optional[torch.Tensor]): [...] size of each input batched. If not provided, root_rot_angles should not be batched
87
+
88
+ Returns:
89
+ local_root_rot_vel (torch.Tensor): [..., T] local root rotation velocity (in radian/s)
90
+ """
91
+ device = root_rot_angles.device
92
+ if lengths is None:
93
+ assert root_rot_angles.shape[0] == 1, "If lengths is not provided, the input should not be batched."
94
+ lengths = torch.tensor([root_rot_angles.shape[1]], device=device)
95
+
96
+ # useful for indexing
97
+ range_len = torch.arange(len(lengths))
98
+
99
+ local_root_rot_vel = diff_angles(root_rot_angles, fps)
100
+ pad_rot_vel_angles = torch.zeros_like(root_rot_angles[:, 0])
101
+ local_root_rot_vel, _ = einops.pack(
102
+ [local_root_rot_vel, pad_rot_vel_angles],
103
+ "batch *",
104
+ )
105
+ # repeat the last rotation velocity
106
+ # with special care for different lengths with batches
107
+ local_root_rot_vel[(range_len, lengths - 1)] = local_root_rot_vel[(range_len, lengths - 2)]
108
+ return local_root_rot_vel
109
+
110
+
111
+ @ensure_batched(posed_joints=4)
112
+ def compute_heading_angle(posed_joints: torch.Tensor, skeleton: SkeletonBase) -> torch.Tensor:
113
+ """Compute the heading direction from joint positions using the hip vector.
114
+
115
+ Args:
116
+ posed_joints: [B, T, J, 3] global joint positions.
117
+ skeleton: Skeleton instance used to get hip joint indices.
118
+
119
+ Returns:
120
+ [B, T] heading angle in radians.
121
+ """
122
+ # compute root heading for the sequence from hip positions
123
+ r_hip, l_hip = skeleton.hip_joint_idx
124
+ diff = posed_joints[:, :, r_hip] - posed_joints[:, :, l_hip]
125
+ heading_angle = torch.atan2(diff[..., 2], -diff[..., 0])
126
+ return heading_angle
127
+
128
+
129
+ def length_to_mask(
130
+ length: Union[torch.Tensor, List],
131
+ max_len: Optional[int] = None,
132
+ device=None,
133
+ ) -> torch.Tensor:
134
+ """Convert sequence lengths to a boolean validity mask.
135
+
136
+ Args:
137
+ length: Sequence lengths, either a tensor ``[B]`` or a Python list.
138
+ max_len: Optional mask width. If omitted, uses ``max(length)``.
139
+ device: Optional device. When ``length`` is a list, this controls where
140
+ the new tensor is created.
141
+
142
+ Returns:
143
+ A boolean tensor of shape ``[B, max_len]`` where ``True`` marks valid
144
+ timesteps.
145
+ """
146
+ if isinstance(length, list):
147
+ if device is None:
148
+ device = "cpu"
149
+ length = torch.tensor(length, device=device)
150
+
151
+ # Use requested device for output; move length if needed so mask and length match
152
+ if device is not None:
153
+ target = torch.device(device)
154
+ if length.device != target:
155
+ length = length.to(target)
156
+ device = length.device
157
+
158
+ if max_len is None:
159
+ max_len = max(length)
160
+
161
+ mask = torch.arange(max_len, device=device).expand(len(length), max_len) < length.unsqueeze(1)
162
+ return mask
163
+
164
+
165
+ class RotateFeatures:
166
+ """Helper that applies a global heading rotation to motion features."""
167
+
168
+ def __init__(self, angle: torch.Tensor):
169
+ """Precompute 2D and 3D rotation matrices for a batch of angles.
170
+
171
+ Args:
172
+ angle: Rotation angle(s) in radians, shaped ``[B]``.
173
+ """
174
+ self.angle = angle
175
+
176
+ # Create the necessary rotation matrices
177
+ cos, sin = torch.cos(angle), torch.sin(angle)
178
+ one, zero = torch.ones_like(angle), torch.zeros_like(angle)
179
+
180
+ # 2D rotation matrix, transposed (signs of sin flipped)
181
+ self.corrective_mat_2d_T = torch.stack((cos, sin, -sin, cos), -1).reshape(angle.shape + (2, 2))
182
+ # 3D rotation on Y axis
183
+ self.corrective_mat_Y = torch.stack((cos, zero, sin, zero, one, zero, -sin, zero, cos), -1).reshape(
184
+ angle.shape + (3, 3)
185
+ )
186
+ self.corrective_mat_Y_T = self.corrective_mat_Y.transpose(1, 2).contiguous()
187
+
188
+ def rotate_positions(self, positions: torch.Tensor):
189
+ """Rotate 3D positions around the Y axis."""
190
+ return positions @ self.corrective_mat_Y_T
191
+
192
+ def rotate_2d_positions(self, positions_2d: torch.Tensor):
193
+ """Rotate 2D ``(x, z)`` vectors in the ground plane."""
194
+ return positions_2d @ self.corrective_mat_2d_T
195
+
196
+ def rotate_rotations(self, rotations: torch.Tensor):
197
+ """Left-multiply global rotation matrices by the heading correction."""
198
+ # "Rotate" the global rotations
199
+ # which means add an extra Y rotation after the transform
200
+ # so at the left R' = R_y R
201
+ # (since we use the convention x' = R x)
202
+ # "bik,btdkj->btdij"
203
+
204
+ B, T, J = rotations.shape[:3]
205
+ BTJ = B * T * J
206
+ return (
207
+ self.corrective_mat_Y[:, None, None].expand(B, T, J, 3, 3).reshape(BTJ, 3, 3) @ rotations.reshape(BTJ, 3, 3)
208
+ ).reshape(B, T, J, 3, 3)
209
+
210
+ def rotate_6d_rotations(self, rotations_6d: torch.Tensor):
211
+ """Rotate 6D rotation features via matrix conversion."""
212
+ return matrix_to_cont6d(self.rotate_rotations(cont6d_to_matrix(rotations_6d)))
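Two small sanity checks for the helpers above (a sketch, not shipped tests, assuming the kimodo package is importable): length_to_mask turns lengths into a padding mask, and diff_angles stays well-behaved across the +/- pi wrap-around where a naive diff() would report a jump of almost 2*pi.

    import torch
    from kimodo.motion_rep.feature_utils import diff_angles, length_to_mask

    # padding mask for two sequences of lengths 3 and 5
    mask = length_to_mask([3, 5])
    assert mask.shape == (2, 5)
    assert mask[0].tolist() == [True, True, True, False, False]

    # ~0.08 rad apart, but on opposite sides of the +/- pi boundary
    angles = torch.tensor([[3.1, -3.1]])
    rate = diff_angles(angles, fps=1.0)
    assert rate.abs().item() < 0.1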
kimodo/motion_rep/feet.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Foot contact detection from joint positions and velocities."""
4
+
5
+ import torch
6
+
7
+ from ..tools import ensure_batched
8
+
9
+
10
+ @ensure_batched(positions=4, velocity=4)
11
+ def foot_detect_from_pos_and_vel(
12
+ positions: torch.Tensor,
13
+ velocity: torch.Tensor,
14
+ skeleton,
15
+ vel_thres: float,
16
+ height_thresh: float,
17
+ ) -> torch.Tensor:
18
+ """Compute foot contact labels using heuristics combining joint height and velocities.
19
+
20
+ Args:
21
+ positions (torch.Tensor): [X, T, J, 3] global joint positions
22
+ velocity (torch.Tensor): [X, T, J, 3] velocities (already padded correctly), already multiplied by 1 / dt
23
+ vel_thres (float): threshold for joint velocity
24
+ height_thresh (float): threshold for joint height
25
+
26
+ Returns:
27
+ torch.Tensor: [X, T, 4] contact labels for left and right foot joints
28
+ (heel/toe order follows the skeleton joint index definition), where
29
+ ``1`` denotes contact.
30
+ """
31
+
32
+ device = positions.device
33
+ # Use at most 2 foot joints per side (ankle + toe); SOMA77 defines a
34
+ # third end-effector (ToeEnd) that SOMA30 and other skeletons omit.
35
+ fid_l = skeleton.left_foot_joint_idx[:2]
36
+ fid_r = skeleton.right_foot_joint_idx[:2]
37
+
38
+ velfactor, heightfactor = (
39
+ torch.tensor([vel_thres, vel_thres], device=device),
40
+ torch.tensor([height_thresh, height_thresh], device=device),
41
+ )
42
+
43
+ feet_l_v = torch.linalg.norm(velocity[:, :, fid_l], axis=-1)
44
+ feet_l_h = positions[:, :, fid_l, 1]
45
+
46
+ feet_l = torch.logical_and(
47
+ feet_l_v < velfactor,
48
+ feet_l_h < heightfactor,
49
+ ).to(positions.dtype)
50
+
51
+ feet_r_v = torch.linalg.norm(velocity[:, :, fid_r], axis=-1)
52
+ feet_r_h = positions[:, :, fid_r, 1]
53
+
54
+ feet_r = torch.logical_and(
55
+ feet_r_v < velfactor,
56
+ feet_r_h < heightfactor,
57
+ ).to(positions.dtype)
58
+
59
+ foot_contacts = torch.cat((feet_l, feet_r), axis=-1)
60
+ return foot_contacts
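A toy illustration of the contact heuristic above with made-up numbers (no skeleton needed): a foot joint counts as planted only when both its speed and its height fall below the thresholds.

    import torch

    vel_thres, height_thresh = 0.15, 0.10
    speed = torch.tensor([[0.05, 0.40], [0.10, 0.02]])    # [T=2, 2 foot joints], m/s
    height = torch.tensor([[0.03, 0.03], [0.30, 0.05]])   # metres above ground

    contacts = torch.logical_and(speed < vel_thres, height < height_thresh).float()
    # frame 0: first joint planted, second swinging; frame 1: first lifted, second planted
    assert contacts.tolist() == [[1.0, 0.0], [0.0, 1.0]]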
kimodo/motion_rep/reps/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Motion representation implementations: base, Kimodo, and TMR."""
4
+
5
+ from .base import MotionRepBase
6
+ from .kimodo_motionrep import KimodoMotionRep
7
+ from .tmr_motionrep import TMRMotionRep
8
+
9
+ __all__ = [
10
+ "MotionRepBase",
11
+ "KimodoMotionRep",
12
+ "TMRMotionRep",
13
+ ]
kimodo/motion_rep/reps/base.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Base motion representation: feature layout, normalization, and conditioning helpers."""
4
+
5
+ import os
6
+ from typing import Optional
7
+
8
+ import einops
9
+ import numpy as np
10
+ import torch
11
+ from einops import repeat
12
+
13
+ from ...tools import ensure_batched
14
+ from ..conditioning import build_condition_dicts
15
+ from ..feature_utils import compute_vel_angle, compute_vel_xyz
16
+ from ..stats import Stats
17
+
18
+
19
+ def _require_split_stats_layout(stats_path: str) -> None:
20
+ """Raise if stats_path does not contain the required global_root, local_root, body subdirs."""
21
+ subdirs = ("global_root", "local_root", "body")
22
+ missing = []
23
+ for name in subdirs:
24
+ subpath = os.path.join(stats_path, name)
25
+ mean_path = os.path.join(subpath, "mean.npy")
26
+ if not os.path.isfile(mean_path):
27
+ missing.append(f"{subpath}/ (mean.npy)")
28
+ if missing:
29
+ raise FileNotFoundError(
30
+ f"Checkpoint stats must use the split layout with subfolders "
31
+ f"global_root/, local_root/, and body/ under '{stats_path}'. "
32
+ f"Missing or incomplete: {', '.join(missing)}. "
33
+ )
34
+
35
+
36
+ class MotionRepBase:
37
+ """Base class for motion representations used in generation and conditioning.
38
+
39
+ Subclasses define:
40
+ - ``size_dict``: feature blocks and their shapes,
41
+ - ``last_root_feature``: last entry of the root block,
42
+ - ``local_root_size_dict``: local-root feature layout,
43
+ and implement transform-specific methods such as ``__call__``, ``inverse``,
44
+ ``rotate``, ``translate_2d`` and ``create_conditions``.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ skeleton,
50
+ fps,
51
+ stats_path: Optional[str] = None,
52
+ ):
53
+ """Initialize feature slicing metadata and optional normalization stats."""
54
+
55
+ self.skeleton = skeleton
56
+ self.fps = fps
57
+ self.nbjoints = skeleton.nbjoints
58
+
59
+ self.feature_names = list(self.size_dict.keys())
60
+ self.ps = list(self.size_dict.values())
61
+ self.nfeats_dict = {key: val.numel() for key, val in self.size_dict.items()}
62
+ feats_cumsum = np.cumsum([0] + list(self.nfeats_dict.values())).tolist()
63
+ self.slice_dict = {key: slice(feats_cumsum[i], feats_cumsum[i + 1]) for i, key in enumerate(self.feature_names)}
64
+
65
+ self.motion_rep_dim = sum(self.nfeats_dict.values())
66
+ self.root_slice = slice(0, self.slice_dict[self.last_root_feature].stop)
67
+ self.body_slice = slice(self.root_slice.stop, self.motion_rep_dim)
68
+ self.body_dim = self.body_slice.stop - self.body_slice.start
69
+ self.global_root_dim = self.root_slice.stop
70
+ self.local_root_dim = sum(val.numel() for val in self.local_root_size_dict.values())
71
+
72
+ if stats_path:
73
+ _require_split_stats_layout(stats_path)
74
+ self.global_root_stats = Stats(os.path.join(stats_path, "global_root"))
75
+ self.local_root_stats = Stats(os.path.join(stats_path, "local_root"))
76
+ self.body_stats = Stats(os.path.join(stats_path, "body"))
77
+ # self.stats not set; normalize/unnormalize apply per-part below
78
+
79
+ def get_root_pos(self, features: torch.Tensor, fallback_to_smooth: bool = True):
80
+ """Extract root positions from a feature tensor.
81
+
82
+ Supports both ``root_pos`` and ``smooth_root_pos`` representations.
83
+ """
84
+ if "root_pos" in self.slice_dict:
85
+ return features[..., self.slice_dict["root_pos"]]
86
+
87
+ if "smooth_root_pos" not in self.slice_dict:
88
+ raise TypeError("This motion rep should have either a root_pos or smooth_root_pos field")
89
+
90
+ if fallback_to_smooth:
91
+ return features[:, :, self.slice_dict["smooth_root_pos"]]
92
+
93
+ # else compute the root pos from the smooth root and local joints offset
94
+ smooth_root_pos = features[:, :, self.slice_dict["smooth_root_pos"]].clone()
95
+ local_joints_positions_flatten = features[..., self.slice_dict["local_joints_positions"]]
96
+ hips_offset = local_joints_positions_flatten[..., self.skeleton.root_idx : self.skeleton.root_idx + 3]
97
+ root_pos = torch.stack(
98
+ [
99
+ smooth_root_pos[..., 0] + hips_offset[..., 0],
100
+ smooth_root_pos[..., 1],
101
+ smooth_root_pos[..., 2] + hips_offset[..., 2],
102
+ ],
103
+ axis=-1,
104
+ )
105
+ return root_pos
106
+
107
+ @ensure_batched(root_features=3, lengths=1)
108
+ def global_root_to_local_root(
109
+ self,
110
+ root_features: torch.Tensor,
111
+ normalized: bool,
112
+ lengths: Optional[torch.Tensor],
113
+ ):
114
+ """Convert global root features to local-root motion features.
115
+
116
+ Args:
117
+ root_features: Root feature tensor containing root position and
118
+ global heading, shaped ``[B, T, D_root]``.
119
+ normalized: Whether ``root_features`` are normalized.
120
+ lengths: Optional valid lengths per sequence.
121
+
122
+ Returns:
123
+ Tensor ``[B, T, 4]`` with local root rotational velocity, planar
124
+ velocity, and global root height.
125
+ """
126
+ if normalized:
127
+ root_features = self.global_root_stats.unnormalize(root_features)
128
+
129
+ [root_pos, global_root_heading] = einops.unpack(root_features, self.ps[:2], "batch time *")
130
+ cos, sin = global_root_heading.unbind(-1)
131
+ heading_angle = torch.arctan2(sin, cos)
132
+
133
+ local_root_rot_vel = compute_vel_angle(heading_angle, self.fps, lengths=lengths)
134
+ local_root_vel = compute_vel_xyz(
135
+ root_pos[..., None, :],
136
+ self.fps,
137
+ lengths=lengths,
138
+ )[..., 0, [0, 2]]
139
+ global_root_y = root_pos[..., 1]
140
+ local_root_motion = torch.cat(
141
+ [
142
+ local_root_rot_vel[..., None],
143
+ local_root_vel,
144
+ global_root_y[..., None],
145
+ ],
146
+ axis=-1,
147
+ )
148
+
149
+ if normalized:
150
+ local_root_motion = self.local_root_stats.normalize(local_root_motion)
151
+ return local_root_motion
152
+
153
+ def get_root_heading_angle(self, features: torch.Tensor) -> torch.Tensor:
154
+ """Compute root heading angle from cosine/sine heading features."""
155
+ global_root_heading = features[:, :, self.slice_dict["global_root_heading"]]
156
+ cos, sin = global_root_heading.unbind(-1)
157
+ return torch.arctan2(sin, cos)
158
+
159
+ @ensure_batched(features=3)
160
+ def rotate_to(
161
+ self,
162
+ features: torch.Tensor,
163
+ target_angle: torch.Tensor,
164
+ return_delta_angle=False,
165
+ ):
166
+ """Rotate each sequence so frame-0 heading matches ``target_angle``."""
167
+ # rotate so that the first frame angle is the target
168
+ # it put the motion_rep to the angle
169
+ current_first_angle = self.get_root_heading_angle(features)[:, 0]
170
+ delta_angle = target_angle - current_first_angle
171
+ rotated_features = self.rotate(features, delta_angle)
172
+ if return_delta_angle:
173
+ return rotated_features, delta_angle
174
+ return rotated_features
175
+
176
+ @ensure_batched(features=3)
177
+ def rotate_to_zero(
178
+ self,
179
+ features: torch.Tensor,
180
+ return_delta_angle=False,
181
+ ):
182
+ """Rotate each sequence so frame-0 heading becomes zero."""
183
+ target_angle = torch.zeros(len(features), device=features.device)
184
+ return self.rotate_to(features, target_angle, return_delta_angle=return_delta_angle)
185
+
186
+ @ensure_batched(features=3)
187
+ def randomize_first_heading(
188
+ self,
189
+ features: torch.Tensor,
190
+ return_delta_angle=False,
191
+ ) -> torch.Tensor:
192
+ """Rotate each sequence to a random frame-0 heading."""
193
+ target_heading_angle = torch.rand(features.shape[0], device=features.device) * 2 * np.pi
194
+ return self.rotate_to(
195
+ features,
196
+ target_heading_angle,
197
+ return_delta_angle=return_delta_angle,
198
+ )
199
+
200
+ @ensure_batched(features=3, target_2d_pos=2)
201
+ def translate_2d_to(
202
+ self,
203
+ features: torch.Tensor,
204
+ target_2d_pos: torch.Tensor,
205
+ return_delta_pos: bool = False,
206
+ ) -> torch.Tensor:
207
+ """Translate each sequence so frame-0 root ``(x, z)`` matches a target."""
208
+ root_pos = self.get_root_pos(features)
209
+ current_first_2d_pos = root_pos[:, 0, [0, 2]].clone()
210
+ delta_2d_pos = target_2d_pos - current_first_2d_pos
211
+ translated_features = self.translate_2d(features, delta_2d_pos)
212
+ if return_delta_pos:
213
+ return translated_features, delta_2d_pos
214
+ return translated_features
215
+
216
+ @ensure_batched(features=3)
217
+ def translate_2d_to_zero(
218
+ self,
219
+ features: torch.Tensor,
220
+ return_delta_pos: bool = False,
221
+ ) -> torch.Tensor:
222
+ """Translate each sequence so frame-0 root ``(x, z)`` is at the origin."""
223
+ target_2d_pos = torch.zeros(len(features), 2, device=features.device)
224
+ return self.translate_2d_to(features, target_2d_pos, return_delta_pos=return_delta_pos)
225
+
226
+ @ensure_batched(features=3)
227
+ def canonicalize(self, features: torch.Tensor):
228
+ """Canonicalize heading and planar position at frame 0."""
229
+ rotated_features = self.rotate_to_zero(features)
230
+ return self.translate_2d_to_zero(rotated_features)
231
+
232
+ def normalize(self, features):
233
+ """Normalize features using per-part stats (global_root, local_root, body)."""
234
+ gr = slice(0, self.global_root_dim)
235
+ lr = slice(self.global_root_dim, self.global_root_dim + self.local_root_dim)
236
+ out = torch.empty_like(features, device=features.device, dtype=features.dtype)
237
+ out[..., gr] = self.global_root_stats.normalize(features[..., gr])
238
+ out[..., lr] = self.local_root_stats.normalize(features[..., lr])
239
+ out[..., self.body_slice] = self.body_stats.normalize(features[..., self.body_slice])
240
+ return out
241
+
242
+ def unnormalize(self, features):
243
+ """Undo feature normalization using per-part stats."""
244
+ gr = slice(0, self.global_root_dim)
245
+ lr = slice(self.global_root_dim, self.global_root_dim + self.local_root_dim)
246
+ out = torch.empty_like(features, device=features.device, dtype=features.dtype)
247
+ out[..., gr] = self.global_root_stats.unnormalize(features[..., gr])
248
+ out[..., lr] = self.local_root_stats.unnormalize(features[..., lr])
249
+ out[..., self.body_slice] = self.body_stats.unnormalize(features[..., self.body_slice])
250
+ return out
251
+
252
+ def create_conditions_from_constraints(
253
+ self,
254
+ constraints_lst: list,
255
+ length: int,
256
+ to_normalize: bool,
257
+ device: str,
258
+ ):
259
+ """Create a conditioning tensor and mask from constraint objects."""
260
+ index_dict, data_dict = build_condition_dicts(constraints_lst)
261
+ return self.create_conditions(index_dict, data_dict, length, to_normalize, device)
262
+
263
+ def create_conditions_from_constraints_batched(
264
+ self,
265
+ constraints_lst: list | list[list],
266
+ lengths: torch.Tensor,
267
+ to_normalize: bool,
268
+ device: str,
269
+ ):
270
+ """Batched version of ``create_conditions_from_constraints``.
271
+
272
+ Supports either one shared constraint list for all batch elements, or a per-sample list of
273
+ constraint lists.
274
+ """
275
+ num_samples = len(lengths)
276
+ if not constraints_lst or not isinstance(constraints_lst[0], list):
277
+ # If no constraints, or constraints are shared across the batch,
278
+ # build once and repeat.
279
+ observed_motion, motion_mask = self.create_conditions_from_constraints(
280
+ constraints_lst, int(lengths.max()), to_normalize, device
281
+ )
282
+ observed_motion = repeat(observed_motion, "t d -> b t d", b=num_samples)
283
+ motion_mask = repeat(motion_mask, "t d -> b t d", b=num_samples)
284
+ return observed_motion, motion_mask
285
+
286
+ length = int(lengths.max())
287
+ observed_motion_lst = []
288
+ motion_mask_lst = []
289
+ for constraints_lst_el in constraints_lst:
290
+ observed_motion, motion_mask = self.create_conditions_from_constraints(
291
+ constraints_lst_el,
292
+ length,
293
+ to_normalize,
294
+ device,
295
+ )
296
+ observed_motion_lst.append(observed_motion)
297
+ motion_mask_lst.append(motion_mask)
298
+ observed_motion = torch.stack(observed_motion_lst, axis=0)
299
+ motion_mask = torch.stack(motion_mask_lst, axis=0)
300
+ return observed_motion, motion_mask
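A small sketch (with a made-up feature layout) of the cumulative-sum bookkeeping MotionRepBase.__init__ uses to turn size_dict into per-feature slices of the flattened feature vector.

    import numpy as np
    import torch

    size_dict = {
        "root_pos": torch.Size([3]),
        "global_root_heading": torch.Size([2]),
        "local_joints_positions": torch.Size([4, 3]),   # pretend 4 joints
    }
    nfeats = {key: val.numel() for key, val in size_dict.items()}
    cumsum = np.cumsum([0] + list(nfeats.values())).tolist()
    slices = {key: slice(cumsum[i], cumsum[i + 1]) for i, key in enumerate(size_dict)}

    assert slices["root_pos"] == slice(0, 3)
    assert slices["global_root_heading"] == slice(3, 5)
    assert slices["local_joints_positions"] == slice(5, 17)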
kimodo/motion_rep/reps/kimodo_motionrep.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Optional
5
+
6
+ import einops
7
+ import torch
8
+ from torch import Tensor
9
+
10
+ from kimodo.tools import to_numpy
11
+
12
+ from ...geometry import cont6d_to_matrix, matrix_to_cont6d
13
+ from ...skeleton.kinematics import fk
14
+ from ...skeleton.transforms import global_rots_to_local_rots
15
+ from ...tools import ensure_batched
16
+ from ..conditioning import get_unique_index_and_data
17
+ from ..feature_utils import RotateFeatures, compute_heading_angle, compute_vel_xyz
18
+ from ..feet import foot_detect_from_pos_and_vel
19
+ from ..smooth_root import get_smooth_root_pos
20
+ from .base import MotionRepBase
21
+
22
+
23
+ class KimodoMotionRep(MotionRepBase):
24
+ """Global root / global joints rotations representation, relative to a smooth root."""
25
+
26
+ def __init__(
27
+ self,
28
+ skeleton,
29
+ fps,
30
+ stats_path: Optional[str] = None,
31
+ ):
32
+ nbjoints = skeleton.nbjoints
33
+
34
+ self.size_dict = {
35
+ "smooth_root_pos": torch.Size([3]),
36
+ "global_root_heading": torch.Size([2]),
37
+ "local_joints_positions": torch.Size([nbjoints, 3]),
38
+ "global_rot_data": torch.Size([nbjoints, 6]),
39
+ "velocities": torch.Size([nbjoints, 3]),
40
+ "foot_contacts": torch.Size([4]),
41
+ }
42
+ self.last_root_feature = "global_root_heading"
43
+ self.local_root_size_dict = {
44
+ "local_root_rot_vel": torch.Size([1]),
45
+ "local_root_vel": torch.Size([2]),
46
+ "global_root_y": torch.Size([1]),
47
+ }
48
+ super().__init__(skeleton, fps, stats_path)
49
+
50
+ @ensure_batched(local_joint_rots=5, root_positions=3, lengths=1)
51
+ def __call__(
52
+ self,
53
+ local_joint_rots: torch.Tensor,
54
+ root_positions: torch.Tensor,
55
+ to_normalize: bool,
56
+ lengths: Optional[torch.Tensor] = None,
57
+ ) -> torch.Tensor:
58
+ """Convert local rotations and root trajectory into smooth-root features.
59
+
60
+ Args:
61
+ local_joint_rots: Local joint rotation matrices ``[B, T, J, 3, 3]``.
62
+ root_positions: Root positions ``[B, T, 3]``.
63
+ to_normalize: Whether to normalize output features.
64
+ lengths: Optional valid lengths for variable-length batches.
65
+
66
+ Returns:
67
+ Motion features with shape ``[B, T, motion_rep_dim]``.
68
+ """
69
+ device = local_joint_rots.device
70
+ if lengths is None:
71
+ assert local_joint_rots.shape[0] == 1, "If lengths is not provided, the input should not be batched."
72
+ lengths = torch.tensor([local_joint_rots.shape[1]], device=device)
73
+
74
+ (
75
+ global_joints_rots,
76
+ global_joints_positions,
77
+ local_joints_positions_origin_is_pelvis,
78
+ ) = fk(local_joint_rots, root_positions, self.skeleton)
79
+
80
+ root_heading_angle = compute_heading_angle(global_joints_positions, self.skeleton)
81
+ global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1)
82
+
83
+ smooth_root_pos = get_smooth_root_pos(root_positions)
84
+ hips_offset = root_positions - smooth_root_pos
85
+ hips_offset[..., 1] = root_positions[..., 1]
86
+ local_joints_positions = local_joints_positions_origin_is_pelvis + hips_offset[:, :, None]
87
+
88
+ velocities = compute_vel_xyz(global_joints_positions, self.fps, lengths=lengths)
89
+ foot_contacts = foot_detect_from_pos_and_vel(global_joints_positions, velocities, self.skeleton, 0.15, 0.10)
90
+ global_rot_data = matrix_to_cont6d(global_joints_rots)
91
+
92
+ features, _ = einops.pack(
93
+ [
94
+ smooth_root_pos,
95
+ global_root_heading,
96
+ local_joints_positions,
97
+ global_rot_data,
98
+ velocities,
99
+ foot_contacts,
100
+ ],
101
+ "batch time *",
102
+ )
103
+
104
+ if to_normalize:
105
+ features = self.normalize(features)
106
+ return features
107
+
108
+ @ensure_batched(features=3, angle=1)
109
+ def rotate(self, features: torch.Tensor, angle: torch.Tensor):
110
+ """Rotate root/joint positional and rotational features by heading."""
111
+ # assume it is not normalized
112
+ bs = features.shape[0]
113
+ device = features.device
114
+ [
115
+ smooth_root_pos,
116
+ global_root_heading,
117
+ local_joints_positions,
118
+ global_rot_data,
119
+ velocities,
120
+ foot_contacts,
121
+ ] = einops.unpack(features, self.ps, "batch time *")
122
+
123
+ if not isinstance(angle, torch.Tensor):
124
+ angle = torch.tensor(angle, device=device)
125
+ if len(angle.shape) == 0:
126
+ angle = angle.repeat(bs)
127
+
128
+ RF = RotateFeatures(angle)
129
+ new_features, _ = einops.pack(
130
+ [
131
+ RF.rotate_positions(smooth_root_pos),
132
+ RF.rotate_2d_positions(global_root_heading),
133
+ RF.rotate_positions(local_joints_positions),
134
+ RF.rotate_6d_rotations(global_rot_data),
135
+ RF.rotate_positions(velocities),
136
+ foot_contacts,
137
+ ],
138
+ "batch time *",
139
+ )
140
+ return new_features
141
+
142
+ @ensure_batched(features=3, translation_2d=2)
143
+ def translate_2d(
144
+ self,
145
+ features: torch.Tensor,
146
+ translation_2d: torch.Tensor,
147
+ ) -> torch.Tensor:
148
+ """Translate smooth root planar position by ``(dx, dz)``."""
149
+ # only move on the ground
150
+ # If we need a translate_3D function, we should not forget to move the local_joints_positions as well
151
+ bs = features.shape[0]
152
+ if len(translation_2d.shape) == 1:
153
+ translation_2d = translation_2d.repeat(bs, 1)
154
+
155
+ new_features = features.clone()
156
+ new_smooth_root_pos = new_features[:, :, self.slice_dict["smooth_root_pos"]]
157
+ new_smooth_root_pos[:, :, 0] += translation_2d[:, [0]]
158
+ new_smooth_root_pos[:, :, 2] += translation_2d[:, [1]]
159
+ return new_features
160
+
161
+ @ensure_batched(features=3)
162
+ def inverse(
163
+ self,
164
+ features: torch.Tensor,
165
+ is_normalized: bool,
166
+ posed_joints_from="rotations",
167
+ return_numpy: bool = False,
168
+ ) -> torch.Tensor:
169
+ """Decode smooth-root features into motion tensors."""
170
+ assert posed_joints_from in [
171
+ "rotations",
172
+ "positions",
173
+ ], "posed_joints_from should 'rotations' or 'positions'"
174
+
175
+ if is_normalized:
176
+ features = self.unnormalize(features)
177
+
178
+ [
179
+ smooth_root_pos,
180
+ global_root_heading,
181
+ local_joints_positions,
182
+ global_rot_data,
183
+ velocities,
184
+ foot_contacts,
185
+ ] = einops.unpack(features, self.ps, "batch time *")
186
+
187
+ global_rot_mats = cont6d_to_matrix(global_rot_data)
188
+ local_rot_mats = global_rots_to_local_rots(global_rot_mats, self.skeleton)
189
+
190
+ posed_joints_from_pos = local_joints_positions.clone()
191
+ posed_joints_from_pos[..., 0] += smooth_root_pos[..., None, 0]
192
+ posed_joints_from_pos[..., 2] += smooth_root_pos[..., None, 2]
193
+ root_positions = posed_joints_from_pos[..., self.skeleton.root_idx, :]
194
+ foot_contacts = foot_contacts > 0.5
195
+
196
+ if posed_joints_from == "rotations":
197
+ _, posed_joints, _ = self.skeleton.fk(
198
+ local_rot_mats,
199
+ root_positions,
200
+ )
201
+ else:
202
+ posed_joints = posed_joints_from_pos
203
+
204
+ output_tensor_dict = {
205
+ "local_rot_mats": local_rot_mats,
206
+ "global_rot_mats": global_rot_mats,
207
+ "posed_joints": posed_joints,
208
+ "root_positions": root_positions,
209
+ "smooth_root_pos": smooth_root_pos,
210
+ "foot_contacts": foot_contacts,
211
+ "global_root_heading": global_root_heading,
212
+ }
213
+ if return_numpy:
214
+ return to_numpy(output_tensor_dict)
215
+ return output_tensor_dict
216
+
217
+ def create_conditions(
218
+ self,
219
+ index_dict: dict[Tensor],
220
+ data_dict: dict[Tensor],
221
+ length: int,
222
+ to_normalize: bool,
223
+ device: str,
224
+ ):
225
+ """Build sparse conditioning tensors for smooth-root representation."""
226
+ # create empty features and mask to be filled in
227
+ observed_motion = torch.zeros(length, self.motion_rep_dim, device=device)
228
+ motion_mask = torch.zeros(length, self.motion_rep_dim, dtype=bool, device=device)
229
+
230
+ def _cat_indices(indices_list: list[Tensor]) -> Tensor:
231
+ indices = torch.cat([torch.tensor(x) if not isinstance(x, Tensor) else x for x in indices_list])
232
+ return indices.to(device=device, dtype=torch.long)
233
+
234
+ def _match_obs_dtype(tensor: Tensor) -> Tensor:
235
+ return tensor.to(device=device, dtype=observed_motion.dtype)
236
+
237
+ if (fname := "smooth_root_2d") in index_dict and index_dict[fname]:
238
+ indices = _cat_indices(index_dict[fname])
239
+ indices, smooth_root_2d = get_unique_index_and_data(indices, torch.cat(data_dict[fname]))
240
+ smooth_root_2d = _match_obs_dtype(smooth_root_2d)
241
+ f_sliced = observed_motion[:, self.slice_dict["smooth_root_pos"]]
242
+ f_sliced[indices, 0] = smooth_root_2d[:, 0]
243
+ f_sliced[indices, 2] = smooth_root_2d[:, 1]
244
+ m_sliced = motion_mask[:, self.slice_dict["smooth_root_pos"]]
245
+ m_sliced[indices, 0] = True
246
+ m_sliced[indices, 2] = True
247
+
248
+ if (fname := "root_y_pos") in index_dict and index_dict[fname]:
249
+ indices = _cat_indices(index_dict[fname])
250
+ indices, root_pos_Y = get_unique_index_and_data(indices, torch.cat(data_dict[fname]))
251
+ root_pos_Y = _match_obs_dtype(root_pos_Y)
252
+ f_sliced = observed_motion[:, self.slice_dict["smooth_root_pos"]]
253
+ f_sliced[indices, 1] = root_pos_Y
254
+ m_sliced = motion_mask[:, self.slice_dict["smooth_root_pos"]]
255
+ m_sliced[indices, 1] = True
256
+
257
+ if (fname := "global_root_heading") in index_dict and index_dict[fname]:
258
+ indices = _cat_indices(index_dict[fname])
259
+ indices, global_root_heading = get_unique_index_and_data(indices, torch.cat(data_dict[fname]))
260
+ global_root_heading = _match_obs_dtype(global_root_heading)
261
+ f_sliced = observed_motion[:, self.slice_dict[fname]]
262
+ f_sliced[indices] = global_root_heading
263
+ m_sliced = motion_mask[:, self.slice_dict[fname]]
264
+ m_sliced[indices] = True
265
+
266
+ if (fname := "global_joints_rots") in index_dict and index_dict[fname]:
267
+ indices_lst = _cat_indices(index_dict[fname])
268
+ indices_lst, global_joints_rots = get_unique_index_and_data(indices_lst, torch.cat(data_dict[fname]))
269
+ global_joints_rots = _match_obs_dtype(global_joints_rots)
270
+ global_rot_data = matrix_to_cont6d(global_joints_rots)
271
+ f_sliced = observed_motion[:, self.slice_dict["global_rot_data"]]
272
+ masking = torch.zeros(len(f_sliced) * self.nbjoints, 6, device=device, dtype=bool)
273
+ masking[indices_lst.T[0] * self.nbjoints + indices_lst.T[1]] = True
274
+ masking = masking.reshape(len(f_sliced), self.nbjoints * 6)
275
+ f_sliced[masking] = global_rot_data.flatten()
276
+ m_sliced = motion_mask[:, self.slice_dict["global_rot_data"]]
277
+ m_sliced[masking] = True
278
+
279
+ if (fname := "global_joints_positions") in index_dict and index_dict[fname]:
280
+ indices_lst = _cat_indices(index_dict[fname])
281
+ indices_lst, global_joints_positions = get_unique_index_and_data(indices_lst, torch.cat(data_dict[fname]))
282
+ global_joints_positions = _match_obs_dtype(global_joints_positions)
283
+ T_indices = indices_lst[:, 0].contiguous()
284
+ _test = motion_mask[T_indices, self.slice_dict["smooth_root_pos"]]
285
+ if not _test[:, [0, 2]].all():
286
+ raise ValueError("For constraining global positions, the smooth root should also be constrained.")
287
+ smooth_root_pos = observed_motion[T_indices, self.slice_dict["smooth_root_pos"]].clone()
288
+ local_reference = smooth_root_pos.clone()
289
+ local_reference[..., 1] = 0.0
290
+ local_joints_positions = global_joints_positions - local_reference
291
+ f_sliced = observed_motion[:, self.slice_dict["local_joints_positions"]]
292
+ masking = torch.zeros(len(f_sliced) * self.nbjoints, 3, device=device, dtype=bool)
293
+ masking[indices_lst.T[0] * self.nbjoints + indices_lst.T[1]] = True
294
+ masking = masking.reshape(len(f_sliced), self.nbjoints * 3)
295
+ f_sliced[masking] = local_joints_positions.flatten()
296
+ m_sliced = motion_mask[:, self.slice_dict["local_joints_positions"]]
297
+ m_sliced[masking] = True
298
+
299
+ if to_normalize:
300
+ observed_motion = self.normalize(observed_motion)
301
+ return observed_motion, motion_mask
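A short standalone check of the heading convention used by this representation: the heading is stored as (cos(theta), sin(theta)) and recovered with arctan2, which keeps the feature continuous across the +/- pi boundary.

    import torch

    theta = torch.tensor([0.0, 1.5, -3.0])
    heading = torch.stack([torch.cos(theta), torch.sin(theta)], dim=-1)   # [..., 2]
    recovered = torch.arctan2(heading[..., 1], heading[..., 0])
    assert torch.allclose(recovered, theta, atol=1e-6)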
kimodo/motion_rep/reps/tmr_motionrep.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """TMR motion representation: global root, global joints, velocities, and foot contacts."""
4
+
5
+ from typing import Optional
6
+
7
+ import einops
8
+ import torch
9
+
10
+ from ...skeleton.kinematics import fk
11
+ from ...tools import ensure_batched, to_numpy
12
+ from ..feature_utils import RotateFeatures, compute_heading_angle, compute_vel_xyz
13
+ from ..feet import foot_detect_from_pos_and_vel
14
+ from .base import MotionRepBase
15
+
16
+
17
+ class TMRMotionRep(MotionRepBase):
18
+ """Motion representation with global root and global joint positions.
19
+
20
+ Feature layout:
21
+ - root position ``(x, y, z)``
22
+ - root heading as ``(cos(theta), sin(theta))``
23
+ - local joint positions (root removed, ground-referenced)
24
+ - global joint velocities
25
+ - binary foot contacts
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ skeleton,
31
+ fps,
32
+ stats_path: Optional[str] = None,
33
+ ):
34
+ nbjoints = skeleton.nbjoints
35
+
36
+ self.size_dict = {
37
+ "root_pos": torch.Size([3]),
38
+ "global_root_heading": torch.Size([2]),
39
+ "local_joints_positions": torch.Size([nbjoints - 1, 3]),
40
+ "velocities": torch.Size([nbjoints, 3]),
41
+ "foot_contacts": torch.Size([4]),
42
+ }
43
+ self.last_root_feature = "global_root_heading"
44
+ self.local_root_size_dict = {
45
+ "local_root_rot_vel": torch.Size([1]),
46
+ "local_root_vel": torch.Size([2]),
47
+ "global_root_y": torch.Size([1]),
48
+ }
49
+ super().__init__(skeleton, fps, stats_path)
50
+
51
+ @ensure_batched(local_joint_rots=5, root_positions=3, posed_joints=4, lengths=1)
52
+ def __call__(
53
+ self,
54
+ local_joint_rots: Optional[torch.Tensor] = None,
55
+ root_positions: Optional[torch.Tensor] = None,
56
+ posed_joints: Optional[torch.Tensor] = None,
57
+ *,
58
+ to_normalize: bool,
59
+ lengths: Optional[torch.Tensor] = None,
60
+ ) -> torch.Tensor:
61
+ """Convert motion inputs to this feature representation.
62
+
63
+ Args:
64
+ local_joint_rots: Local joint rotation matrices ``[B, T, J, 3, 3]``.
65
+ Required when ``posed_joints`` is not provided.
66
+ root_positions: Root translations ``[B, T, 3]``. Required when
67
+ ``posed_joints`` is not provided.
68
+ posed_joints: Optional precomputed global joint positions
69
+ ``[B, T, J, 3]``. If passed, FK is skipped.
70
+ to_normalize: Whether to normalize output features.
71
+ lengths: Optional valid lengths for variable-length batches.
72
+
73
+ Returns:
74
+ Motion features with shape ``[B, T, motion_rep_dim]``.
75
+ """
76
+ if posed_joints is not None:
77
+ device = posed_joints.device
78
+ nbatch, nbframes, nbjoints = posed_joints.shape[:3]
79
+ else:
80
+ device = local_joint_rots.device
81
+ nbatch, nbframes, nbjoints = local_joint_rots.shape[:3]
82
+
83
+ if lengths is None:
84
+ assert nbatch == 1, "If lenghts is not provided, the input should not be batched."
85
+ lengths = torch.tensor([nbframes], device=device)
86
+
87
+ if posed_joints is None:
88
+ _, global_positions, local_joints_positions_origin_is_pelvis = fk(
89
+ local_joint_rots, root_positions, self.skeleton
90
+ )
91
+ else:
92
+ global_positions = posed_joints
93
+ root_positions = posed_joints[:, :, 0]
94
+ local_joints_positions_origin_is_pelvis = posed_joints - root_positions[:, :, None]
95
+
96
+ root_heading_angle = compute_heading_angle(global_positions, self.skeleton)
97
+ global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1)
98
+
99
+ ground_offset = 0 * root_positions
100
+ ground_offset[..., 1] = root_positions[..., 1]
101
+ local_joints_positions = local_joints_positions_origin_is_pelvis[:, :, 1:] + ground_offset[:, :, None]
102
+ velocities = compute_vel_xyz(global_positions, self.fps, lengths=lengths)
103
+ foot_contacts = foot_detect_from_pos_and_vel(global_positions, velocities, self.skeleton, 0.15, 0.10)
104
+
105
+ features, _ = einops.pack(
106
+ [
107
+ root_positions,
108
+ global_root_heading,
109
+ local_joints_positions,
110
+ velocities,
111
+ foot_contacts,
112
+ ],
113
+ "batch time *",
114
+ )
115
+
116
+ if to_normalize:
117
+ features = self.normalize(features)
118
+ return features
119
+
120
+ @ensure_batched(features=3, angle=1)
121
+ def rotate(self, features: torch.Tensor, angle: torch.Tensor):
122
+ """Rotate all spatial features by a heading delta (radians)."""
123
+ # rotate by the angle
124
+ # it adds the angle to the current features
125
+ # assume it is not normalized
126
+ bs = features.shape[0]
127
+ device = features.device
128
+ [
129
+ root_pos,
130
+ global_root_heading,
131
+ local_joints_positions,
132
+ velocities,
133
+ foot_contacts,
134
+ ] = einops.unpack(features, self.ps, "batch time *")
135
+
136
+ if not isinstance(angle, torch.Tensor):
137
+ angle = torch.tensor(angle, device=device)
138
+ if len(angle.shape) == 0:
139
+ angle = angle.repeat(bs)
140
+
141
+ RF = RotateFeatures(angle)
142
+ new_features, _ = einops.pack(
143
+ [
144
+ RF.rotate_positions(root_pos),
145
+ RF.rotate_2d_positions(global_root_heading),
146
+ RF.rotate_positions(local_joints_positions),
147
+ RF.rotate_positions(velocities),
148
+ foot_contacts,
149
+ ],
150
+ "batch time *",
151
+ )
152
+ return new_features
153
+
154
+ @ensure_batched(features=3, translation_2d=2)
155
+ def translate_2d(
156
+ self,
157
+ features: torch.Tensor,
158
+ translation_2d: torch.Tensor,
159
+ ) -> torch.Tensor:
160
+ """Translate root planar position by ``(dx, dz)``."""
161
+ # only move on the ground
162
+ # For 3D, we should not forget to move the local_joints_positions as well
163
+ bs = features.shape[0]
164
+ if len(translation_2d.shape) == 1:
165
+ translation_2d = translation_2d.repeat(bs, 1)
166
+
167
+ new_features = features.clone()
168
+ new_root_pos = new_features[:, :, self.slice_dict["root_pos"]]
169
+ new_root_pos[:, :, 0] += translation_2d[:, 0:1]  # [B, 1] broadcasts over the time axis
170
+ new_root_pos[:, :, 2] += translation_2d[:, 1:2]
171
+ return new_features
172
+
173
+ @ensure_batched(features=3)
174
+ def inverse(
175
+ self,
176
+ features: torch.Tensor,
177
+ is_normalized: bool,
178
+ posed_joints_from="positions",
179
+ return_numpy: bool = False,
180
+ ) -> torch.Tensor:
181
+ """Decode features back to a motion dictionary.
182
+
183
+ Args:
184
+ features: Feature tensor ``[B, T, D]``.
185
+ is_normalized: Whether input features are normalized.
186
+ posed_joints_from: Must be ``"positions"`` for this representation.
187
+ return_numpy: Whether to convert tensors to numpy arrays.
188
+
189
+ Returns:
190
+ Dictionary containing reconstructed positions and auxiliary data.
191
+ """
192
+ assert posed_joints_from == "positions"
193
+ if is_normalized:
194
+ features = self.unnormalize(features)
195
+
196
+ [
197
+ root_positions,
198
+ global_root_heading,
199
+ local_joints_positions,
200
+ velocities,
201
+ foot_contacts,
202
+ ] = einops.unpack(features, self.ps, "batch time *")
203
+
204
+ dummy_root = 0 * local_joints_positions[:, :, [0]]
205
+ posed_joints_from_pos = torch.cat([dummy_root, local_joints_positions], dim=2)  # prepend the root joint
206
+ posed_joints_from_pos[..., 0] += root_positions[..., None, 0]
207
+ posed_joints_from_pos[..., 2] += root_positions[..., None, 2]
208
+ root_positions = posed_joints_from_pos[..., self.skeleton.root_idx, :]
209
+ foot_contacts = foot_contacts > 0.5
210
+ posed_joints = posed_joints_from_pos
211
+
212
+ output_tensor_dict = {
213
+ "local_rot_mats": None,
214
+ "global_rot_mats": None,
215
+ "posed_joints": posed_joints,
216
+ "root_positions": root_positions,
217
+ "foot_contacts": foot_contacts,
218
+ "global_root_heading": global_root_heading,
219
+ }
220
+ if return_numpy:
221
+ return to_numpy(output_tensor_dict)
222
+ return output_tensor_dict
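
A minimal round-trip sketch for the position-based representation above; the class name PositionMotionRep, the skeleton object, and the fps value are placeholders, while the __call__ and inverse signatures follow the code in this diff.

import torch

rep = PositionMotionRep(skeleton=skeleton, fps=30)        # hypothetical class/constructor
posed_joints = torch.randn(1, 120, skeleton.nbjoints, 3)  # [B, T, J, 3] global joint positions
lengths = torch.tensor([120])

# Encode: FK is skipped because posed_joints is supplied directly.
features = rep(posed_joints=posed_joints, to_normalize=True, lengths=lengths)  # [B, T, D]

# Decode back to positions; this representation does not recover joint rotations
# (local_rot_mats / global_rot_mats are None in the output dict).
decoded = rep.inverse(features, is_normalized=True, posed_joints_from="positions")
print(decoded["posed_joints"].shape, decoded["foot_contacts"].dtype)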
kimodo/motion_rep/smooth_root.py ADDED
@@ -0,0 +1,234 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Smooth root trajectory: ADMM-based smoother with margin constraints and get_smooth_root_pos helper."""
4
+
5
+ import math
6
+
7
+ import numpy as np
8
+ import torch
9
+ from scipy import sparse
10
+ from scipy.sparse.linalg import splu
11
+
12
+ from kimodo.tools import ensure_batched
13
+
14
+
15
+ class TrajectorySmoother:
16
+ """Modify trajectories to hit target values while respecting soft constraints.
17
+
18
+ This smoother keeps the trajectory close to the original positions while minimizing
19
+ accelerations. Targets are enforced at specified frames via per-frame margin constraints.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ margins,
25
+ pos_weight=0.0,
26
+ loop=False,
27
+ admm_iters=100,
28
+ alpha_overrelax=1.0,
29
+ circle_project=False,
30
+ ):
31
+ """Initialize the TrajectorySmoother.
32
+
33
+ Args:
34
+ margins: Array of margin values for each frame.
35
+ margins[i] < 0: unconstrained
36
+ margins[i] == 0: pinned on this frame
37
+ margins[i] > 0: can deviate within the margin
38
+ pos_weight: Weight for position preservation
39
+ loop: Whether the trajectory should loop
40
+ admm_iters: Number of ADMM iterations
+ alpha_overrelax: ADMM over-relaxation coefficient
+ circle_project: If True, project each vector to the unit sphere
41
+ """
42
+ self.pos_weight = pos_weight
43
+ self.admm_iters = admm_iters
44
+ self.alpha_overrelax = alpha_overrelax
45
+ self.circle_project = circle_project
46
+ N = len(margins)
47
+
48
+ # Store margin information as numpy arrays
49
+ self.margin_vals = margins
50
+
51
+ # Build acceleration matrix A
52
+ a_data = []
53
+ a_rows = []
54
+ a_cols = []
55
+
56
+ for i in range(1, N - 1):
57
+ scale = 1.0
58
+ a_data.extend([-scale, 2.0 * scale, -scale])
59
+ a_rows.extend([i, i, i])
60
+ a_cols.extend([i - 1, i, i + 1])
61
+
62
+ if loop:
63
+ # Add periodic accelerations
64
+ scale = 1.0
65
+ a_data.extend([-scale, 2.0 * scale, -scale])
66
+ a_rows.extend([0, 0, 0])
67
+ a_cols.extend([N - 1, 0, 1])
68
+
69
+ scale = 1.0
70
+ a_data.extend([-scale, 2.0 * scale, -scale])
71
+ a_rows.extend([N - 1, N - 1, N - 1])
72
+ a_cols.extend([N - 2, N - 1, 0])
73
+
74
+ A = sparse.csr_matrix((a_data, (a_rows, a_cols)), shape=(N, N))
75
+
76
+ # Build identity matrix
77
+ identity_matrix = sparse.eye(N)
78
+
79
+ # Build system matrix M
80
+ M = pos_weight * identity_matrix + A.T @ A
81
+
82
+ # Calculate ADMM step size
83
+ diag_max = max(abs(M.diagonal()))
84
+ self.admm_stepsize = 0.25 * np.sqrt(diag_max)
85
+
86
+ M = M + self.admm_stepsize * identity_matrix
87
+ self.system_lu = splu(M.tocsc())
88
+
89
+ def smooth(self, targets, x0):
90
+ """Interpolate between reference positions while satisfying constraints.
91
+
92
+ Args:
93
+ targets: Target positions that margins are measured from (numpy array)
94
+ x0: Initial guess for the smoothed trajectory (numpy array)
96
+
97
+ Returns:
98
+ Interpolated positions (numpy array)
99
+ """
100
+ x_target = targets.copy()
101
+ x = x0.copy()
102
+ z = np.zeros_like(x)
103
+ u = np.zeros_like(x)
104
+
105
+ for _ in range(self.admm_iters):
106
+ self.z_update(z, x, x_target, u)
107
+ self.u_update(u, x, z)
108
+ self.x_update(x, z, u, x_target)
109
+
110
+ return x
111
+
112
+ def x_update(self, x, z, u, x_t):
113
+ """Update x in the ADMM iteration."""
114
+
115
+ # x = (wp * I + A^T A + p I)^-1 (wp * x_orig + p (z - u))
116
+ r = self.pos_weight * x_t + self.admm_stepsize * (z - u)
117
+ x[:] = self.system_lu.solve(r)
118
+
119
+ def z_update(self, z, x, z_t, u):
120
+ """Update z in the ADMM iteration using vectorized operations."""
121
+ # Compute the difference from target for all margin locations at once
122
+ z[:] = x + u - z_t
123
+
124
+ # Check if we need to project back to margin
125
+ z_diff_norms = np.linalg.norm(z, axis=1)
126
+ mask = z_diff_norms > self.margin_vals
127
+ if np.any(mask):
128
+ scale_factors = self.margin_vals[mask] / z_diff_norms[mask]
129
+ z[mask] *= scale_factors[:, np.newaxis]
130
+
131
+ # Add back the target
132
+ z[:] += z_t
133
+
134
+ if self.circle_project:
135
+ z[:] = z / (np.linalg.norm(z, axis=1, keepdims=True) + 1.0e-6)
136
+
137
+ def u_update(self, u, x, z):
138
+ """Update u in the ADMM iteration using vectorized operations."""
139
+ u[:] += self.alpha_overrelax * (x - z)
140
+
141
+
142
+ def smooth_signal(x, margins, pos_weight=0, alpha_overrelax=1.8, admm_iters=500, circle_project=False):
143
+ """Multigrid trajectory smoothing with margin constraints.
144
+
145
+ Args:
146
+ x: Input trajectory ``[T, D]`` as a NumPy array.
147
+ margins: Allowed radius around each target frame ``[T]``.
148
+ pos_weight: Weight for staying close to the original signal.
149
+ alpha_overrelax: ADMM over-relaxation coefficient.
150
+ admm_iters: ADMM iterations per multigrid level.
151
+ circle_project: If ``True``, project each vector to the unit sphere.
152
+
153
+ Returns:
154
+ Smoothed trajectory of shape ``[T, D]``.
155
+ """
156
+ x_smoothed = x.copy()
157
+ x_smoothed[:] = x.mean(axis=0, keepdims=True)
158
+
159
+ # smooth the signal, multigrid style by starting out coarse,
160
+ # doubling the resolution and repeating until we're at the full
161
+ # resolution, using the previous result as the initial guess.
162
+ levels = int(math.floor(math.log2(len(x))))
163
+ levels = max(levels - 4, 1)
164
+
165
+ stepsize = 2**levels
166
+ while True:
167
+ # smooth signals at this level:
168
+ num_steps = len(x_smoothed[::stepsize])
169
+ smoother = TrajectorySmoother(
170
+ margins=margins[::stepsize],
171
+ pos_weight=pos_weight,
172
+ alpha_overrelax=alpha_overrelax,
173
+ admm_iters=admm_iters,
174
+ circle_project=circle_project,
175
+ )
176
+ x_smoothed[::stepsize] = smoother.smooth(x[::stepsize], x_smoothed[::stepsize])
177
+
178
+ # interpolate to next level:
179
+ next_stepsize = stepsize // 2
180
+ num_interleaved = len(x_smoothed[next_stepsize::stepsize])
181
+ if num_interleaved == num_steps:
182
+ # linearly extrapolate the last value if we have to:
183
+ x_smoothed[next_stepsize::stepsize][-1] = (
184
+ x_smoothed[::stepsize][-1] + (x_smoothed[::stepsize][-1] - x_smoothed[::stepsize][-2]) / 2
185
+ )
186
+ num_interleaved = num_interleaved - 1
187
+
188
+ # linearly interpolate the remaining values:
189
+ x_smoothed[next_stepsize::stepsize][:num_interleaved] = (
190
+ x_smoothed[::stepsize][:-1] + x_smoothed[::stepsize][1:]
191
+ ) / 2
192
+
193
+ if stepsize == 1:
194
+ break
195
+
196
+ stepsize //= 2
197
+
198
+ return x_smoothed
199
+
200
+
201
+ @ensure_batched(hip_translations=3)
202
+ def get_smooth_root_pos(hip_translations):
203
+ """Smooth root trajectory in the ground plane while preserving height.
204
+
205
+ Args:
206
+ hip_translations: Root translations ``[B, T, 3]``.
207
+
208
+ Returns:
209
+ Smoothed root translations ``[B, T, 3]`` where ``x/z`` are smoothed and
210
+ ``y`` remains unchanged.
211
+ """
212
+ root_translations_xz = hip_translations[..., [0, 2]]
213
+ root_translations_y = hip_translations[..., [1]]
214
+
215
+ batch_size, nframes = root_translations_xz.shape[:2]
216
+ margins = np.full(root_translations_xz.shape[1], 0.06)
217
+
218
+ root_translations_smoothed_xz = []
219
+ for batch in range(batch_size):
220
+ root_translations_smoothed_xz.append(
221
+ smooth_signal(root_translations_xz[batch].detach().cpu().numpy(), margins)[None]
222
+ )
223
+
224
+ root_translations_smoothed_xz = torch.tensor(np.concatenate(root_translations_smoothed_xz))
225
+
226
+ root_translations = torch.cat(
227
+ [
228
+ root_translations_smoothed_xz.to(root_translations_y.device),
229
+ root_translations_y,
230
+ ],
231
+ dim=-1,
232
+ )[..., [0, 2, 1]]
233
+
234
+ return root_translations
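
A minimal usage sketch of the margin-constrained smoothing above, applied to a noisy planar path; the 0.05 margin and the toy signal are illustrative, and the import path follows this commit's layout.

import numpy as np
from kimodo.motion_rep.smooth_root import smooth_signal

T = 256
t = np.linspace(0.0, 1.0, T)
noisy = np.stack([t, np.sin(4.0 * np.pi * t)], axis=1) + 0.02 * np.random.randn(T, 2)

margins = np.full(T, 0.05)       # each frame may deviate up to 0.05 from the input
margins[0] = margins[-1] = 0.0   # pin the first and last frames exactly

smoothed = smooth_signal(noisy, margins, pos_weight=0.0)
print(smoothed.shape)            # (256, 2)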
kimodo/motion_rep/stats.py ADDED
@@ -0,0 +1,123 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Feature normalization statistics (mean/std) for motion representations."""
4
+
5
+ import logging
6
+ import os
7
+ from typing import Optional
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+
15
+ class Stats(torch.nn.Module):
16
+ """Utility module for feature normalization statistics.
17
+
18
+ Normalization follows:
19
+ ``(data - mean) / sqrt(std**2 + eps)``
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ folder: Optional[str] = None,
25
+ load: bool = True,
26
+ eps=1e-05,
27
+ ):
28
+ super().__init__()
29
+ self.folder = folder
30
+ self.eps = eps
31
+ if folder is not None and load:
32
+ self.load()
33
+
34
+ def sliced(self, indices):
35
+ """Return a new ``Stats`` object containing selected feature indices."""
36
+ new_stats = Stats(folder=self.folder, load=False, eps=self.eps)
37
+ new_stats.register_from_tensors(
38
+ self.mean[..., indices].clone(),
39
+ self.std[..., indices].clone(),
40
+ )
41
+ return new_stats
42
+
43
+ def load(self):
44
+ """Load ``mean.npy`` and ``std.npy`` from ``self.folder``."""
45
+ mean_path = os.path.join(self.folder, "mean.npy")
46
+ std_path = os.path.join(self.folder, "std.npy")
47
+ if not os.path.exists(mean_path) or not os.path.exists(std_path):
48
+ raise FileNotFoundError(
49
+ f"Missing stats files in '{self.folder}'. Expected:\n"
50
+ f" - {mean_path}\n"
51
+ f" - {std_path}\n\n"
52
+ "Make sure the checkpoint/stats have been downloaded and are mounted into the container.\n"
53
+ "If you're using Docker Compose, run it from the repo root so `./:/workspace` mounts the correct directory."
54
+ )
55
+
56
+ mean = torch.from_numpy(np.load(mean_path))
57
+ std = torch.from_numpy(np.load(std_path))
58
+ self.register_from_tensors(mean, std)
59
+
60
+ def register_from_tensors(self, mean: torch.Tensor, std: torch.Tensor):
61
+ """Register mean/std tensors as non-persistent buffers."""
62
+ self.register_buffer("mean", mean, persistent=False)
63
+ self.register_buffer("std", std, persistent=False)
64
+
65
+ def normalize(self, data: torch.Tensor) -> torch.Tensor:
66
+ """Normalize data using the stored statistics."""
67
+ mean = self.mean.to(device=data.device, dtype=data.dtype)
68
+ std = self.std.to(device=data.device, dtype=data.dtype)
69
+ # adjust std with eps
70
+ return (data - mean) / torch.sqrt(std**2 + self.eps)
71
+
72
+ def unnormalize(self, data: torch.Tensor) -> torch.Tensor:
73
+ """Undo normalization using the stored statistics."""
74
+ mean = self.mean.to(device=data.device, dtype=data.dtype)
75
+ std = self.std.to(device=data.device, dtype=data.dtype)
76
+ # adjust std with eps
77
+ return data * torch.sqrt(std**2 + self.eps) + mean
78
+
79
+ def is_loaded(self):
80
+ """Return whether statistics are currently available."""
81
+ return hasattr(self, "mean")
82
+
83
+ def get_dim(self):
84
+ """Return feature dimensionality."""
85
+ return self.mean.shape[0]
86
+
87
+ def save(
88
+ self,
89
+ folder: Optional[str] = None,
90
+ mean: Optional[torch.Tensor] = None,
91
+ std: Optional[torch.Tensor] = None,
92
+ ):
93
+ """Save statistics to ``folder`` as ``mean.npy`` and ``std.npy``."""
94
+ if folder is None:
95
+ folder = self.folder
96
+ if folder is None:
97
+ raise ValueError("No folder to save stats")
98
+
99
+ if mean is None and std is None:
100
+ try:
101
+ mean = self.mean.cpu().numpy()
102
+ std = self.std.cpu().numpy()
103
+ except AttributeError:
104
+ raise ValueError("Stats were not loaded")
105
+
106
+ # don't override stats folder
107
+ os.makedirs(folder, exist_ok=False)
108
+
109
+ np.save(os.path.join(folder, "mean.npy"), mean)
110
+ np.save(os.path.join(folder, "std.npy"), std)
111
+
112
+ def __eq__(self, other):
113
+ return (self.mean.cpu() == other.mean.cpu()).all() and (self.std.cpu() == other.std.cpu()).all()
114
+
115
+ # should define a hash value for pytorch, as we defined __eq__
116
+ def __hash__(self):
117
+ # Convert mean and std to bytes for a consistent hash value
118
+ mean_hash = hash(self.mean.detach().cpu().numpy().tobytes())
119
+ std_hash = hash(self.std.detach().cpu().numpy().tobytes())
120
+ return hash((mean_hash, std_hash))
121
+
122
+ def __repr__(self):
123
+ return f'Stats(folder="{self.folder}")'
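
An in-memory round trip with the Stats module above (no mean.npy/std.npy files involved); the feature dimension of 8 is arbitrary.

import torch
from kimodo.motion_rep.stats import Stats

stats = Stats(load=False)  # no folder, nothing loaded from disk
stats.register_from_tensors(mean=torch.zeros(8), std=torch.ones(8))

x = torch.randn(4, 8)
x_norm = stats.normalize(x)          # (x - mean) / sqrt(std**2 + eps)
x_back = stats.unnormalize(x_norm)
print(torch.allclose(x, x_back, atol=1e-6))  # True up to rounding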
kimodo/pipeline/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ """Pipeline utilities for prompt/script to Kimodo generation flows."""
2
+
3
+ from .blend_quality import (
4
+ BlendGuardrailConfig,
5
+ TransitionSettings,
6
+ apply_transition_guardrails,
7
+ harmonize_scene_transitions,
8
+ )
9
+ from .script_to_kimodo import (
10
+ CharacterKimodoPlan,
11
+ build_character_plan,
12
+ generator_request_to_plans,
13
+ run_multi_character_generation,
14
+ )
15
+ from .scheduler_runtime import SceneScheduleResult, run_scheduled_scene
16
+
17
+ __all__ = [
18
+ "CharacterKimodoPlan",
19
+ "BlendGuardrailConfig",
20
+ "TransitionSettings",
21
+ "apply_transition_guardrails",
22
+ "harmonize_scene_transitions",
23
+ "build_character_plan",
24
+ "generator_request_to_plans",
25
+ "run_multi_character_generation",
26
+ "SceneScheduleResult",
27
+ "run_scheduled_scene",
28
+ ]
kimodo/pipeline/blend_quality.py ADDED
@@ -0,0 +1,116 @@
1
+ """Card 7 blend quality guardrails for transition blending safety and consistency."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+
8
+ @dataclass(frozen=True)
9
+ class TransitionSettings:
10
+ """Transition settings passed to Kimodo generation."""
11
+
12
+ num_transition_frames: int
13
+ share_transition: bool
14
+ percentage_transition_override: float
15
+
16
+
17
+ @dataclass(frozen=True)
18
+ class BlendGuardrailConfig:
19
+ """Runtime safety bounds for transition blending."""
20
+
21
+ min_transition_frames: int = 1
22
+ max_transition_frames: int = 12
23
+ min_segment_frames_for_share: int = 12
24
+ max_transition_ratio: float = 0.30
25
+ max_shared_window_frames: int = 24
26
+ harmonize_window: int = 2
27
+
28
+
29
+ def _clamp(value: float, low: float, high: float) -> float:
30
+ return max(low, min(high, value))
31
+
32
+
33
+ def apply_transition_guardrails(
34
+ segment_frames: list[int],
35
+ policies: list[str],
36
+ requested: TransitionSettings,
37
+ *,
38
+ config: BlendGuardrailConfig = BlendGuardrailConfig(),
39
+ ) -> TransitionSettings:
40
+ """Clamp transition settings to safe ranges for short/long segments.
41
+
42
+ Guardrails avoid transition windows that dominate short segments and reduce blending artifacts
43
+ for scripted interactions.
44
+ """
45
+ if len(segment_frames) < 2:
46
+ safe_frames = int(_clamp(requested.num_transition_frames, config.min_transition_frames, config.max_transition_frames))
47
+ return TransitionSettings(
48
+ num_transition_frames=safe_frames,
49
+ share_transition=False,
50
+ percentage_transition_override=0.0,
51
+ )
52
+
53
+ min_prev = min(segment_frames[:-1])
54
+ min_next = min(segment_frames[1:])
55
+ # Keep at least one non-transition frame in the shortest pair.
56
+ shortest_pair_budget = max(config.min_transition_frames, min(min_prev, min_next) - 1)
57
+
58
+ safe_frames = int(
59
+ _clamp(
60
+ requested.num_transition_frames,
61
+ config.min_transition_frames,
62
+ min(config.max_transition_frames, shortest_pair_budget),
63
+ )
64
+ )
65
+
66
+ has_cut = "cut" in policies
67
+ can_share = (
68
+ requested.share_transition
69
+ and not has_cut
70
+ and min_prev >= config.min_segment_frames_for_share
71
+ and min_next >= config.min_segment_frames_for_share
72
+ )
73
+
74
+ if not can_share:
75
+ return TransitionSettings(
76
+ num_transition_frames=safe_frames,
77
+ share_transition=False,
78
+ percentage_transition_override=0.0,
79
+ )
80
+
81
+ safe_pct = _clamp(requested.percentage_transition_override, 0.0, config.max_transition_ratio)
82
+
83
+ # Cap shared overlap by configured hard ceiling and shortest-pair budget.
84
+ max_pct_from_shared_window = max(0.0, (config.max_shared_window_frames - safe_frames) / max(1, min_prev))
85
+ max_pct_from_shortest_pair = max(0.0, (shortest_pair_budget - safe_frames) / max(1, min_prev))
86
+ safe_pct = min(safe_pct, max_pct_from_shared_window, max_pct_from_shortest_pair)
87
+
88
+ return TransitionSettings(
89
+ num_transition_frames=safe_frames,
90
+ share_transition=True,
91
+ percentage_transition_override=float(safe_pct),
92
+ )
93
+
94
+
95
+ def harmonize_scene_transitions(
96
+ settings_by_character: dict[str, TransitionSettings],
97
+ *,
98
+ config: BlendGuardrailConfig = BlendGuardrailConfig(),
99
+ ) -> dict[str, TransitionSettings]:
100
+ """Nudge transition-frame counts toward a scene median for multi-character consistency."""
101
+ if len(settings_by_character) < 2:
102
+ return settings_by_character
103
+
104
+ frame_values = sorted(setting.num_transition_frames for setting in settings_by_character.values())
105
+ median = frame_values[len(frame_values) // 2]
106
+ low = max(config.min_transition_frames, median - config.harmonize_window)
107
+ high = min(config.max_transition_frames, median + config.harmonize_window)
108
+
109
+ harmonized: dict[str, TransitionSettings] = {}
110
+ for character_id, setting in settings_by_character.items():
111
+ harmonized[character_id] = TransitionSettings(
112
+ num_transition_frames=int(_clamp(setting.num_transition_frames, low, high)),
113
+ share_transition=setting.share_transition,
114
+ percentage_transition_override=setting.percentage_transition_override,
115
+ )
116
+ return harmonized
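
A small sketch of the guardrails above for a two-segment character; the "blend" policy string is a placeholder (only "cut" is special-cased in the code), and the numbers are chosen to trigger clamping under the default BlendGuardrailConfig.

from kimodo.pipeline.blend_quality import TransitionSettings, apply_transition_guardrails

requested = TransitionSettings(
    num_transition_frames=20,            # above the 12-frame ceiling
    share_transition=True,
    percentage_transition_override=0.5,  # above the 0.30 ratio cap
)
safe = apply_transition_guardrails(
    segment_frames=[30, 18],
    policies=["blend"],
    requested=requested,
)
# Frames are clamped to 12 and the shared overlap to (17 - 12) / 30 ~= 0.167.
print(safe.num_transition_frames, safe.share_transition, safe.percentage_transition_override)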
kimodo/pipeline/scheduler_runtime.py ADDED
@@ -0,0 +1,139 @@
1
+ """Card 8 runtime orchestration: deterministic multi-character scheduling."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from dataclasses import dataclass
7
+ from typing import Any, Optional
8
+
9
+ from kimodo.pipeline.script_to_kimodo import run_multi_character_generation
10
+ from kimodo.schemas import GeneratorRequest
11
+ from kimodo.scheduler import (
12
+ CharacterState,
13
+ CharacterSegmentState,
14
+ ConflictResolutionPolicy,
15
+ DeterministicLoop,
16
+ )
17
+
18
+ LOGGER = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass(frozen=True)
22
+ class SceneScheduleResult:
23
+ """Structured result for scheduled scene execution."""
24
+
25
+ outputs: dict[str, dict[str, Any]]
26
+ errors: dict[str, str]
27
+ plans: dict[str, Any]
28
+ state_hashes: list[str]
29
+ interactions: list[tuple[int, str, str]]
30
+ completed_segments: dict[str, int]
31
+
32
+
33
+ def _activate_next_segment(loop: DeterministicLoop, character_id: str, plan: Any, segment_index: int) -> None:
34
+ """Set active segment in loop state for one character."""
35
+ slot = loop.characters[character_id]
36
+ slot.segment_state = CharacterSegmentState(
37
+ character_id=character_id,
38
+ segment_index=segment_index,
39
+ frames_elapsed=0,
40
+ total_frames=plan.num_frames[segment_index],
41
+ )
42
+ policy = plan.segment_transition_policies[segment_index]
43
+ # Interaction target is encoded in planner request segments; set later in per-tick update.
44
+ slot.current_state = CharacterState.BUSY if policy != "cut" else CharacterState.TRANSITIONING
45
+
46
+
47
+ def run_scheduled_scene(
48
+ model: Any,
49
+ request: GeneratorRequest,
50
+ *,
51
+ fps: float,
52
+ seed: int = 42,
53
+ conflict_policy: ConflictResolutionPolicy = ConflictResolutionPolicy.COOLDOWN,
54
+ diffusion_steps: int = 100,
55
+ cfg_weight: Optional[list[float]] = None,
56
+ cfg_type: Optional[str] = None,
57
+ post_processing: bool = True,
58
+ root_margin: float = 0.04,
59
+ constraint_resolver: Optional[Any] = None,
60
+ continue_on_error: bool = False,
61
+ ) -> SceneScheduleResult:
62
+ """Run generation then deterministic timeline scheduling for all characters in a scene."""
63
+ LOGGER.info("card8.run_scheduled_scene.start scene_id=%s chars=%s", request.scene_id, len(request.characters))
64
+
65
+ outputs, errors, plans = run_multi_character_generation(
66
+ model,
67
+ request,
68
+ fps=fps,
69
+ diffusion_steps=diffusion_steps,
70
+ cfg_weight=cfg_weight,
71
+ cfg_type=cfg_type,
72
+ post_processing=post_processing,
73
+ root_margin=root_margin,
74
+ constraint_resolver=constraint_resolver,
75
+ continue_on_error=continue_on_error,
76
+ )
77
+
78
+ loop = DeterministicLoop(
79
+ fps=int(fps),
80
+ seed=seed,
81
+ conflict_policy=conflict_policy,
82
+ )
83
+
84
+ for priority, character in enumerate(request.characters):
85
+ loop.register_character(character.character_id, character.skeleton_type, priority=priority)
86
+
87
+ segment_indices = {character.character_id: 0 for character in request.characters}
88
+ completed_segments = {character.character_id: 0 for character in request.characters}
89
+
90
+ for character in request.characters:
91
+ plan = plans.get(character.character_id)
92
+ if plan is None:
93
+ continue
94
+ if not plan.num_frames:
95
+ continue
96
+ _activate_next_segment(loop, character.character_id, plan, segment_index=0)
97
+ first_segment = character.segments[0]
98
+ loop.characters[character.character_id].interaction_target = first_segment.interaction_target
99
+
100
+ total_scene_frames = max((plan.total_frames for plan in plans.values()), default=0)
101
+ state_hashes: list[str] = []
102
+ interactions: list[tuple[int, str, str]] = []
103
+
104
+ for _ in range(total_scene_frames):
105
+ tick = loop.advance_tick({})
106
+ state_hashes.append(loop.get_state_hash())
107
+
108
+ for winner, loser in tick.interactions:
109
+ interactions.append((tick.tick_number, winner, loser))
110
+
111
+ for character_id in tick.completed_segments:
112
+ plan = plans.get(character_id)
113
+ if plan is None:
114
+ continue
115
+ completed_segments[character_id] += 1
116
+ next_index = segment_indices[character_id] + 1
117
+ if next_index < len(plan.num_frames):
118
+ segment_indices[character_id] = next_index
119
+ _activate_next_segment(loop, character_id, plan, next_index)
120
+ source_char = next(c for c in request.characters if c.character_id == character_id)
121
+ loop.characters[character_id].interaction_target = source_char.segments[next_index].interaction_target
122
+ else:
123
+ loop.characters[character_id].segment_state = None
124
+ loop.characters[character_id].interaction_target = None
125
+
126
+ LOGGER.info(
127
+ "card8.run_scheduled_scene.exit scene_id=%s hashes=%s interactions=%s",
128
+ request.scene_id,
129
+ len(state_hashes),
130
+ len(interactions),
131
+ )
132
+ return SceneScheduleResult(
133
+ outputs=outputs,
134
+ errors=errors,
135
+ plans=plans,
136
+ state_hashes=state_hashes,
137
+ interactions=interactions,
138
+ completed_segments=completed_segments,
139
+ )
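
A hedged sketch of consuming the scheduling result above; model and request are assumed to be constructed elsewhere in the package (a loaded Kimodo model and a kimodo.schemas.GeneratorRequest), and the fps/seed values are illustrative.

result = run_scheduled_scene(model, request, fps=30.0, seed=42)

print("characters with outputs:", sorted(result.outputs))
print("characters with errors:", sorted(result.errors))
print("ticks simulated:", len(result.state_hashes))
for tick, winner, loser in result.interactions:
    print(f"tick {tick}: {winner} took priority over {loser}")
for character_id, n_done in result.completed_segments.items():
    print(character_id, "segments completed:", n_done)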