diff --git a/kimodo/__init__.py b/kimodo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed1fbe045195bcbb460c57abed1c87665eb50974 --- /dev/null +++ b/kimodo/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Kimodo: text-driven and constrained motion generation model.""" + +from .model.load_model import AVAILABLE_MODELS, DEFAULT_MODEL, load_model + +__all__ = [ + "AVAILABLE_MODELS", + "DEFAULT_MODEL", + "load_model", +] diff --git a/kimodo/assets.py b/kimodo/assets.py new file mode 100644 index 0000000000000000000000000000000000000000..91facad0faed373c9ed1ad4667980cf19788b093 --- /dev/null +++ b/kimodo/assets.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +PACKAGE_ROOT = Path(__file__).resolve().parent +ASSETS_ROOT = PACKAGE_ROOT / "assets" +DEMO_ASSETS_ROOT = ASSETS_ROOT / "demo" +DEMO_EXAMPLES_ROOT = DEMO_ASSETS_ROOT / "examples" +SKELETONS_ROOT = ASSETS_ROOT / "skeletons" +SOMA_ASSETS_ROOT = ASSETS_ROOT / "SOMA" + + +def skeleton_asset_path(*parts: str) -> Path: + return SKELETONS_ROOT.joinpath(*parts) + + +def demo_asset_path(*parts: str) -> Path: + return DEMO_ASSETS_ROOT.joinpath(*parts) diff --git a/kimodo/constraints.py b/kimodo/constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..7accd150dc5ccf8481442fc885562e92ed269765 --- /dev/null +++ b/kimodo/constraints.py @@ -0,0 +1,625 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Constraint sets for conditioning motion generation (root 2D, full body, end-effectors).""" + +from typing import Optional, Union + +import torch +from torch import Tensor + +from kimodo.motion_rep.feature_utils import compute_heading_angle +from kimodo.skeleton import SkeletonBase, SOMASkeleton30, SOMASkeleton77 +from kimodo.tools import ensure_batched, load_json, save_json + +from .geometry import axis_angle_to_matrix, matrix_to_axis_angle + + +def _convert_constraint_local_rots_to_skeleton(local_rot_mats: Tensor, skeleton: SkeletonBase) -> Tensor: + """Convert loaded local rotation matrices to match the skeleton's joint count. + + Handles SOMA 30↔77: constraint files may have been saved with 30 or 77 joints while the session + skeleton (e.g. from the SOMA30 model) uses SOMASkeleton77. + """ + n_joints = local_rot_mats.shape[-3] + skeleton_joints = skeleton.nbjoints + if n_joints == skeleton_joints: + return local_rot_mats + if n_joints == 77 and skeleton_joints == 30 and isinstance(skeleton, SOMASkeleton30): + return skeleton.from_SOMASkeleton77(local_rot_mats) + if n_joints == 30 and skeleton_joints == 77 and isinstance(skeleton, SOMASkeleton77): + skel30 = SOMASkeleton30() + return skel30.to_SOMASkeleton77(local_rot_mats) + raise ValueError( + f"Constraint joint count ({n_joints}) does not match skeleton joint count " + f"({skeleton_joints}). Only SOMA 30↔77 conversion is supported." 
+ ) + + +def create_pairs(tensor_A: Tensor, tensor_B: Tensor) -> Tensor: + """Form all (a, b) pairs from two 1D tensors; output shape (len(A)*len(B), 2).""" + pairs = torch.stack( + ( + tensor_A[:, None].expand(-1, len(tensor_B)), + tensor_B.expand(len(tensor_A), -1), + ), + dim=-1, + ).reshape(-1, 2) + return pairs + + +def compute_global_heading(global_joints_positions: Tensor, skeleton: SkeletonBase) -> Tensor: + """Compute global root heading (cos, sin) from global joint positions using skeleton.""" + root_heading_angle = compute_heading_angle(global_joints_positions, skeleton) + global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1) + return global_root_heading + + +def _tensor_to( + t: Tensor, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + """Move tensor to device and/or dtype. + + Returns same tensor if no args. + """ + if device is not None and dtype is not None: + return t.to(device=device, dtype=dtype) + if device is not None: + return t.to(device=device) + if dtype is not None: + return t.to(dtype=dtype) + return t + + +class Root2DConstraintSet: + """Constraint set fixing root (x, z) trajectory and optionally global heading on given + frames.""" + + name = "root2d" + + def __init__( + self, + skeleton: SkeletonBase, + frame_indices: Tensor, + smooth_root_2d: Tensor, + to_crop: bool = False, + global_root_heading: Optional[Tensor] = None, + ) -> None: + self.skeleton = skeleton + + # if we pass the full smooth root 3D as input + if smooth_root_2d.shape[-1] == 3: + smooth_root_2d = smooth_root_2d[..., [0, 1]] + + if to_crop: + smooth_root_2d = smooth_root_2d[frame_indices] + if global_root_heading is not None: + global_root_heading = global_root_heading[frame_indices] + else: + assert len(smooth_root_2d) == len( + frame_indices + ), "The number of smooth root 2d should be match the number of frames" + if global_root_heading is not None: + assert len(global_root_heading) == len( + frame_indices + ), "The number of global root heading should be match the number of frames" + + self.smooth_root_2d = smooth_root_2d + self.global_root_heading = global_root_heading + self.frame_indices = frame_indices + + def update_constraints(self, data_dict: dict, index_dict: dict) -> None: + """Append this constraint's smooth_root_2d (and optional global_root_heading) to data/index + dicts.""" + data_dict["smooth_root_2d"].append(self.smooth_root_2d) + index_dict["smooth_root_2d"].append(self.frame_indices) + + if self.global_root_heading is not None: + # constraint the global heading + data_dict["global_root_heading"].append(self.global_root_heading) + index_dict["global_root_heading"].append(self.frame_indices) + + def crop_move(self, start: int, end: int) -> "Root2DConstraintSet": + """Return a new constraint set for the cropped frame range [start, end).""" + mask = (self.frame_indices >= start) & (self.frame_indices < end) + + if self.global_root_heading is not None: + masked_global_root_heading = self.global_root_heading[mask] + else: + masked_global_root_heading = None + + return Root2DConstraintSet( + self.skeleton, + self.frame_indices[mask] - start, + self.smooth_root_2d[mask], + global_root_heading=masked_global_root_heading, + ) + + def get_save_info(self) -> dict: + """Return a dict suitable for JSON serialization (frame_indices, smooth_root_2d, optional + global_root_heading).""" + out = { + "type": self.name, + "frame_indices": self.frame_indices, + "smooth_root_2d": 
self.smooth_root_2d, + } + if self.global_root_heading is not None: + out["global_root_heading"] = self.global_root_heading + return out + + def to( + self, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + ) -> "Root2DConstraintSet": + self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype) + self.frame_indices = _tensor_to(self.frame_indices, device, dtype) + if self.global_root_heading is not None: + self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype) + if device is not None and hasattr(self.skeleton, "to"): + self.skeleton = self.skeleton.to(device) + return self + + @classmethod + def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "Root2DConstraintSet": + """Build a Root2DConstraintSet from a dict (e.g. loaded from JSON).""" + device = skeleton.device if hasattr(skeleton, "device") else "cpu" + + if "global_root_heading" in dico: + global_root_heading = torch.tensor(dico["global_root_heading"], device=device) + else: + global_root_heading = None + + return cls( + skeleton, + frame_indices=torch.tensor(dico["frame_indices"]), + smooth_root_2d=torch.tensor(dico["smooth_root_2d"], device=device), + global_root_heading=global_root_heading, + ) + + +class FullBodyConstraintSet: + """Constraint set fixing full-body global positions and rotations on given keyframes.""" + + name = "fullbody" + + def __init__( + self, + skeleton: SkeletonBase, + frame_indices: Tensor, + global_joints_positions: Tensor, + global_joints_rots: Tensor, + smooth_root_2d: Optional[Tensor] = None, + to_crop: bool = False, + ): + self.skeleton = skeleton + self.frame_indices = frame_indices + + # if we pass the full smooth root 3D as input + if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3: + smooth_root_2d = smooth_root_2d[..., [0, 1]] + + if to_crop: + global_joints_positions = global_joints_positions[frame_indices] + global_joints_rots = global_joints_rots[frame_indices] + if smooth_root_2d is not None: + smooth_root_2d = smooth_root_2d[frame_indices] + else: + assert len(global_joints_positions) == len( + frame_indices + ), "The number of global positions should be match the number of frames" + assert len(global_joints_rots) == len( + frame_indices + ), "The number of global joint rotations should be match the number of frames" + + if smooth_root_2d is not None: + assert len(smooth_root_2d) == len( + frame_indices + ), "The number of smooth root 2d (if specified) should be match the number of frames" + + if smooth_root_2d is None: + # substitute the smooth root 2d with the real root + smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]] + + # root y: from smooth or pelvis is the same + self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1] + + self.global_joints_positions = global_joints_positions + self.global_joints_rots = global_joints_rots + self.global_root_heading = compute_global_heading(global_joints_positions, skeleton) + self.smooth_root_2d = smooth_root_2d + + def update_constraints(self, data_dict: dict, index_dict: dict) -> None: + """Append global positions, smooth root 2D, root y, and global heading to data/index + dicts.""" + nbjoints = self.skeleton.nbjoints + indices_lst = create_pairs( + self.frame_indices, + torch.arange(nbjoints, device=self.frame_indices.device), + ) + data_dict["global_joints_positions"].append( + self.global_joints_positions.reshape(-1, 3) + ) # flatten the global positions + index_dict["global_joints_positions"].append(indices_lst) + + # 
global rotations are not used here + + # as we use smooth root, also constraint the smooth root to get the same full body + # maybe keep storing the hips offset, if we smooth it ourselves + data_dict["smooth_root_2d"].append(self.smooth_root_2d) + index_dict["smooth_root_2d"].append(self.frame_indices) + + # constraint the y pos of the root + data_dict["root_y_pos"].append(self.root_y_pos) + index_dict["root_y_pos"].append(self.frame_indices) + + # constraint the global heading + data_dict["global_root_heading"].append(self.global_root_heading) + index_dict["global_root_heading"].append(self.frame_indices) + + def crop_move(self, start: int, end: int) -> "FullBodyConstraintSet": + """Return a new FullBodyConstraintSet for the cropped frame range [start, end).""" + mask = (self.frame_indices >= start) & (self.frame_indices < end) + return FullBodyConstraintSet( + self.skeleton, + self.frame_indices[mask] - start, + self.global_joints_positions[mask], + self.global_joints_rots[mask], + self.smooth_root_2d[mask], + ) + + def get_save_info(self) -> dict: + """Return a dict for JSON save: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d.""" + local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots) + if isinstance(self.skeleton, SOMASkeleton30): + local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot) + local_joints_rot = matrix_to_axis_angle(local_joints_rot) + + root_positions = self.global_joints_positions[:, self.skeleton.root_idx] + return { + "type": self.name, + "frame_indices": self.frame_indices, + "local_joints_rot": local_joints_rot, + "root_positions": root_positions, + "smooth_root_2d": self.smooth_root_2d, + } + + def to( + self, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + ) -> "FullBodyConstraintSet": + self.frame_indices = _tensor_to(self.frame_indices, device, dtype) + self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype) + self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype) + self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype) + self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype) + self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype) + if device is not None and hasattr(self.skeleton, "to"): + self.skeleton = self.skeleton.to(device) + return self + + @classmethod + def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "FullBodyConstraintSet": + """Build a FullBodyConstraintSet from a dict (e.g. 
loaded from JSON).""" + frame_indices = torch.tensor(dico["frame_indices"]) + device = skeleton.device if hasattr(skeleton, "device") else "cpu" + local_rot = torch.tensor(dico["local_joints_rot"], device=device) + local_rot_mats = axis_angle_to_matrix(local_rot) + local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton) + global_joints_rots, global_joints_positions, _ = skeleton.fk( + local_rot_mats, + torch.tensor(dico["root_positions"], device=device), + ) + smooth_root_2d = None + if "smooth_root_2d" in dico: + smooth_root_2d = torch.tensor(dico["smooth_root_2d"], device=device) + + return cls( + skeleton, + frame_indices=frame_indices, + global_joints_positions=global_joints_positions, + global_joints_rots=global_joints_rots, + smooth_root_2d=smooth_root_2d, + ) + + +class EndEffectorConstraintSet: + """Constraint set fixing selected end-effector positions and rotations on given frames.""" + + name = "end-effector" + + def __init__( + self, + skeleton: SkeletonBase, + frame_indices: Tensor, + global_joints_positions: Tensor, + global_joints_rots: Tensor, + smooth_root_2d: Optional[Tensor], + *, + joint_names: list[str], + to_crop: bool = False, + ) -> None: + self.skeleton = skeleton + self.frame_indices = frame_indices + self.joint_names = joint_names + + # joint_names are constant for all the frames + rot_joint_names, pos_joint_names = self.skeleton.expand_joint_names(self.joint_names) + # indexing works for motion_rep with smooth root only (contains pelvis index) + self.pos_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in pos_joint_names]) + self.rot_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in rot_joint_names]) + + # if we pass the full smooth root 3D as input + if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3: + smooth_root_2d = smooth_root_2d[..., [0, 1]] + + if to_crop: + global_joints_positions = global_joints_positions[frame_indices] + global_joints_rots = global_joints_rots[frame_indices] + if smooth_root_2d is not None: + smooth_root_2d = smooth_root_2d[frame_indices] + else: + assert len(global_joints_positions) == len( + frame_indices + ), "The number of global positions should be match the number of frames" + assert len(global_joints_rots) == len( + frame_indices + ), "The number of global joint rotations should be match the number of frames" + if smooth_root_2d is not None: + assert len(smooth_root_2d) == len( + frame_indices + ), "The number of smooth root 2d (if specified) should be match the number of frames" + + if smooth_root_2d is None: + # substitute the smooth root 2d with the real root + smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]] + + # root y: from smooth or pelvis is the same + self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1] + + self.global_joints_positions = global_joints_positions + self.global_root_heading = compute_global_heading(global_joints_positions, skeleton) + self.global_joints_rots = global_joints_rots + self.smooth_root_2d = smooth_root_2d + + def update_constraints(self, data_dict: dict, index_dict: dict) -> None: + """Append constrained joint positions/rots, smooth root 2D, root y, and heading to + data/index dicts.""" + crop_frames_indexing = torch.arange(len(self.frame_indices), device=self.frame_indices.device) + + # constraint positions + pos_indices_real = create_pairs( + self.frame_indices, + self.pos_indices, + ) + pos_indices_crop = create_pairs( + crop_frames_indexing, + self.pos_indices, + ) + 
data_dict["global_joints_positions"].append(self.global_joints_positions[tuple(pos_indices_crop.T)]) + index_dict["global_joints_positions"].append(pos_indices_real) + + # constraint rotations + rot_indices_real = create_pairs( + self.frame_indices, + self.rot_indices, + ) + rot_indices_crop = create_pairs( + crop_frames_indexing, + self.rot_indices, + ) + data_dict["global_joints_rots"].append(self.global_joints_rots[tuple(rot_indices_crop.T)]) + index_dict["global_joints_rots"].append(rot_indices_real) + + # as we use smooth root, also constraint the smooth root to get the same full body + # maybe keep storing the hips offset, if we smooth it ourselves + data_dict["smooth_root_2d"].append(self.smooth_root_2d) + index_dict["smooth_root_2d"].append(self.frame_indices) + + # constraint the y pos of the root + data_dict["root_y_pos"].append(self.root_y_pos) + index_dict["root_y_pos"].append(self.frame_indices) + + # constraint the global heading + data_dict["global_root_heading"].append(self.global_root_heading) + index_dict["global_root_heading"].append(self.frame_indices) + + def crop_move(self, start: int, end: int) -> "EndEffectorConstraintSet": + """Return a new EndEffectorConstraintSet for the cropped frame range [start, end).""" + mask = (self.frame_indices >= start) & (self.frame_indices < end) + + cls = type(self) + kwargs = {} + if not hasattr(cls, "joint_names"): + kwargs["joint_names"] = self.joint_names + + return cls( + self.skeleton, + self.frame_indices[mask] - start, + self.global_joints_positions[mask], + self.global_joints_rots[mask], + self.smooth_root_2d[mask], + **kwargs, + ) + + def get_save_info(self) -> dict: + """Return a dict for JSON save: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d, joint_names.""" + local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots) + if isinstance(self.skeleton, SOMASkeleton30): + local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot) + local_joints_rot = matrix_to_axis_angle(local_joints_rot) + + root_positions = self.global_joints_positions[:, self.skeleton.root_idx] + output = { + "type": self.name, + "frame_indices": self.frame_indices, + "local_joints_rot": local_joints_rot, + "root_positions": root_positions, + "smooth_root_2d": self.smooth_root_2d, + } + if not hasattr(self.__class__, "joint_names"): + # save the joint_names for this base class + # but not for children + output["joint_names"] = self.joint_names + return output + + def to( + self, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + ) -> "EndEffectorConstraintSet": + self.frame_indices = _tensor_to(self.frame_indices, device, dtype) + self.pos_indices = _tensor_to(self.pos_indices, device, dtype) + self.rot_indices = _tensor_to(self.rot_indices, device, dtype) + self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype) + self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype) + self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype) + self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype) + self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype) + if device is not None and hasattr(self.skeleton, "to"): + self.skeleton = self.skeleton.to(device) + return self + + @classmethod + def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "EndEffectorConstraintSet": + """Build an EndEffectorConstraintSet from a dict (e.g. 
loaded from JSON).""" + frame_indices = torch.tensor(dico["frame_indices"]) + device = skeleton.device if hasattr(skeleton, "device") else "cpu" + local_rot = torch.tensor(dico["local_joints_rot"], device=device) + local_rot_mats = axis_angle_to_matrix(local_rot) + local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton) + global_joints_rots, global_joints_positions, _ = skeleton.fk( + local_rot_mats, + torch.tensor(dico["root_positions"], device=device), + ) + smooth_root_2d = None + if "smooth_root_2d" in dico: + smooth_root_2d = torch.tensor(dico["smooth_root_2d"], device=device) + + kwargs = {} + if not hasattr(cls, "joint_names"): + kwargs["joint_names"] = dico["joint_names"] + + return cls( + skeleton, + frame_indices=frame_indices, + global_joints_positions=global_joints_positions, + global_joints_rots=global_joints_rots, + smooth_root_2d=smooth_root_2d, + **kwargs, + ) + + +class LeftHandConstraintSet(EndEffectorConstraintSet): + """End-effector constraint for the left hand only.""" + + name = "left-hand" + joint_names: list[str] = ["LeftHand"] + + def __init__(self, *args, **kwargs: dict): + super().__init__(*args, joint_names=self.joint_names, **kwargs) + + +class RightHandConstraintSet(EndEffectorConstraintSet): + """End-effector constraint for the right hand only.""" + + name = "right-hand" + joint_names: list[str] = ["RightHand"] + + def __init__(self, *args, **kwargs: dict): + super().__init__(*args, joint_names=self.joint_names, **kwargs) + + +class LeftFootConstraintSet(EndEffectorConstraintSet): + """End-effector constraint for the left foot only.""" + + name = "left-foot" + joint_names: list[str] = ["LeftFoot"] + + def __init__(self, *args, **kwargs: dict): + super().__init__(*args, joint_names=self.joint_names, **kwargs) + + +class RightFootConstraintSet(EndEffectorConstraintSet): + """End-effector constraint for the right foot only.""" + + name = "right-foot" + joint_names: list[str] = ["RightFoot"] + + def __init__(self, *args, **kwargs: dict): + super().__init__(*args, joint_names=self.joint_names, **kwargs) + + +TYPE_TO_CLASS = { + "root2d": Root2DConstraintSet, + "fullbody": FullBodyConstraintSet, + "left-hand": LeftHandConstraintSet, + "right-hand": RightHandConstraintSet, + "left-foot": LeftFootConstraintSet, + "right-foot": RightFootConstraintSet, + "end-effector": EndEffectorConstraintSet, +} + + +def load_constraints_lst( + path_or_data: str | list, + skeleton: SkeletonBase, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, +): + """Load a list of constraints from JSON path or list of dicts. + + Args: + path_or_data: Path to constraints.json or list of constraint dicts. + skeleton: Skeleton instance (used for from_dict). + device: If set, move all constraint tensors and skeleton to this device. + dtype: If set, cast constraint tensors to this dtype. + """ + if isinstance(path_or_data, str): + saved = load_json(path_or_data) + else: + saved = path_or_data + + constraints_lst = [] + for el in saved: + cls = TYPE_TO_CLASS[el["type"]] + c = cls.from_dict(skeleton, el) + if device is not None or dtype is not None: + c.to(device=device, dtype=dtype) + constraints_lst.append(c) + return constraints_lst + + +def save_constraints_lst(path: str, constraints_lst: list) -> list | None: + """Save a list of constraint sets to a JSON file. + + Returns None if list is empty. + """ + if not constraints_lst: + print("The constraints lst is empty. 
Skip saving") + return + + to_save = [] + + def tensor_to_list(obj): + """Recursively convert tensors to lists for JSON serialization.""" + if isinstance(obj, Tensor): + return obj.cpu().tolist() + elif isinstance(obj, dict): + return {k: tensor_to_list(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [tensor_to_list(v) for v in obj] + else: + return obj + + for constraint in constraints_lst: + constraint_info = constraint.get_save_info() + # Convert all tensors to lists for JSON serialization + constraint_info = tensor_to_list(constraint_info) + to_save.append(constraint_info) + + save_json(path, to_save) + print(f"Saved constraints to {path}") + return to_save diff --git a/kimodo/exports/__init__.py b/kimodo/exports/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57ce23ddceef985d59b206f2f3fd0f14ee36ca69 --- /dev/null +++ b/kimodo/exports/__init__.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Export utilities: MuJoCo, BVH, SMPLX/AMASS, and motion I/O helpers.""" + +from .bvh import bvh_to_kimodo_motion, motion_to_bvh_bytes, read_bvh_frame_time_seconds, save_motion_bvh +from .motion_convert_lib import convert_motion_files +from .motion_formats import ( + infer_npz_kind, + infer_source_format_from_path, + infer_target_format_from_path, + resolve_source_fps, +) +from .motion_io import ( + KIMODO_CONVERT_TARGET_FPS, + amass_npz_to_bytes, + complete_motion_dict, + g1_csv_to_bytes, + kimodo_npz_to_bytes, + load_amass_npz, + load_g1_csv, + load_kimodo_npz, + load_kimodo_npz_as_torch, + load_motion_file, + motion_dict_to_numpy, + save_kimodo_npz, + save_kimodo_npz_at_target_fps, +) +from .mujoco import MujocoQposConverter, apply_g1_real_robot_projection +from .smplx import ( + AMASSConverter, + amass_npz_to_kimodo_motion, + get_amass_parameters, + kimodo_y_up_to_amass_coord_rotation_matrix, +) + +__all__ = [ + "AMASSConverter", + "KIMODO_CONVERT_TARGET_FPS", + "MujocoQposConverter", + "amass_npz_to_bytes", + "amass_npz_to_kimodo_motion", + "apply_g1_real_robot_projection", + "bvh_to_kimodo_motion", + "complete_motion_dict", + "convert_motion_files", + "g1_csv_to_bytes", + "get_amass_parameters", + "infer_npz_kind", + "infer_source_format_from_path", + "infer_target_format_from_path", + "kimodo_npz_to_bytes", + "kimodo_y_up_to_amass_coord_rotation_matrix", + "load_amass_npz", + "load_g1_csv", + "load_kimodo_npz", + "load_kimodo_npz_as_torch", + "load_motion_file", + "motion_dict_to_numpy", + "motion_to_bvh_bytes", + "read_bvh_frame_time_seconds", + "resolve_source_fps", + "save_kimodo_npz", + "save_kimodo_npz_at_target_fps", + "save_motion_bvh", +] diff --git a/kimodo/exports/bvh.py b/kimodo/exports/bvh.py new file mode 100644 index 0000000000000000000000000000000000000000..02cc2e4546633cacce2fd1425e4b809eba10a417 --- /dev/null +++ b/kimodo/exports/bvh.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Export utilities for converting internal motion representations into common file formats. + +This module is intended to hold lightweight serialization / export helpers that can be reused +outside of interactive demos. 
+""" + +import os +import tempfile +from pathlib import Path +from typing import Tuple, Union + +import numpy as np +import torch + +from kimodo.geometry import matrix_to_quaternion as _matrix_to_quaternion + + +def _strip_end_site_blocks(bvh_text: str) -> str: + """Remove all 'End Site { ... }' blocks from BVH text so output matches original format. + + bvhio adds an End Site for every leaf joint when writing; we do not set EndSite on joints, so we + post-process the string to remove these blocks for Blender/original compatibility. + """ + lines = bvh_text.splitlines(keepends=True) + result = [] + i = 0 + while i < len(lines): + line = lines[i] + if "End Site" in line: + # Skip this line and the following block { ... }; brace-count to find closing } + i += 1 + if i < len(lines) and "{" in lines[i]: + i += 1 + depth = 1 + while i < len(lines) and depth > 0: + if "{" in lines[i]: + depth += 1 + if "}" in lines[i]: + depth -= 1 + i += 1 + continue + result.append(line) + i += 1 + return "".join(result) + + +def _coerce_batch(name: str, x: torch.Tensor, *, expected_ndim: int) -> torch.Tensor: + """Coerce (T, ...) or (1, T, ...) into (T, ...).""" + if x.ndim == expected_ndim: + return x + if x.ndim == expected_ndim + 1: + if int(x.shape[0]) != 1: + raise ValueError( + f"{name} has batch dimension B={int(x.shape[0])}, but BVH export " "only supports a single clip (B==1)." + ) + return x[0] + raise ValueError(f"{name} must have shape (T, ...) or (1, T, ...); got {tuple(x.shape)}") + + +def motion_to_bvh( + local_rot_mats: torch.Tensor, + root_positions: torch.Tensor, + *, + skeleton, + fps: float, +) -> str: + """Convert local rotations and root positions to BVH format; return UTF-8 string. + + Args: + local_rot_mats: (T, J, 3, 3) or (1, T, J, 3, 3) local rotation matrices. + root_positions: (T, 3) or (1, T, 3) root joint positions (e.g. from posed joints). + skeleton: Skeleton with bone_order_names, bvh_neutral_joints, etc. + fps: Frames per second for the motion. + + Notes: + BVH is plain-text. Root is named "Root" with ZYX rotation order; leaf joints + have no End Site block. + """ + try: + import bvhio # type: ignore[import-not-found] + import glm # type: ignore[import-not-found] + from SpatialTransform import Pose # type: ignore[import-not-found] + except Exception as e: # pragma: no cover + raise ImportError( + "BVH export requires `bvhio` (and its deps `PyGLM` + `SpatialTransform`). " + "Install with: `pip install bvhio`." + ) from e + + local_rot_mats = local_rot_mats.detach() + root_positions = root_positions.detach() + # SOMA: accept either somaskel30 (convert to 77) or somaskel77 (use as-is) + if skeleton.name == "somaskel30": + local_rot_mats = skeleton.to_SOMASkeleton77(local_rot_mats) + skeleton = skeleton.somaskel77 + + local_rot_mats, _ = skeleton.from_standard_tpose(local_rot_mats) + + neutral = skeleton.bvh_neutral_joints.detach().cpu().numpy() + joint_names = list(skeleton.bone_order_names) + parents = skeleton.joint_parents.detach().cpu().numpy().astype(int) + root_idx = int(skeleton.root_idx) + + local_rot_mats = _coerce_batch("local_rot_mats", local_rot_mats, expected_ndim=4) + T, J = local_rot_mats.shape[:2] + q_wxyz = _matrix_to_quaternion(local_rot_mats).detach().cpu().numpy() # [T, J, 4] + + root_xyz = _coerce_batch("root_positions", root_positions, expected_ndim=2) + root_xyz = root_xyz.cpu().numpy() # [T, 3] + + # Build BVH hierarchy: Root (wrapper at origin) -> Hips (pelvis with offset in meters) -> ... + # Offsets are in meters to match the original format. 
+ children: dict[int, list[int]] = {i: [] for i in range(J)} + for i, p in enumerate(parents): + if p >= 0: + children[int(p)].append(int(i)) + + _ROOT_CHANNELS = [ + "Xposition", + "Yposition", + "Zposition", + "Zrotation", + "Yrotation", + "Xrotation", + ] + _JOINT_CHANNELS = ["Zrotation", "Yrotation", "Xrotation"] + + # Scale from meters to centimeters (match original BVH scale). + neutral = neutral * 100 + root_xyz = root_xyz * 100 + + # Hips offset from Root: use skeleton neutral; if root is at origin (zeros), use a + # nominal pelvis height so the hierarchy is non-degenerate in Blender. + hips_offset = neutral[root_idx] + if (hips_offset == 0).all(): + hips_offset = np.array([0.0, 100.0, 0.0], dtype=neutral.dtype) # 1 m in cm + + def _make_joint(i: int) -> "bvhio.BvhJoint": + name = joint_names[i] + j = bvhio.BvhJoint(name, offset=glm.vec3(0, 0, 0)) + if i == root_idx: + # Hips: offset from Root (origin) in cm + off = hips_offset + j.Offset = glm.vec3(float(off[0]), float(off[1]), float(off[2])) + j.Channels = _ROOT_CHANNELS.copy() + else: + p = int(parents[i]) + off = neutral[i] - neutral[p] + j.Offset = glm.vec3(float(off[0]), float(off[1]), float(off[2])) + j.Channels = _JOINT_CHANNELS.copy() + + for c in children[i]: + j.Children.append(_make_joint(c)) + return j + + # Wrapper Root at origin; single child is Hips (skeleton root). + root_wrapper = bvhio.BvhJoint("Root", offset=glm.vec3(0.0, 0.0, 0.0)) + root_wrapper.Channels = _ROOT_CHANNELS.copy() + root_wrapper.Children.append(_make_joint(root_idx)) + root_joint = root_wrapper + + # Populate keyframes: Root = identity/zero, Hips = root motion, others = local rotation. + bvh_layout = root_joint.layout() + name_to_id = {n: idx for idx, n in enumerate(joint_names)} + ordered_joint_ids = [] + for bj, _, _ in bvh_layout: + if bj.Name == "Root": + ordered_joint_ids.append(None) + else: + ordered_joint_ids.append(name_to_id[bj.Name]) + + bvh_joints = [bj for bj, _, _ in bvh_layout] + for bj in bvh_joints: + bj.Keyframes = [None] * T # type: ignore[list-item] + + identity_quat = glm.quat(1.0, 0.0, 0.0, 0.0) + zero_vec = glm.vec3(0.0, 0.0, 0.0) + for t in range(T): + for bj, jid in zip(bvh_joints, ordered_joint_ids): + if jid is None: + position = zero_vec + rotation = identity_quat + elif jid == root_idx: + pos = root_xyz[t] + position = glm.vec3(float(pos[0]), float(pos[1]), float(pos[2])) + qw, qx, qy, qz = q_wxyz[t, jid] + rotation = glm.quat(float(qw), float(qx), float(qy), float(qz)) + else: + position = zero_vec + qw, qx, qy, qz = q_wxyz[t, jid] + rotation = glm.quat(float(qw), float(qx), float(qy), float(qz)) + bj.Keyframes[t] = Pose(position, rotation) # type: ignore[index] + + container = bvhio.BvhContainer(root_joint, frameCount=T, frameTime=1.0 / float(fps)) + with tempfile.NamedTemporaryFile(mode="w", suffix=".bvh", delete=False, encoding="utf-8") as f: + tmp_path = f.name + try: + bvhio.writeBvh(tmp_path, container, percision=6) + bvh_text = Path(tmp_path).read_text(encoding="utf-8") + return _strip_end_site_blocks(bvh_text) + finally: + try: + os.remove(tmp_path) + except Exception: + pass + + +def motion_to_bvh_bytes( + local_rot_mats: torch.Tensor, + root_positions: torch.Tensor, + *, + skeleton, + fps: float, +) -> bytes: + """Convert local rotations and root positions to BVH bytes (UTF-8). + + Convenience wrapper around :func:`motion_to_bvh`. 
+ """ + return motion_to_bvh(local_rot_mats, root_positions, skeleton=skeleton, fps=fps).encode("utf-8") + + +def save_motion_bvh( + path: Union[str, Path], + local_rot_mats: torch.Tensor, + root_positions: torch.Tensor, + *, + skeleton, + fps: float, +) -> None: + """Write local rotations and root positions to a BVH file at the given path.""" + Path(path).write_text( + motion_to_bvh(local_rot_mats, root_positions, skeleton=skeleton, fps=fps), + encoding="utf-8", + ) + + +def read_bvh_frame_time_seconds(path: Union[str, Path]) -> float: + """Read ``Frame Time`` from a BVH file (seconds per frame).""" + with open(path, encoding="utf-8") as f: + for line in f: + if "Frame Time:" in line: + parts = line.split() + return float(parts[-1]) + raise ValueError(f"Could not find 'Frame Time:' in {path}") + + +def bvh_to_kimodo_motion( + path: Union[str, Path], + skeleton=None, +) -> Tuple: + """Load a Kimodo-style SOMA BVH into a Kimodo motion dict. + + Expects the same hierarchy as :func:`save_motion_bvh` (``Root`` wrapper + SOMA77 joints). + The frame rate is always read from the BVH ``Frame Time`` header. Callers + that need a different playback rate should resample the returned motion dict + (see :func:`~kimodo.exports.motion_io.resample_motion_dict_to_kimodo_fps`). + + Returns: + ``(motion_dict, source_fps)`` where ``source_fps`` is the native BVH + frame rate read from the file header. + """ + from kimodo.exports.motion_io import complete_motion_dict + from kimodo.skeleton.bvh import parse_bvh_motion + from kimodo.skeleton.registry import build_skeleton + + if skeleton is None: + skeleton = build_skeleton(77) + device = skeleton.neutral_joints.device + + local_rot_mats, root_trans, bvh_fps = parse_bvh_motion(str(path)) + local_rot_mats = local_rot_mats.to(device=device) + root_trans = root_trans.to(device=device) + + if int(local_rot_mats.shape[1]) != int(skeleton.nbjoints): + raise ValueError( + f"BVH has {local_rot_mats.shape[1]} joints but skeleton has {skeleton.nbjoints}; " + "use a Kimodo-exported SOMA BVH or matching skeleton." + ) + local_rot_mats, _ = skeleton.to_standard_tpose(local_rot_mats) + + return complete_motion_dict(local_rot_mats, root_trans, skeleton, float(bvh_fps)), bvh_fps diff --git a/kimodo/exports/motion_convert_lib.py b/kimodo/exports/motion_convert_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..bd2deaa6bcaf1580761204ac1d1665600d829fc6 --- /dev/null +++ b/kimodo/exports/motion_convert_lib.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Library API for converting between Kimodo NPZ, AMASS NPZ, SOMA BVH, and G1 MuJoCo CSV.""" + +from __future__ import annotations + +import warnings + +import numpy as np + +from kimodo.exports.bvh import bvh_to_kimodo_motion, save_motion_bvh +from kimodo.exports.motion_formats import ( + infer_source_format_from_path, + infer_target_format_from_path, + resolve_source_fps, +) +from kimodo.exports.motion_io import ( + load_amass_npz, + load_g1_csv, + load_kimodo_npz_as_torch, + save_kimodo_npz_at_target_fps, +) +from kimodo.exports.mujoco import MujocoQposConverter +from kimodo.exports.smplx import AMASSConverter +from kimodo.skeleton.registry import build_skeleton + + +def convert_motion_files( + input_path: str, + output_path: str, + *, + from_fmt: str | None = None, + to_fmt: str | None = None, + source_fps: float | None = None, + z_up: bool = True, + mujoco_rest_zero: bool = False, +) -> None: + """Convert a motion file between Kimodo-supported formats. + + Supported pairs (hub-and-spoke through Kimodo NPZ): + + - amass <-> kimodo + - soma-bvh <-> kimodo + - g1-csv <-> kimodo + + Args: + input_path: Source file (``.npz``, ``.bvh``, or ``.csv``). + output_path: Destination file. + from_fmt: Source format; inferred from extension/contents when ``None``. + to_fmt: Target format; inferred from extension when ``None``. + source_fps: Source motion frame rate (Hz). If provided, trusted as-is. + If ``None``, auto-detected from BVH ``Frame Time``, AMASS + ``mocap_frame_rate``, or default 30. + z_up: For AMASS conversions, apply the Z-up <-> Kimodo Y-up transform. + mujoco_rest_zero: For G1 CSV, joint angles relative to MuJoCo rest pose. + """ + from_fmt = from_fmt or infer_source_format_from_path(input_path) + to_fmt = to_fmt or infer_target_format_from_path(output_path, from_fmt) + + _validate_output_extension(to_fmt, output_path) + + pair = (from_fmt, to_fmt) + + if pair == ("amass", "kimodo"): + sk = build_skeleton(22) + effective_source = source_fps + if effective_source is None: + with np.load(input_path, allow_pickle=True) as z: + effective_source = float(z["mocap_frame_rate"]) if "mocap_frame_rate" in z.files else 30.0 + motion = load_amass_npz(input_path, source_fps=effective_source, z_up=z_up) + save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path) + return + + if pair == ("kimodo", "amass"): + data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False) + if J != 22: + raise ValueError(f"Kimodo→AMASS requires 22 joints (SMPL-X); this file has J={J}.") + sk = build_skeleton(22) + effective_source = resolve_source_fps(source_fps, "kimodo", input_path, None) + converter = AMASSConverter(fps=effective_source, skeleton=sk) + converter.convert_save_npz(data, output_path, z_up=z_up) + return + + if pair == ("soma-bvh", "kimodo"): + sk = build_skeleton(77) + motion, bvh_fps = bvh_to_kimodo_motion(input_path, skeleton=sk) + effective_source = source_fps if source_fps is not None else bvh_fps + save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path) + return + + if pair == ("kimodo", "soma-bvh"): + data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False) + if J == 30: + warnings.warn( + f"Input has 30 joints (somaskel30); expanding to somaskel77 for BVH export.", + UserWarning, + stacklevel=2, + ) + sk = build_skeleton(30) + elif J == 77: + sk = build_skeleton(77) + else: + raise ValueError(f"Kimodo→BVH requires a SOMA skeleton (30 or 77 joints); this file has J={J}.") + effective_source = 
resolve_source_fps(source_fps, "kimodo", input_path, None) + save_motion_bvh( + output_path, + data["local_rot_mats"], + data["root_positions"], + skeleton=sk, + fps=effective_source, + ) + return + + if pair == ("g1-csv", "kimodo"): + sk = build_skeleton(34) + effective_source = resolve_source_fps(source_fps, "g1-csv", input_path, None) + motion = load_g1_csv(input_path, source_fps=effective_source, mujoco_rest_zero=mujoco_rest_zero) + save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path) + return + + if pair == ("kimodo", "g1-csv"): + data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False) + if J != 34: + raise ValueError(f"Kimodo→CSV requires G1 with 34 joints; this file has J={J}.") + sk = build_skeleton(34) + effective_source = resolve_source_fps(source_fps, "kimodo", input_path, None) + converter = MujocoQposConverter(sk) + qpos = converter.dict_to_qpos( + {k: v for k, v in data.items() if k in ("local_rot_mats", "root_positions")}, + device=str(sk.neutral_joints.device), + numpy=True, + mujoco_rest_zero=mujoco_rest_zero, + ) + converter.save_csv(qpos, output_path) + return + + raise ValueError( + f"Unsupported conversion {from_fmt!r} → {to_fmt!r}. " + "Supported: amass↔kimodo (SMPL-X NPZ), soma-bvh↔kimodo, g1-csv↔kimodo." + ) + + +def _validate_output_extension(to_fmt: str, output_path: str) -> None: + lower = output_path.lower() + if to_fmt == "kimodo" and lower.endswith(".npz"): + return + if to_fmt == "amass": + if not lower.endswith(".npz"): + raise ValueError("AMASS output must use a .npz path.") + elif to_fmt == "soma-bvh": + if not lower.endswith(".bvh"): + raise ValueError("SOMA BVH output must use a .bvh path.") + elif to_fmt == "g1-csv": + if not lower.endswith(".csv"): + raise ValueError("G1 CSV output must use a .csv path.") diff --git a/kimodo/exports/motion_formats.py b/kimodo/exports/motion_formats.py new file mode 100644 index 0000000000000000000000000000000000000000..c2ba4cb4a0aeb214d79084fe6911a8176a7f92c9 --- /dev/null +++ b/kimodo/exports/motion_formats.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Infer motion file formats from paths and NPZ contents.""" + +from __future__ import annotations + +import os +from typing import Literal + +import numpy as np + +MotionSourceFormat = Literal["amass", "kimodo", "soma-bvh", "g1-csv"] +MotionTargetFormat = Literal["amass", "kimodo", "soma-bvh", "g1-csv"] +NpzMotionKind = Literal["amass", "kimodo"] + + +def infer_npz_kind(path: str) -> NpzMotionKind: + """Classify a ``.npz`` as AMASS SMPL-X or Kimodo from required array keys.""" + with np.load(path, allow_pickle=False) as z: + keys = set(z.files) + if "trans" in keys and "pose_body" in keys and "root_orient" in keys: + return "amass" + if "local_rot_mats" in keys or "posed_joints" in keys: + return "kimodo" + raise ValueError( + f"Unrecognized NPZ {path!r}: expected AMASS keys (trans, pose_body, ...) " + "or Kimodo keys (local_rot_mats, posed_joints, ...)." 
+ ) + + +def infer_source_format_from_path(path: str) -> MotionSourceFormat: + """Infer converter input format from file extension and NPZ contents when needed.""" + ext = os.path.splitext(path)[1].lower() + if ext == ".bvh": + return "soma-bvh" + if ext == ".csv": + return "g1-csv" + if ext == ".npz": + return infer_npz_kind(path) # type: ignore[return-value] + raise ValueError(f"Cannot infer format from extension of {path!r}") + + +def infer_target_format_from_path(path: str, from_fmt: MotionSourceFormat) -> MotionTargetFormat: + """Infer converter output format from destination path and source format.""" + ext = os.path.splitext(path)[1].lower() + if ext == ".bvh": + return "soma-bvh" + if ext == ".csv": + return "g1-csv" + if ext == ".npz": + if from_fmt == "amass": + return "kimodo" + if from_fmt == "kimodo": + return "amass" + if from_fmt in ("g1-csv", "soma-bvh"): + return "kimodo" + raise ValueError( + "Ambiguous .npz output: set --to to 'kimodo' or 'amass' when the input format is not amass/kimodo." + ) + raise ValueError(f"Cannot infer output format from extension of {path!r}") + + +def resolve_source_fps( + fps: float | None, + from_kind: str, + input_path: str, + data: dict | None, +) -> float: + """Resolve source frame rate (Hz) for conversion when ``fps`` is not overridden.""" + if fps is not None: + return float(fps) + if data is not None and "mocap_frame_rate" in data: + return float(np.asarray(data["mocap_frame_rate"]).item()) + if from_kind == "soma-bvh": + from kimodo.exports.bvh import read_bvh_frame_time_seconds + + return 1.0 / read_bvh_frame_time_seconds(input_path) + return 30.0 diff --git a/kimodo/exports/motion_io.py b/kimodo/exports/motion_io.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4bacc98ef84f24b77be9db01479bd1a966e877 --- /dev/null +++ b/kimodo/exports/motion_io.py @@ -0,0 +1,443 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Assemble Kimodo NPZ-compatible motion dicts from local rotations + root trajectory.""" + +from __future__ import annotations + +import os +import warnings +from typing import Any, Dict, Tuple + +import numpy as np +import torch + +from kimodo.geometry import matrix_to_quaternion, quaternion_to_matrix +from kimodo.motion_rep.feature_utils import compute_heading_angle, compute_vel_xyz +from kimodo.motion_rep.feet import foot_detect_from_pos_and_vel +from kimodo.motion_rep.smooth_root import get_smooth_root_pos +from kimodo.skeleton import SkeletonBase +from kimodo.skeleton.registry import build_skeleton +from kimodo.tools import to_numpy + +# Default motion rate for Kimodo NPZ produced by format conversion (matches common model FPS). 
+KIMODO_CONVERT_TARGET_FPS = 30.0 + + +def _quaternion_slerp(q0: torch.Tensor, q1: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """Spherical linear interpolation; ``q0``, ``q1`` (..., 4) wxyz; ``t`` broadcastable to (..., + 1).""" + if t.dim() < q0.dim(): + t = t.unsqueeze(-1) + dot = (q0 * q1).sum(dim=-1, keepdim=True) + q1 = torch.where(dot < 0, -q1, q1) + dot = torch.abs(dot).clamp(-1.0, 1.0) + theta_0 = torch.acos(dot) + sin_theta = torch.sin(theta_0) + s0 = torch.sin((1.0 - t) * theta_0) / sin_theta.clamp(min=1e-8) + s1 = torch.sin(t * theta_0) / sin_theta.clamp(min=1e-8) + q = s0 * q0 + s1 * q1 + return q / torch.linalg.norm(q, dim=-1, keepdim=True).clamp(min=1e-8) + + +def resample_motion_dict_to_kimodo_fps( + motion_dict: Dict[str, torch.Tensor], + skeleton: SkeletonBase, + source_fps: float, + target_fps: float = KIMODO_CONVERT_TARGET_FPS, +) -> Tuple[Dict[str, torch.Tensor], bool]: + """Resample a Kimodo motion dict to ``target_fps``. + + When the fps ratio is close to an integer (e.g. 120 / 30 = 4), the faster + stepping method is used (take every *step*-th frame). Otherwise falls back + to linear interp (root) + quaternion slerp (joints). + + Re-runs :func:`complete_motion_dict` at the target rate so derived channels stay consistent. + + Returns: + The motion dict and ``True`` if time resampling was applied, else ``False`` (already at + ``target_fps`` with matching frame count; only re-derived via FK). + """ + local_rot_mats = motion_dict["local_rot_mats"] + root_positions = motion_dict["root_positions"] + local_rot_mats, root_positions = _coerce_time_local_root(local_rot_mats, root_positions) + t_in = int(local_rot_mats.shape[0]) + if t_in < 1: + raise ValueError("Motion must have at least one frame.") + if source_fps <= 0: + raise ValueError(f"source_fps must be positive; got {source_fps}") + + t_out = max(1, int(round(t_in * target_fps / source_fps))) + if t_out == t_in and abs(float(source_fps) - float(target_fps)) < 1e-3: + return complete_motion_dict(local_rot_mats, root_positions, skeleton, float(target_fps)), False + + ratio = source_fps / target_fps + step = round(ratio) + if step >= 2 and abs(ratio - step) < 0.05: + local_out = local_rot_mats[::step] + root_out = root_positions[::step] + else: + device = local_rot_mats.device + dtype = local_rot_mats.dtype + u = torch.linspace(0, t_in - 1, t_out, device=device, dtype=dtype) + i0 = u.floor().long().clamp(0, t_in - 1) + i1 = torch.minimum(i0 + 1, torch.tensor(t_in - 1, device=device)) + tau_1d = (u - i0.float()).unsqueeze(-1) + rp0 = root_positions[i0] + rp1 = root_positions[i1] + root_out = (1.0 - tau_1d) * rp0 + tau_1d * rp1 + + quats = matrix_to_quaternion(local_rot_mats) + q0 = quats[i0] + q1 = quats[i1] + tau_q = (u - i0.float()).view(t_out, 1, 1) + quat_out = _quaternion_slerp(q0, q1, tau_q) + local_out = quaternion_to_matrix(quat_out) + + return complete_motion_dict(local_out, root_out, skeleton, float(target_fps)), True + + +def warn_kimodo_npz_framerate(source_fps: float, t_before: int, t_after: int) -> None: + """Emit a warning after time resampling for Kimodo NPZ (linear root, quaternion slerp per + joint).""" + warnings.warn( + f"Resampled motion to {KIMODO_CONVERT_TARGET_FPS:.0f} Hz for Kimodo NPZ " + f"(source ~{source_fps:.4g} Hz, {t_before} input frames → {t_after} output frames). 
" + "Pass --source-fps if the detected source rate is wrong.", + UserWarning, + stacklevel=3, + ) + + +def _coerce_time_local_root( + local_rot_mats: torch.Tensor, + root_positions: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Normalize to shapes (T, J, 3, 3) and (T, 3).""" + if local_rot_mats.dim() == 5: + if int(local_rot_mats.shape[0]) != 1: + raise ValueError(f"local_rot_mats batch size must be 1 for single clip; got {local_rot_mats.shape[0]}") + local_rot_mats = local_rot_mats[0] + if root_positions.dim() == 3: + if int(root_positions.shape[0]) != 1: + raise ValueError(f"root_positions batch size must be 1; got {root_positions.shape[0]}") + root_positions = root_positions[0] + if local_rot_mats.dim() != 4: + raise ValueError(f"local_rot_mats must be (T,J,3,3); got {tuple(local_rot_mats.shape)}") + if root_positions.dim() != 2 or int(root_positions.shape[-1]) != 3: + raise ValueError(f"root_positions must be (T,3); got {tuple(root_positions.shape)}") + if int(local_rot_mats.shape[0]) != int(root_positions.shape[0]): + raise ValueError("local_rot_mats and root_positions must have the same number of frames") + return local_rot_mats, root_positions + + +def complete_motion_dict( + local_rot_mats: torch.Tensor, + root_positions: torch.Tensor, + skeleton: SkeletonBase, + fps: float, +) -> Dict[str, torch.Tensor]: + """Build the Kimodo motion output dict from local rotations and root positions. + + Matches keys written by CLI generation (see docs/source/user_guide/output_formats.md). + + Args: + local_rot_mats: (T, J, 3, 3) or (1, T, J, 3, 3) local rotation matrices. + root_positions: (T, 3) or (1, T, 3) root / pelvis world positions (meters). + skeleton: Skeleton instance (SOMA77, G1, SMPL-X, etc.). + fps: Sampling rate (Hz). + + Returns: + Dict with tensors ``posed_joints``, ``global_rot_mats``, ``local_rot_mats``, + ``foot_contacts``, ``smooth_root_pos``, ``root_positions``, ``global_root_heading``. 
+ """ + device = local_rot_mats.device + dtype = local_rot_mats.dtype + local_rot_mats, root_positions = _coerce_time_local_root( + local_rot_mats.to(device=device, dtype=dtype), + root_positions.to(device=device, dtype=dtype), + ) + + global_rot_mats, posed_joints, _ = skeleton.fk(local_rot_mats, root_positions) + + smooth_root_pos = get_smooth_root_pos(root_positions.unsqueeze(0)).squeeze(0) + + lengths = torch.tensor([posed_joints.shape[0]], device=device) + velocities = compute_vel_xyz(posed_joints.unsqueeze(0), fps, lengths=lengths).squeeze(0) + + heading_angle = compute_heading_angle(posed_joints.unsqueeze(0), skeleton).squeeze(0) + global_root_heading = torch.stack([torch.cos(heading_angle), torch.sin(heading_angle)], dim=-1) + + foot_contacts = foot_detect_from_pos_and_vel( + posed_joints.unsqueeze(0), + velocities.unsqueeze(0), + skeleton, + 0.15, + 0.10, + ).squeeze(0) + + return { + "posed_joints": posed_joints, + "global_rot_mats": global_rot_mats, + "local_rot_mats": local_rot_mats, + "foot_contacts": foot_contacts, + "smooth_root_pos": smooth_root_pos, + "root_positions": root_positions, + "global_root_heading": global_root_heading, + } + + +def motion_dict_to_numpy(d: Dict[str, Any]) -> Dict[str, np.ndarray]: + """Convert motion dict values to numpy arrays for ``np.savez``.""" + out: Dict[str, np.ndarray] = {} + for k, v in d.items(): + if hasattr(v, "detach"): + out[k] = to_numpy(v) + elif isinstance(v, np.ndarray): + out[k] = v + else: + out[k] = np.asarray(v) + return out + + +def save_kimodo_npz(path: str, motion_dict: Dict[str, Any]) -> None: + """Save a Kimodo-compatible motion dict to ``.npz`` (numpy arrays).""" + np.savez(path, **motion_dict_to_numpy(motion_dict)) + + +def load_kimodo_npz(path: str) -> Dict[str, np.ndarray]: + """Load arrays from a Kimodo ``.npz`` file.""" + with np.load(path, allow_pickle=False) as data: + return {k: np.asarray(data[k]) for k in data.files} + + +def load_g1_csv( + path: str, + source_fps: float = KIMODO_CONVERT_TARGET_FPS, + *, + mujoco_rest_zero: bool = False, +) -> Dict[str, torch.Tensor]: + """Load a G1 MuJoCo ``qpos`` CSV (``(T, 36)``) into a Kimodo motion dict. + + Args: + path: CSV path (comma-separated, no header). + source_fps: Source frame rate (Hz) of the CSV data. + mujoco_rest_zero: Must match how the CSV was written (see :class:`MujocoQposConverter`). + """ + from kimodo.exports.mujoco import MujocoQposConverter + + qpos = np.loadtxt(path, delimiter=",") + if qpos.ndim != 2 or qpos.shape[-1] != 36: + raise ValueError(f"Expected G1 CSV with shape (T, 36); got {qpos.shape}") + sk = build_skeleton(34) + converter = MujocoQposConverter(sk) + return converter.qpos_to_motion_dict(qpos, float(source_fps), mujoco_rest_zero=mujoco_rest_zero) + + +def load_amass_npz( + path: str, + source_fps: float | None = None, + *, + z_up: bool = True, +) -> Dict[str, torch.Tensor]: + """Load an AMASS-style SMPL-X ``.npz`` into a Kimodo motion dict (22 joints). + + Args: + path: NPZ with ``trans``, ``root_orient``, ``pose_body``, etc. + source_fps: Source frame rate (Hz); if ``None``, uses ``mocap_frame_rate`` + from the file when present, else 30 Hz. + z_up: If ``True``, apply AMASS Z-up to Kimodo Y-up transform (same as CLI). 
+ """ + from kimodo.exports.smplx import amass_npz_to_kimodo_motion + + sk = build_skeleton(22) + return amass_npz_to_kimodo_motion(path, sk, source_fps=source_fps, z_up=z_up) + + +def load_kimodo_npz_as_torch( + path: str, + source_fps: float = KIMODO_CONVERT_TARGET_FPS, + *, + ensure_complete: bool = True, +) -> tuple[Dict[str, torch.Tensor], int]: + """Load a Kimodo NPZ and return all arrays as torch tensors on the skeleton device. + + Args: + path: Kimodo NPZ file path. + source_fps: Source frame rate (Hz) used for derived channels when + ``ensure_complete=True``. + ensure_complete: If ``True`` and the NPZ lacks derived channels + (``posed_joints``, ``global_rot_mats``, …), run :func:`complete_motion_dict` + to fill them from ``local_rot_mats`` + ``root_positions``. + If ``False``, load all arrays verbatim (requires ``local_rot_mats``). + + Returns: + ``(tensor_dict, num_joints)`` + """ + raw = load_kimodo_npz(path) + if "local_rot_mats" in raw: + j = int(raw["local_rot_mats"].shape[1]) + elif "posed_joints" in raw: + j = int(raw["posed_joints"].shape[1]) + else: + raise ValueError("Kimodo NPZ must contain 'local_rot_mats' or 'posed_joints'.") + sk = build_skeleton(j) + device = sk.neutral_joints.device + dtype = torch.float32 + + if not ensure_complete: + if "local_rot_mats" not in raw: + raise ValueError("Kimodo NPZ must contain 'local_rot_mats' (and typically 'root_positions').") + out: Dict[str, torch.Tensor] = {} + for k, v in raw.items(): + out[k] = torch.from_numpy(np.asarray(v)).to(device=device, dtype=dtype) + return out, j + + if "posed_joints" in raw and "global_rot_mats" in raw: + out = {} + for k, v in raw.items(): + out[k] = torch.from_numpy(np.asarray(v)).to(device=device, dtype=dtype) + return out, j + + if "local_rot_mats" not in raw or "root_positions" not in raw: + raise ValueError("Kimodo NPZ must contain posed_joints+global_rot_mats, or local_rot_mats+root_positions.") + local = torch.from_numpy(np.asarray(raw["local_rot_mats"])).to(device=device, dtype=dtype) + root = torch.from_numpy(np.asarray(raw["root_positions"])).to(device=device, dtype=dtype) + return complete_motion_dict(local, root, sk, float(source_fps)), j + + +def save_kimodo_npz_at_target_fps( + motion: Dict[str, torch.Tensor], + skeleton: SkeletonBase, + source_fps: float, + output_path: str, + target_fps: float = KIMODO_CONVERT_TARGET_FPS, +) -> None: + """Resample a motion dict to ``target_fps`` when needed, then save Kimodo NPZ.""" + t_before = int(motion["local_rot_mats"].shape[0]) + motion, did_resample = resample_motion_dict_to_kimodo_fps(motion, skeleton, source_fps, target_fps) + t_after = int(motion["local_rot_mats"].shape[0]) + if did_resample: + warn_kimodo_npz_framerate(source_fps, t_before, t_after) + save_kimodo_npz(output_path, motion) + + +def kimodo_npz_to_bytes(motion_dict: Dict[str, Any]) -> bytes: + """Serialize a Kimodo motion dict to in-memory NPZ bytes.""" + import io + + buf = io.BytesIO() + np.savez(buf, **motion_dict_to_numpy(motion_dict)) + return buf.getvalue() + + +def g1_csv_to_bytes(motion_dict: Dict[str, Any], skeleton: SkeletonBase, device: Any) -> bytes: + """Convert a motion dict to G1 MuJoCo CSV bytes via :class:`MujocoQposConverter`.""" + import io + + from kimodo.exports.mujoco import MujocoQposConverter + + converter = MujocoQposConverter(skeleton) + qpos = converter.dict_to_qpos( + {k: v for k, v in motion_dict.items() if k in ("local_rot_mats", "root_positions")}, + device, + numpy=True, + ) + buf = io.StringIO() + np.savetxt(buf, qpos, delimiter=",") + return 
buf.getvalue().encode("utf-8") + + +def amass_npz_to_bytes(motion_dict: Dict[str, Any], skeleton: SkeletonBase, fps: float) -> bytes: + """Convert a motion dict to AMASS NPZ bytes via :class:`AMASSConverter`.""" + import io + + from kimodo.exports.smplx import AMASSConverter + + converter = AMASSConverter(skeleton=skeleton, fps=fps) + buf = io.BytesIO() + converter.convert_save_npz( + {k: v for k, v in motion_dict.items() if k in ("local_rot_mats", "root_positions")}, + buf, + ) + return buf.getvalue() + + +def _read_amass_source_fps(path: str) -> float: + """Read the source frame rate from an AMASS NPZ, defaulting to 30 Hz.""" + with np.load(path, allow_pickle=True) as z: + if "mocap_frame_rate" in z.files: + return float(z["mocap_frame_rate"]) + return 30.0 + + +def load_motion_file( + path: str, + source_fps: float | None = None, + target_fps: float | None = None, + *, + z_up: bool = True, + mujoco_rest_zero: bool = False, +) -> tuple[Dict[str, torch.Tensor], int]: + """Load a motion file and return a Kimodo motion dict plus joint count. + + Supports SOMA BVH (``.bvh``), G1 MuJoCo CSV (``.csv``), Kimodo NPZ, and AMASS SMPL-X NPZ + (``.npz``). + + The motion is loaded at its native (or overridden) source rate, then + resampled to ``target_fps`` when they differ. + + Args: + path: Path to ``.bvh``, ``.csv``, or ``.npz``. + source_fps: Source frame rate (Hz). If provided, trusted as-is. + If ``None``, auto-detected per format: BVH ``Frame Time`` header, + AMASS ``mocap_frame_rate``, or :data:`KIMODO_CONVERT_TARGET_FPS` + (30 Hz) for CSV / Kimodo NPZ. + target_fps: Desired output frame rate (Hz). Defaults to + :data:`KIMODO_CONVERT_TARGET_FPS` (30 Hz). The motion is + resampled when ``source_fps`` and ``target_fps`` differ. + z_up: AMASS NPZ only; passed to :func:`load_amass_npz`. + mujoco_rest_zero: G1 CSV only; passed to :func:`load_g1_csv`. + + Returns: + ``(motion_dict, num_joints)`` with the same keys as :func:`complete_motion_dict`. 
+ """ + from kimodo.exports.motion_formats import infer_npz_kind + + if target_fps is None: + target_fps = KIMODO_CONVERT_TARGET_FPS + + ext = os.path.splitext(path)[1].lower() + if ext == ".bvh": + from kimodo.exports.bvh import bvh_to_kimodo_motion + + motion_dict, bvh_fps = bvh_to_kimodo_motion(path) + effective_source = source_fps if source_fps is not None else bvh_fps + num_joints = int(motion_dict["local_rot_mats"].shape[1]) + elif ext == ".csv": + effective_source = source_fps if source_fps is not None else KIMODO_CONVERT_TARGET_FPS + motion_dict = load_g1_csv(path, source_fps=effective_source, mujoco_rest_zero=mujoco_rest_zero) + num_joints = 34 + elif ext == ".npz": + kind = infer_npz_kind(path) + if kind == "amass": + effective_source = source_fps if source_fps is not None else _read_amass_source_fps(path) + motion_dict = load_amass_npz(path, source_fps=effective_source, z_up=z_up) + num_joints = 22 + else: + effective_source = source_fps if source_fps is not None else KIMODO_CONVERT_TARGET_FPS + motion_dict, num_joints = load_kimodo_npz_as_torch(path, source_fps=effective_source) + else: + raise ValueError(f"Unsupported motion file {path!r}; expected .bvh, .csv, or .npz") + + if abs(effective_source - target_fps) > 0.5: + sk = build_skeleton(num_joints) + motion_dict, did_resample = resample_motion_dict_to_kimodo_fps(motion_dict, sk, effective_source, target_fps) + if did_resample: + t_out = int(motion_dict["local_rot_mats"].shape[0]) + warnings.warn( + f"Resampled motion from {effective_source:.4g} Hz to " f"{target_fps:.0f} Hz ({t_out} frames).", + UserWarning, + stacklevel=2, + ) + + return motion_dict, num_joints diff --git a/kimodo/exports/mujoco.py b/kimodo/exports/mujoco.py new file mode 100644 index 0000000000000000000000000000000000000000..77015dd24015f239529c1437d393cdc5859cdd97 --- /dev/null +++ b/kimodo/exports/mujoco.py @@ -0,0 +1,588 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Convert kimodo motion (y-up, z-forward) to MuJoCo qpos (z-up, x-forward) for G1 skeleton.""" + +import os +import xml.etree.ElementTree as ET +from typing import Optional + +import numpy as np +import torch +from scipy.spatial.transform import Rotation + +from kimodo.assets import skeleton_asset_path +from kimodo.geometry import ( + axis_angle_to_matrix, + matrix_to_axis_angle, + matrix_to_quaternion, + quaternion_to_matrix, +) +from kimodo.skeleton import G1Skeleton34, SkeletonBase, global_rots_to_local_rots +from kimodo.tools import ensure_batched, to_numpy, to_torch + +# Cache so that the same (skeleton, xml_path) returns the same converter instance. +_converter_cache: dict[tuple[int, str], "MujocoQposConverter"] = {} + + +class MujocoQposConverter: + """Fast batch converter from our dictionary format to mujoco qpos with precomputed transforms. + + In mujoco, the coordination is z up and x forward, right handed. + + Features (30 joints): + - root (pelvis, 7 = translation + rotation) + 29 dof joints (29) + + In kimodo, the coordinate system is y up and z forward, right handed. + Features (34 joints): + - root (pelvis) + (34 - 1) joints; among these joints, 4 are end-effector joints added by kimodo. + + Cached by (input_skeleton id, xml_path); repeated calls with the same args return the same instance. 
+ """ + + def __new__( + cls, + input_skeleton: SkeletonBase, + xml_path: str = str(skeleton_asset_path("g1skel34", "xml", "g1.xml")), + ): + key = (id(input_skeleton), xml_path) + if key not in _converter_cache: + inst = object.__new__(cls) + _converter_cache[key] = inst + return _converter_cache[key] + + def __init__( + self, + input_skeleton: SkeletonBase, + xml_path: str = str(skeleton_asset_path("g1skel34", "xml", "g1.xml")), + ): + """Initialize converter with precomputed transforms. + + Args: + xml_path: Path to the mujoco XML file containing joint definitions + """ + if getattr(self, "_initialized", False): + return + self.xml_path = xml_path + self.skeleton = input_skeleton + self._prepare_transforms() + self._subtree_joints = {} + self._initialized = True + + def _prepare_transforms(self): + """Precompute all necessary transforms for efficient batch processing.""" + # Define coordinate transformations between mujoco and kimodo space + # 1) R_zup_to_yup: rotation around x-axis by -90 degrees + # 2) x_forward_to_y_forward: rotation around z-axis by -90 degrees + # Combined transformation matrix: mujoco_to_kimodo = R_zup_to_yup * x_forward_to_y_forward + self.mujoco_to_kimodo_matrix = torch.tensor( + [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], dtype=torch.float32 + ) + self.kimodo_to_mujoco_matrix = self.mujoco_to_kimodo_matrix.T # Inverse transformation: kimodo_to_mujoco + + # Parse XML once and extract joint information + tree = ET.parse(self.xml_path) + root = tree.getroot() + + xml_classes = [x for x in tree.findall(".//default") if "class" in x.attrib] + joint_axes = dict() + class_ranges: dict[str, tuple[float, float]] = {} + for xml_class in xml_classes: + j = xml_class.findall("joint") + if j: + joint_axes[xml_class.get("class")] = j[0].get("axis") + range_str = j[0].get("range") + if range_str: + range_vals = [float(x) for x in range_str.split()] + if len(range_vals) == 2: + class_ranges[xml_class.get("class")] = ( + range_vals[0], + range_vals[1], + ) + + mujoco_hinge_joints = root.find("worldbody").findall(".//joint") # skip the base joint + self._mujoco_joint_axis_values_kimodo_space = torch.zeros( + (len(mujoco_hinge_joints), 3), dtype=torch.float32 + ) # mujoco order but kimodo space + self._mujoco_joint_axis_values_mujoco_space = torch.zeros( + (len(mujoco_hinge_joints), 3), dtype=torch.float32 + ) # mujoco order but mujoco space + + # for the below indices, mujoco_indices_to_kimodo_indices does not include mujoco root (30 - 1 = 29 elements), + # while kimodo_indices_to_mujoco_indices inclues the kimodo root (32 elements). 
+ self._mujoco_indices_to_kimodo_indices = torch.zeros((len(mujoco_hinge_joints),), dtype=torch.int32) + self._kimodo_indices_to_mujoco_indices = ( + torch.ones((self.skeleton.nbjoints,), dtype=torch.int32) * -1 + ) # -1 means not in the csv skeleton + + self._nb_joints_mujoco = len(mujoco_hinge_joints) + 1 + self._nb_joints_kimodo = self.skeleton.nbjoints + self._mujoco_joint_including_root_parent_list = torch.full( + (len(mujoco_hinge_joints) + 1,), -1, dtype=torch.int32 + ) + self._mujoco_joint_including_root_list = ["pelvis_skel"] + + for joint_id_in_csv, joint in enumerate(mujoco_hinge_joints): + joint_name_in_skeleton = joint.get("name").replace("_joint", "_skel") + joint_parent_name_in_skeleton = self.skeleton.bone_parents[joint_name_in_skeleton] + + self._mujoco_joint_including_root_list.append(joint_name_in_skeleton) + self._mujoco_joint_including_root_parent_list[joint_id_in_csv + 1] = ( + self._mujoco_joint_including_root_list.index(joint_parent_name_in_skeleton) + ) + + joint_idx_in_kimodo_skeleton = self.skeleton.bone_order_names.index(joint_name_in_skeleton) + axis_values = [float(x) for x in (joint.get("axis") or joint_axes[joint.get("class")]).split(" ")] + + # the mapped axis in kimodo skeleton space is calculated as bones_axis = mujoco_to_kimodo.apply(axis_values) + # [1, 0, 0] -> [0, 0, 1]; [0, 1, 0] -> [1, 0, 0]; [0, 0, 1] -> [0, 1, 0] + mujoco_joint_axis_mapping_kimodo_space = [ + torch.tensor([0, 0, 1]), + torch.tensor([1, 0, 0]), + torch.tensor([0, 1, 0]), + ][np.argmax(axis_values)] + + self._mujoco_joint_axis_values_kimodo_space[joint_id_in_csv] = mujoco_joint_axis_mapping_kimodo_space + self._mujoco_joint_axis_values_mujoco_space[joint_id_in_csv] = torch.tensor(axis_values) + + self._mujoco_indices_to_kimodo_indices[joint_id_in_csv] = joint_idx_in_kimodo_skeleton + self._kimodo_indices_to_mujoco_indices[joint_idx_in_kimodo_skeleton] = ( + joint_id_in_csv + 1 + ) # +1 for the root + self._kimodo_indices_to_mujoco_indices[0] = 0 # the root joint mapping + + # Joint limits (min, max) in radians for each mujoco hinge, for clamping + self._joint_limits_min = torch.full((len(mujoco_hinge_joints),), float("-inf"), dtype=torch.float32) + self._joint_limits_max = torch.full((len(mujoco_hinge_joints),), float("inf"), dtype=torch.float32) + for joint_id_in_csv, joint in enumerate(mujoco_hinge_joints): + range_vals = None + if joint.get("range"): + range_vals = [float(x) for x in joint.get("range").split()] + elif joint.get("class") and joint.get("class") in class_ranges: + lo, hi = class_ranges[joint.get("class")] + range_vals = [lo, hi] + if range_vals is not None and len(range_vals) == 2: + self._joint_limits_min[joint_id_in_csv] = range_vals[0] + self._joint_limits_max[joint_id_in_csv] = range_vals[1] + + # load the offset matrices from the xml + R_zup_to_yup = Rotation.from_euler("x", -90, degrees=True) + x_forward_to_y_forward = Rotation.from_euler("z", -90, degrees=True) + mujoco_to_kimodo = R_zup_to_yup * x_forward_to_y_forward + + self._rot_offsets_q2t = torch.zeros(len(self._kimodo_indices_to_mujoco_indices), 3, 3, dtype=torch.float32) + self._rot_offsets_q2t[...] = torch.eye(3)[None] + + self._rot_offsets_f2q = torch.zeros(len(self._kimodo_indices_to_mujoco_indices), 3, 3, dtype=torch.float32) + self._rot_offsets_f2q[...] 
= torch.eye(3)[None] + parent_map = {child: parent for parent in root.iter() for child in parent} + for i, joint in enumerate(mujoco_hinge_joints): + body = parent_map[joint] + if "quat" in body.attrib: + rot = Rotation.from_quat( + [float(x) for x in body.get("quat").strip().split(" ")], + scalar_first=True, + ) + idx = self._mujoco_indices_to_kimodo_indices[i] + self._rot_offsets_q2t[idx] = torch.from_numpy(rot.as_matrix()) + rot = mujoco_to_kimodo * rot * mujoco_to_kimodo.inv() + self._rot_offsets_f2q[idx] = torch.from_numpy(rot.as_matrix().T) + + # Hinge axis in f2q space so extraction uses the same frame as joint_rot_f2q. + # Then extract(offset) gives the angle s.t. axis_angle(angle * axis_f2q) = offset, and + # reconstruction R_local = offset.T @ axis_angle(angle * axis_f2q) = I when input is identity. + axis_kimodo = self._mujoco_joint_axis_values_kimodo_space + self._mujoco_joint_axis_values_f2q_space = torch.zeros_like(axis_kimodo) + for i in range(len(mujoco_hinge_joints)): + j = self._mujoco_indices_to_kimodo_indices[i].item() + axis_f2q = torch.mv(self._rot_offsets_f2q[j], axis_kimodo[i]) + n = axis_f2q.norm() + if n > 1e-8: + axis_f2q = axis_f2q / n + self._mujoco_joint_axis_values_f2q_space[i] = axis_f2q + + # Rest-pose DOFs: angle we extract when R_local = I (t-pose). MuJoCo limits are + # relative to joint zero (rest pose), so we must clamp in MuJoCo space: convert + # joint_dofs to mujoco_angle = joint_dofs - rest_dofs, clamp, then back. + rest_rot_f2q = self._rot_offsets_f2q[self._mujoco_indices_to_kimodo_indices] + rest_rot_f2q = rest_rot_f2q.unsqueeze(0).unsqueeze(0) + self._rest_dofs = self._local_rots_f2q_to_joint_dofs(rest_rot_f2q).squeeze(0).squeeze(0) + # Axis-angle rest DOFs: angle s.t. axis_angle(angle * axis_f2q) = offset. Used in + # project_to_real_robot_rotations so extract+reconstruct round-trip and t-pose is preserved. + rest_rot_f2q_flat = self._rot_offsets_f2q[self._mujoco_indices_to_kimodo_indices] + full_aa = matrix_to_axis_angle(rest_rot_f2q_flat) + self._rest_dofs_axis_angle = (full_aa * self._mujoco_joint_axis_values_f2q_space).sum(dim=-1) + + def dict_to_qpos( + self, + output: dict, + device: Optional[str] = None, + root_quat_w_first: bool = True, + numpy: bool = True, + mujoco_rest_zero: bool = False, + ): + """Convert kimodo output dict to mujoco qpos format. + + Args: + output: dict with keys "local_rot_mats" and "root_positions". + device: device to use for the output. + root_quat_w_first: If True, quaternion in qpos is (w,x,y,z). + numpy: If True, convert the output to numpy array. + mujoco_rest_zero: If True, joint angles are written so that kimodo rest (t-pose) + maps to q=0 in MuJoCo. If False, write raw joint_dofs. + + Returns: + qpos: (B, T, 7+J) mujoco qpos format. + """ + local_rot_mats = to_torch(output["local_rot_mats"], device) + root_positions = to_torch(output["root_positions"], device) + + qpos = self.to_qpos( + local_rot_mats, + root_positions, + root_quat_w_first=root_quat_w_first, + mujoco_rest_zero=mujoco_rest_zero, + ) + if numpy: + qpos = to_numpy(qpos) + return qpos + + def qpos_to_motion_dict( + self, + qpos: torch.Tensor | np.ndarray, + source_fps: float, + *, + root_quat_w_first: bool = True, + mujoco_rest_zero: bool = False, + ): + """Inverse of :meth:`to_qpos` / :meth:`dict_to_qpos` for MuJoCo CSV ``(T, 36)`` rows. + + Args: + qpos: Shape ``(T, 36)`` or ``(1, T, 36)`` (root xyz, root quat wxyz, 29 joint angles). + source_fps: Source frame rate (Hz) of the qpos data. 
+            root_quat_w_first: Must match how the CSV was written (default ``True``).
+            mujoco_rest_zero: Must match :meth:`dict_to_qpos` / :meth:`to_qpos`.
+
+        Returns:
+            Kimodo motion dict (see :func:`kimodo.exports.motion_io.complete_motion_dict`).
+        """
+        from kimodo.exports.motion_io import complete_motion_dict
+
+        qpos = to_torch(qpos, None)
+        if qpos.dim() == 2:
+            qpos = qpos.unsqueeze(0)
+        device = qpos.device
+        dtype = qpos.dtype
+        batch_size, num_frames, ncols = qpos.shape
+        if ncols != 36:
+            raise ValueError(f"Expected qpos last dim 36; got {ncols}")
+
+        kimodo_to_mujoco_matrix = self.kimodo_to_mujoco_matrix.to(device=device, dtype=dtype)
+        mujoco_to_kimodo_matrix = kimodo_to_mujoco_matrix.T
+
+        root_mujoco = qpos[..., :3]
+        root_positions = torch.matmul(mujoco_to_kimodo_matrix[None, None, ...], root_mujoco[..., None]).squeeze(-1)
+
+        quat = qpos[..., 3:7]
+        if root_quat_w_first:
+            root_rot_mujoco = quaternion_to_matrix(quat)
+        else:
+            quat_wxyz = quat[..., [3, 0, 1, 2]]
+            root_rot_mujoco = quaternion_to_matrix(quat_wxyz)
+
+        O0 = self._rot_offsets_f2q[0].to(device=device, dtype=dtype)
+        # root_rot_mujoco is (..., 3, 3) after optional batch unsqueeze (e.g. (1, T, 3, 3)).
+        # Use ``...il`` so ``k`` sums with ``kl``; ``...ik`` incorrectly keeps ``k`` in the output.
+        R_f2q_root = torch.einsum(
+            "ij,...jk,kl->...il",
+            mujoco_to_kimodo_matrix,
+            root_rot_mujoco,
+            kimodo_to_mujoco_matrix,
+        )
+        R_kimodo_root = torch.einsum("ij,...jk->...ik", O0.T, R_f2q_root)
+
+        joint_dofs = qpos[..., 7:]
+        if mujoco_rest_zero:
+            rest_dofs = self._rest_dofs.to(device=device, dtype=dtype)
+            angles = joint_dofs + rest_dofs[None, None, :]
+            use_relative = True
+        else:
+            angles = joint_dofs
+            use_relative = False
+
+        nb_joints = self.skeleton.nbjoints
+        template = torch.eye(3, device=device, dtype=dtype).expand(batch_size, num_frames, nb_joints, 3, 3).contiguous()
+        template[:, :, 0] = R_kimodo_root
+
+        local_rot_mats = self._joint_dofs_to_local_rot_mats(
+            angles,
+            template,
+            device,
+            dtype,
+            use_relative=use_relative,
+        )
+
+        if batch_size != 1:
+            raise ValueError(f"Only a single clip is supported; got batch_size={batch_size}")
+
+        return complete_motion_dict(local_rot_mats[0], root_positions[0], self.skeleton, source_fps)
+
+    def save_csv(self, qpos: torch.Tensor | np.ndarray, csv_path):
+        """Save qpos to CSV: a (T, 36) array becomes one file; a (B, T, 36) batch becomes one file per motion."""
+        qpos = to_numpy(qpos)
+        shape = qpos.shape
+        if len(shape) == 2:
+            # only one motion: save it
+            np.savetxt(csv_path, qpos, delimiter=",")
+        elif len(shape) == 3:
+            # batch of motions
+            if shape[0] == 1:
+                # if only one motion, just save it
+                np.savetxt(csv_path, qpos[0], delimiter=",")
+            else:
+                csv_path_base, ext = os.path.splitext(csv_path)
+                for i in range(shape[0]):
+                    self.save_csv(qpos[i], csv_path_base + "_" + str(i).zfill(2) + ext)
+
+    def _local_rots_to_joint_dofs(
+        self,
+        local_rot_mats: torch.Tensor,
+        axis_vals: torch.Tensor,
+    ) -> torch.Tensor:
+        """Extract per-joint single-DoF angles (radians) via Euler projection (for to_qpos/f2q)."""
+        x_joint_dof = torch.atan2(local_rot_mats[..., 2, 1], local_rot_mats[..., 2, 2])
+        y_joint_dof = torch.atan2(local_rot_mats[..., 0, 2], local_rot_mats[..., 0, 0])
+        z_joint_dof = torch.atan2(local_rot_mats[..., 1, 0], local_rot_mats[..., 1, 1])
+        xyz_joint_dofs = torch.stack([x_joint_dof, y_joint_dof, z_joint_dof], dim=-1)
+        axis_vals = axis_vals.to(device=local_rot_mats.device, dtype=local_rot_mats.dtype)
+        joint_dofs = (xyz_joint_dofs * axis_vals[None, None, :, :]).sum(dim=-1)
+        return joint_dofs
+
+    def _local_rots_to_joint_dofs_axis_angle(
+        self,
local_rot_mats: torch.Tensor, + axis_vals: torch.Tensor, + ) -> torch.Tensor: + """Extract per-joint single-DoF angles (radians) via axis-angle; round-trips with + axis_angle_to_matrix. + + Args: + local_rot_mats: (..., num_hinges, 3, 3) in same frame as axis_vals. + axis_vals: (num_hinges, 3) unit axis per hinge. + Returns: + joint_dofs: (..., num_hinges) signed angle = dot(axis_angle(R), axis). + """ + axis_vals = axis_vals.to(device=local_rot_mats.device, dtype=local_rot_mats.dtype) + full_aa = matrix_to_axis_angle(local_rot_mats) + joint_dofs = (full_aa * axis_vals).sum(dim=-1) + return joint_dofs + + def _local_rots_f2q_to_joint_dofs(self, local_rot_mats_f2q: torch.Tensor) -> torch.Tensor: + """Extract per-joint single-DoF angles from local rotations in f2q space (for to_qpos).""" + axis_vals = self._mujoco_joint_axis_values_f2q_space + return self._local_rots_to_joint_dofs(local_rot_mats_f2q, axis_vals) + + def _clamp_to_limits(self, joint_dofs: torch.Tensor) -> torch.Tensor: + """Clamp joint angles to XML limits (radians). + + Angles are in kimodo convention (0 = rest). + """ + device = joint_dofs.device + lo = self._joint_limits_min.to(device=device, dtype=joint_dofs.dtype) + hi = self._joint_limits_max.to(device=device, dtype=joint_dofs.dtype) + return torch.clamp(joint_dofs, lo[None, None, :], hi[None, None, :]) + + def _clamp_joint_dofs(self, joint_dofs: torch.Tensor, rest_dofs: torch.Tensor) -> torch.Tensor: + """Clamp joint angles to MuJoCo limits (radians), with rest_dofs conversion.""" + device = joint_dofs.device + rest_dofs = rest_dofs.to(device=device, dtype=joint_dofs.dtype) + mujoco_dofs = joint_dofs - rest_dofs[None, None, :] + lo = self._joint_limits_min.to(device=device, dtype=joint_dofs.dtype) + hi = self._joint_limits_max.to(device=device, dtype=joint_dofs.dtype) + mujoco_dofs = torch.clamp(mujoco_dofs, lo[None, None, :], hi[None, None, :]) + return mujoco_dofs + rest_dofs[None, None, :] + + def _joint_dofs_to_local_rot_mats( + self, + joint_dofs: torch.Tensor, + original_local_rot_mats: torch.Tensor, + device: torch.device, + dtype: torch.dtype, + use_relative: bool = False, + ) -> torch.Tensor: + """Reconstruct full local rotation matrices from 1-DoF angles.""" + out = original_local_rot_mats.clone() + axis_kimodo = self._mujoco_joint_axis_values_kimodo_space.to(device=device, dtype=dtype) + for i in range(joint_dofs.shape[-1]): + j = self._mujoco_indices_to_kimodo_indices[i].item() + angle = joint_dofs[..., i] + axis = axis_kimodo[i] + if use_relative: + axis_angle = angle[..., None] * axis[None, None, :] + R_local = axis_angle_to_matrix(axis_angle) + else: + rot_offsets_f2q = self._rot_offsets_f2q.to(device=device, dtype=dtype) + axis_in_f2q = torch.mv(rot_offsets_f2q[j], axis) + axis_angle = angle[..., None] * axis_in_f2q[None, None, :] + R_f2q = axis_angle_to_matrix(axis_angle) + R_local = torch.einsum("ij,btjk->btik", rot_offsets_f2q[j].T, R_f2q) + out[:, :, j, :, :] = R_local + return out + + @ensure_batched(local_rot_mats=5, root_positions=3, lengths=1) + def project_to_real_robot_rotations( + self, + local_rot_mats: torch.Tensor, + root_positions: torch.Tensor, + clamp_to_limits: bool = True, + mujoco_rest_zero: bool = False, + ) -> dict: + """Project full 3D local rotations to G1 real robot DoF and back to 3D for viz. + + Joint angles are extracted along each hinge axis, optionally clamped to XML limits, then + reconstructed to 3D rotations. When mujoco_rest_zero=False (default), raw angles are used + (baked-with-quat). 
When True, angles are relative to rest (0 = T-pose in MuJoCo). + """ + device = local_rot_mats.device + dtype = local_rot_mats.dtype + + # Transform to f2q frame and extract 1-DoF angles (axis-angle projection). + local_rot_f2q = torch.matmul(self._rot_offsets_f2q.to(device=device, dtype=dtype), local_rot_mats) + hinge_rots = local_rot_f2q[:, :, self._mujoco_indices_to_kimodo_indices, :, :] + axis_f2q = self._mujoco_joint_axis_values_f2q_space.to(device=device, dtype=dtype) + joint_dofs = self._local_rots_to_joint_dofs_axis_angle(hinge_rots, axis_f2q) + + # Optionally express angles relative to rest (MuJoCo q=0 at T-pose). + if mujoco_rest_zero: + rest_dofs = self._rest_dofs_axis_angle.to(device=device, dtype=dtype) + angles = joint_dofs - rest_dofs[None, None, :] + use_relative = True + else: + angles = joint_dofs + use_relative = False + + if clamp_to_limits: + if mujoco_rest_zero: + angles = self._clamp_to_limits(angles) + else: + rest_dofs_aa = self._rest_dofs_axis_angle.to(device=device, dtype=dtype) + angles = self._clamp_joint_dofs(angles, rest_dofs_aa) + + # Reconstruct 3D local rotations from 1-DoF angles and run FK. + local_rot_mats_proj = self._joint_dofs_to_local_rot_mats( + angles, local_rot_mats, device, dtype, use_relative=use_relative + ) + global_rot_mats, posed_joints, _ = self.skeleton.fk(local_rot_mats_proj, root_positions) + return { + "local_rot_mats": local_rot_mats_proj, + "global_rot_mats": global_rot_mats, + "posed_joints": posed_joints, + "root_positions": root_positions, + } + + @ensure_batched(local_rot_mats=5, root_positions=3, lengths=1) + def to_qpos( + self, + local_rot_mats: torch.Tensor, + root_positions: torch.Tensor, + root_quat_w_first: bool = True, + mujoco_rest_zero: bool = False, + ) -> torch.Tensor: + """Fast batch conversion from kimodo features to mujoco qpos format. + + Args: + local_rot_mats: (B, T, J, 3, 3) local rotation matrices (kimodo convention). + root_positions: (B, T, 3) root positions. + root_quat_w_first: If True, quaternion in qpos is (w,x,y,z). + mujoco_rest_zero: If True, joint angles are written so that kimodo rest (t-pose) + maps to q=0 in MuJoCo. If False, write raw joint_dofs. 
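The axis-angle projection behind `_local_rots_to_joint_dofs_axis_angle` and `project_to_real_robot_rotations` can be sanity-checked in isolation. This standalone sketch uses only the geometry helpers added in this diff: for a rotation that is exactly a hinge rotation about `axis`, `dot(matrix_to_axis_angle(R), axis)` recovers the signed angle, and `axis_angle_to_matrix(angle * axis)` reconstructs `R`:

```python
import torch
from kimodo.geometry import axis_angle_to_matrix, matrix_to_axis_angle

axis = torch.tensor([0.0, 0.0, 1.0])  # unit hinge axis
angle = torch.tensor(0.7)             # radians
R = axis_angle_to_matrix(angle * axis)

# Signed 1-DoF angle = projection of the full axis-angle vector onto the hinge axis.
recovered = (matrix_to_axis_angle(R) * axis).sum(-1)
assert torch.allclose(recovered, angle, atol=1e-5)
assert torch.allclose(axis_angle_to_matrix(recovered * axis), R, atol=1e-5)
```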
+ + Returns: + torch.Tensor of shape [batch, numFrames, 36] containing mujoco qpos data: + - root_trans (3) + root_quat (4) + joint_dofs (29) = 36 columns + """ + + batch_size, num_frames, nb_joints = local_rot_mats.shape[:3] + device, dtype = local_rot_mats.device, local_rot_mats.dtype + + local_rot_mats = torch.matmul(self._rot_offsets_f2q.to(device), local_rot_mats) + + batch_size, num_frames = root_positions.shape[0], root_positions.shape[1] + + # Move precomputed matrices to the same device/dtype + kimodo_to_mujoco_matrix = self.kimodo_to_mujoco_matrix.to(device=device, dtype=dtype) + + # Initialize output tensor: [batch, numFrames, 36] + qpos = torch.zeros((batch_size, num_frames, 36), dtype=dtype, device=device) + + # Convert root translation: apply coordinate transformation + root_positions_mujoco = torch.matmul(kimodo_to_mujoco_matrix[None, None, ...], root_positions[..., None]) + qpos[:, :, :3] = root_positions_mujoco.view(batch_size, num_frames, 3) + + # Convert root rotation: apply coordinate transformation to rotation matrix + root_rot = local_rot_mats[:, :, 0, :] # [batch, numFrames, 3, 3] + + # Apply coordinate transformation: R_mujoco = kimodo_to_mujoco * R_kimodo * kimodo_to_mujoco^T + mujoco_to_kimodo_matrix = kimodo_to_mujoco_matrix.T + root_rot_mujoco = torch.matmul( + torch.matmul(kimodo_to_mujoco_matrix[None, None, ...], root_rot), + mujoco_to_kimodo_matrix[None, None, ...], + ) + root_rot_quat = matrix_to_quaternion(root_rot_mujoco) # [w, x, y, z] + if root_quat_w_first: + qpos[:, :, 3:7] = root_rot_quat[:, :, [0, 1, 2, 3]] # [w, x, y, z] + else: + qpos[:, :, 3:7] = root_rot_quat[:, :, [1, 2, 3, 0]] # [w, x, y, z] -> [x, y, z, w] + + # Joint DOFs: raw angles or relative to rest (rest = q=0 in MuJoCo). + joint_rot_f2q = local_rot_mats[:, :, self._mujoco_indices_to_kimodo_indices, :, :] + joint_dofs = self._local_rots_f2q_to_joint_dofs(joint_rot_f2q) + if mujoco_rest_zero: + rest_dofs = self._rest_dofs.to(device=device, dtype=dtype) + qpos[:, :, 7:] = joint_dofs - rest_dofs[None, None, :] + else: + qpos[:, :, 7:] = joint_dofs + return qpos + + +def apply_g1_real_robot_projection( + skeleton: G1Skeleton34, + joints_pos: torch.Tensor, + joints_rot: torch.Tensor, + clamp_to_limits: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """Project G1 motion to real robot DoF (1-DoF per joint) with optional axis limits. + + Extracts a single angle per hinge along its axis (1-DoF), optionally clamps to + joint limits from the MuJoCo XML (when clamp_to_limits=True), then reconstructs + 3D rotations and runs FK. T-pose (identity local rotations) is preserved. + + Args: + skeleton: G1 skeleton instance. + joints_pos: (T, J, 3) or (B, T, J, 3) joint positions in global space. + joints_rot: (T, J, 3, 3) or (B, T, J, 3, 3) global rotation matrices. + clamp_to_limits: If True, clamp joint angles to XML axis limits (default True). + + Returns: + (posed_joints, global_rot_mats) as tensors, same shape as inputs (batch preserved). + """ + + local_rot_mats = global_rots_to_local_rots(joints_rot, skeleton) + root_positions = joints_pos[..., skeleton.root_idx, :] + + # Converter expects batch dim (B, T, ...); add and remove if single sequence. 
+ single_sequence = local_rot_mats.dim() == 4 + if single_sequence: + local_rot_mats = local_rot_mats.unsqueeze(0) + root_positions = root_positions.unsqueeze(0) + + converter = MujocoQposConverter(skeleton) + projected = converter.project_to_real_robot_rotations( + local_rot_mats, root_positions, clamp_to_limits=clamp_to_limits + ) + + out_pos = projected["posed_joints"] + out_rot = projected["global_rot_mats"] + if single_sequence: + out_pos = out_pos.squeeze(0) + out_rot = out_rot.squeeze(0) + return out_pos, out_rot diff --git a/kimodo/exports/smplx.py b/kimodo/exports/smplx.py new file mode 100644 index 0000000000000000000000000000000000000000..ce1d15262fccf91800b006a9b679731e95431da6 --- /dev/null +++ b/kimodo/exports/smplx.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Convert kimodo motion to AMASS/SMPL-X compatible parameters (axis-angle, Y-up or Z-up).""" + +import os +from typing import Optional + +import einops +import numpy as np +import torch + +from kimodo.assets import skeleton_asset_path +from kimodo.geometry import axis_angle_to_matrix, matrix_to_axis_angle +from kimodo.tools import ensure_batched, to_numpy, to_torch + + +def kimodo_y_up_to_amass_coord_rotation_matrix() -> np.ndarray: + """3x3 rotation mapping Kimodo Y-up (+Z forward) to AMASS Z-up (+Y forward). + + Used by :func:`get_amass_parameters` and :func:`amass_arrays_to_kimodo_motion` (inverse). + """ + y_up_to_z_up = np.array( + [ + [1.0, 0.0, 0.0], + [0.0, 0.0, -1.0], + [0.0, 1.0, 0.0], + ], + dtype=np.float32, + ) + rot_z_180 = np.array( + [ + [-1.0, 0.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + return np.matmul(rot_z_180, y_up_to_z_up).astype(np.float32) + + +@ensure_batched(local_rot_mats=5, root_positions=3, lengths=1) +def get_amass_parameters( + local_rot_mats, + root_positions, + skeleton, + z_up=True, +): + """Convert local rot mats and root positions to AMASS-style trans and pose_body; optional z_up + coordinate transform. + + Our method generates motions with Y-up and +Z forward; if z_up=True, transform to Z-up and +Y + forward as in AMASS. + """ + # Our method generate motions with Y-up and +Z forward + # if z_up = True, we transform this to: Z-up with +Y forward, as in AMASS + # Remove the root offset; SMPL-X FK adds pelvis offset back. + pelvis_offset = skeleton.neutral_joints[skeleton.root_idx].cpu().numpy() + trans = root_positions - pelvis_offset + + root_rot_mats = to_numpy(local_rot_mats[:, :, 0]) + local_rot_axis_angle = to_numpy(matrix_to_axis_angle(to_torch(local_rot_mats))) + pose_body = einops.rearrange(local_rot_axis_angle[:, :, 1:], "b t j d -> b t (j d)") + + # Optionally convert from Y-up to Z-up coordinates. + if z_up: + y_up_to_z_up = kimodo_y_up_to_amass_coord_rotation_matrix() + root_rot_mats = np.matmul(y_up_to_z_up, root_rot_mats) + trans = np.matmul(trans + pelvis_offset, y_up_to_z_up.T) - pelvis_offset + + root_orient = to_numpy(matrix_to_axis_angle(to_torch(root_rot_mats))) + return trans, root_orient, pose_body + + +def amass_arrays_to_kimodo_motion( + trans: np.ndarray, + root_orient: np.ndarray, + pose_body: np.ndarray, + skeleton, + source_fps: float, + *, + z_up: bool = True, +): + """Inverse of :func:`get_amass_parameters` for a single sequence (AMASS → Kimodo motion dict). + + Args: + trans: ``(T, 3)`` AMASS root translation (same as ``trans`` in AMASS NPZ). 
+ root_orient: ``(T, 3)`` axis-angle root orientation in AMASS coordinates (z-up when ``z_up``). + pose_body: ``(T, 63)`` body pose axis-angle (21 joints × 3). + skeleton: :class:`~kimodo.skeleton.definitions.SMPLXSkeleton22` instance. + source_fps: Source frame rate (Hz) of the AMASS recording. + z_up: If ``True``, invert the same Y-up↔Z-up transform as ``get_amass_parameters(..., z_up=True)``. + + Returns: + Motion dict compatible with :func:`kimodo.exports.motion_io.save_kimodo_npz`. + """ + from kimodo.exports.motion_io import complete_motion_dict + + trans = np.asarray(trans, dtype=np.float32) + root_orient = np.asarray(root_orient, dtype=np.float32) + pose_body = np.asarray(pose_body, dtype=np.float32) + if trans.ndim != 2 or trans.shape[-1] != 3: + raise ValueError(f"trans must be (T, 3); got {trans.shape}") + if root_orient.shape != trans.shape: + raise ValueError(f"root_orient shape {root_orient.shape} must match trans {trans.shape}") + t = trans.shape[0] + if pose_body.shape != (t, 63): + raise ValueError(f"pose_body must be (T, 63); got {pose_body.shape}") + + pelvis_offset = skeleton.neutral_joints[skeleton.root_idx].detach().cpu().numpy().astype(np.float32) + device = skeleton.neutral_joints.device + dtype = torch.float32 + + Y_np = kimodo_y_up_to_amass_coord_rotation_matrix() + if z_up: + y_up_to_z_up = torch.from_numpy(Y_np).to(device=device, dtype=dtype) + # trans_amass = root_kimodo @ Y.T - pelvis_offset => root_kimodo = (trans_amass + pelvis_offset) @ Y + root_positions_np = (trans + pelvis_offset) @ Y_np + else: + root_positions_np = trans + pelvis_offset + + root_positions = torch.from_numpy(root_positions_np).to(device=device, dtype=dtype) + + R_amass_root = axis_angle_to_matrix(torch.from_numpy(root_orient).to(device=device, dtype=dtype)) + if z_up: + R_kimodo_root = torch.einsum("ij,tjk->tik", y_up_to_z_up.T, R_amass_root) + else: + R_kimodo_root = R_amass_root + + nb = skeleton.nbjoints + if nb != 22: + raise ValueError(f"Expected SMPL-X body skeleton with 22 joints; got {nb}") + + local_rot_mats = torch.zeros((t, nb, 3, 3), device=device, dtype=dtype) + local_rot_mats[:, 0] = R_kimodo_root + + pose_aa = torch.from_numpy(pose_body.reshape(t, 21, 3)).to(device=device, dtype=dtype) + local_rot_mats[:, 1:] = axis_angle_to_matrix(pose_aa.reshape(-1, 3)).reshape(t, 21, 3, 3) + + return complete_motion_dict(local_rot_mats, root_positions, skeleton, source_fps) + + +def amass_npz_to_kimodo_motion(npz_path: str, skeleton, source_fps: Optional[float] = None, *, z_up: bool = True): + """Load an AMASS-style ``.npz`` and return a Kimodo motion dict. + + Args: + npz_path: Path to AMASS NPZ (``trans``, ``root_orient``, ``pose_body``, ...). + skeleton: SMPL-X skeleton instance. + source_fps: Source frame rate (Hz); if ``None``, uses ``mocap_frame_rate`` + from the file when present, else ``30.0``. + z_up: Same meaning as :func:`amass_arrays_to_kimodo_motion`. 
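Since `amass_arrays_to_kimodo_motion` is documented as the inverse of `get_amass_parameters`, a round trip is a natural smoke test. This sketch assumes `SMPLXSkeleton22` is importable from `kimodo.skeleton.definitions` as the docstring above suggests, and uses a trivial T-pose clip:

```python
import torch
from kimodo.exports.smplx import amass_arrays_to_kimodo_motion, get_amass_parameters
from kimodo.skeleton.definitions import SMPLXSkeleton22  # import path assumed from the docstring

sk = SMPLXSkeleton22()
T = 4
local = torch.eye(3).expand(T, 22, 3, 3).contiguous()              # identity local rotations
roots = sk.neutral_joints[sk.root_idx].expand(T, 3).clone()        # root at the neutral pelvis

# Kimodo -> AMASS arrays (ensure_batched adds a batch dim) -> back to a Kimodo motion dict.
trans, root_orient, pose_body = get_amass_parameters(local, roots, sk, z_up=True)
motion = amass_arrays_to_kimodo_motion(trans[0], root_orient[0], pose_body[0], sk, source_fps=30.0, z_up=True)
```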
+ """ + with np.load(npz_path, allow_pickle=True) as data: + trans = np.asarray(data["trans"], dtype=np.float32) + root_orient = np.asarray(data["root_orient"], dtype=np.float32) + pose_body = np.asarray(data["pose_body"], dtype=np.float32) + if source_fps is None: + source_fps = float(data["mocap_frame_rate"]) if "mocap_frame_rate" in data.files else 30.0 + + return amass_arrays_to_kimodo_motion(trans, root_orient, pose_body, skeleton, source_fps, z_up=z_up) + + +class AMASSConverter: + def __init__( + self, + fps, + skeleton, + beta_path=str(skeleton_asset_path("smplx22", "beta.npy")), + mean_hands_path=str(skeleton_asset_path("smplx22", "mean_hands.npy")), + ): + self.fps = fps + self.skeleton = skeleton + # Load betas + if os.path.exists(beta_path): + # only use first 16 betas to match AMASS + betas = np.load(beta_path)[:16] + else: + betas = np.zeros(16) + + # Load mean hands + if os.path.exists(mean_hands_path): + mean_hands = np.load(mean_hands_path) + else: + mean_hands = np.zeros(90) + + self.default_frame_params = { + "pose_jaw": np.zeros(3), + "pose_eye": np.zeros(6), + "pose_hand": mean_hands, + } + self.output_dict_base = { + "gender": "neutral", + "surface_model_type": "smplx", + "betas": betas, + "num_betas": len(betas), + "mocap_frame_rate": float(fps), + } + + def convert_save_npz(self, output: dict, npz_path, z_up=True): + trans, root_orient, pose_body = get_amass_parameters( + output["local_rot_mats"], + output["root_positions"], + self.skeleton, + z_up=z_up, + ) + nb_frames = trans.shape[-2] + + amass_output_base = self.output_dict_base.copy() + for key, val in self.default_frame_params.items(): + amass_output_base[key] = einops.repeat(val, "d -> t d", t=nb_frames) + + amass_output_base["mocap_time_length"] = nb_frames / self.fps + self.save_npz(trans, root_orient, pose_body, amass_output_base, npz_path) + + def save_npz(self, trans, root_orient, pose_body, base_output, npz_path): + shape = trans.shape + if len(shape) == 3 and shape[0] == 1: + # if only one motion, squeeze the data + trans = trans[0] + root_orient = root_orient[0] + pose_body = pose_body[0] + shape = trans.shape + if len(shape) == 2: + amass_output = { + "trans": trans, + "root_orient": root_orient, + "pose_body": pose_body, + } | base_output + np.savez(npz_path, **amass_output) + + elif len(shape) == 3: + # real batch of motions + npz_path_base, ext = os.path.splitext(npz_path) + for i in range(shape[0]): + npz_path_i = npz_path_base + "_" + str(i).zfill(2) + ext + self.save_npz(trans[i], root_orient[i], pose_body[i], base_output, npz_path_i) + + +# amass_output = { +# "gender": "neutral", +# "surface_model_type": "smplx", +# "mocap_frame_rate": float(fps), +# "mocap_time_length": len(motion) / float(fps) +# "trans": trans, +# "betas": betas, +# "num_betas": len(betas), +# "root_orient": np.array([T, 3]), # axis angle +# "pose_body": np.array([T, 63]), # 63=21*3, axis angle 21 = 22 - root +# "pose_hand": np.array([T, 90]), # 90=30*3=15*2*3 axis angle (load from mean_hands) +# "pose_jaw": np.array([T, 3]), # all zeros is fine +# "pose_eye": np.array([T, 6]), # all zeros is fine` +# } diff --git a/kimodo/geometry.py b/kimodo/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d2397bf5f4517fc92280caa7dfbab993452940 --- /dev/null +++ b/kimodo/geometry.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Rotation and representation conversions: axis-angle, quaternion, matrix, 6D continuous.""" + +import torch +import torch.nn.functional as F + + +def angle_to_Y_rotation_matrix(angle: torch.Tensor) -> torch.Tensor: + """Build a rotation matrix around the Y axis from a scalar angle (radians). + + Shape: angle.shape + (3, 3). + """ + cos, sin = torch.cos(angle), torch.sin(angle) + one, zero = torch.ones_like(angle), torch.zeros_like(angle) + mat = torch.stack((cos, zero, sin, zero, one, zero, -sin, zero, cos), -1) + mat = mat.reshape(angle.shape + (3, 3)) + return mat + + +def matrix_to_cont6d(matrix: torch.Tensor) -> torch.Tensor: + """Convert rotation matrix to 6D continuous representation (first two columns). + + Shape: (..., 3, 3) -> (..., 6). + """ + cont_6d = torch.concat([matrix[..., 0], matrix[..., 1]], dim=-1) + return cont_6d + + +def cont6d_to_matrix(cont6d: torch.Tensor) -> torch.Tensor: + """Convert 6D continuous representation to rotation matrix (Gram–Schmidt on two columns). + + Last dim must be 6. + """ + assert cont6d.shape[-1] == 6, "The last dimension must be 6" + x_raw = cont6d[..., 0:3] + y_raw = cont6d[..., 3:6] + + x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) + z = torch.cross(x, y_raw, dim=-1) + z = z / torch.norm(z, dim=-1, keepdim=True) + + y = torch.cross(z, x, dim=-1) + + x = x[..., None] + y = y[..., None] + z = z[..., None] + + mat = torch.cat([x, y, z], dim=-1) + return mat + + +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """Convert axis-angle to rotation matrix. + + Args: + axis_angle: (..., 3) axis-angle vectors (angle = norm, axis = normalized) + Returns: + rotmat: (..., 3, 3) rotation matrices + """ + eps = 1e-6 + angle = torch.norm(axis_angle, dim=-1, keepdim=True) # (..., 1) + axis = axis_angle / (angle + eps) + + x, y, z = axis.unbind(-1) + + zero = torch.zeros_like(x) + K = torch.stack([zero, -z, y, z, zero, -x, -y, x, zero], dim=-1).reshape(*axis.shape[:-1], 3, 3) + + eye = torch.eye(3, device=axis.device, dtype=axis.dtype) + eye = eye.expand(*axis.shape[:-1], 3, 3) + + sin = torch.sin(angle)[..., None] + cos = torch.cos(angle)[..., None] + + R = eye + sin * K + (1 - cos) * (K @ K) + return R + + +def matrix_to_axis_angle(R: torch.Tensor) -> torch.Tensor: + """Convert rotation matrix to axis-angle via quaternions (more numerically stable). + + Args: + R: (..., 3, 3) rotation matrices + Returns: + axis_angle: (..., 3) + """ + # Go through quaternions for numerical stability + quat = matrix_to_quaternion(R) # (..., 4) with (w, x, y, z) + return quaternion_to_axis_angle(quat) + + +def quaternion_to_axis_angle(quat: torch.Tensor) -> torch.Tensor: + """Convert quaternion to axis-angle representation. + + Args: + quat: (..., 4) quaternions with real part first (w, x, y, z) + Returns: + axis_angle: (..., 3) + """ + eps = 1e-6 + + # Ensure canonical form to avoid sign ambiguity. + # Primary: prefer w > 0. When w ≈ 0 (angle ≈ π), prefer first nonzero xyz > 0. 
+ w = quat[..., 0:1] + xyz = quat[..., 1:] + + # Find first significant component of xyz for tie-breaking when w ≈ 0 + first_significant = xyz[..., 0:1] # use x component as tie-breaker + + # Flip if: w < 0, OR (w ≈ 0 AND first xyz component < 0) + should_flip = (w < -eps) | ((w.abs() <= eps) & (first_significant < 0)) + quat = torch.where(should_flip, -quat, quat) + + w = quat[..., 0] + xyz = quat[..., 1:] + + # sin(angle/2) = ||xyz|| + sin_half_angle = xyz.norm(dim=-1) + + # angle = 2 * atan2(sin(angle/2), cos(angle/2)) + # This is more stable than 2 * acos(w) near angle=0 + angle = 2.0 * torch.atan2(sin_half_angle, w) + + # axis = xyz / sin(angle/2), but handle small angles + # For small angles: axis-angle ≈ 2 * xyz (since sin(x) ≈ x for small x) + small_angle = sin_half_angle.abs() < eps + + # Safe division + scale = torch.where( + small_angle, + 2.0 * torch.ones_like(angle), # small angle: axis_angle ≈ 2 * xyz + angle / sin_half_angle.clamp(min=eps), + ) + + return xyz * scale.unsqueeze(-1) + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """Returns torch.sqrt(torch.max(0, x)) subgradient is zero where x is 0.""" + return torch.sqrt(x * (x > 0).to(x.dtype)) + + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + return ( + (F.one_hot(q_abs.argmax(dim=-1), num_classes=4)[..., None] * quat_candidates) + .sum(dim=-2) + .reshape(batch_dim + (4,)) + ) + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) diff --git a/kimodo/meta.py b/kimodo/meta.py new file mode 100644 index 0000000000000000000000000000000000000000..dd9ff2f75e55e8d1dd94a4b55e1475a139338d32 --- /dev/null +++ b/kimodo/meta.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Parse and normalize prompt text/duration data from meta dicts.""" + +import os +from typing import Any, Optional + +from kimodo.tools import load_json + +from .sanitize import sanitize_text, sanitize_texts + + +def load_prompts_from_meta(meta_path: str, **kwargs): + """Load prompts from a meta dict or file. If fps is provided, the durations are converted to + frames. + + Args: + meta_path: Path to the meta file. + **kwargs: Additional arguments to pass to parse_prompts_from_meta. + + Returns: + texts: List of texts. + durations: List of durations in seconds or frames. + """ + if not os.path.exists(meta_path): + raise FileNotFoundError(f"meta.json not found in input folder: {meta_path}") + + meta = load_json(meta_path) + return parse_prompts_from_meta(meta, **kwargs) + + +def parse_prompts_from_meta( + meta: dict[str, Any], + fps: Optional[float] = None, + sanitize: bool = False, +) -> tuple[list[str], list[float]]: + """Parse prompt texts and durations from a meta dict into normalized lists. If fps is provided, + the durations are converted to frames. + + Accepts either: + - Single prompt: "text" (str) and "duration" (float) in seconds. + - Multiple prompts: "texts" (list of str) and "durations" (list of float) in seconds. + + Returns: + (texts, durations): texts as list of str, durations as list of float (seconds or frames). + Lengths of both lists are equal. + + Raises: + ValueError: If meta does not contain a recognized format. 
+ """ + # Single prompt + if "text" in meta and "duration" in meta: + text = meta["text"] + duration = float(meta["duration"]) + if fps is not None: + duration = int(duration * fps) + if isinstance(text, list): + raise ValueError("meta has 'text' but it is a list; use 'texts' for multiple prompts") + + if sanitize: + text = sanitize_text(text) + return ([text], [duration]) + + # Multiple prompts + if "texts" in meta and "durations" in meta: + texts = meta["texts"] + durations = meta["durations"] + if not isinstance(texts, list) or not isinstance(durations, list): + raise ValueError("meta 'texts' and 'durations' must be lists") + if len(texts) != len(durations): + raise ValueError(f"meta 'texts' and 'durations' length mismatch: {len(texts)} vs {len(durations)}") + durations = [float(d) for d in durations] + if fps is not None: + durations = [int(d * fps) for d in durations] + + if sanitize: + texts = sanitize_texts(texts) + return texts, durations + + raise ValueError("meta must contain either 'text' and 'duration', or 'texts' and 'durations'.") diff --git a/kimodo/metrics/__init__.py b/kimodo/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..159e0da823a0f1c75bce3d75da44584d45f11ffc --- /dev/null +++ b/kimodo/metrics/__init__.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Evaluation metrics for motion quality (foot skate, contact consistency, constraint following).""" + +from .base import ( + Metric, + aggregate_metrics, + clear_metrics, + compute_metrics, +) +from .constraints import ContraintFollow +from .foot_skate import ( + FootContactConsistency, + FootSkateFromContacts, + FootSkateFromHeight, + FootSkateRatio, +) +from .tmr import ( + TMR_EmbeddingMetric, + TMR_Metric, + compute_tmr_per_sample_retrieval, + compute_tmr_retrieval_metrics, +) + +__all__ = [ + "Metric", + "ContraintFollow", + "FootContactConsistency", + "FootSkateFromContacts", + "FootSkateFromHeight", + "FootSkateRatio", + "TMR_EmbeddingMetric", + "TMR_Metric", + "aggregate_metrics", + "clear_metrics", + "compute_metrics", + "compute_tmr_per_sample_retrieval", + "compute_tmr_retrieval_metrics", +] diff --git a/kimodo/metrics/base.py b/kimodo/metrics/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca1ebc248c0bd8cbd58eca12a3458f4a65d0745 --- /dev/null +++ b/kimodo/metrics/base.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Base metric class and batch/aggregate helpers.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Dict, List + +import torch + + +class Metric: + """Base class for metrics that accumulate results over multiple __call__ and expose + aggregate().""" + + def __init__(self, **kwargs): + self.clear() + + def __call__(self, *args, **kwargs): + """Compute metric for current batch, append to saved_metrics, and return the batch + result.""" + metrics = self._compute(*args, **kwargs) + for key, val in metrics.items(): + self.saved_metrics[key].append(val.detach().cpu().float()) + return metrics + + def _compute(self, **kwargs): + """Subclasses implement this to compute metric dict from batch inputs.""" + raise NotImplementedError() + + def clear(self): + """Reset all accumulated metric values.""" + self.saved_metrics = defaultdict(list) + + def aggregate(self): + """Return a dict of concatenated/stacked tensors over all accumulated batches.""" + output = {} + for key, lst in self.saved_metrics.items(): + try: + output[key] = torch.cat(lst) + except RuntimeError: + output[key] = torch.stack(lst) + return output + + +def compute_metrics(metrics_list: List[Metric], metrics_in: Dict) -> Dict: + """Run each metric on metrics_in and return the combined dict of batch results.""" + metrics_out = {} + for metric in metrics_list: + metrics_out.update(metric(**metrics_in)) + return metrics_out + + +def aggregate_metrics(metrics_list: List[Metric]) -> Dict: + """Return combined aggregated results (concatenated over batches) for all metrics.""" + metrics_out = {} + for metric in metrics_list: + metrics_out.update(metric.aggregate()) + return metrics_out + + +def clear_metrics(metrics_list: List[Metric]) -> None: + """Clear accumulated values for all metrics in the list.""" + for metric in metrics_list: + metric.clear() diff --git a/kimodo/metrics/constraints.py b/kimodo/metrics/constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..d31ba99cf859cc0d22d71efcb1b324c47c5d7931 --- /dev/null +++ b/kimodo/metrics/constraints.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
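The `Metric` base class above is all a new metric needs. Here is a hedged sketch of a toy subclass and the compute/aggregate/clear loop; the metric itself is illustrative, not part of the library:

```python
import torch
from kimodo.metrics import Metric, aggregate_metrics, clear_metrics, compute_metrics

class MeanJointHeight(Metric):  # hypothetical example metric
    def _compute(self, posed_joints, **kwargs):
        # (B, T, J, 3) -> per-sequence mean y-coordinate, shape (B,)
        return {"mean_joint_height": posed_joints[..., 1].mean(dim=(1, 2))}

metrics = [MeanJointHeight()]
for _ in range(3):  # three batches of two sequences each
    batch = {"posed_joints": torch.rand(2, 16, 22, 3)}
    compute_metrics(metrics, batch)
print(aggregate_metrics(metrics)["mean_joint_height"].shape)  # torch.Size([6])
clear_metrics(metrics)
```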
+# SPDX-License-Identifier: Apache-2.0
+"""Constraint-following metrics."""
+
+from __future__ import annotations
+
+from collections import defaultdict
+from typing import Dict, List, Optional
+
+import torch
+from torch import Tensor
+
+from kimodo.constraints import (
+    EndEffectorConstraintSet,
+    FullBodyConstraintSet,
+    Root2DConstraintSet,
+)
+from kimodo.tools import ensure_batched
+
+from .base import Metric
+
+
+class ContraintFollow(Metric):
+    """Constraint-following metric dispatcher for kimodo constraint sets."""
+
+    def __init__(
+        self,
+        skeleton,
+        root_threshold: float = 0.10,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.skeleton = skeleton
+        self.root_threshold = root_threshold
+
+    @ensure_batched(posed_joints=4, constraints_lst=2, lengths=1)
+    def _compute(
+        self,
+        posed_joints: Tensor,
+        constraints_lst: Optional[List],
+        lengths: Optional[Tensor] = None,
+        **kwargs,
+    ) -> Dict:
+        if not constraints_lst:
+            return {}
+
+        root_idx = self.skeleton.root_idx
+        output = defaultdict(list)
+
+        for posed_joints_s, constraint_lst_s, lengths_s in zip(posed_joints, constraints_lst, lengths):
+            output_seq = defaultdict(list)
+            for constraint in constraint_lst_s:
+                frame_idx = constraint.frame_indices.to(device=posed_joints_s.device, dtype=torch.long)
+                # Skip empty constraints first: calling .max() on an empty tensor would raise.
+                if frame_idx.numel() == 0:
+                    continue
+                assert frame_idx.max() < lengths_s, "The constraint is defined outside the length of the motion."
+
+                if isinstance(constraint, Root2DConstraintSet):
+                    pred_root2d = posed_joints_s[frame_idx, root_idx][:, [0, 2]]
+                    target = constraint.smooth_root_2d.to(posed_joints_s.device)
+
+                    dist = torch.norm(pred_root2d - target, dim=-1)
+                    output_seq["constraint_root2d_err"].append(dist)
+                    hit = (dist <= self.root_threshold).float()
+                    output_seq["constraint_root2d_acc"].append(hit)
+
+                elif isinstance(constraint, FullBodyConstraintSet):
+                    pred = posed_joints_s[frame_idx]
+                    target = constraint.global_joints_positions.to(posed_joints_s.device)
+                    err = torch.norm(pred - target, dim=-1)
+                    output_seq["constraint_fullbody_keyframe"].append(err)
+
+                elif isinstance(constraint, EndEffectorConstraintSet):
+                    pos_idx = constraint.pos_indices.to(device=posed_joints_s.device, dtype=torch.long)
+                    pred = posed_joints_s[frame_idx].index_select(1, pos_idx)
+                    target = constraint.global_joints_positions.to(posed_joints_s.device).index_select(1, pos_idx)
+                    err = torch.norm(pred - target, dim=-1)
+                    output_seq["constraint_end_effector"].append(err)
+
+            # in case the list contains several constraints of the same type
+            for key, val in output_seq.items():
+                output[key].append(torch.cat(val).mean())
+
+        reduced = {}
+        for key, vals in output.items():
+            reduced[key] = torch.stack(vals, dim=0)
+        return reduced
diff --git a/kimodo/metrics/foot_skate.py b/kimodo/metrics/foot_skate.py
new file mode 100644
index 0000000000000000000000000000000000000000..7da474006ccbd915c44b3a68107b3e60d7fdf3b7
--- /dev/null
+++ b/kimodo/metrics/foot_skate.py
@@ -0,0 +1,232 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +"""Foot skate and contact consistency metrics.""" + +from __future__ import annotations + +from typing import Dict, Optional + +import torch +from torch import Tensor + +from kimodo.motion_rep.feature_utils import compute_vel_xyz +from kimodo.motion_rep.feet import foot_detect_from_pos_and_vel +from kimodo.skeleton import SkeletonBase +from kimodo.tools import ensure_batched + +from .base import Metric + + +class FootSkateFromHeight(Metric): + """When toe joint is near the floor, measures mean velocity of the toes.""" + + def __init__( + self, + skeleton: SkeletonBase, + fps: float, + height_thresh: float = 0.05, + **kwargs, + ): + super().__init__(**kwargs) + self.height_thresh = height_thresh + self.skeleton = skeleton + self.fps = fps + + @ensure_batched(posed_joints=4, lengths=1) + def _compute( + self, + posed_joints: Tensor, + lengths: Optional[Tensor] = None, + **kwargs, + ) -> Dict: + fidx = self.skeleton.foot_joint_idx + if len(fidx) != 4: + raise ValueError("FootSkateFromHeight expects four foot joints (heel/toe per foot)") + + feet_pos = posed_joints[:, :, fidx] + toe_pos = feet_pos[:, :, [1, 3]] + + toe_on_floor = (toe_pos[..., 1] < self.height_thresh)[:, :-1] # y-up [B, T, 2] where [left right] + + dt = 1.0 / self.fps + toe_vel = torch.norm(toe_pos[:, 1:] - toe_pos[:, :-1], dim=-1) / dt # [B, nframes-1, 2] + + # compute err + contact_toe_vel = toe_vel * toe_on_floor # vel when corresponding toe is on ground + + # account for generated length + # since they are velocities use length-1 to avoid inaccurate vel going one frame past len + device = toe_on_floor.device + len_mask = torch.arange(toe_on_floor.shape[1], device=device)[None, :, None].expand(toe_on_floor.shape) < ( + lengths[:, None, None] - 1 + ) + toe_on_floor = toe_on_floor * len_mask + contact_toe_vel = contact_toe_vel * len_mask + + mean_vel = torch.sum(contact_toe_vel, (1, 2)) / (torch.sum(toe_on_floor, (1, 2)) + 1e-6) + return {"foot_skate_from_height": mean_vel} + + +class FootSkateFromContacts(Metric): + """Measures velocity of the toes and ankles when predicted to be in contact.""" + + def __init__( + self, + skeleton: SkeletonBase, + fps: float, + **kwargs, + ): + super().__init__(**kwargs) + self.skeleton = skeleton + self.fps = fps + + @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1) + def _compute( + self, + posed_joints: Tensor, + foot_contacts: Tensor, + lengths: Optional[Tensor] = None, + **kwargs, + ) -> Dict: + fidx = self.skeleton.foot_joint_idx + feet_pos = posed_joints[:, :, fidx] + dt = 1.0 / self.fps + foot_vel = torch.norm(feet_pos[:, 1:] - feet_pos[:, :-1], dim=-1) / dt + + foot_contacts = foot_contacts[:, :-1] + vel_err = foot_vel * foot_contacts + + # account for generated length + # since they are velocities use length-1 to avoid inaccurate vel going one frame past len + device = foot_contacts.device + len_mask = torch.arange(foot_contacts.shape[1], device=device)[None, :, None].expand(foot_contacts.shape) < ( + lengths[:, None, None] - 1 + ) + foot_contacts = foot_contacts * len_mask + vel_err = vel_err * len_mask + + mean_vel = torch.sum(vel_err, (1, 2)) / (torch.sum(foot_contacts, (1, 2)) + 1e-6) # mean over contacting frames + + # Compute max velocity error across all feet and frames (per batch) + max_vel = vel_err.amax(dim=(1, 2)) # [B] + + return { + "foot_skate_from_pred_contacts": mean_vel, + "foot_skate_max_vel": max_vel, + } + + +class FootSkateRatio(Metric): + """Compute fraction of frames where the foot skates when it is on the 
ground. + + Inspired by GMD: https://github.com/korrawe/guided-motion-diffusion/blob/main/data_loaders/humanml/utils/metrics.py#L204 + """ + + def __init__( + self, + skeleton: SkeletonBase, + fps: float, + height_thresh=0.05, + vel_thresh=0.2, + **kwargs, + ): + super().__init__(**kwargs) + self.height_thresh = height_thresh + self.vel_thresh = vel_thresh + + self.skeleton = skeleton + self.fps = fps + + @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1) + def _compute( + self, + posed_joints: Tensor, + foot_contacts: Tensor, + lengths: Optional[Tensor] = None, + **kwargs, + ) -> Dict: + fidx = self.skeleton.foot_joint_idx + assert len(fidx) == 4, "This metric assumes 4 foot joints: heel, toe, heel, toe" + + feet_pos = posed_joints[:, :, fidx] + toe_pos = feet_pos[:, :, [1, 3]] + + toe_on_floor = toe_pos[..., 1] < self.height_thresh # y-up [B, T, 2] where [left right] + # current and next frame on floor to consider it in contact + toe_on_floor = torch.logical_and(toe_on_floor[:, :-1], toe_on_floor[:, 1:]) # [B, T-1, 2] + + dt = 1.0 / self.fps + toe_vel = torch.norm(toe_pos[:, 1:] - toe_pos[:, :-1], dim=-1) / dt # [B, nframes-1, 2] + + # compute err + contact_toe_vel = toe_vel * toe_on_floor # vel when corresponding toe is on ground + + # account for generated length + # since they are velocities use length-1 to avoid inaccurate vel going one frame past len + device = toe_on_floor.device + len_mask = torch.arange(toe_on_floor.shape[1], device=device)[None, :, None].expand(toe_on_floor.shape) < ( + lengths[:, None, None] - 1 + ) + toe_on_floor = toe_on_floor * len_mask + contact_toe_vel = contact_toe_vel * len_mask + + # skating if velocity during contact > thresh + toe_skate = contact_toe_vel > self.vel_thresh + skate_ratio = torch.sum(toe_skate, (1, 2)) / (torch.sum(toe_on_floor, (1, 2)) + 1e-6) + return {"foot_skate_ratio": skate_ratio} + + +class FootContactConsistency(Metric): + """Measures consistency between heuristic detected foot contacts (from height and velocity) and + predicted foot contacts. + + i.e. accuracy of how well predicted matches heuristic. 
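A hedged usage sketch for the skate metrics above: random joints stand in for generated motion, and it assumes the G1 skeleton exposes a four-entry `foot_joint_idx` (heel/toe per foot) as the metrics require:

```python
import torch
from kimodo.metrics import FootSkateFromContacts, FootSkateFromHeight
from kimodo.skeleton import G1Skeleton34

sk = G1Skeleton34()
B, T = 2, 30
posed_joints = torch.rand(B, T, sk.nbjoints, 3)  # (B, T, J, 3), y-up
lengths = torch.tensor([30, 24])                 # valid frames per sequence

height_metric = FootSkateFromHeight(sk, fps=30.0)
print(height_metric(posed_joints=posed_joints, lengths=lengths))  # {"foot_skate_from_height": (B,)}

contacts = (torch.rand(B, T, 4) > 0.5).float()   # stand-in for predicted per-foot-joint contacts
contact_metric = FootSkateFromContacts(sk, fps=30.0)
print(contact_metric(posed_joints=posed_joints, foot_contacts=contacts, lengths=lengths))
```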
+ """ + + def __init__( + self, + skeleton: SkeletonBase, + fps: float, + vel_thresh: float = 0.15, + height_thresh: float = 0.10, + **kwargs, + ): + super().__init__(**kwargs) + self.vel_thresh = vel_thresh + self.height_thresh = height_thresh + + self.skeleton = skeleton + self.fps = fps + + @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1) + def _compute( + self, + posed_joints: Tensor, + foot_contacts: Tensor, + lengths: Optional[Tensor] = None, + **kwargs, + ) -> Dict: + velocity = compute_vel_xyz(posed_joints, float(self.fps), lengths=lengths) + heuristic_contacts = foot_detect_from_pos_and_vel( + posed_joints, + velocity, + self.skeleton, + self.vel_thresh, + self.height_thresh, + ) + + # compute accuracy of predicted, treating heuristic as ground truth + num_contacts = foot_contacts.shape[-1] + incorrect = torch.logical_xor(heuristic_contacts, foot_contacts) + # account for generated length + # since they are velocities, use length-1 to avoid inaccurate vel going one frame past len + device = foot_contacts.device + len_mask = torch.arange(foot_contacts.shape[1], device=device)[None, :, None].expand(foot_contacts.shape) < ( + lengths[:, None, None] - 1 + ) + incorrect = incorrect * len_mask + + incorrect_ratio = torch.sum(incorrect, (1, 2)) / (num_contacts * (lengths - 1)) + accuracy = 1 - incorrect_ratio + + return {"foot_contact_consistency": accuracy} diff --git a/kimodo/metrics/tmr.py b/kimodo/metrics/tmr.py new file mode 100644 index 0000000000000000000000000000000000000000..4c4835704cbf7957066648736f389f4c05debad2 --- /dev/null +++ b/kimodo/metrics/tmr.py @@ -0,0 +1,530 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""TMR evaluation metrics: text-motion retrieval, R-Precision, and related scores.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +from scipy import linalg +from torch import Tensor + +from kimodo.model.tmr import TMR + +from .base import Metric + + +# Scores are between 0 and 1 +def get_score_matrix_unit(x, y): + sim_matrix = np.einsum("b i, c i -> b c", x, y) + scores = sim_matrix / 2 + 0.5 + return scores + + +def get_scores_unit(x, y): + similarity = np.einsum("... i, ... i", x, y) + scores = similarity / 2 + 0.5 + return scores + + +def compute_tmr_per_sample_retrieval( + motion_emb: np.ndarray, + text_emb: np.ndarray, + sample_ids: List[str], + texts: List[str], + top_k: int = 5, +) -> List[Dict[str, Any]]: + """For each sample (text query i), compute t2m rank of motion i and top-k retrieved motions with + ids and texts. + + Returns list of dicts: [{"rank": int, "top_k": [{"id": str, "text": str}, ...]}, ...]. 
+ """ + motion_emb = np.asarray(motion_emb).squeeze() + text_emb = np.asarray(text_emb).squeeze() + if motion_emb.ndim == 1: + motion_emb = motion_emb[np.newaxis, :] + if text_emb.ndim == 1: + text_emb = text_emb[np.newaxis, :] + n = motion_emb.shape[0] + assert text_emb.shape[0] == n and len(sample_ids) == n and len(texts) == n + scores = get_score_matrix_unit(text_emb, motion_emb) + out: List[Dict[str, Any]] = [] + for i in range(n): + row = np.asarray(scores[i]) + order = np.argsort(row)[::-1] + rank = int(np.where(order == i)[0][0]) + 1 + top_indices = order[:top_k] + top_k_list = [{"id": sample_ids[j], "text": texts[j]} for j in top_indices] + out.append({"rank": rank, "top_k": top_k_list}) + return out + + +class TMR_Metric(Metric): + def __init__( + self, + tmr_model: TMR, + ranks: List = [1, 2, 3, 5, 10], + ranks_rounding=2, + **kwargs, + ): + super().__init__(**kwargs) + self.tmr_model = tmr_model + self.ranks = ranks + self.ranks_rounding = ranks_rounding + + def clear(self): + self.saved_metrics = defaultdict(list) + self.saved_text_latents = [] + self.saved_motion_gen_latents = [] + self.saved_motion_gt_latents = [] + + def _compute( + self, + motion_rep, + pred_joints_output: Dict, + gt_joints_output: Dict, + text_x_dict: Dict, + lengths: Tensor, + **kwargs, + ) -> Dict: + pred_posed_joints = pred_joints_output["posed_joints"] + original_skeleton = motion_rep.skeleton if motion_rep is not None else None + latents_motion = self.tmr_model.encode_motion( + pred_posed_joints, + lengths=lengths, + original_skeleton=original_skeleton, + unit_vector=True, + ) + latents_motion = latents_motion.cpu().numpy() + + if isinstance(text_x_dict, dict) and "texts" in text_x_dict: + latents_text = self.tmr_model.encode_raw_text(text_x_dict["texts"], unit_vector=True) + else: + latents_text = self.tmr_model.encode_text(text_x_dict, unit_vector=True) + if latents_text.dim() == 1: + latents_text = latents_text.unsqueeze(0) + latents_text = latents_text.cpu().numpy() + + self.saved_text_latents.append(latents_text) + self.saved_motion_gen_latents.append(latents_motion) + + scores_text = get_scores_unit(latents_motion, latents_text) + output = {"TMR/t2m_sim": scores_text} + + if gt_joints_output is not None and "posed_joints" in gt_joints_output: + gt_posed_joints = gt_joints_output["posed_joints"] + gt_latents_motion = self.tmr_model.encode_motion( + gt_posed_joints, + lengths=lengths, + original_skeleton=original_skeleton, + unit_vector=True, + ) + gt_latents_motion = gt_latents_motion.cpu().numpy() + self.saved_motion_gt_latents.append(gt_latents_motion) + + gt_scores_text = get_scores_unit(gt_latents_motion, latents_text) + scores_motion = get_scores_unit(latents_motion, gt_latents_motion) + + output["TMR/t2m_gt_sim"] = gt_scores_text + output["TMR/m2m_sim"] = scores_motion + + # pytorch tensors + for key, val in output.items(): + output[key] = torch.tensor(val) + return output + + def aggregate(self): + output = {} + for key, lst in self.saved_metrics.items(): + output[key] = np.concatenate(lst) + + assert self.saved_text_latents, "Should call the metric at least once." 
+ + text_latents = np.concatenate(self.saved_text_latents) + motion_gen_latents = np.concatenate(self.saved_motion_gen_latents) + + batch_size = len(text_latents) + assert text_latents.shape == motion_gen_latents.shape + + scores_t2m = get_score_matrix_unit(text_latents, motion_gen_latents) + scores_t2t = get_score_matrix_unit(text_latents, text_latents) + + t2m_metrics = contrastive_metrics( + scores=scores_t2m, + scores_t2t=scores_t2t, + threshold=0.99, + rounding=2, + ) + + for key, val in t2m_metrics.items(): + output["TMR/t2m_R/" + key] = val + + mu_gen, cov_gen = calculate_activation_statistics(motion_gen_latents) + mu_text, cov_text = calculate_activation_statistics(text_latents) + + fid_gen_text = calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text) + output["TMR/FID/gen_text"] = fid_gen_text + + if self.saved_motion_gt_latents: + motion_gt_latents = np.concatenate(self.saved_motion_gt_latents) + assert motion_gt_latents.shape == motion_gen_latents.shape + + scores_m2gm = get_score_matrix_unit(motion_gen_latents, motion_gt_latents) + scores_t2gm = get_score_matrix_unit(text_latents, motion_gt_latents) + + m2gm_metrics = contrastive_metrics( + scores=scores_m2gm, + scores_t2t=scores_t2t, + threshold=0.99, + rounding=2, + ) + for key, val in m2gm_metrics.items(): + output["TMR/m2m_R/" + key] = val + + t2gm_metrics = contrastive_metrics( + scores=scores_t2gm, + scores_t2t=scores_t2t, + threshold=0.99, + rounding=2, + ) + for key, val in t2gm_metrics.items(): + output["TMR/t2m_gt_R/" + key] = val + + mu_gt_motion, cov_gt_motion = calculate_activation_statistics(motion_gt_latents) + fid_gen_motion = calculate_frechet_distance( + mu_gen, + cov_gen, + mu_gt_motion, + cov_gt_motion, + ) + output["TMR/FID/gen_gt"] = fid_gen_motion + + fid_gt_text = calculate_frechet_distance( + mu_gt_motion, + cov_gt_motion, + mu_text, + cov_text, + ) + output["TMR/FID/gt_text"] = fid_gt_text + + for key, val in output.items(): + if isinstance(val, (int, float, np.integer, np.floating)): + val = torch.tensor([val for _ in range(batch_size)]) + + if isinstance(val, np.ndarray): + val = torch.from_numpy(val) + + output[key] = val.cpu().float() + return output + + +class TMR_EmbeddingMetric(Metric): + """TMR metrics from precomputed motion and text embeddings (no model load). + + Use in the loop: pass motion_emb and text_emb per sample; aggregate() computes retrieval metrics. 
+ """ + + def __init__(self, ranks_rounding: int = 2, **kwargs): + super().__init__(**kwargs) + self.ranks_rounding = ranks_rounding + + def clear(self): + self.saved_metrics = defaultdict(list) + self.saved_text_latents = [] + self.saved_motion_gen_latents = [] + self.saved_motion_gt_latents = [] + + def _compute( + self, + motion_emb=None, + text_emb=None, + gt_motion_emb=None, + **kwargs, + ) -> Dict: + if motion_emb is None or text_emb is None: + return {} + motion_emb = np.asarray(motion_emb) + text_emb = np.asarray(text_emb) + if motion_emb.ndim == 1: + motion_emb = motion_emb[np.newaxis, :] + if text_emb.ndim == 1: + text_emb = text_emb[np.newaxis, :] + self.saved_text_latents.append(text_emb) + self.saved_motion_gen_latents.append(motion_emb) + if gt_motion_emb is not None: + gt_motion_emb = np.asarray(gt_motion_emb) + if gt_motion_emb.ndim == 1: + gt_motion_emb = gt_motion_emb[np.newaxis, :] + self.saved_motion_gt_latents.append(gt_motion_emb) + scores = get_scores_unit(motion_emb, text_emb) + return {"TMR/t2m_sim": torch.tensor(scores, dtype=torch.float32)} + + def aggregate(self): + output = {} + for key, lst in self.saved_metrics.items(): + output[key] = np.concatenate(lst) + if not self.saved_text_latents: + return output + text_latents = np.concatenate(self.saved_text_latents) + motion_gen_latents = np.concatenate(self.saved_motion_gen_latents) + batch_size = len(text_latents) + assert text_latents.shape == motion_gen_latents.shape + scores_t2m = get_score_matrix_unit(text_latents, motion_gen_latents) + scores_t2t = get_score_matrix_unit(text_latents, text_latents) + t2m_metrics = contrastive_metrics( + scores=scores_t2m, + scores_t2t=scores_t2t, + threshold=0.99, + rounding=self.ranks_rounding, + ) + for key, val in t2m_metrics.items(): + output["TMR/t2m_R/" + key] = val + mu_gen, cov_gen = calculate_activation_statistics(motion_gen_latents) + mu_text, cov_text = calculate_activation_statistics(text_latents) + output["TMR/FID/gen_text"] = calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text) + if self.saved_motion_gt_latents: + motion_gt_latents = np.concatenate(self.saved_motion_gt_latents) + assert motion_gt_latents.shape == motion_gen_latents.shape + scores_m2gm = get_score_matrix_unit(motion_gen_latents, motion_gt_latents) + scores_t2gm = get_score_matrix_unit(text_latents, motion_gt_latents) + m2gm_metrics = contrastive_metrics( + scores=scores_m2gm, + scores_t2t=scores_t2t, + threshold=0.99, + rounding=self.ranks_rounding, + ) + for key, val in m2gm_metrics.items(): + output["TMR/m2m_R/" + key] = val + t2gm_metrics = contrastive_metrics( + scores=scores_t2gm, + scores_t2t=scores_t2t, + threshold=0.99, + rounding=self.ranks_rounding, + ) + for key, val in t2gm_metrics.items(): + output["TMR/t2m_gt_R/" + key] = val + mu_gt_motion, cov_gt_motion = calculate_activation_statistics(motion_gt_latents) + output["TMR/FID/gen_gt"] = calculate_frechet_distance(mu_gen, cov_gen, mu_gt_motion, cov_gt_motion) + output["TMR/FID/gt_text"] = calculate_frechet_distance(mu_gt_motion, cov_gt_motion, mu_text, cov_text) + for key, val in output.items(): + if isinstance(val, (int, float, np.integer, np.floating)): + val = torch.tensor([val for _ in range(batch_size)]) + if isinstance(val, np.ndarray): + val = torch.from_numpy(val) + output[key] = val.cpu().float() + return output + + +def compute_tmr_retrieval_metrics( + motion_emb: np.ndarray, + text_emb: np.ndarray, + gt_motion_emb: Optional[np.ndarray] = None, + rounding: int = 2, +) -> Dict[str, float]: + """Compute TMR retrieval 
metrics from precomputed embeddings.""" + if motion_emb.shape != text_emb.shape: + raise ValueError(f"Expected same shape for motion/text embeddings, got {motion_emb.shape} vs {text_emb.shape}") + + scores_t2m = get_score_matrix_unit(text_emb, motion_emb) + scores_t2t = get_score_matrix_unit(text_emb, text_emb) + + output: Dict[str, float] = {} + t2m_metrics = contrastive_metrics( + scores=scores_t2m, + scores_t2t=scores_t2t, + threshold=0.99, + rounding=rounding, + ) + for key, val in t2m_metrics.items(): + output[f"TMR/t2m_R/{key}"] = float(val) + + mu_gen, cov_gen = calculate_activation_statistics(motion_emb) + mu_text, cov_text = calculate_activation_statistics(text_emb) + output["TMR/FID/gen_text"] = float(calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text)) + + if gt_motion_emb is not None: + if gt_motion_emb.shape != motion_emb.shape: + raise ValueError(f"Expected gt motion embeddings shape {motion_emb.shape}, got {gt_motion_emb.shape}") + + scores_m2gm = get_score_matrix_unit(motion_emb, gt_motion_emb) + scores_t2gm = get_score_matrix_unit(text_emb, gt_motion_emb) + + m2gm_metrics = contrastive_metrics( + scores=scores_m2gm, + scores_t2t=scores_t2t, + threshold=0.99, + rounding=rounding, + ) + for key, val in m2gm_metrics.items(): + output[f"TMR/m2m_R/{key}"] = float(val) + + t2gm_metrics = contrastive_metrics( + scores=scores_t2gm, + scores_t2t=scores_t2t, + threshold=0.99, + rounding=rounding, + ) + for key, val in t2gm_metrics.items(): + output[f"TMR/t2m_gt_R/{key}"] = float(val) + + mu_gt_motion, cov_gt_motion = calculate_activation_statistics(gt_motion_emb) + output["TMR/FID/gen_gt"] = float(calculate_frechet_distance(mu_gen, cov_gen, mu_gt_motion, cov_gt_motion)) + output["TMR/FID/gt_text"] = float(calculate_frechet_distance(mu_gt_motion, cov_gt_motion, mu_text, cov_text)) + + return output + + +def all_contrastive_metrics(sims, emb=None, threshold=None, rounding=2, return_cols=False): + text_selfsim = None + if emb is not None: + text_selfsim = emb @ emb.T + + t2m_m, t2m_cols = contrastive_metrics(sims, text_selfsim, threshold, return_cols=True, rounding=rounding) + m2t_m, m2t_cols = contrastive_metrics(sims.T, text_selfsim, threshold, return_cols=True, rounding=rounding) + + all_m = {} + for key in t2m_m: + all_m[f"t2m/{key}"] = t2m_m[key] + all_m[f"m2t/{key}"] = m2t_m[key] + + all_m["t2m/len"] = float(len(sims)) + all_m["m2t/len"] = float(len(sims[0])) + if return_cols: + return all_m, t2m_cols, m2t_cols + return all_m + + +def contrastive_metrics( + scores, + scores_t2t=None, + threshold=None, + rounding=2, +): + n, m = scores.shape + assert n == m + num_queries = n + + dists = -scores + sorted_dists = np.sort(dists, axis=1) + # GT is in the diagonal + gt_dists = np.diag(dists)[:, None] + + if scores_t2t is not None and threshold is not None: + real_threshold = 2 * threshold - 1 + idx = np.argwhere(scores_t2t > real_threshold) + partition = np.unique(idx[:, 0], return_index=True)[1] + # take as GT the minimum score of similar values + gt_dists = np.minimum.reduceat(dists[tuple(idx.T)], partition) + gt_dists = gt_dists[:, None] + + rows, cols = np.where((sorted_dists - gt_dists) == 0) # find column position of GT + + # if there are ties + if rows.size > num_queries: + assert np.unique(rows).size == num_queries, "issue in metric evaluation" + avg_cols = break_ties_average(sorted_dists, gt_dists) + cols = avg_cols + + msg = "expected ranks to match queries ({} vs {}) " + assert cols.size == num_queries, msg + + metrics = {} + vals = [str(x).zfill(2) for x in [1, 
2, 3, 5, 10]] + for val in vals: + metrics[f"R{val}"] = 100 * float(np.sum(cols < int(val))) / num_queries + + metrics["MedR"] = float(np.median(cols) + 1) + metrics["len"] = num_queries + + if rounding is not None: + for key in metrics: + metrics[key] = round(metrics[key], rounding) + return metrics + + +def break_ties_average(sorted_dists, gt_dists): + # fast implementation, based on this code: + # https://stackoverflow.com/a/49239335 + locs = np.argwhere((sorted_dists - gt_dists) == 0) + + # Find the split indices + steps = np.diff(locs[:, 0]) + splits = np.nonzero(steps)[0] + 1 + splits = np.insert(splits, 0, 0) + + # Compute the result columns + summed_cols = np.add.reduceat(locs[:, 1], splits) + counts = np.diff(np.append(splits, locs.shape[0])) + avg_cols = summed_cols / counts + return avg_cols + + +def calculate_activation_statistics(activations): + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + mu = np.mean(activations, axis=0) + cov = np.cov(activations, rowvar=False) + return mu, cov + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. The Frechet distance between two multivariate + Gaussians X_1 ~ N(mu_1, C_1) + + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative dataset set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative dataset set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" + assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ("fid calculation produces singular product; " "adding %s to diagonal of cov estimates") % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + # try again with diagonal %s + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean diff --git a/kimodo/model/__init__.py b/kimodo/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d4cd87748f15bd0538d256763a52b61ed480d4 --- /dev/null +++ b/kimodo/model/__init__.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Kimodo model package: main model class, text encoders, and loading utilities.""" + +from .common import resolve_target +from .kimodo_model import Kimodo +from .llm2vec import LLM2VecEncoder +from .load_model import load_model +from .loading import ( + AVAILABLE_MODELS, + DEFAULT_MODEL, + DEFAULT_TEXT_ENCODER_URL, + MODEL_NAMES, + load_checkpoint_state_dict, +) +from .tmr import TMR +from .twostage_denoiser import TwostageDenoiser + +__all__ = [ + "Kimodo", + "LLM2VecEncoder", + "TMR", + "TwostageDenoiser", + "load_model", + "load_checkpoint_state_dict", + "resolve_target", + "AVAILABLE_MODELS", + "DEFAULT_MODEL", + "DEFAULT_TEXT_ENCODER_URL", + "MODEL_NAMES", +] diff --git a/kimodo/model/backbone.py b/kimodo/model/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..014f6599f2b0ff7fddfabb9a47db8d0962941b11 --- /dev/null +++ b/kimodo/model/backbone.py @@ -0,0 +1,312 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Transformer backbone: padding, masking, and encoder stack for the denoiser.""" + +import logging +from typing import Optional, Union + +import torch +from omegaconf import ListConfig +from pydantic.dataclasses import dataclass +from torch import Tensor, nn +from torch.nn import TransformerEncoder, TransformerEncoderLayer + +from kimodo.tools import validate + +log = logging.getLogger(__name__) + + +def pad_x_and_mask_to_fixed_size(x: Tensor, mask: Tensor, size: int): + """Pad a feature vector x and the mask to always have the same size. + + Args: + x (torch.Tensor): [B, T, D] + mask (torch.Tensor): [B, T] + size (int) + Returns: + torch.Tensor: [B, size, D] + torch.Tensor: [B, size] + """ + + batch_size, cur_max_size, dim = x.shape[0], x.shape[1], x.shape[2] + + if cur_max_size == size: + # already padded to this size, probably in the collate function + return x, mask + + if cur_max_size > size: + # This issue should have been handled in the collate function + # usefull as a check for test time + log.warn("The size of the tensor is larger than the maximum size. 
Cropping the input..") + cur_max_size = size + + new_x = torch.zeros( + (batch_size, size, dim), + dtype=x.dtype, + device=x.device, + ) + new_x[:, :cur_max_size] = x + + # same for the mask + new_mask = torch.zeros( + (batch_size, size), + dtype=mask.dtype, + device=mask.device, + ) + new_mask[:, :cur_max_size] = mask + return new_x, new_mask + + +@dataclass(frozen=True, config=dict(extra="forbid", arbitrary_types_allowed=True)) +class TransformerEncoderBlockConfig: + """Configuration for the transformer encoder backbone.""" + + # input features dimension + input_dim: int + # output features dimension + output_dim: int + + # skeleton object + skeleton: object + + # dimension of the text embeddings + llm_shape: Union[list[int], ListConfig] + + # mask the text or not + use_text_mask: bool + + # latent dimension of the model + latent_dim: int + # dimension of the feedforward network in transformer + ff_size: int + # num layers in transformer + num_layers: int + # num heads in transformer + num_heads: int + # activation in transformer + activation: str + # dropout rate for the transformer + dropout: float + # dropout rate for the positional embeddings + pe_dropout: float + # use norm first or not + norm_first: bool = False + # artificially extend the number of text tokens + num_text_tokens_override: Optional[int] = None + + # Input first heading angle + input_first_heading_angle: bool = False + + +class TransformerEncoderBlock(nn.Module): + @validate(TransformerEncoderBlockConfig, save_args=True, super_init=True) + def __init__(self, conf): + self.nbjoints = self.skeleton.nbjoints + llm_dim = self.llm_shape[-1] + self.embed_text = nn.Linear(llm_dim, self.latent_dim) + + self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.pe_dropout) + + # maximum number of tokens + self.num_text_tokens = self.llm_shape[0] + if self.num_text_tokens_override is not None: + self.num_text_tokens = self.num_text_tokens_override + + self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) + + self.input_linear = nn.Linear(self.input_dim, self.latent_dim) + self.output_linear = nn.Linear(self.latent_dim, self.output_dim) + self.linear_first_heading_angle = nn.Linear(2, self.latent_dim) + + trans_enc_layer = TransformerEncoderLayer( + d_model=self.latent_dim, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation, + batch_first=True, + norm_first=self.norm_first, + ) + self.seqTransEncoder = TransformerEncoder( + trans_enc_layer, + num_layers=self.num_layers, + enable_nested_tensor=False, + ) + + def forward( + self, + x: Tensor, + x_pad_mask: torch.Tensor, + text_feat: torch.Tensor, + text_feat_pad_mask: torch.Tensor, + timesteps: Tensor, + first_heading_angle: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x (torch.Tensor): [B, T, dim_motion] current noisy motion + x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not + text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts + text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not + timesteps (torch.Tensor): [B,] current denoising step + + Returns: + torch.Tensor: [B, T, output_dim] + """ + batch_size = len(x) + x = self.input_linear(x) # [B, T, D] + + # Pad the text tokens + mask to always have the same size == self.num_text_tokens + # done here if it was not done in the collate function + if self.num_text_tokens is not 
None: + text_feat, text_feat_pad_mask = pad_x_and_mask_to_fixed_size( + text_feat, + text_feat_pad_mask, + self.num_text_tokens, + ) + + # Encode the text features and the time information + emb_text = self.embed_text(text_feat) # [B, max_text_len, D] + emb_time = self.embed_timestep(timesteps) # [B, 1, D] + + # Create mask for the time information + time_mask = torch.ones((batch_size, 1), dtype=bool, device=x.device) + + # Create the prefix features (text, time, etc): [B, max_text_len + 1 + etc] + prefix_feats = torch.cat((emb_text, emb_time), axis=1) + + # Behavior from old code: not use text mask -> True for all the tokens + if not self.use_text_mask: + text_feat_pad_mask = torch.ones( + (batch_size, emb_text.shape[1]), + dtype=torch.bool, + device=x.device, + ) + + prefix_mask = torch.cat((text_feat_pad_mask, time_mask), axis=1) + + # add the input first heading angle + if self.input_first_heading_angle: + assert first_heading_angle is not None, "The first heading angle is mandatory for this model" + # cos(angle) / sin(angle) + first_heading_angle_feats = torch.stack( + [ + torch.cos(first_heading_angle), + torch.sin(first_heading_angle), + ], + axis=-1, + ) + + first_heading_angle_feats = self.linear_first_heading_angle(first_heading_angle_feats) + first_heading_angle_feats = first_heading_angle_feats[:, None] # for cat + first_heading_angle_mask = torch.ones( + (batch_size, 1), + dtype=bool, + device=x.device, + ) + prefix_feats = torch.cat((prefix_feats, first_heading_angle_feats), axis=1) + prefix_mask = torch.cat((prefix_mask, first_heading_angle_mask), axis=1) + + # compute the number of prefix features + pose_start_ind = prefix_feats.shape[1] + + # Concatenate prefix and x: [B, len(prefix) + T, D] + xseq = torch.cat((prefix_feats, x), axis=1) + + # Concatenate the masks and negate them: [B, len(prefix) + T] + src_key_padding_mask = ~torch.cat((prefix_mask, x_pad_mask), axis=1) + + # Add positional encoding + xseq = self.sequence_pos_encoder(xseq) + + # Input to the transformer and keep the motion indexes + if isinstance(self.seqTransEncoder, nn.TransformerEncoder): + assert not self.seqTransEncoder.use_nested_tensor, "Flash attention should be disabled due to bug!" 
+ + output = self.seqTransEncoder( + xseq, + src_key_padding_mask=src_key_padding_mask, + ) + output = output[:, pose_start_ind:] # [B, T, D] + output = self.output_linear(output) # [B, T, OD] + return output + + +class PositionalEncoding(nn.Module): + """Non-learned positional encoding.""" + + def __init__( + self, + d_model: int, + dropout: Optional[float] = 0.1, + max_len: Optional[int] = 5000, + ): + """ + Args: + d_model (int): input dim + dropout (Optional[float] = 0.1): dropout probability on output + max_len (Optional[int] = 5000): maximum sequence length + """ + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + + # Note: have to replace torch.exp() and math.log() with torch.pow() + # due to MKL exp() and ln() throws floating point exceptions on certain CPUs + # see corresponding commit and MR + div_term = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model) + # div_term = torch.exp( + # torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) + # ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) # [1, T, D] + + self.register_buffer("pe", pe, persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply positional encoding to input sequence. + + Args: + x (torch.Tensor): [B, T, D] input motion sequence + + Returns: + torch.Tensor: [B, T, D] input motion with PE added to it (and optionally dropout) + """ + x = x + self.pe[:, : x.shape[1], :] + return self.dropout(x) + + +class TimestepEmbedder(nn.Module): + """Encoder for diffusion step.""" + + def __init__(self, latent_dim: int, sequence_pos_encoder: PositionalEncoding): + """ + Args: + latent_dim (int): dim to encode to + sequence_pos_encoder (PositionalEncoding): the PE to use on timesteps + """ + super().__init__() + self.latent_dim = latent_dim + self.sequence_pos_encoder = sequence_pos_encoder + + time_embed_dim = self.latent_dim + self.time_embed = nn.Sequential( + nn.Linear(self.latent_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + """Embed timesteps by adding PE then going through linear layers. + + Args: + timesteps (torch.Tensor): [B] + + Returns: + torch.Tensor: [B, 1, D] + """ + return self.time_embed(self.sequence_pos_encoder.pe.transpose(0, 1)[timesteps]) diff --git a/kimodo/model/cfg.py b/kimodo/model/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..6c39defdbbf0c074a0adb746ec4d98266e5ee463 --- /dev/null +++ b/kimodo/model/cfg.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Classifier-free guidance wrapper for the denoiser at sampling time.""" + +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +CFG_TYPES = ["nocfg", "regular", "separated"] + + +class ClassifierFreeGuidedModel(nn.Module): + """Wrapper around denoiser to use classifier-free guidance at sampling time.""" + + def __init__(self, model: nn.Module, cfg_type: Optional[str] = "separated"): + """Wrap the denoiser for classifier-free guidance; cfg_type in CFG_TYPES (e.g. 
'regular', + 'nocfg').""" + super().__init__() + self.model = model + assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}" + self.cfg_type_default = cfg_type + + def forward( + self, + cfg_weight: Union[float, Tuple[float, float]], + x: torch.Tensor, + x_pad_mask: torch.Tensor, + text_feat: torch.Tensor, + text_feat_pad_mask: torch.Tensor, + timesteps: torch.Tensor, + first_heading_angle: Optional[torch.Tensor] = None, + motion_mask: Optional[torch.Tensor] = None, + observed_motion: Optional[torch.Tensor] = None, + cfg_type: Optional[str] = None, + ) -> torch.Tensor: + """ + Args: + cfg_weight (float): guidance weight float or tuple of floats with (text, constraint) weights if using separated cfg + x (torch.Tensor): [B, T, dim_motion] current noisy motion + x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not + text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts + text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not + timesteps (torch.Tensor): [B,] current denoising step + motion_mask + observed_motion + neutral_joints (torch.Tensor): [B, nbjoints] The neutral joints of the motions + + Returns: + torch.Tensor: same size as input x + """ + + if cfg_type is None: + cfg_type = self.cfg_type_default + + assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}" + + # batched conditional and uncond pass together + if cfg_type == "nocfg": + return self.model( + x, + x_pad_mask, + text_feat, + text_feat_pad_mask, + timesteps, + first_heading_angle=first_heading_angle, + motion_mask=motion_mask, + observed_motion=observed_motion, + ) + elif cfg_type == "regular": + assert isinstance(cfg_weight, (float, int)), "cfg_weight must be a single float for regular CFG" + # out_uncond + w * (out_text_and_constraint - out_uncond) + text_feat = torch.concatenate([text_feat, 0 * text_feat], dim=0) + if motion_mask is not None: + motion_mask = torch.concatenate([motion_mask, 0 * motion_mask], dim=0) + if observed_motion is not None: + observed_motion = torch.concatenate([observed_motion, observed_motion], dim=0) + if first_heading_angle is not None: + first_heading_angle = torch.concatenate([first_heading_angle, first_heading_angle], dim=0) + + out_cond_uncond = self.model( + torch.concatenate([x, x], dim=0), + torch.concatenate([x_pad_mask, x_pad_mask], dim=0), + text_feat, + torch.concatenate([text_feat_pad_mask, False * text_feat_pad_mask], dim=0), + torch.concatenate([timesteps, timesteps], dim=0), + first_heading_angle=first_heading_angle, + motion_mask=motion_mask, + observed_motion=observed_motion, + ) + + out, out_uncond = torch.chunk(out_cond_uncond, 2) + out_new = out_uncond + (cfg_weight * (out - out_uncond)) + elif cfg_type == "separated": + assert len(cfg_weight) == 2, "cfg_weight must be a tuple of two floats for separated CFG" + # out_uncond + w_text * (out_text - out_uncond) + w_constraint * (out_constraint - out_uncond) + text_feat = torch.concatenate([text_feat, 0 * text_feat, 0 * text_feat], dim=0) + if motion_mask is not None: + motion_mask = torch.concatenate([0 * motion_mask, motion_mask, 0 * motion_mask], dim=0) + if observed_motion is not None: + observed_motion = torch.concatenate([observed_motion, observed_motion, observed_motion], dim=0) + if first_heading_angle is not None: + first_heading_angle = torch.concatenate( + [first_heading_angle, first_heading_angle, first_heading_angle], + dim=0, + ) + + out_cond_uncond = 
self.model( + torch.concatenate([x, x, x], dim=0), + torch.concatenate([x_pad_mask, x_pad_mask, x_pad_mask], dim=0), + text_feat, + torch.concatenate( + [ + text_feat_pad_mask, + False * text_feat_pad_mask, + False * text_feat_pad_mask, + ], + dim=0, + ), + torch.concatenate([timesteps, timesteps, timesteps], dim=0), + first_heading_angle=first_heading_angle, + motion_mask=motion_mask, + observed_motion=observed_motion, + ) + + out_text, out_constraint, out_uncond = torch.chunk(out_cond_uncond, 3) + out_new = ( + out_uncond + (cfg_weight[0] * (out_text - out_uncond)) + (cfg_weight[1] * (out_constraint - out_uncond)) + ) + else: + raise ValueError(f"Invalid cfg_type: {cfg_type}") + + return out_new diff --git a/kimodo/model/common.py b/kimodo/model/common.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6937bb98bc676d380ef4657cfd58f42bc5f294 --- /dev/null +++ b/kimodo/model/common.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Config hydration: env vars, _target_ resolution, and recursive instantiation.""" + +import importlib +import os + + +def get_env_var(name: str, default=None): + """Read env var by name and by lowercased name; return default if neither set.""" + return os.getenv(name, os.getenv(name.lower(), default)) + + +def resolve_target(target: str): + """Import module and return the attribute named by a dotted path (e.g. 'pkg.mod.Class').""" + module_name, attr_name = target.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, attr_name) + + +def materialize_value(value): + """Recursively turn dicts with '_target_' into instances; lists/dicts traversed; leaves + unchanged.""" + if isinstance(value, dict): + if "_target_" in value: + return instantiate_from_dict(value) + return {k: materialize_value(v) for k, v in value.items()} + if isinstance(value, list): + return [materialize_value(v) for v in value] + return value + + +def instantiate_from_dict(node, overrides=None): + """Build an instance from a config dict: '_target_' gives the class, other keys are kwargs; overrides merged in.""" + if not isinstance(node, dict) or "_target_" not in node: + raise ValueError("Config node must be a dict with a '_target_' key.") + + target = resolve_target(node["_target_"]) + kwargs = {} + for key, value in node.items(): + if key == "_target_": + continue + kwargs[key] = materialize_value(value) + + if overrides: + kwargs.update({k: v for k, v in overrides.items() if v is not None}) + + return target(**kwargs) diff --git a/kimodo/model/diffusion.py b/kimodo/model/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..7e36d9940142b01d9447ceac2f9425d87589af4d --- /dev/null +++ b/kimodo/model/diffusion.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Diffusion process and DDIM sampling for motion generation.""" + +import math +from typing import Optional, Tuple + +import torch +from torch import nn + + +def get_beta_schedule( + num_diffusion_timesteps: int, + max_beta: Optional[float] = 0.999, +) -> torch.Tensor: + """Get cosine beta schedule.""" + + def alpha_bar(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float) + + +class Diffusion(torch.nn.Module): + """Cosine-schedule diffusion process: betas, alphas, and DDIM step mapping.""" + + def __init__(self, num_base_steps: int): + """Set up cosine beta schedule and precompute diffusion variables for num_base_steps.""" + super().__init__() + self.num_base_steps = num_base_steps + betas_base = get_beta_schedule(self.num_base_steps) + self.register_buffer("betas_base", betas_base, persistent=False) + alphas_cumprod_base = torch.cumprod(1.0 - self.betas_base, dim=0) + self.register_buffer("alphas_cumprod_base", alphas_cumprod_base, persistent=False) + use_timesteps, _ = self.space_timesteps(self.num_base_steps) + self.calc_diffusion_vars(use_timesteps) + + def extra_repr(self) -> str: + return f"num_base_steps={self.num_base_steps}" + + @property + def device(self): + return self.betas_base.device + + def space_timesteps(self, num_denoising_steps: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Return (use_timesteps, map_tensor) for a subsampled denoising schedule of + num_denoising_steps.""" + nsteps_train = self.num_base_steps + frac_stride = (nsteps_train - 1) / max(1, num_denoising_steps - 1) + use_timesteps = torch.round(torch.arange(nsteps_train, device=self.device) * frac_stride).to(torch.long) + use_timesteps = torch.clamp(use_timesteps, max=nsteps_train - 1) + map_tensor = torch.arange(nsteps_train, device=self.device, dtype=torch.long)[use_timesteps] + return use_timesteps, map_tensor + + def calc_diffusion_vars(self, use_timesteps: torch.Tensor) -> None: + """Update buffers (betas, alphas, alphas_cumprod, etc.) 
for the given subsampled + timesteps.""" + alphas_cumprod = self.alphas_cumprod_base[use_timesteps] + last_alpha_cumprod = torch.cat([torch.tensor([1.0]).to(alphas_cumprod), alphas_cumprod[:-1]]) + betas = 1.0 - alphas_cumprod / last_alpha_cumprod + self.register_buffer("betas", betas, persistent=False) + + alphas = 1.0 - self.betas + self.register_buffer("alphas", alphas, persistent=False) + alphas_cumprod = torch.cumprod(self.alphas, dim=0) + alphas_cumprod = torch.clamp(alphas_cumprod, min=1e-9) + self.register_buffer("alphas_cumprod", alphas_cumprod, persistent=False) + + alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(self.alphas_cumprod), self.alphas_cumprod[:-1]]) + self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev, persistent=False) + + sqrt_recip_alphas_cumprod = torch.rsqrt(self.alphas_cumprod) + self.register_buffer("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod, persistent=False) + + sqrt_recipm1_alphas_cumprod = torch.rsqrt(self.alphas_cumprod / (1.0 - self.alphas_cumprod)) + self.register_buffer("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod, persistent=False) + + posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.register_buffer("posterior_variance", posterior_variance, persistent=False) + + sqrt_alphas_cumprod = torch.rsqrt(1.0 / self.alphas_cumprod) + self.register_buffer("sqrt_alphas_cumprod", sqrt_alphas_cumprod, persistent=False) + + sqrt_one_minus_alphas_cumprod = torch.rsqrt(1.0 / (1.0 - self.alphas_cumprod)) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + sqrt_one_minus_alphas_cumprod, + persistent=False, + ) + + def q_sample( + self, + x_start: torch.Tensor, + t: torch.Tensor, + noise: torch.Tensor = None, + ): + if noise is None: + noise = torch.randn_like(x_start) + assert noise.shape == x_start.shape + + xt = ( + self.sqrt_alphas_cumprod[t, None, None] * x_start + + self.sqrt_one_minus_alphas_cumprod[t, None, None] * noise + ) + return xt + + +class DDIMSampler(nn.Module): + """Deterministic DDIM sampler (eta = 0).""" + + def __init__(self, diffusion: Diffusion): + super().__init__() + self.diffusion = diffusion + + def __call__( + self, + use_timesteps: torch.Tensor, + x_t: torch.Tensor, + pred_xstart: torch.Tensor, + t: torch.Tensor, + ) -> torch.Tensor: + self.diffusion.calc_diffusion_vars(use_timesteps) + eps = ( + self.diffusion.sqrt_recip_alphas_cumprod[t, None, None] * x_t - pred_xstart + ) / self.diffusion.sqrt_recipm1_alphas_cumprod[t, None, None] + alpha_bar_prev = self.diffusion.alphas_cumprod_prev[t, None, None] + x = pred_xstart * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev) * eps + return x diff --git a/kimodo/model/kimodo_model.py b/kimodo/model/kimodo_model.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf78eac5145ebe2d3e42c8b2d5994531ce026f3 --- /dev/null +++ b/kimodo/model/kimodo_model.py @@ -0,0 +1,605 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Kimodo model: denoiser, text encoder, diffusion sampling, and post-processing.""" + +import logging +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from tqdm.auto import tqdm + +from kimodo.constraints import FullBodyConstraintSet +from kimodo.motion_rep.feature_utils import compute_heading_angle, length_to_mask +from kimodo.postprocess import post_process_motion +from kimodo.sanitize import sanitize_texts +from kimodo.skeleton import SOMASkeleton30 +from kimodo.tools import to_numpy + +from .cfg import ClassifierFreeGuidedModel +from .diffusion import DDIMSampler, Diffusion + +log = logging.getLogger(__name__) + + +class Kimodo(nn.Module): + """Helper class for test time.""" + + def __init__( + self, + denoiser: nn.Module, + text_encoder: nn.Module, + num_base_steps: int, + device: Optional[Union[str, torch.device]] = None, + cfg_type: Optional[str] = "separated", + ): + super().__init__() + + self.denoiser = denoiser.eval() + + if cfg_type is None: + cfg_type = "nocfg" + + # Add Classifier-free guidance to the model if needed + self.denoiser = ClassifierFreeGuidedModel(self.denoiser, cfg_type=cfg_type) + + self.motion_rep = denoiser.motion_rep + self.skeleton = self.motion_rep.skeleton + + self.fps = denoiser.motion_rep.fps + + self.diffusion = Diffusion(num_base_steps=num_base_steps) + self.sampler = DDIMSampler(self.diffusion) + self.text_encoder = text_encoder + + self.device = device + # for classifier-free guidance + + self.to(device) + + @property + def output_skeleton(self): + """Skeleton used for model output (somaskel77 for SOMA, else unchanged).""" + if isinstance(self.skeleton, SOMASkeleton30): + return self.skeleton.somaskel77 + return self.skeleton + + def train(self, mode: bool): + self.denoiser.train(mode) + return self + + def eval(self): + self.denoiser.eval() + return self + + def denoising_step( + self, + motion: torch.Tensor, + pad_mask: torch.Tensor, + text_feat: torch.Tensor, + text_pad_mask: torch.Tensor, + t: torch.Tensor, + first_heading_angle: Optional[torch.Tensor], + motion_mask: torch.Tensor, + observed_motion: torch.Tensor, + num_denoising_steps: torch.Tensor, + cfg_weight: Union[float, Tuple[float, float]], + guide_masks: Optional[Dict] = None, + cfg_type: Optional[str] = None, + ) -> torch.Tensor: + """Single denoising step. + + Returns: + torch.Tensor: [B, T, D] noisy motion input to t-1 + """ + # subsample timesteps + # NOTE: do this at every step due to ONNX export, i.e. num_samp_stepsmay change dynamically when + # running onnx version so need to account for that. 
+ num_denoising_steps = num_denoising_steps[0] + use_timesteps, map_tensor = self.diffusion.space_timesteps(num_denoising_steps) + self.diffusion.calc_diffusion_vars(use_timesteps) + + # first compute initial clean prediction from denoiser + t_map = map_tensor[t] + + with torch.inference_mode(): + pred_clean = self.denoiser( + cfg_weight, + motion, + pad_mask, + text_feat, + text_pad_mask, + t_map, + first_heading_angle, + motion_mask, + observed_motion, + cfg_type=cfg_type, + ) + + # sampler computes next step noisy motion + x_tm1 = self.sampler(use_timesteps, motion, pred_clean, t) + return x_tm1 + + def _multiprompt( + self, + prompts: list[str], + num_frames: int | list[int], + num_denoising_steps: int, + constraint_lst: Optional[list] = [], + cfg_weight: Optional[float] = [2.0, 2.0], + num_samples: Optional[int] = None, + cfg_type: Optional[str] = None, + return_numpy: bool = False, + first_heading_angle: Optional[torch.Tensor] = None, + # for transitioning + num_transition_frames: int = 5, + share_transition: bool = True, + percentage_transition_override=0.10, + # for postprocess + post_processing: bool = False, + root_margin: float = 0.04, + # progress bar + progress_bar=tqdm, + ) -> torch.Tensor: + device = self.device + + bs = num_samples + texts = sanitize_texts(prompts) + + if isinstance(num_frames, int): + # same duration for all the segments + num_frames = [num_frames for _ in range(num_samples)] + + tosqueeze = False + if num_samples is None: + num_samples = 1 + tosqueeze = True + + if constraint_lst is None: + constraint_lst = [] + + # Generate one chunck at a time + current_frame = 0 + generated_motions = [] + + for idx, (text, num_frame) in enumerate(zip(texts, num_frames)): + texts_bs = [text for _ in range(num_samples)] + + lengths = torch.tensor( + [num_frame for _ in range(num_samples)], + device=device, + ) + + is_first_motion = not generated_motions + + observed_motion, motion_mask = None, None + + # filter the constraint_lst to only keep the relevent ones + constraint_lst_base = [ + constraint.crop_move(current_frame, current_frame + num_frame) for constraint in constraint_lst + ] # this move temporally but not spatially + + observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched( + constraint_lst_base, + lengths, + to_normalize=False, # don't normalize yet, it needs to be moved around + device=device, + ) + + if not is_first_motion: + prev_num_frame = num_frames[idx - 1] + if share_transition: + # starting the transitioning earlier, to "share" the transition between A and B + # in any case, we still use "num_transition_frames" for conditioning + # we don't condition until the end of A + # we compute the number of frames of transition as a percentage of the last motion + nb_transition_frames = num_transition_frames + int(prev_num_frame * percentage_transition_override) + else: + nb_transition_frames = num_transition_frames + + latest_motions = generated_motions.pop() + # remove the transition part of A (will be put back afterward) + generated_motions.append(latest_motions[:, :-nb_transition_frames]) + latest_frames = latest_motions[:, -nb_transition_frames:] + # latest_frames[..., 2] += 0.5 + + last_output = self.motion_rep.inverse( + latest_frames, + is_normalized=False, + return_numpy=False, + ) + smooth_root_2d = last_output["smooth_root_pos"][..., [0, 2]] + + # add constraints at the begining to allow natural transitions + constraint_lst_transition = [] + for batch_id in range(bs): + new_constraint = FullBodyConstraintSet( + 
self.skeleton, + torch.arange(num_transition_frames), + last_output["posed_joints"][batch_id, :num_transition_frames], + last_output["local_rot_mats"][batch_id, :num_transition_frames], + smooth_root_2d[batch_id, :num_transition_frames], + ) + + # new lists + constraint_lst_transition.append([new_constraint]) + + transition_lengths = torch.tensor( + [nb_transition_frames for _ in range(num_samples)], + device=device, + ) + + observed_motion_transition, motion_mask_transition = ( + self.motion_rep.create_conditions_from_constraints_batched( + constraint_lst_transition, + transition_lengths, + to_normalize=False, # don't normalize yet + device=device, + ) + ) + + # concatenate the obversed motion / motion mask + observed_motion = torch.cat([observed_motion_transition, observed_motion], axis=1) + motion_mask = torch.cat([motion_mask_transition, motion_mask], axis=1) + + # we need to move each observed motion in the batch to the new starting points + last_smooth_root_2d = smooth_root_2d[:, 0] + observed_motion = self.motion_rep.translate_2d( + observed_motion, -last_smooth_root_2d + ) # equivalent to: self.motion_rep.translate_2d_to_zero(observed_motion) + + # remove dummy values after moving + observed_motion = observed_motion * motion_mask + + lengths = lengths + transition_lengths + first_heading_angle = compute_heading_angle(last_output["posed_joints"], self.skeleton)[:, 0] + else: + if first_heading_angle is None: + # Start at 0 angle, but this will change afterward + first_heading_angle = torch.tensor([0.0] * bs, device=device) + else: + first_heading_angle = torch.as_tensor(first_heading_angle, device=device) + if first_heading_angle.numel() == 1: + first_heading_angle = first_heading_angle.repeat(bs) + + observed_motion = self.motion_rep.normalize(observed_motion) + + max_frames = max(lengths) + motion_pad_mask = length_to_mask(lengths) + + motion = self._generate( + texts_bs, + max_frames, + num_denoising_steps=num_denoising_steps, + pad_mask=motion_pad_mask, + first_heading_angle=first_heading_angle, + motion_mask=motion_mask, + observed_motion=observed_motion, + cfg_weight=cfg_weight, + cfg_type=cfg_type, + ) + + motion = self.motion_rep.unnormalize(motion) + + if not is_first_motion: + motion_with_transition = self.motion_rep.translate_2d( + motion, + last_smooth_root_2d, + ) + + motion = motion_with_transition[:, num_transition_frames:] + transition_frames = motion_with_transition[:, :num_transition_frames] + # for sharing = True, the new motion contains the very last of A + + # linearly combine the previously generated transitions with the newly generated ones + # so that we linearly go from previous gen to new gen + alpha = torch.linspace(1, 0, num_transition_frames, device=device)[:, None] + new_transition_frames = ( + latest_frames[:, :num_transition_frames] * alpha + (1 - alpha) * transition_frames + ) + + # add new transitions frames for A (merging with B predition of the history) + # for share_transition == True, this remove (do not add back) a small part of the end of A + # the small last part of A has been re-generated by B + generated_motions.append(new_transition_frames) + + # motion[..., 2] += 0.5 + + generated_motions.append(motion) + current_frame += num_frame + + generated_motions = torch.cat(generated_motions, axis=1) # temporal axis (b, t, d) + + if tosqueeze: + generated_motions = generated_motions[0] + + output = self.motion_rep.inverse( + generated_motions, + is_normalized=False, + return_numpy=False, + ) + + # Apply post-processing if requested + if 
post_processing: + corrected = post_process_motion( + output["local_rot_mats"], + output["root_positions"], + output["foot_contacts"], + self.skeleton, + constraint_lst, + root_margin=root_margin, + ) + output.update(corrected) + + # Convert SOMA output to somaskel77 for external API + if isinstance(self.skeleton, SOMASkeleton30): + output = self.skeleton.output_to_SOMASkeleton77(output) + + # Convert to numpy if requested + if return_numpy: + output = to_numpy(output) + return output + + def __call__( + self, + prompts: str | list[str], + num_frames: int | list[int], + num_denoising_steps: int, + multi_prompt: bool = False, + constraint_lst: Optional[list] = [], + cfg_weight: Optional[float] = [2.0, 2.0], + num_samples: Optional[int] = None, + cfg_type: Optional[str] = None, + return_numpy: bool = False, + first_heading_angle: Optional[torch.Tensor] = None, + # for transitioning + num_transition_frames: int = 5, + share_transition: bool = True, + percentage_transition_override=0.10, + # for postprocess + post_processing: bool = False, + root_margin: float = 0.04, + # progress bar + progress_bar=tqdm, + ) -> dict: + """Generate motion from text prompts and optional kinematic constraints. + + When a single prompt/num_frames pair is given, one motion is generated. + Passing lists of prompts and/or num_frames produces a batch of + independent motions. With ``multi_prompt=True``, the prompts are + treated as sequential segments that are generated and stitched together + with smooth transitions. + + Args: + prompts: One or more text descriptions of the desired motion. + A single string generates one sample; a list generates a batch + (or sequential segments when ``multi_prompt=True``). + num_frames: Duration of the generated motion in frames. Can be a + single int applied to every prompt or a per-prompt list. + num_denoising_steps: Number of DDIM denoising steps. More steps + generally improve quality at the cost of speed. + multi_prompt: If ``True``, treat ``prompts`` as an ordered sequence + of segments and concatenate them with transitions. + constraint_lst: Per-sample list of kinematic constraints (e.g. + keyframe poses, end-effector targets, 2-D paths). Pass an + empty list for unconstrained generation. + cfg_weight: Classifier-free guidance scale(s). A two-element list + ``[text_cfg, constraint_cfg]`` controls text and constraint + guidance independently. + num_samples: Number of samples to generate. + cfg_type: Override the default CFG strategy set at init + (e.g. ``"separated"``). + return_numpy: If ``True``, convert all output tensors to numpy + arrays. + first_heading_angle: Initial body heading in radians. Shape + ``(B,)`` or scalar. Defaults to ``0`` (facing +Z). + num_transition_frames: Number of overlapping frames used to blend + consecutive segments in multi-prompt mode. + share_transition: If ``True``, transition frames are shared between + adjacent segments rather than appended. + percentage_transition_override: Fraction of each segment's length + that may be overridden by the transition blend. + post_processing: If ``True``, apply post-processing + (foot-skate cleanup and constraint enforcement). + root_margin: Horizontal margin (in meters) used by the post-processor + to determine when to correct root motion. When root deviates more than + margin from the constraint, the post-processor will correct it. + progress_bar: Callable wrapping an iterable to display progress + (default: ``tqdm``). Pass a no-op to silence output. 
+ + Returns: + dict: A dictionary of motion tensors (or numpy arrays if + ``return_numpy=True``) with the following keys: + + - ``local_rot_mats`` – Local joint rotations as rotation matrices. + - ``global_rot_mats`` – Global joint rotations as rotation matrices. + - ``posed_joints`` – Joint positions in world space. + - ``root_positions`` – Root joint positions. + - ``smooth_root_pos`` – Smoothed root trajectory. + - ``foot_contacts`` – Boolean foot-contact labels [left heel, left toe, right heel, right toe]. + - ``global_root_heading`` – Root heading angle over time. + """ + device = self.device + + if multi_prompt: + # multi prompt generation + return self._multiprompt( + prompts, + num_frames, + num_denoising_steps, + constraint_lst, + cfg_weight, + num_samples, + cfg_type, + return_numpy, + first_heading_angle, + num_transition_frames, + share_transition, + percentage_transition_override, + post_processing, + root_margin, + progress_bar, + ) + + # Input checking + tosqueeze = False + if isinstance(prompts, list) and isinstance(num_frames, list): + assert len(prompts) == len(num_frames), "The number of prompts should match the number of num_frames." + num_samples = len(prompts) + elif isinstance(prompts, list): + num_samples = len(prompts) + num_frames = [num_frames for _ in range(num_samples)] + elif isinstance(num_frames, list): + num_samples = len(num_frames) + prompts = [prompts for _ in range(num_samples)] + else: + if num_samples is None: + tosqueeze = True + num_samples = 1 + prompts = [prompts for _ in range(num_samples)] + num_frames = [num_frames for _ in range(num_samples)] + + bs = num_samples + texts = sanitize_texts(prompts) + + lengths = torch.tensor( + num_frames, + device=device, + ) + max_frames = max(lengths) + motion_pad_mask = length_to_mask(lengths) + + if first_heading_angle is None: + # Start at 0 angle + first_heading_angle = torch.tensor([0.0] * bs, device=device) + else: + first_heading_angle = torch.as_tensor(first_heading_angle, device=device) + if first_heading_angle.numel() == 1: + first_heading_angle = first_heading_angle.repeat(bs) + + observed_motion, motion_mask = None, None + if constraint_lst: + observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched( + constraint_lst, + lengths, + to_normalize=True, + device=device, + ) + + motion = self._generate( + texts, + max_frames, + num_denoising_steps=num_denoising_steps, + pad_mask=motion_pad_mask, + first_heading_angle=first_heading_angle, + motion_mask=motion_mask, + observed_motion=observed_motion, + cfg_weight=cfg_weight, + cfg_type=cfg_type, + progress_bar=progress_bar, + ) + + if tosqueeze: + motion = motion[0] + + output = self.motion_rep.inverse( + motion, + is_normalized=True, + return_numpy=False, # Keep as tensor for potential post-processing + ) + + # Apply post-processing if requested + if post_processing: + corrected = post_process_motion( + output["local_rot_mats"], + output["root_positions"], + output["foot_contacts"], + self.skeleton, + constraint_lst, + root_margin=root_margin, + ) + # key frame outputs / foot contacts are not changed + output.update(corrected) + + # Convert SOMA output to somaskel77 for external API + if isinstance(self.skeleton, SOMASkeleton30): + output = self.skeleton.output_to_SOMASkeleton77(output) + + # Convert to numpy if requested + if return_numpy: + output = to_numpy(output) + return output + + def _generate( + self, + texts: List[str], + max_frames: int, + num_denoising_steps: int, + pad_mask: torch.Tensor, + 
first_heading_angle: Optional[torch.Tensor], + motion_mask: torch.Tensor, + observed_motion: torch.Tensor, + cfg_weight: Optional[float] = 2.0, + text_feat: Optional[torch.Tensor] = None, + text_pad_mask: Optional[torch.Tensor] = None, + guide_masks: Optional[Dict] = None, + cfg_type: Optional[str] = None, + progress_bar=tqdm, + ) -> torch.Tensor: + """Sample full denoising loop. + + Args: + texts (List[str]): batch of text prompts to use for sampling (if text_feat is not passed in) + """ + + device = self.device + if text_feat is None: + assert text_pad_mask is None + log.info("Encoding text...") + text_feat, text_length = self.text_encoder(texts) + text_feat = text_feat.to(device) + + # handle empty string (set to zero) + empty_text_mask = [len(text.strip()) == 0 for text in texts] + text_feat[empty_text_mask] = 0 + + # Create the pad mask for the text + batch_size, maxlen = text_feat.shape[:2] + tensor_text_length = torch.tensor(text_length, device=device) + tensor_text_length[empty_text_mask] = 0 + text_pad_mask = torch.arange(maxlen, device=device).expand(batch_size, maxlen) < tensor_text_length[:, None] + + if motion_mask is not None: + if motion_mask.dtype == torch.bool: + motion_mask = 1 * motion_mask + + batch_size = text_feat.shape[0] + + # sample loop + indices = list(range(num_denoising_steps))[::-1] + shape = (batch_size, max_frames, self.motion_rep.motion_rep_dim) + cur_mot = torch.randn(shape, device=self.device) + num_denoising_steps = torch.tensor( + [num_denoising_steps], device=self.device + ) # this and t need to be tensor for onnx export + # init diffusion with correct num steps before looping + use_timesteps = self.diffusion.space_timesteps(num_denoising_steps[0])[0] + self.diffusion.calc_diffusion_vars(use_timesteps) + for i in progress_bar(indices): + t = torch.tensor([i] * cur_mot.size(0), device=self.device) + with torch.inference_mode(): + cur_mot = self.denoising_step( + cur_mot, + pad_mask, + text_feat, + text_pad_mask, + t, + first_heading_angle, + motion_mask, + observed_motion, + num_denoising_steps, + cfg_weight, + guide_masks=guide_masks, + cfg_type=cfg_type, + ) + return cur_mot diff --git a/kimodo/model/llm2vec/README.md b/kimodo/model/llm2vec/README.md new file mode 100644 index 0000000000000000000000000000000000000000..450e995c0aa9cdb0f1117aa33cb30a30d84a41c0 --- /dev/null +++ b/kimodo/model/llm2vec/README.md @@ -0,0 +1 @@ +This is a patched version of the original [LLM2Vec](https://github.com/McGill-NLP/llm2vec) codebase so that `McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised` works with `transformers==5.0.0rc3`. diff --git a/kimodo/model/llm2vec/__init__.py b/kimodo/model/llm2vec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5890f5848bc66237649197109d3a31328d3f77a1 --- /dev/null +++ b/kimodo/model/llm2vec/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""LLM2Vec text encoder and wrapper for Kimodo.""" + +from .llm2vec import LLM2Vec +from .llm2vec_wrapper import LLM2VecEncoder + +__all__ = [ + "LLM2Vec", + "LLM2VecEncoder", +] diff --git a/kimodo/model/llm2vec/llm2vec.py b/kimodo/model/llm2vec/llm2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..6d01f5716ed57a6b6ea9b8cca3c9f292bdf19b5c --- /dev/null +++ b/kimodo/model/llm2vec/llm2vec.py @@ -0,0 +1,477 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import logging +import os +from functools import partial +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +import torch.multiprocessing as mp +from peft import PeftModel +from torch import Tensor, device, nn +from tqdm.autonotebook import tqdm, trange +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, + GemmaConfig, + LlamaConfig, + MistralConfig, + PretrainedConfig, + Qwen2Config, +) + +logger = logging.getLogger(__name__) + + +def batch_to_device(batch, target_device: device): + """Send a pytorch batch to a device (CPU/GPU)""" + for key in batch: + if isinstance(batch[key], Tensor): + batch[key] = batch[key].to(target_device) + return batch + + +class LLM2Vec(nn.Module): + def __init__( + self, + model: AutoModel, + tokenizer: AutoTokenizer, + pooling_mode: str = "mean", + max_length: int = 512, + doc_max_length: int = 400, + skip_instruction: bool = True, + ): + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.pooling_mode = pooling_mode + self.skip_instruction = skip_instruction + self.max_length = max_length + self.doc_max_length = doc_max_length + self.config = model.config + + @classmethod + def _get_model_class(cls, config_class_name, enable_bidirectional): + if not enable_bidirectional: + return AutoModel + if config_class_name == "MistralConfig": + from .models.bidirectional_mistral import MistralBiModel + + return MistralBiModel + elif config_class_name == "LlamaConfig": + from .models.bidirectional_llama import LlamaBiModel + + return LlamaBiModel + elif config_class_name == "GemmaConfig": + from .models.bidirectional_gemma import GemmaBiModel + + return GemmaBiModel + elif config_class_name == "Qwen2Config": + from .models.bidirectional_qwen2 import Qwen2BiModel + + return Qwen2BiModel + else: + raise ValueError(f"{config_class_name} is not supported yet with bidirectional models.") + + @classmethod + def from_pretrained( + cls, + base_model_name_or_path, + peft_model_name_or_path=None, + merge_peft=False, + enable_bidirectional=True, + **kwargs, + ): + # pop out encoder args + keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"] + encoder_args = {key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None} + + tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + config = AutoConfig.from_pretrained(base_model_name_or_path) + config_class_name = config.__class__.__name__ + + model_class = cls._get_model_class(config_class_name, enable_bidirectional=enable_bidirectional) + + model = model_class.from_pretrained(base_model_name_or_path, **kwargs) + + if os.path.isdir(base_model_name_or_path) and os.path.exists(f"{base_model_name_or_path}/config.json"): + with open(f"{base_model_name_or_path}/config.json", "r") as fIn: + config_dict = json.load(fIn) + config = PretrainedConfig.from_dict(config_dict) + model.config._name_or_path = config._name_or_path + + # For special case where config.json and adapter weights are in the same directory + if hasattr(model, "peft_config"): + model = PeftModel.from_pretrained( + model, + base_model_name_or_path, + ) + model = model.merge_and_unload() + + if peft_model_name_or_path is not None: + model = PeftModel.from_pretrained( + model, + peft_model_name_or_path, + ) + if merge_peft: + model = model.merge_and_unload() + + config = {} + config_addr = peft_model_name_or_path if peft_model_name_or_path is not None else 
base_model_name_or_path + if os.path.exists(f"{config_addr}/llm2vec_config.json"): + with open(f"{config_addr}/llm2vec_config.json", "r") as fIn: + llm2vec_config = json.load(fIn) + config.update(llm2vec_config) + + for key, value in encoder_args.items(): + config[key] = value + + return cls(model=model, tokenizer=tokenizer, **config) + + def prepare_for_tokenization(self, text): + if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct": + text = "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>" + return text + if self.model.config._name_or_path in [ + "mistralai/Mistral-7B-Instruct-v0.2", + "meta-llama/Llama-2-7b-chat-hf", + ]: + text = "[INST] " + text.strip() + " [/INST]" + if self.model.config._name_or_path in [ + "google/gemma-2-9b-it", + ]: + text = "user\n" + text.strip() + "" + if self.model.config._name_or_path in [ + "Qwen/Qwen2-1.5B-Instruct", + "Qwen/Qwen2-7B-Instruct", + ]: + text = "<|im_start|>user\n" + text.strip() + "<|im_end|>" + if self.pooling_mode == "eos_token": + if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B": + text = text.strip() + "<|end_of_text|>" + elif isinstance(self.model.config, LlamaConfig) or isinstance(self.model.config, MistralConfig): + text = text.strip() + " " + elif isinstance(self.model.config, GemmaConfig): + text = text.strip() + "" + elif isinstance(self.model.config, Qwen2Config): + text = text.strip() + "<|endoftext|>" + return text + + def tokenize(self, texts): + texts_2 = [] + original_texts = [] + for text in texts: + t = text.split("!@#$%^&*()") + texts_2.append(t[1] if len(t) > 1 else "") + original_texts.append("".join(t)) + + original = self.tokenizer( + original_texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + ) + embed_mask = None + for t_i, t in enumerate(texts_2): + ids = self.tokenizer( + [t], + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + if embed_mask is None: + e_m = torch.zeros_like(original["attention_mask"][t_i]) + if len(ids["input_ids"][0]) > 0: + e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0])) + embed_mask = e_m.unsqueeze(0) + else: + e_m = torch.zeros_like(original["attention_mask"][t_i]) + if len(ids["input_ids"][0]) > 0: + e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0])) + embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0) + + original["embed_mask"] = embed_mask + return original + + def _skip_instruction(self, sentence_feature): + assert sentence_feature["attention_mask"].shape == sentence_feature["embed_mask"].shape + sentence_feature["attention_mask"] = sentence_feature["embed_mask"] + + def forward(self, sentence_feature: Dict[str, Tensor]): + embed_mask = None + if "embed_mask" in sentence_feature: + embed_mask = sentence_feature.pop("embed_mask") + reps = self.model(**sentence_feature) + sentence_feature["embed_mask"] = embed_mask + + return self.get_pooling(sentence_feature, reps.last_hidden_state) + + def get_pooling(self, features, last_hidden_states): # All models padded from left + assert self.tokenizer.padding_side == "left", "Pooling modes are implemented for padding from left." 
+ if self.skip_instruction: + self._skip_instruction(features) + seq_lengths = features["attention_mask"].sum(dim=-1) + if self.pooling_mode == "mean": + return torch.stack( + [last_hidden_states[i, -length:, :].mean(dim=0) for i, length in enumerate(seq_lengths)], + dim=0, + ) + elif self.pooling_mode == "weighted_mean": + bs, l, _ = last_hidden_states.shape + complete_weights = torch.zeros(bs, l, device=last_hidden_states.device) + for i, seq_l in enumerate(seq_lengths): + if seq_l > 0: + complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1 + complete_weights[i] /= torch.clamp(complete_weights[i].sum(), min=1e-9) + return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1) + elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token": + return last_hidden_states[:, -1] + elif self.pooling_mode == "bos_token": + return last_hidden_states[features["input_ids"] == self.tokenizer.bos_token_id] + else: + raise ValueError(f"{self.pooling_mode} is not implemented yet.") + + def _convert_to_str(self, instruction, text): + tokenized_q = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + tokenized_q_length = len(tokenized_q["input_ids"][0]) + + while tokenized_q_length > self.doc_max_length: + reduction_ratio = self.doc_max_length / tokenized_q_length + reduced_length = int(len(text.split()) * reduction_ratio) + text = " ".join(text.split()[:reduced_length]) + tokenized_q = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + tokenized_q_length = len(tokenized_q["input_ids"][0]) + + return f"{instruction.strip()} !@#$%^&*(){text}" if instruction else f"!@#$%^&*(){text}" + + def encode( + self, + sentences: Union[str, List[str]], + batch_size: int = 32, + show_progress_bar: bool = True, + convert_to_numpy: bool = False, + convert_to_tensor: bool = False, + device: Optional[str] = None, + ): + """ + Encode a list of sentences to their respective embeddings. The sentences can be a list of strings or a string. + Args: + sentences: sentence or sentences to encode. + batch_size: batch size for turning sentence tokens into embeddings. + show_progress_bar: whether to show progress bars during encoding steps. + convert_to_numpy: If true, return numpy arrays instead of torch tensors. + convert_to_tensor: If true, return torch tensors (default). + device: torch backend device identifier (e.g., 'cuda', 'cpu','mps' etc.). If not specified, + the default is to use cuda when available, otherwise cpu. Note that only the choice of 'cuda' supports + multiprocessing as currently implemented. + + Returns: embeddings of the sentences. Embeddings are detached and always on the CPU (see _encode implementation). 
+ + """ + if isinstance(sentences[0], str) and isinstance(sentences[-1], int): + sentences = [sentences] + # required for MEDI version of MTEB + if isinstance(sentences[0], str): + sentences = [[""] + [sentence] for sentence in sentences] + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + concatenated_input_texts = [] + for sentence in sentences: + assert isinstance(sentence[0], str) + assert isinstance(sentence[1], str) + concatenated_input_texts.append(self._convert_to_str(sentence[0], sentence[1])) + sentences = concatenated_input_texts + + self.eval() + + if convert_to_tensor: + convert_to_numpy = False + + length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) + sentences_sorted = [sentences[idx] for idx in length_sorted_idx] + all_embeddings = [] + + if torch.cuda.device_count() <= 1: + # This branch also support mps devices + self.to(device) + for start_index in trange( + 0, + len(sentences), + batch_size, + desc="Batches", + disable=not show_progress_bar, + ): + sentences_batch = sentences_sorted[start_index : start_index + batch_size] + embeddings = self._encode(sentences_batch, device=device, convert_to_numpy=convert_to_numpy) + all_embeddings.append(embeddings) + else: + num_proc = torch.cuda.device_count() + cuda_compatible_multiprocess = mp.get_context("spawn") + with cuda_compatible_multiprocess.Pool(num_proc) as p: + sentences_batches = [ + sentences_sorted[start_index : start_index + batch_size] + for start_index in range(0, len(sentences), batch_size) + ] + + progress_bar = tqdm( + total=len(sentences_batches), + desc="Batches", + disable=not show_progress_bar, + ) + results = [] + + def update(*args): + progress_bar.update() + + for batch in sentences_batches: + results.append( + p.apply_async( + self._encode, + args=(batch, None, convert_to_numpy, True), + callback=update, + ) + ) + + all_embeddings = [result.get() for result in results] + progress_bar.close() + + all_embeddings = torch.cat(all_embeddings, dim=0) + all_embeddings = all_embeddings[np.argsort(length_sorted_idx)] + all_embeddings = all_embeddings.to(torch.float32) + if convert_to_numpy: + all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) + return all_embeddings + + def save(self, output_path, merge_before_save=False, save_config=True): + if merge_before_save and isinstance(self.model, PeftModel): + self.model = self.model.merge_and_unload() + # Fixes the issue of saving - https://huggingface.co/McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse/discussions/1 + if hasattr(self.model, "_hf_peft_config_loaded"): + self.model._hf_peft_config_loaded = False + + self.model.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + llm2vec_config = { + "pooling_mode": self.pooling_mode, + "max_length": self.max_length, + "doc_max_length": self.doc_max_length, + "skip_instruction": self.skip_instruction, + } + + if save_config: + os.makedirs(output_path, exist_ok=True) + with open(f"{output_path}/llm2vec_config.json", "w") as fOut: + json.dump(llm2vec_config, fOut, indent=4) + + def _encode( + self, + sentences_batch, + device: Optional[str] = None, + convert_to_numpy: bool = False, + multiprocessing=False, + ): + if multiprocessing: + # multiprocessing only supports CUDA devices at this time, so we ignore the value of device + # and use cuda:rank for the device + rank = mp.current_process()._identity[0] + if device is None and torch.cuda.is_available(): + device = f"cuda:{rank % torch.cuda.device_count()}" + + 
self.to(device) + features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch]) + features = batch_to_device(features, device) + + with torch.no_grad(): + embeddings = self.forward(features) + embeddings = embeddings.detach() + embeddings = embeddings.cpu() + + return embeddings + + def _text_length(self, text: Union[List[int], List[List[int]]]): + """Help function to get the length for the input text. + + Text can be either a string (which means a single text) a list of ints (which means a single + tokenized text), or a tuple of list of ints (representing several text inputs to the model). + """ + if ( + isinstance(text, str) or (isinstance(text, list) and isinstance(text[0], int)) or len(text) == 0 + ): # Single text, list of ints, or empty + return len(text) + if isinstance(text, dict): # {key: value} case + return len(next(iter(text.values()))) + elif not hasattr(text, "__len__"): # Object has no len() method + return 1 + else: + return sum([len(t) for t in text]) + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + ) -> nn.Embedding: + return self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) diff --git a/kimodo/model/llm2vec/llm2vec_wrapper.py b/kimodo/model/llm2vec/llm2vec_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..eb33c871b518ac2992b573986078895742d59437 --- /dev/null +++ b/kimodo/model/llm2vec/llm2vec_wrapper.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""LLM2Vec encoder wrapper for Kimodo text conditioning.""" + +import os + +import numpy as np +import torch + +from .llm2vec import LLM2Vec + + +class LLM2VecEncoder: + """LLM2Vec text embeddings.""" + + def __init__( + self, + base_model_name_or_path: str, + peft_model_name_or_path: str, + dtype: str, + llm_dim: int, + ) -> None: + torch_dtype = getattr(torch, dtype) + self.llm_dim = llm_dim + + cache_dir = os.environ.get("HUGGINGFACE_CACHE_DIR") + + if "TEXT_ENCODERS_DIR" in os.environ: + base_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], base_model_name_or_path) + peft_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], peft_model_name_or_path) + + self.model = LLM2Vec.from_pretrained( + base_model_name_or_path=base_model_name_or_path, + peft_model_name_or_path=peft_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + ) + self.model.eval() + for p in self.model.parameters(): + p.requires_grad = False + + def to(self, device: torch.device): + self.model = self.model.to(device) + return self + + def eval(self): + self.model.eval() + return self + + def get_device(self): + return self.model.model.device + + def __call__(self, text: list[str] | str): + is_string = False + if isinstance(text, str): + text = [text] + is_string = True + + with torch.no_grad(): + encoded_text = self.model.encode(text, batch_size=len(text), show_progress_bar=False) + + assert len(encoded_text.shape) + assert self.llm_dim == encoded_text.shape[-1] + + encoded_text = encoded_text[:, None] + lengths = np.ones(len(encoded_text), dtype=int).tolist() + + if is_string: + encoded_text = encoded_text[0] + lengths = lengths[0] + + encoded_text = torch.tensor(encoded_text).to(self.get_device()) + return encoded_text, lengths diff --git a/kimodo/model/llm2vec/models/__init__.py b/kimodo/model/llm2vec/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2504048d0ee1addb1d3c95cbc04aae2b59e3e68 --- /dev/null +++ b/kimodo/model/llm2vec/models/__init__.py @@ -0,0 +1,4 @@ +# from .bidirectional_gemma import GemmaBiForMNTP, GemmaBiModel +# from .bidirectional_llama import LlamaBiForMNTP, LlamaBiModel +# from .bidirectional_mistral import MistralBiForMNTP, MistralBiModel +# from .bidirectional_qwen2 import Qwen2BiForMNTP, Qwen2BiModel diff --git a/kimodo/model/llm2vec/models/attn_mask_utils.py b/kimodo/model/llm2vec/models/attn_mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..787f99172c52b3e340b037025ad6ac4c6f8f1929 --- /dev/null +++ b/kimodo/model/llm2vec/models/attn_mask_utils.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from typing import List, Optional, Tuple, Union + +import torch +from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D + mask of shape `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter( + is_causal=False, sliding_window=sliding_window + ) # is_causal=True in original implementation + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None and len(attention_mask.shape) == 2: + attention_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + key_value_length=key_value_length, + dtype=inputs_embeds.dtype, + ) + elif attention_mask is not None and len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], + input_shape[-1], + key_value_length, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + return attention_mask + + +# Adapted from _prepare_4d_causal_attention_mask +def _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """Prepares the correct `attn_mask` argument to be used by + `torch.nn.functional.scaled_dot_product_attention`. 
+ + In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and + `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + attn_mask_converter = AttentionMaskConverter( + is_causal=False, sliding_window=sliding_window + ) # is_causal=True in original implementation + + key_value_length = input_shape[-1] + past_key_values_length + batch_size, query_length = input_shape + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). + is_tracing = ( + torch.jit.is_tracing() + or isinstance(inputs_embeds, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + + if attention_mask is not None: + # 4d mask is passed through + if len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype) + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + return attention_mask + + elif not is_tracing and torch.all(attention_mask == 1): + if query_length == 1: + # For query_length == 1, causal attention and bi-directional attention are the same. + attention_mask = None + elif key_value_length == query_length: + attention_mask = None + else: + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + pass + elif query_length > 1 and key_value_length != query_length: + # See the comment above (https://github.com/pytorch/pytorch/issues/108108). + # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`. + attention_mask = True + elif is_tracing: + raise ValueError( + 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' 
+ ) + + if attention_mask is None: + expanded_4d_mask = None + elif attention_mask is True: + expanded_4d_mask = attn_mask_converter.to_causal_4d( + input_shape[0], + input_shape[-1], + key_value_length, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + else: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, + ) + + # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + if not is_tracing and expanded_4d_mask.device.type == "cuda": + expanded_4d_mask = AttentionMaskConverter._unmask_unattended( + expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min + ) + + return expanded_4d_mask diff --git a/kimodo/model/llm2vec/models/bidirectional_llama.py b/kimodo/model/llm2vec/models/bidirectional_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..f3e624e6d342ac37bd1eaecb85d9b76823de6864 --- /dev/null +++ b/kimodo/model/llm2vec/models/bidirectional_llama.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from peft import PeftModel +from torch import nn +from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel +from transformers.cache_utils import Cache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + # LlamaFlashAttention2, + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, + # LlamaSdpaAttention, +) +from transformers.utils import logging + +from .utils import is_transformers_attn_greater_or_equal_4_43_1 + +logger = logging.get_logger(__name__) + + +class ModifiedLlamaAttention(LlamaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +# class ModifiedLlamaFlashAttention2(LlamaFlashAttention2): +# def __init__(self, *args, **kwargs): +# super().__init__(*args, **kwargs) +# self.is_causal = False + + +# class ModifiedLlamaSdpaAttention(LlamaSdpaAttention): +# def __init__(self, *args, **kwargs): +# super().__init__(*args, **kwargs) +# self.is_causal = False + + +# LLAMA_ATTENTION_CLASSES = { +# "eager": ModifiedLlamaAttention, +# "flash_attention_2": ModifiedLlamaFlashAttention2, +# "sdpa": ModifiedLlamaSdpaAttention, +# } + + +class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + + self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx) + # self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( + # config=config, layer_idx=layer_idx + # ) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class LlamaBiModel(LlamaModel): + _no_split_modules = ["ModifiedLlamaDecoderLayer"] + + def __init__(self, config: LlamaConfig): + if not is_transformers_attn_greater_or_equal_4_43_1(): + raise ValueError( + "The current implementation of LlamaEncoderModel follows modeling_llama.py of transformers version >= 4.43.1" + ) + LlamaPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def _update_causal_mask( + self, + attention_mask, + input_tensor, + cache_position, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + # if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + # if AttentionMaskConverter._ignore_causal_mask_sdpa( + # attention_mask, + # inputs_embeds=input_tensor, + # past_key_values_length=past_seen_tokens, + # is_training=self.training, + # ): + # return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + causal_mask = torch.zeros( + (sequence_length, target_length), dtype=dtype, device=device + ) # in original implementation - torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + # Commenting out next 2 lines to disable causal masking + # if sequence_length != 1: + # causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. 
+ if attention_mask.shape[-2] < cache_position[0] + sequence_length: + offset = cache_position[0] + else: + offset = 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[ + : mask_shape[0], + : mask_shape[1], + offset : mask_shape[2] + offset, + : mask_shape[3], + ] = mask_slice + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaBiForMNTP(LlamaForCausalLM): + def __init__(self, config): + LlamaPreTrainedModel.__init__(self, config) + self.model = LlamaBiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # getter for PEFT model + def get_model_for_peft(self): + return self.model + + # setter for PEFT model + def set_model_for_peft(self, model: PeftModel): + self.model = model + + # save the PEFT model + def save_peft_model(self, path): + self.model.save_pretrained(path) diff --git a/kimodo/model/llm2vec/models/utils.py b/kimodo/model/llm2vec/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..92ed8ec7058668697e628da342621b7599fb5ca2 --- /dev/null +++ b/kimodo/model/llm2vec/models/utils.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import importlib.metadata + +from packaging import version +from transformers.utils.import_utils import _is_package_available + + +def is_transformers_attn_greater_or_equal_4_43_1(): + if not _is_package_available("transformers"): + return False + + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.43.1") diff --git a/kimodo/model/load_model.py b/kimodo/model/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b2bed994aa49a5e53e4187caa166da46aaaa827c --- /dev/null +++ b/kimodo/model/load_model.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Load Kimodo diffusion models from local checkpoints or Hugging Face.""" + +from pathlib import Path +from typing import Optional + +from huggingface_hub import snapshot_download +from omegaconf import OmegaConf + +from .loading import ( + AVAILABLE_MODELS, + DEFAULT_MODEL, + DEFAULT_TEXT_ENCODER_URL, + MODEL_NAMES, + TMR_MODELS, + get_env_var, + instantiate_from_dict, +) +from .registry import get_model_info, resolve_model_name + +DEFAULT_TEXT_ENCODER = "llm2vec" +TEXT_ENCODER_PRESETS = { + "llm2vec": { + "target": "kimodo.model.LLM2VecEncoder", + "kwargs": { + "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp", + "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised", + "dtype": "bfloat16", + "llm_dim": 4096, + }, + } +} + + +def _resolve_hf_model_path(modelname: str) -> Path: + """Resolve model name to a local path, using Hugging Face cache or CHECKPOINT_DIR.""" + try: + repo_id = MODEL_NAMES[modelname] + except KeyError: + raise ValueError(f"Model '{modelname}' not found. Available models: {MODEL_NAMES.keys()}") + + local_cache = get_env_var("LOCAL_CACHE", "False").lower() == "true" + if not local_cache: + snapshot_dir = snapshot_download(repo_id=repo_id) # will check online no matter what + return Path(snapshot_dir) + + try: + snapshot_dir = snapshot_download(repo_id=repo_id, local_files_only=True) # will check local cache only + return Path(snapshot_dir) + except Exception: + # if local cache is not found, download from online + try: + snapshot_dir = snapshot_download(repo_id=repo_id) + return Path(snapshot_dir) + except Exception: + raise RuntimeError(f"Could not resolve model '{modelname}' from Hugging Face (repo: {repo_id}). ") from None + + +def _build_api_text_encoder_conf(text_encoder_url: str) -> dict: + return { + "_target_": "kimodo.model.text_encoder_api.TextEncoderAPI", + "url": text_encoder_url, + } + + +def _build_local_text_encoder_conf() -> dict: + text_encoder_name = get_env_var("TEXT_ENCODER", DEFAULT_TEXT_ENCODER) + if text_encoder_name not in TEXT_ENCODER_PRESETS: + available = ", ".join(sorted(TEXT_ENCODER_PRESETS)) + raise ValueError(f"Unknown TEXT_ENCODER='{text_encoder_name}'. Available: {available}") + + preset = TEXT_ENCODER_PRESETS[text_encoder_name] + return { + "_target_": preset["target"], + **preset["kwargs"], + } + + +def _select_text_encoder_conf(text_encoder_url: str) -> dict: + # TEXT_ENCODER_MODE options: + # - "api": force TextEncoderAPI + # - "local": force local LLM2VecEncoder + # - "auto": try API first, fallback to local if unreachable + mode = get_env_var("TEXT_ENCODER_MODE", "auto").lower() + if mode == "local": + return _build_local_text_encoder_conf() + if mode == "api": + return _build_api_text_encoder_conf(text_encoder_url) + + api_conf = _build_api_text_encoder_conf(text_encoder_url) + try: + text_encoder = instantiate_from_dict(api_conf) + # Probe availability early so inference doesn't fail later. + text_encoder(["healthcheck"]) + return api_conf + except Exception as error: + print( + "Text encoder service is unreachable, falling back to local LLM2Vec " + f"encoder. ({type(error).__name__}: {error})" + ) + return _build_local_text_encoder_conf() + + +def load_model( + modelname=None, + device=None, + eval_mode: bool = True, + default_family: Optional[str] = "Kimodo", + return_resolved_name: bool = False, +): + """Load a kimodo model by name (e.g. 'g1', 'soma'). + + Resolution of partial/full names (e.g. 
Kimodo-SOMA-RP-v1, SOMA) is done + inside this function using default_family when the name is not a known + short key. + + Args: + modelname: Model identifier; uses DEFAULT_MODEL if None. Can be a short key, + a full name (e.g. Kimodo-SOMA-RP-v1), or a partial name; unknown names + are resolved via resolve_model_name using default_family. + device: Target device for the model (e.g. 'cuda', 'cpu'). + eval_mode: If True, set model to eval mode. + default_family: Used when modelname is not in AVAILABLE_MODELS to resolve + partial names ("Kimodo" for demo/generation, "TMR" for embed script). + Default "Kimodo". + return_resolved_name: If True, return (model, resolved_short_key). If False, + return only the model. + + Returns: + Loaded model in eval mode, or (model, resolved short key) if + return_resolved_name is True. + + Raises: + ValueError: If modelname is not in AVAILABLE_MODELS and cannot be resolved. + FileNotFoundError: If config.yaml is missing in the checkpoint folder. + """ + if modelname is None: + modelname = DEFAULT_MODEL + if modelname not in AVAILABLE_MODELS: + if default_family is not None: + modelname = resolve_model_name(modelname, default_family) + else: + raise ValueError( + f"""The model is not recognized. + Please choose between: {AVAILABLE_MODELS}""" + ) + + resolved_modelname = modelname + + # In case, we specify a custom checkpoint directory + configured_checkpoint_dir = get_env_var("CHECKPOINT_DIR") + if configured_checkpoint_dir: + print(f"CHECKPOINT_DIR is set to {configured_checkpoint_dir}, checking the local cache...") + # Checkpoint folders are named by display name (e.g. Kimodo-SOMA-RP-v1) + info = get_model_info(modelname) + checkpoint_folder_name = info.display_name if info is not None else modelname + model_path = Path(configured_checkpoint_dir) / checkpoint_folder_name + if not model_path.exists() and modelname != checkpoint_folder_name: + # Fallback: try short_key for backward compatibility + model_path = Path(configured_checkpoint_dir) / modelname + if not model_path.exists(): + print(f"Model folder not found at '{model_path}', downloading it from Hugging Face...") + model_path = _resolve_hf_model_path(modelname) + else: + # Otherwise, we load the model from the local cache or download it from Hugging Face. + model_path = _resolve_hf_model_path(modelname) + + model_config_path = model_path / "config.yaml" + if not model_config_path.exists(): + raise FileNotFoundError(f"The model checkpoint folder exists but config.yaml is missing: {model_config_path}") + + model_conf = OmegaConf.load(model_config_path) + + if modelname in TMR_MODELS: + # Same process at the moment for TMR and Kimodo + pass + + text_encoder_url = get_env_var("TEXT_ENCODER_URL", DEFAULT_TEXT_ENCODER_URL) + runtime_conf = OmegaConf.create( + { + "checkpoint_dir": str(model_path), + "text_encoder": _select_text_encoder_conf(text_encoder_url), + } + ) + model_cfg = OmegaConf.to_container(OmegaConf.merge(model_conf, runtime_conf), resolve=True) + model_cfg.pop("checkpoint_dir", None) + + model = instantiate_from_dict(model_cfg, overrides={"device": device}) + if eval_mode: + model = model.eval() + if return_resolved_name: + return model, resolved_modelname + return model diff --git a/kimodo/model/loading.py b/kimodo/model/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..b2636a210871cc08e8ab2c667807892e5f9cc539 --- /dev/null +++ b/kimodo/model/loading.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Model loading utilities: checkpoints, registry, env, and Hydra-based instantiation.""" + +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +from hydra.utils import instantiate +from omegaconf import OmegaConf +from safetensors.torch import load_file as load_safetensors + +from .registry import ( + AVAILABLE_MODELS, + DEFAULT_MODEL, + DEFAULT_TEXT_ENCODER_URL, + KIMODO_MODELS, + MODEL_NAMES, + TMR_MODELS, +) + + +def get_env_var(name: str, default: Optional[str] = None) -> Optional[str]: + """Return environment variable value, or default if unset/empty.""" + return os.environ.get(name) or default + + +def instantiate_from_dict( + cfg: Dict[str, Any], + overrides: Optional[Dict[str, Any]] = None, +): + """Instantiate an object from a config dict (e.g. from OmegaConf.to_container). + + The dict must contain _target_ with a fully qualified class path. Nested configs are + instantiated recursively. + """ + if overrides: + cfg = {**cfg, **overrides} + conf = OmegaConf.create(cfg) + return instantiate(conf) + + +def load_checkpoint_state_dict(ckpt_path: Union[str, Path]) -> dict: + """Load a state dict from a checkpoint file. + + If the checkpoint is a dict with a 'state_dict' key (e.g. PyTorch Lightning), + that is returned; otherwise the whole checkpoint is treated as the state dict. + + Args: + ckpt_path: Path to the checkpoint file. + + Returns: + state_dict suitable for model.load_state_dict(). + """ + ckpt_path = str(ckpt_path) + + if ckpt_path.endswith(".safetensors"): + state_dict = load_safetensors(ckpt_path) + else: + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif isinstance(checkpoint, dict): + state_dict = checkpoint + else: + raise ValueError(f"Unsupported checkpoint format: {ckpt_path}") + return {key: val.detach().cpu() for key, val in state_dict.items()} + + +__all__ = [ + "get_env_var", + "instantiate_from_dict", + "KIMODO_MODELS", + "TMR_MODELS", + "AVAILABLE_MODELS", + "MODEL_NAMES", + "DEFAULT_MODEL", + "DEFAULT_TEXT_ENCODER_URL", + "load_checkpoint_state_dict", +] diff --git a/kimodo/model/registry.py b/kimodo/model/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3718128212adc1ad4c0e3afaa52faa9b30a78b --- /dev/null +++ b/kimodo/model/registry.py @@ -0,0 +1,473 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Registry of model names and Hugging Face repo IDs for Kimodo and TMR. + +Canonical source of truth is the list of repo IDs. Short keys (e.g. soma-rp) and metadata (dataset, +skeleton, version, display name) are derived by parsing. +""" + +import re +from dataclasses import dataclass +from typing import Optional + +# Canonical list: repo IDs in the same syntax as Hugging Face (org/Model-Name-v1). +# Parser expects: org/Family-SKELETON-DATASET-version (e.g. Kimodo-SOMA-RP-v1). +KIMODO_REPO_IDS = [ + "nvidia/Kimodo-SOMA-RP-v1", + "nvidia/Kimodo-SMPLX-RP-v1", + "nvidia/Kimodo-G1-RP-v1", + "nvidia/Kimodo-SOMA-SEED-v1", + "nvidia/Kimodo-G1-SEED-v1", +] +TMR_REPO_IDS = [ + "nvidia/TMR-SOMA-RP-v1", +] + +# Repo ID without org, for display (e.g. Kimodo-SOMA-RP-v1). 
+_REPO_NAME_PATTERN = re.compile(r"^(Kimodo|TMR)-([A-Za-z0-9]+)-(RP|SEED)-v(\d+)$") + + +@dataclass +class ModelInfo: + """Structured metadata for one model, derived from its repo ID.""" + + repo_id: str + short_key: str + family: str + skeleton: str + dataset: str + version: str + display_name: str + + @property + def dataset_ui_label(self) -> str: + return "Rigplay" if self.dataset == "RP" else "SEED" + + +def _parse_repo_id(repo_id: str) -> Optional[ModelInfo]: + """Parse a repo ID into ModelInfo. + + Returns None if format is unrecognized. + """ + # repo_id is "org/Model-Name-v1" + if "/" in repo_id: + _, name = repo_id.split("/", 1) + else: + name = repo_id + m = _REPO_NAME_PATTERN.match(name) + if not m: + return None + family, skeleton, dataset, ver = m.groups() + # Normalize skeleton for display (as is for now) + skeleton_display = skeleton + # Include family so Kimodo-SOMA-RP and TMR-SOMA-RP have distinct keys. + short_key = f"{family.lower()}-{skeleton.lower()}-{dataset.lower()}" + return ModelInfo( + repo_id=repo_id, + short_key=short_key, + family=family, + skeleton=skeleton_display, + dataset=dataset, + version=f"v{ver}", + display_name=name, + ) + + +def _build_registry() -> tuple[list[ModelInfo], dict[str, str], list[str]]: + """Build model infos, short_key -> repo_id map, and list of short keys. + + When multiple versions exist for the same (family, skeleton, dataset), the base short_key (e.g. + kimodo-soma-rp) maps to the latest version's repo_id so that HF resolution finds the newest + model. + """ + + def _version_key(info: ModelInfo) -> int: + v = info.version + if v.startswith("v") and v[1:].isdigit(): + return int(v[1:]) + return 0 + + all_repos = KIMODO_REPO_IDS + TMR_REPO_IDS + infos: list[ModelInfo] = [] + for repo_id in all_repos: + info = _parse_repo_id(repo_id) + if info is None: + raise ValueError(f"Registry repo ID does not match expected pattern: {repo_id}") + infos.append(info) + + # Map each base short_key to the latest version's repo_id (by version number) + model_names: dict[str, str] = {} + seen_short_keys: set[str] = set() + for info in infos: + if info.short_key in seen_short_keys: + continue + seen_short_keys.add(info.short_key) + candidates = [ + i for i in infos if i.family == info.family and i.skeleton == info.skeleton and i.dataset == info.dataset + ] + if candidates: + latest = max(candidates, key=_version_key) + model_names[info.short_key] = latest.repo_id + + return infos, model_names, list(model_names.keys()) + + +MODEL_INFOS, MODEL_NAMES, _SHORT_KEYS = _build_registry() +AVAILABLE_MODELS = _SHORT_KEYS + +# Short-key lists for Kimodo vs TMR (load_model uses TMR_MODELS to branch). +KIMODO_MODELS = [info.short_key for info in MODEL_INFOS if info.family == "Kimodo"] +TMR_MODELS = [info.short_key for info in MODEL_INFOS if info.family == "TMR"] + +# Backward compatibility: FRIENDLY_NAMES for any code that still expects it. +FRIENDLY_NAMES = {info.short_key: info.display_name for info in MODEL_INFOS} + +DEFAULT_MODEL = "kimodo-soma-rp" +DEFAULT_TEXT_ENCODER_URL = "http://127.0.0.1:9550/" + +# Friendly names for skeleton dropdown (key -> label). +SKELETON_DISPLAY_NAMES = { + "SOMA": "SOMA Human Body", + "SMPLX": "SMPLX Human Body", + "G1": "Unitree G1 Humanoid Robot", +} + +# Order for skeleton dropdown: SOMA, SMPLX, G1. +SKELETON_ORDER = ("SOMA", "SMPLX", "G1") + + +def get_skeleton_display_name(skeleton_key: str) -> str: + """Return the UI label for a skeleton key (e.g. 
SOMA -> SOMA Human Body).""" + return SKELETON_DISPLAY_NAMES.get(skeleton_key, skeleton_key) + + +def get_skeleton_key_from_display_name(display_name: str) -> Optional[str]: + """Return the skeleton key for a UI label, or None.""" + for key, label in SKELETON_DISPLAY_NAMES.items(): + if label == display_name: + return key + return None + + +def get_skeleton_display_names_for_dataset(dataset_ui_label: str, family: Optional[str] = None) -> list[str]: + """Return skeleton UI labels for the given dataset. + + If family is set (e.g. "Kimodo"), only skeletons with a model of that family are included. + """ + keys = get_skeletons_for_dataset(dataset_ui_label, family=family) + return [get_skeleton_display_name(k) for k in keys] + + +def get_short_key(repo_id: str) -> Optional[str]: + """Return the short key for a repo ID, or None if not in registry.""" + for info in MODEL_INFOS: + if info.repo_id == repo_id: + return info.short_key + return None + + +def get_model_info(short_key: str) -> Optional[ModelInfo]: + """Return ModelInfo for a short key, or None if not found. + + When multiple versions share the same short_key, returns the one used for loading (the latest + version), so CHECKPOINT_DIR and HF use the same version. + """ + repo_id = MODEL_NAMES.get(short_key) + if repo_id is None: + return None + for info in MODEL_INFOS: + if info.repo_id == repo_id: + return info + return None + + +def get_short_key_from_display_name(display_name: str) -> Optional[str]: + """Return short_key for a display name (e.g. Kimodo-SOMA-RP-v1), or None.""" + for info in MODEL_INFOS: + if info.display_name == display_name: + return info.short_key + return None + + +def get_models_for_demo() -> list[ModelInfo]: + """Return all model infos in registry order (for demo model list).""" + return list(MODEL_INFOS) + + +def get_datasets(family: Optional[str] = None) -> list[str]: + """Return unique dataset UI labels (Rigplay, SEED) present in registry. + + If family is set (e.g. "Kimodo"), only datasets that have a model of that family are included. + """ + infos = MODEL_INFOS + if family is not None: + infos = [i for i in infos if i.family == family] + labels = set() + for info in infos: + labels.add(info.dataset_ui_label) + return sorted(labels) + + +def get_skeletons_for_dataset(dataset_ui_label: str, family: Optional[str] = None) -> list[str]: + """Return skeleton names that have a model for the given dataset. + + Order: SOMA, SMPLX, G1 (only those present for the dataset). + If family is set (e.g. "Kimodo"), only skeletons with a model of that + family are included. + """ + dataset = "RP" if dataset_ui_label == "Rigplay" else "SEED" + infos = MODEL_INFOS + if family is not None: + infos = [i for i in infos if i.family == family] + skeletons = set() + for info in infos: + if info.dataset == dataset: + skeletons.add(info.skeleton) + return [s for s in SKELETON_ORDER if s in skeletons] + + +def get_versions_for_dataset_skeleton(dataset_ui_label: str, skeleton: str) -> list[str]: + """Return version strings (e.g. v1) for the given dataset/skeleton. + + Sorted by version number so the last element is the highest (e.g. v1, v2). + """ + dataset = "RP" if dataset_ui_label == "Rigplay" else "SEED" + versions = [] + for info in MODEL_INFOS: + if info.dataset == dataset and info.skeleton == skeleton: + versions.append(info.version) + + # Sort by numeric part so v2 comes after v1. 
+ def version_key(v: str) -> int: + if v.startswith("v") and v[1:].isdigit(): + return int(v[1:]) + return 0 + + return sorted(set(versions), key=version_key) + + +def get_models_for_dataset_skeleton( + dataset_ui_label: str, skeleton: str, family: Optional[str] = None +) -> list[ModelInfo]: + """Return model infos for the given dataset/skeleton, sorted by version (max first). + + Used to build the Version dropdown (options = full display names, one per model). If family is + set (e.g. "Kimodo"), only models of that family are returned. + """ + dataset = "RP" if dataset_ui_label == "Rigplay" else "SEED" + infos = [info for info in MODEL_INFOS if info.dataset == dataset and info.skeleton == skeleton] + if family is not None: + infos = [i for i in infos if i.family == family] + + def version_key(info: ModelInfo) -> int: + v = info.version + if v.startswith("v") and v[1:].isdigit(): + return int(v[1:]) + return 0 + + return sorted(infos, key=version_key, reverse=True) + + +def resolve_to_short_key(dataset_ui_label: str, skeleton: str, version: str) -> Optional[str]: + """Return the short key for (dataset, skeleton, version), or None.""" + for info in MODEL_INFOS: + if info.dataset_ui_label == dataset_ui_label and info.skeleton == skeleton and info.version == version: + return info.short_key + return None + + +# ----------------------------------------------------------------------------- +# Flexible model name resolution (partial names, case-insensitive, defaults) +# ----------------------------------------------------------------------------- + +_FAMILY_ALIASES = {"kimodo": "Kimodo", "tmr": "TMR"} +_DATASET_ALIASES = {"rp": "RP", "rigplay": "RP", "seed": "SEED"} +_SKELETON_ALIASES = { + "soma": "SOMA", + "smplx": "SMPLX", + "g1": "G1", +} + + +def _normalize_family(s: str) -> Optional[str]: + """Return canonical family (Kimodo/TMR) or None if unknown.""" + return _FAMILY_ALIASES.get(s.strip().lower()) + + +def _normalize_dataset(s: str) -> Optional[str]: + """Return canonical dataset (RP/SEED) or None if unknown.""" + return _DATASET_ALIASES.get(s.strip().lower()) + + +def _normalize_skeleton(s: str) -> Optional[str]: + """Return canonical skeleton (SOMA/SMPLX/G1) or None if unknown.""" + return _SKELETON_ALIASES.get(s.strip().lower()) + + +def _get_latest_for_family_skeleton_dataset(family: str, skeleton: str, dataset: str) -> Optional[ModelInfo]: + """Return the model info with the highest version for (family, skeleton, dataset).""" + candidates = [ + info for info in MODEL_INFOS if info.family == family and info.skeleton == skeleton and info.dataset == dataset + ] + if not candidates: + return None + + def version_key(info: ModelInfo) -> int: + v = info.version + if v.startswith("v") and v[1:].isdigit(): + return int(v[1:]) + return 0 + + return max(candidates, key=version_key) + + +def kimodo_short_key_for_skeleton_dataset(skeleton: str, dataset: str) -> Optional[str]: + """Return the latest Kimodo model short_key for ``skeleton`` and ``dataset`` (RP/SEED), or + None.""" + info = _get_latest_for_family_skeleton_dataset("Kimodo", skeleton, dataset) + return info.short_key if info is not None else None + + +def registry_skeleton_for_joint_count(nb_joints: int) -> str: + """Map motion joint count to registry skeleton key (SOMA / SMPLX / G1).""" + if nb_joints == 34: + return "G1" + if nb_joints == 22: + return "SMPLX" + if nb_joints in (77, 30): + return "SOMA" + raise ValueError(f"No Kimodo model registered for motion with J={nb_joints}") + + +# Optional version: Family-Skeleton-Dataset-vN 
or Family-Skeleton-Dataset +_RESOLVE_FULL_PATTERN = re.compile( + r"^(Kimodo|TMR|kimodo|tmr)[\-_]" r"([A-Za-z0-9]+)[\-_]" r"(RP|SEED|rp|seed)" r"(?:[\-_]v(\d+))?$", + re.IGNORECASE, +) +# Partial: Skeleton-Dataset or Skeleton or Dataset (no family) +_RESOLVE_PARTIAL_PATTERN = re.compile( + r"^([A-Za-z0-9]+)(?:[\-_](RP|SEED|rp|seed))?(?:[\-_]v(\d+))?$", + re.IGNORECASE, +) + + +def resolve_model_name(name: Optional[str], default_family: Optional[str] = None) -> str: + """Resolve a user-facing model name to a short_key. + + Accepts full names (e.g. Kimodo-SOMA-RP-v1), case-insensitive matching, + and partial names with defaults: dataset=RP, skeleton=SOMA, family from + default_family (Kimodo for demo/generation, TMR for embed script). + Omitted version resolves to the latest for that model. + + Args: + name: User-provided name (can be None or empty). + default_family: "Kimodo" or "TMR" when name is empty or omits family. + + Returns: + Short key (e.g. kimodo-soma-rp) for use with load_model / MODEL_NAMES. + + Raises: + ValueError: If name cannot be resolved or default_family is missing when needed. + """ + if name is not None: + name = name.strip() + if not name: + if default_family is None: + raise ValueError('Model name is empty; provide a name or set default_family ("Kimodo" or "TMR").') + fam = _normalize_family(default_family) + if fam is None: + raise ValueError(f"default_family must be 'Kimodo' or 'TMR', got {default_family!r}") + info = _get_latest_for_family_skeleton_dataset(fam, "SOMA", "RP") + if info is None: + raise ValueError(f"No model found for {fam}-SOMA-RP. Available: {list(MODEL_NAMES.keys())}") + return info.short_key + + # Exact short_key + if name in MODEL_NAMES: + return name + + # Case-insensitive match against short_key or display_name + name_lower = name.lower() + matches = [] + for info in MODEL_INFOS: + if name_lower == info.short_key.lower(): + matches.append(info) + disp = info.display_name.lower() + if name_lower == disp or name_lower == ("nvidia/" + disp): + matches.append(info) + if len(matches) == 1: + return matches[0].short_key + if len(matches) > 1: + return matches[0].short_key + + # Parsed full form: Family-Skeleton-Dataset or Family-Skeleton-Dataset-vN + m = _RESOLVE_FULL_PATTERN.match(name) + if m: + fam_raw, skel_raw, ds_raw, ver_num = m.groups() + fam = _normalize_family(fam_raw) + skel = _normalize_skeleton(skel_raw) + ds = _normalize_dataset(ds_raw) + if fam is not None and skel is not None and ds is not None: + if ver_num is not None: + version = f"v{ver_num}" + for info in MODEL_INFOS: + if info.family == fam and info.skeleton == skel and info.dataset == ds and info.version == version: + return info.short_key + else: + info = _get_latest_for_family_skeleton_dataset(fam, skel, ds) + if info is not None: + return info.short_key + + # Parsed partial: Skeleton-Dataset, Skeleton, or Dataset (use default_family) + if default_family is not None: + m = _RESOLVE_PARTIAL_PATTERN.match(name) + if m: + tok1, ds_raw, ver_num = m.groups() + fam = _normalize_family(default_family) + if fam is not None: + skel = _normalize_skeleton(tok1) + ds_candidate = _normalize_dataset(ds_raw) if ds_raw else None + if skel is not None and ds_candidate is not None: + ds = ds_candidate + elif skel is not None: + ds = "RP" + else: + skel = "SOMA" + ds = _normalize_dataset(tok1) if tok1 else "RP" + if ds is None: + ds = "RP" + if ver_num is not None: + version = f"v{ver_num}" + for info in MODEL_INFOS: + if ( + info.family == fam + and info.skeleton == skel + and info.dataset 
== ds + and info.version == version + ): + return info.short_key + else: + info = _get_latest_for_family_skeleton_dataset(fam, skel, ds) + if info is not None: + return info.short_key + + # Single token: skeleton or dataset + fam = _normalize_family(default_family) + if fam is not None: + skel = _normalize_skeleton(name) + if skel is not None: + info = _get_latest_for_family_skeleton_dataset(fam, skel, "RP") + if info is not None: + return info.short_key + ds = _normalize_dataset(name) + if ds is not None: + info = _get_latest_for_family_skeleton_dataset(fam, "SOMA", ds) + if info is not None: + return info.short_key + + raise ValueError( + f"Model name {name!r} could not be resolved. " + f"Use a short key (e.g. {list(MODEL_NAMES.keys())[:3]}...), " + "a full name (e.g. Kimodo-SOMA-RP-v1), or a partial (e.g. SOMA-RP, SOMA) " + "with default_family set." + ) diff --git a/kimodo/model/text_encoder_api.py b/kimodo/model/text_encoder_api.py new file mode 100644 index 0000000000000000000000000000000000000000..020194f7fb6f67a958d392585c627670c2ae333a --- /dev/null +++ b/kimodo/model/text_encoder_api.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Remote text encoder API client (Gradio) for motion generation.""" + +import logging + +import numpy as np +import torch +from gradio_client import Client + +# Suppress the [httpx] logs (GET requests) +logging.getLogger("httpx").setLevel(logging.WARNING) + +# Suppress internal gradio_client logs +logging.getLogger("gradio_client").setLevel(logging.WARNING) + + +class TextEncoderAPI: + """Text encoder API client for motion generation.""" + + def __init__(self, url: str): + self.client = Client(url, verbose=False) + self.device = "cpu" + self.dtype = torch.float + + def _create_np_random_name(self): + import uuid + + return str(uuid.uuid4()) + ".npy" + + def to(self, device=None, dtype=None): + if device is not None: + self.device = device + if dtype is not None: + self.dtype = dtype + return self + + def __call__(self, texts): + """Encode text prompts into tensors. + + Args: + texts (str | list[str]): text prompts to encode + + Returns: + tuple[torch.Tensor, list[int]]: encoded text tensors and their lengths + """ + if isinstance(texts, str): + texts = [texts] + + tensors = [] + lengths = [] + for text in texts: + filename = self._create_np_random_name() + + # Use a long result timeout to tolerate text-encoder cold-start (LLM2Vec model load ~60-120s). + result = self.client.submit( + text=text, + filename=filename, + api_name="/DemoWrapper", + ).result(timeout=300) + path = result[0]["value"] + tensor = np.load(path) + length = tensor.shape[0] + + tensors.append(tensor) + lengths.append(length) + + padded_tensor = np.zeros((len(lengths), max(lengths), tensors[0].shape[-1]), dtype=tensors[0].dtype) + for idx, (tensor, length) in enumerate(zip(tensors, lengths)): + padded_tensor[idx, :length] = tensor + + padded_tensor = torch.from_numpy(padded_tensor) + padded_tensor = padded_tensor.to(device=self.device, dtype=self.dtype) + return padded_tensor, lengths diff --git a/kimodo/model/tmr.py b/kimodo/model/tmr.py new file mode 100644 index 0000000000000000000000000000000000000000..b05c4498785e131e3032f8f1af477b055a787062 --- /dev/null +++ b/kimodo/model/tmr.py @@ -0,0 +1,382 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""TMR model: encoder, and text-to-motion retrieval head.""" + +import contextlib +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from einops import repeat +from torch import Tensor + +from kimodo.model import load_checkpoint_state_dict +from kimodo.motion_rep.feature_utils import length_to_mask +from kimodo.sanitize import sanitize_texts +from kimodo.skeleton import SkeletonBase, build_skeleton +from kimodo.tools import ensure_batched + + +class PositionalEncoding(nn.Module): + """Sinusoidal positional encoding for sequences (batch_first optional).""" + + def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None: + super().__init__() + self.batch_first = batch_first + + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + # Note: have to replace torch.exp() and math.log() with torch.pow() + # due to MKL exp() and ln() throws floating point exceptions on certain CPUs + div_term = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model) + # div_term = torch.exp( + # torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) + # ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer("pe", pe, persistent=False) + + def forward(self, x: Tensor) -> Tensor: + if self.batch_first: + x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :] + else: + x = x + self.pe[: x.shape[0], :] + return self.dropout(x) + + +def load_ckpt(self, ckpt_path): + """Load model weights from checkpoint path.""" + state_dict = load_checkpoint_state_dict(ckpt_path) + self.load_state_dict(state_dict) + + +class ACTORStyleEncoder(nn.Module): + """Motion encoder in ACTOR style: optional motion_rep projection, VAE/MLP tokens, transformer.""" + + def __init__( + self, + motion_rep: Optional[nn.Module], + llm_shape: Optional[Tuple], + vae: bool, + latent_dim: int = 256, + ff_size: int = 1024, + num_layers: int = 4, + num_heads: int = 4, + dropout: float = 0.1, + activation: str = "gelu", + ckpt_path: Optional[str] = None, + ) -> None: + super().__init__() + + self.motion_rep = motion_rep + if motion_rep is not None and llm_shape is None: + nfeats = motion_rep.motion_rep_dim + elif motion_rep is None and llm_shape is not None: + nfeats = llm_shape[-1] + else: + raise ValueError + + self.nfeats = nfeats + self.projection = nn.Linear(nfeats, latent_dim) + + self.vae = vae + self.nbtokens = 2 if vae else 1 + self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim)) + + self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout=dropout, batch_first=True) + + seq_trans_encoder_layer = nn.TransformerEncoderLayer( + d_model=latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation, + batch_first=True, + ) + + self.seqTransEncoder = nn.TransformerEncoder( + seq_trans_encoder_layer, + num_layers=num_layers, + enable_nested_tensor=False, + ) + + if ckpt_path is not None: + load_ckpt(self, ckpt_path) + + def forward(self, x_dict: Dict) -> Tensor: + x = x_dict["x"] + mask = x_dict["mask"] + + x = self.projection(x) + + device = x.device + bs = len(x) + + tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs) + xseq = torch.cat((tokens, x), 1) + + token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device) + aug_mask 
= torch.cat((token_mask, mask), 1) + + # add positional encoding + xseq = self.sequence_pos_encoding(xseq) + final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask) + return final[:, : self.nbtokens] + + +class TMR(nn.Module): + r"""TMR: Text-to-Motion Retrieval inference code (no decoder) + Find more information about the model on the following website: + https://mathis.petrovich.fr/tmr + """ + + @classmethod + def from_args( + cls, + motion_rep: nn.Module, + llm_shape: tuple | list, + vae: bool, + latent_dim: int = 256, + ff_size: int = 1024, + num_layers: int = 4, + num_heads: int = 4, + dropout: float = 0.1, + activation: str = "gelu", + ckpt_folder: Optional[str] = None, + device: Optional[str] = None, + **kwargs, + ): + motion_encoder, top_text_encoder = None, None + + motion_encoder = ACTORStyleEncoder( + motion_rep=motion_rep, + llm_shape=None, + vae=vae, + latent_dim=latent_dim, + ff_size=ff_size, + num_layers=num_layers, + num_heads=num_heads, + dropout=dropout, + activation=activation, + ckpt_path=Path(ckpt_folder) / "motion_encoder.pt", + ).to(device) + + top_text_encoder = ACTORStyleEncoder( + motion_rep=None, + llm_shape=llm_shape, + vae=vae, + latent_dim=latent_dim, + ff_size=ff_size, + num_layers=num_layers, + num_heads=num_heads, + dropout=dropout, + activation=activation, + ckpt_path=Path(ckpt_folder) / "text_encoder.pt", + ).to(device) + return cls( + motion_encoder, + top_text_encoder, + vae, + device=device, + **kwargs, + ) + + def __init__( + self, + motion_encoder: nn.Module, + top_text_encoder: nn.Module, + vae: bool, + text_encoder: Optional = None, + fact: Optional[float] = None, + sample_mean: Optional[bool] = True, + unit_vector: Optional[bool] = False, + compute_grads: bool = False, + device: Optional[str] = None, + ) -> None: + super().__init__() + + self.motion_encoder = motion_encoder + self.text_encoder = top_text_encoder + self.raw_text_encoder = text_encoder + + self.motion_rep = None + self.skeleton = None + if self.motion_encoder is not None: + self.motion_rep = self.motion_encoder.motion_rep + if self.motion_rep is not None: + self.skeleton = self.motion_rep.skeleton + + self.compute_grads = compute_grads + + self.device = device + + # sampling parameters + self.vae = vae + self.fact = fact if fact is not None else 1.0 + self.sample_mean = sample_mean + self.unit_vector = unit_vector + + def full_text_encoder(self, texts: list[str]): + assert isinstance(texts, list), "The input should be batched." 
+ # sanitize the texts first + # then encode the text, and then use the top text encoder + texts = sanitize_texts(texts) + text_feat, text_length = self.raw_text_encoder(texts) + if isinstance(text_length, list): + text_length = torch.tensor(text_length, device=self.device) + else: + text_length = text_length.to(self.device) + inputs = { + "x": text_feat.to(self.device), + "mask": length_to_mask(text_length, device=self.device), + } + return self.text_encoder(inputs) + + def _find_encoder(self, inputs, modality): + assert modality in ["text", "motion", "raw_text", "auto"] + + if modality == "text": + return self.text_encoder + elif modality == "motion": + return self.motion_encoder + elif modality == "raw_text": + return self.full_text_encoder + + if isinstance(inputs[0], str): + return self.full_text_encoder + + m_nfeats = self.motion_encoder.nfeats + t_nfeats = self.text_encoder.nfeats + + if m_nfeats == t_nfeats: + raise ValueError("Cannot automatically find the encoder, as they share the same input space.") + + nfeats = inputs["x"].shape[-1] + if nfeats == m_nfeats: + return self.motion_encoder + elif nfeats == t_nfeats: + return self.text_encoder + else: + raise ValueError("The inputs is not recognized.") + + def _encode( + self, + inputs, + modality: str = "auto", + sample_mean: Optional[bool] = None, + fact: Optional[float] = None, + return_distribution: bool = False, + unit_vector: Optional[bool] = None, + ): + sample_mean = self.sample_mean if sample_mean is None else sample_mean + fact = self.fact if fact is None else fact + unit_vector = self.unit_vector if unit_vector is None else unit_vector + + # Encode the inputs + encoder = self._find_encoder(inputs, modality) + encoded = encoder(inputs) + + # Sampling + if self.vae: + dists = encoded.unbind(1) + mu, logvar = dists + if sample_mean: + latent_vectors = mu + else: + # Reparameterization trick + std = logvar.exp().pow(0.5) + eps = std.data.new(std.size()).normal_() + latent_vectors = mu + fact * eps * std + else: + dists = None + (latent_vectors,) = encoded.unbind(1) + + if unit_vector: + latent_vectors = torch.nn.functional.normalize(latent_vectors, dim=-1) + + if return_distribution: + return latent_vectors, dists + + return latent_vectors + + @ensure_batched(posed_joints=4, lengths=1) + def encode_motion( + self, + posed_joints: torch.Tensor, + original_skeleton: Optional[SkeletonBase] = None, + lengths: Optional[torch.Tensor] = None, + unit_vector: Optional[bool] = None, + ): + # TODO here. + convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext() + + if original_skeleton is None: + original_skeleton = build_skeleton(posed_joints.shape[-2]) + + if lengths is None: + nbatch, nbframes = posed_joints.shape[:2] + device = posed_joints.device + assert nbatch == 1, "If lenghts is not provided, the input should not be batched." 
+ lengths = torch.tensor([nbframes], device=device) + + # slice the posed joints if we use less joints + skel_slice = self.motion_rep.skeleton.get_skel_slice(original_skeleton) + posed_joints = posed_joints[..., skel_slice, :] + + with convert_ctx: + features = self.motion_rep( + posed_joints=posed_joints, + to_normalize=True, + lengths=lengths, + ) + mask = length_to_mask(lengths, device=features.device) + x_dict = {"x": features, "mask": mask} + latent_vectors = self._encode( + x_dict, + modality="motion", + unit_vector=unit_vector, + ) + return latent_vectors + + def encode_text( + self, + x_dict: Dict, + unit_vector: Optional[bool] = None, + ): + # TODO: make it ensure batched + convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext() + + with convert_ctx: + latent_vectors = self._encode( + x_dict, + modality="text", + unit_vector=unit_vector, + ) + return latent_vectors + + def encode_raw_text( + self, + texts: List[str], + unit_vector: Optional[bool] = None, + ): + is_batched = True + if isinstance(texts, str): + is_batched = False + texts = [texts] + + convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext() + + with convert_ctx: + latent_vectors = self._encode( + texts, + modality="raw_text", + unit_vector=unit_vector, + ) + if not is_batched: + latent_vectors = latent_vectors[0] + return latent_vectors diff --git a/kimodo/model/twostage_denoiser.py b/kimodo/model/twostage_denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..d14cf76085cfc1dbba486cb6470644bade4038d0 --- /dev/null +++ b/kimodo/model/twostage_denoiser.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Two-stage transformer denoiser: root stage then body stage for motion diffusion.""" + +import contextlib +from typing import Optional + +import torch +from torch import nn + +from .backbone import TransformerEncoderBlock +from .loading import load_checkpoint_state_dict + + +class TwostageDenoiser(nn.Module): + """Two-stage denoiser: first predicts global root features, then body features conditioned on local root.""" + + def __init__( + self, + motion_rep, + motion_mask_mode, + ckpt_path: Optional[str] = None, + **kwargs, + ): + """Build root and body transformer blocks; optionally load checkpoint from ckpt_path.""" + super().__init__() + self.motion_rep = motion_rep + self.motion_mask_mode = motion_mask_mode + + # it should be a dual motion_rep + # and be global by default + # global motion_rep as inpnut + input_dim = motion_rep.motion_rep_dim + will_concatenate = motion_mask_mode == "concat" + + # stage 1: root only + root_input_dim = input_dim * 2 if will_concatenate else input_dim + root_output_dim = motion_rep.global_root_dim + + self.root_model = TransformerEncoderBlock( + input_dim=root_input_dim, + output_dim=root_output_dim, + skeleton=self.motion_rep.skeleton, + **kwargs, + ) + + # replace the global root by the local root + local_motion_rep_dim = input_dim - motion_rep.global_root_dim + motion_rep.local_root_dim + + # stage 2: local body + body_input_dim = local_motion_rep_dim + ( + input_dim if will_concatenate else 0 + ) # body stage always takes in local root info for motion (but still the global mask) + + body_output_dim = input_dim - motion_rep.global_root_dim + self.body_model = TransformerEncoderBlock( + input_dim=body_input_dim, + output_dim=body_output_dim, + skeleton=self.motion_rep.skeleton, + **kwargs, 
+ ) + + if ckpt_path: + self.load_ckpt(ckpt_path) + + def load_ckpt(self, ckpt_path: str) -> None: + """Load checkpoint from path; state dict keys are stripped of 'denoiser.backbone.' + prefix.""" + state_dict = load_checkpoint_state_dict(ckpt_path) + state_dict = {key.replace("denoiser.backbone.", ""): val for key, val in state_dict.items()} + self.load_state_dict(state_dict) + + def forward( + self, + x: torch.Tensor, + x_pad_mask: torch.Tensor, + text_feat: torch.Tensor, + text_feat_pad_mask: torch.Tensor, + timesteps: torch.Tensor, + first_heading_angle: Optional[torch.Tensor] = None, + motion_mask: Optional[torch.Tensor] = None, + observed_motion: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): [B, T, dim_motion] current noisy motion + x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not + text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts + text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not + timesteps (torch.Tensor): [B,] current denoising step + motion_mask + observed_motion + + Returns: + torch.Tensor: same size as input x + """ + + if self.motion_mask_mode == "concat": + if motion_mask is None or observed_motion is None: + motion_mask = torch.zeros_like(x) + observed_motion = torch.zeros_like(x) + x = x * (1 - motion_mask) + observed_motion * motion_mask + x_extended = torch.cat([x, motion_mask], axis=-1) + else: + x_extended = x + + # Stage 1: predict root motion in global + root_motion_pred = self.root_model( + x_extended, + x_pad_mask, + text_feat, + text_feat_pad_mask, + timesteps, + first_heading_angle, + ) # [B, T, 5] + + # Maybe pass this as argument instead of recomputing it + lengths = x_pad_mask.sum(-1) + + # Convert root pred to local rep + # At test-time want to allow gradient through for guidance + convert_ctx = torch.no_grad() if self.training else contextlib.nullcontext() + with convert_ctx: + root_motion_local = self.motion_rep.global_root_to_local_root( + root_motion_pred, + normalized=True, + lengths=lengths, + ) + if self.training: + root_motion_local = root_motion_local.detach() + + # concatenate the predicted local root with the body motion + body_x = x[..., self.motion_rep.body_slice] + x_new = torch.cat([root_motion_local, body_x], axis=-1) + + if self.motion_mask_mode == "concat": + x_new_extended = torch.cat([x_new, motion_mask], axis=-1) + else: + x_new_extended = x_new + + # Stage 2: predict local body motion based on local root + predicted_body = self.body_model( + x_new_extended, + x_pad_mask, + text_feat, + text_feat_pad_mask, + timesteps, + first_heading_angle, + ) + + # concatenate the predicted local body with the predicted root + output = torch.cat([root_motion_pred, predicted_body], axis=-1) + return output diff --git a/kimodo/motion_rep/__init__.py b/kimodo/motion_rep/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..721e0392fe88404410942e80d7ad120d1a91a88b --- /dev/null +++ b/kimodo/motion_rep/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Motion representation utilities.""" + +from .reps import KimodoMotionRep, MotionRepBase, TMRMotionRep + +__all__ = [ + "MotionRepBase", + "KimodoMotionRep", + "TMRMotionRep", +] diff --git a/kimodo/motion_rep/conditioning.py b/kimodo/motion_rep/conditioning.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c64da24dd143b7684b16a96429ab41ed15db3f --- /dev/null +++ b/kimodo/motion_rep/conditioning.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Constraint conditioning: build index and data dicts from constraint sets for the denoiser.""" + +from collections import defaultdict + +import torch + + +def build_condition_dicts(constraints_lst: list): + index_dict = defaultdict(list) + data_dict = defaultdict(list) + for constraint in constraints_lst: + constraint.update_constraints(data_dict, index_dict) + return index_dict, data_dict + + +def get_unique_index_and_data(indices_lst, data): + # unique + sort them by t + indices_unique, inverse = torch.unique(indices_lst, dim=0, return_inverse=True) + # pick first value for each unique (t, j) + first_idx = torch.zeros(indices_unique.size(0), dtype=torch.long, device=inverse.device) + first_idx.scatter_(0, inverse, torch.arange(len(inverse), device=inverse.device)) + assert (indices_lst[first_idx] == indices_unique).all() + # get the data + indices_lst = indices_lst[first_idx] + data = data[first_idx] + return indices_lst, data diff --git a/kimodo/motion_rep/feature_utils.py b/kimodo/motion_rep/feature_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..664aa99047f29629a79a520efb4e2a4421f0b445 --- /dev/null +++ b/kimodo/motion_rep/feature_utils.py @@ -0,0 +1,212 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Motion representation helpers: velocity, heading, masks, and rotation of features.""" + +from typing import List, Optional, Union + +import einops +import torch + +from kimodo.geometry import cont6d_to_matrix, matrix_to_cont6d +from kimodo.skeleton import SkeletonBase +from kimodo.tools import ensure_batched + + +def diff_angles(angles: torch.Tensor, fps: float) -> torch.Tensor: + """Compute frame-to-frame angular differences in radians, scaled by fps. + + Args: + angles: [..., T] batched sequences of rotation angles in radians. + fps: Sampling rate used to convert frame differences to per-second rate. + + Returns: + [..., T-1] difference between consecutive angles (rad/s). + """ + + cos = torch.cos(angles) + sin = torch.sin(angles) + + cos_diff = cos[..., 1:] * cos[..., :-1] + sin[..., 1:] * sin[..., :-1] + sin_diff = sin[..., 1:] * cos[..., :-1] - cos[..., 1:] * sin[..., :-1] + + # should be close to angles.diff() but more robust + # multiply by fps = 1 / dt + angles_diff = fps * torch.arctan2(sin_diff, cos_diff) + return angles_diff + + +@ensure_batched(positions=4, lengths=1) +def compute_vel_xyz( + positions: torch.Tensor, + fps: float, + lengths: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Compute the velocities from positions: dx/dt. Works with batches. The last velocity is duplicated to keep the same size. + + Args: + positions (torch.Tensor): [..., T, J, 3] xyz positions of a human skeleton + fps (float): frame per seconds + lengths (Optional[torch.Tensor]): [...] size of each input batched. 
If not provided, positions should not be batched + + Returns: + velocity (torch.Tensor): [..., T, J, 3] velocities computed from the positions + """ + device = positions.device + + if lengths is None: + assert positions.shape[0] == 1, "If lengths is not provided, the input should not be batched." + lengths = torch.tensor([len(positions)], device=device) + + # useful for indexing + range_len = torch.arange(len(lengths)) + + # compute velocities with fps + velocity = fps * (positions[:, 1:] - positions[:, :-1]) + # pading the velocity vector + vel_pad = torch.zeros_like(velocity[:, 0]) + velocity, _ = einops.pack([velocity, vel_pad], "batch * nbjoints dim") + + # repeat the last velocities + # with special care for different lengths with batches + velocity[(range_len, lengths - 1)] = velocity[(range_len, lengths - 2)] + return velocity + + +@ensure_batched(root_rot_angles=2, lengths=1) +def compute_vel_angle( + root_rot_angles: torch.Tensor, + fps: float, + lengths: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Compute the local root rotation velocity: dtheta/dt. + + Args: + root_rot_angles (torch.Tensor): [..., T] rotation angle (in radian) + fps (float): frame per seconds + lengths (Optional[torch.Tensor]): [...] size of each input batched. If not provided, root_rot_angles should not be batched + + Returns: + local_root_rot_vel (torch.Tensor): [..., T] local root rotation velocity (in radian/s) + """ + device = root_rot_angles.device + if lengths is None: + assert root_rot_angles.shape[0] == 1, "If lengths is not provided, the input should not be batched." + lengths = torch.tensor([len(root_rot_angles)], device=device) + + # useful for indexing + range_len = torch.arange(len(lengths)) + + local_root_rot_vel = diff_angles(root_rot_angles, fps) + pad_rot_vel_angles = torch.zeros_like(root_rot_angles[:, 0]) + local_root_rot_vel, _ = einops.pack( + [local_root_rot_vel, pad_rot_vel_angles], + "batch *", + ) + # repeat the last rotation angle + # with special care for different lengths with batches + local_root_rot_vel[(range_len, lengths - 1)] = local_root_rot_vel[(range_len, lengths - 2)] + return local_root_rot_vel + + +@ensure_batched(posed_joints=4) +def compute_heading_angle(posed_joints: torch.Tensor, skeleton: SkeletonBase) -> torch.Tensor: + """Compute the heading direction from joint positions using the hip vector. + + Args: + posed_joints: [B, T, J, 3] global joint positions. + skeleton: Skeleton instance used to get hip joint indices. + + Returns: + [B] heading angle in radians. + """ + # compute root heading for the sequence from hip positions + r_hip, l_hip = skeleton.hip_joint_idx + diff = posed_joints[:, :, r_hip] - posed_joints[:, :, l_hip] + heading_angle = torch.atan2(diff[..., 2], -diff[..., 0]) + return heading_angle + + +def length_to_mask( + length: Union[torch.Tensor, List], + max_len: Optional[int] = None, + device=None, +) -> torch.Tensor: + """Convert sequence lengths to a boolean validity mask. + + Args: + length: Sequence lengths, either a tensor ``[B]`` or a Python list. + max_len: Optional mask width. If omitted, uses ``max(length)``. + device: Optional device. When ``length`` is a list, this controls where + the new tensor is created. + + Returns: + A boolean tensor of shape ``[B, max_len]`` where ``True`` marks valid + timesteps. 
+ """ + if isinstance(length, list): + if device is None: + device = "cpu" + length = torch.tensor(length, device=device) + + # Use requested device for output; move length if needed so mask and length match + if device is not None: + target = torch.device(device) + if length.device != target: + length = length.to(target) + device = length.device + + if max_len is None: + max_len = max(length) + + mask = torch.arange(max_len, device=device).expand(len(length), max_len) < length.unsqueeze(1) + return mask + + +class RotateFeatures: + """Helper that applies a global heading rotation to motion features.""" + + def __init__(self, angle: torch.Tensor): + """Precompute 2D and 3D rotation matrices for a batch of angles. + + Args: + angle: Rotation angle(s) in radians, shaped ``[B]``. + """ + self.angle = angle + + ## Create the necessary rotations matrices + cos, sin = torch.cos(angle), torch.sin(angle) + one, zero = torch.ones_like(angle), torch.zeros_like(angle) + + # 2D rotation transposed (sin are -sin) + self.corrective_mat_2d_T = torch.stack((cos, sin, -sin, cos), -1).reshape(angle.shape + (2, 2)) + # 3D rotation on Y axis + self.corrective_mat_Y = torch.stack((cos, zero, sin, zero, one, zero, -sin, zero, cos), -1).reshape( + angle.shape + (3, 3) + ) + self.corrective_mat_Y_T = self.corrective_mat_Y.transpose(1, 2).contiguous() + + def rotate_positions(self, positions: torch.Tensor): + """Rotate 3D positions around the Y axis.""" + return positions @ self.corrective_mat_Y_T + + def rotate_2d_positions(self, positions_2d: torch.Tensor): + """Rotate 2D ``(x, z)`` vectors in the ground plane.""" + return positions_2d @ self.corrective_mat_2d_T + + def rotate_rotations(self, rotations: torch.Tensor): + """Left-multiply global rotation matrices by the heading correction.""" + # "Rotate" the global rotations + # which means add an extra Y rotation after the transform + # so at the left R' = R_y R + # (since we use the convention x' = R x) + # "bik,btdkj->btdij" + + B, T, J = rotations.shape[:3] + BTJ = B * T * J + return ( + self.corrective_mat_Y[:, None, None].expand(B, T, J, 3, 3).reshape(BTJ, 3, 3) @ rotations.reshape(BTJ, 3, 3) + ).reshape(B, T, J, 3, 3) + + def rotate_6d_rotations(self, rotations_6d: torch.Tensor): + """Rotate 6D rotation features via matrix conversion.""" + return matrix_to_cont6d(self.rotate_rotations(cont6d_to_matrix(rotations_6d))) diff --git a/kimodo/motion_rep/feet.py b/kimodo/motion_rep/feet.py new file mode 100644 index 0000000000000000000000000000000000000000..89e6bdb155fb7e5e6be219c709a5856702515600 --- /dev/null +++ b/kimodo/motion_rep/feet.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Foot contact detection from joint positions and velocities.""" + +import torch + +from ..tools import ensure_batched + + +@ensure_batched(positions=4, velocity=4) +def foot_detect_from_pos_and_vel( + positions: torch.Tensor, + velocity: torch.Tensor, + skeleton, + vel_thres: float, + height_thresh: float, +) -> torch.Tensor: + """Compute foot contact labels using heuristics combining joint height and velocities. 
+ + Args: + positions (torch.Tensor): [X, T, J, 3] global joint positions + velocity (torch.Tensor): [X, T, J, 3] velocities (already padded correctly), already multiplied by 1 / dt + vel_thres (float): threshold for joint velocity + height_thresh (float): threshold for joint height + + Returns: + torch.Tensor: [X, T, 4] contact labels for left and right foot joints + (heel/toe order follows the skeleton joint index definition), where + ``1`` denotes contact. + """ + + device = positions.device + # Use at most 2 foot joints per side (ankle + toe); SOMA77 defines a + # third end-effector (ToeEnd) that SOMA30 and other skeletons omit. + fid_l = skeleton.left_foot_joint_idx[:2] + fid_r = skeleton.right_foot_joint_idx[:2] + + velfactor, heightfactor = ( + torch.tensor([vel_thres, vel_thres], device=device), + torch.tensor([height_thresh, height_thresh], device=device), + ) + + feet_l_v = torch.linalg.norm(velocity[:, :, fid_l], axis=-1) + feet_l_h = positions[:, :, fid_l, 1] + + feet_l = torch.logical_and( + feet_l_v < velfactor, + feet_l_h < heightfactor, + ).to(positions.dtype) + + feet_r_v = torch.linalg.norm(velocity[:, :, fid_r], axis=-1) + feet_r_h = positions[:, :, fid_r, 1] + + feet_r = torch.logical_and( + feet_r_v < velfactor, + feet_r_h < heightfactor, + ).to(positions.dtype) + + foot_contacts = torch.cat((feet_l, feet_r), axis=-1) + return foot_contacts diff --git a/kimodo/motion_rep/reps/__init__.py b/kimodo/motion_rep/reps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd6ac79c5f0392c501a99018223fe6095a42c19 --- /dev/null +++ b/kimodo/motion_rep/reps/__init__.py @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Motion representation implementations: base, Kimodo, and TMR.""" + +from .base import MotionRepBase +from .kimodo_motionrep import KimodoMotionRep +from .tmr_motionrep import TMRMotionRep + +__all__ = [ + "MotionRepBase", + "KimodoMotionRep", + "TMRMotionRep", +] diff --git a/kimodo/motion_rep/reps/base.py b/kimodo/motion_rep/reps/base.py new file mode 100644 index 0000000000000000000000000000000000000000..aac2245d7c1bf38fcb112b65a343afa4eb1f822d --- /dev/null +++ b/kimodo/motion_rep/reps/base.py @@ -0,0 +1,300 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Base motion representation: feature layout, normalization, and conditioning helpers.""" + +import os +from typing import Optional + +import einops +import numpy as np +import torch +from einops import repeat + +from ...tools import ensure_batched +from ..conditioning import build_condition_dicts +from ..feature_utils import compute_vel_angle, compute_vel_xyz +from ..stats import Stats + + +def _require_split_stats_layout(stats_path: str) -> None: + """Raise if stats_path does not contain the required global_root, local_root, body subdirs.""" + subdirs = ("global_root", "local_root", "body") + missing = [] + for name in subdirs: + subpath = os.path.join(stats_path, name) + mean_path = os.path.join(subpath, "mean.npy") + if not os.path.isfile(mean_path): + missing.append(f"{subpath}/ (mean.npy)") + if missing: + raise FileNotFoundError( + f"Checkpoint stats must use the split layout with subfolders " + f"global_root/, local_root/, and body/ under '{stats_path}'. " + f"Missing or incomplete: {', '.join(missing)}. 
" + ) + + +class MotionRepBase: + """Base class for motion representations used in generation and conditioning. + + Subclasses define: + - ``size_dict``: feature blocks and their shapes, + - ``last_root_feature``: last entry of the root block, + - ``local_root_size_dict``: local-root feature layout, + and implement transform-specific methods such as ``__call__``, ``inverse``, + ``rotate``, ``translate_2d`` and ``create_conditions``. + """ + + def __init__( + self, + skeleton, + fps, + stats_path: Optional[str] = None, + ): + """Initialize feature slicing metadata and optional normalization stats.""" + + self.skeleton = skeleton + self.fps = fps + self.nbjoints = skeleton.nbjoints + + self.feature_names = list(self.size_dict.keys()) + self.ps = list(self.size_dict.values()) + self.nfeats_dict = {key: val.numel() for key, val in self.size_dict.items()} + feats_cumsum = np.cumsum([0] + list(self.nfeats_dict.values())).tolist() + self.slice_dict = {key: slice(feats_cumsum[i], feats_cumsum[i + 1]) for i, key in enumerate(self.feature_names)} + + self.motion_rep_dim = sum(self.nfeats_dict.values()) + self.root_slice = slice(0, self.slice_dict[self.last_root_feature].stop) + self.body_slice = slice(self.root_slice.stop, self.motion_rep_dim) + self.body_dim = self.body_slice.stop - self.body_slice.start + self.global_root_dim = self.root_slice.stop + self.local_root_dim = sum(val.numel() for val in self.local_root_size_dict.values()) + + if stats_path: + _require_split_stats_layout(stats_path) + self.global_root_stats = Stats(os.path.join(stats_path, "global_root")) + self.local_root_stats = Stats(os.path.join(stats_path, "local_root")) + self.body_stats = Stats(os.path.join(stats_path, "body")) + # self.stats not set; normalize/unnormalize apply per-part below + + def get_root_pos(self, features: torch.Tensor, fallback_to_smooth: bool = True): + """Extract root positions from a feature tensor. + + Supports both ``root_pos`` and ``smooth_root_pos`` representations. + """ + if "root_pos" in self.slice_dict: + return features[..., self.slice_dict["root_pos"]] + + if "smooth_root_pos" not in self.slice_dict: + raise TypeError("This motion rep should have either a root_pos or smooth_root_pos field") + + if fallback_to_smooth: + return features[:, :, self.slice_dict["smooth_root_pos"]] + + # else compute the root pos from the smooth root and local joints offset + smooth_root_pos = features[:, :, self.slice_dict["smooth_root_pos"]].clone() + local_joints_positions_flatten = features[..., self.slice_dict["local_joints_positions"]] + hips_offset = local_joints_positions_flatten[..., self.skeleton.root_idx : self.skeleton.root_idx + 3] + root_pos = torch.stack( + [ + smooth_root_pos[..., 0] + hips_offset[..., 0], + smooth_root_pos[..., 1], + smooth_root_pos[..., 2] + hips_offset[..., 2], + ], + axis=-1, + ) + return root_pos + + @ensure_batched(root_features=3, lengths=1) + def global_root_to_local_root( + self, + root_features: torch.Tensor, + normalized: bool, + lengths: Optional[torch.Tensor], + ): + """Convert global root features to local-root motion features. + + Args: + root_features: Root feature tensor containing root position and + global heading, shaped ``[B, T, D_root]``. + normalized: Whether ``root_features`` are normalized. + lengths: Optional valid lengths per sequence. + + Returns: + Tensor ``[B, T, 4]`` with local root rotational velocity, planar + velocity, and global root height. 
+ """ + if normalized: + root_features = self.global_root_stats.unnormalize(root_features) + + [root_pos, global_root_heading] = einops.unpack(root_features, self.ps[:2], "batch time *") + cos, sin = global_root_heading.unbind(-1) + heading_angle = torch.arctan2(sin, cos) + + local_root_rot_vel = compute_vel_angle(heading_angle, self.fps, lengths=lengths) + local_root_vel = compute_vel_xyz( + root_pos[..., None, :], + self.fps, + lengths=lengths, + )[..., 0, [0, 2]] + global_root_y = root_pos[..., 1] + local_root_motion = torch.cat( + [ + local_root_rot_vel[..., None], + local_root_vel, + global_root_y[..., None], + ], + axis=-1, + ) + + if normalized: + local_root_motion = self.local_root_stats.normalize(local_root_motion) + return local_root_motion + + def get_root_heading_angle(self, features: torch.Tensor) -> torch.Tensor: + """Compute root heading angle from cosine/sine heading features.""" + global_root_heading = features[:, :, self.slice_dict["global_root_heading"]] + cos, sin = global_root_heading.unbind(-1) + return torch.arctan2(sin, cos) + + @ensure_batched(features=3) + def rotate_to( + self, + features: torch.Tensor, + target_angle: torch.Tensor, + return_delta_angle=False, + ): + """Rotate each sequence so frame-0 heading matches ``target_angle``.""" + # rotate so that the first frame angle is the target + # it put the motion_rep to the angle + current_first_angle = self.get_root_heading_angle(features)[:, 0] + delta_angle = target_angle - current_first_angle + rotated_features = self.rotate(features, delta_angle) + if return_delta_angle: + return rotated_features, delta_angle + return rotated_features + + @ensure_batched(features=3) + def rotate_to_zero( + self, + features: torch.Tensor, + return_delta_angle=False, + ): + """Rotate each sequence so frame-0 heading becomes zero.""" + target_angle = torch.zeros(len(features), device=features.device) + return self.rotate_to(features, target_angle, return_delta_angle=return_delta_angle) + + @ensure_batched(features=3) + def randomize_first_heading( + self, + features: torch.Tensor, + return_delta_angle=False, + ) -> torch.Tensor: + """Rotate each sequence to a random frame-0 heading.""" + target_heading_angle = torch.rand(features.shape[0]) * 2 * np.pi + return self.rotate_to( + features, + target_heading_angle, + return_delta_angle=return_delta_angle, + ) + + @ensure_batched(features=3, target_2d_pos=2) + def translate_2d_to( + self, + features: torch.Tensor, + target_2d_pos: torch.Tensor, + return_delta_pos: bool = False, + ) -> torch.Tensor: + """Translate each sequence so frame-0 root ``(x, z)`` matches a target.""" + root_pos = self.get_root_pos(features) + current_first_2d_pos = root_pos[:, 0, [0, 2]].clone() + delta_2d_pos = target_2d_pos - current_first_2d_pos + translated_features = self.translate_2d(features, delta_2d_pos) + if return_delta_pos: + return translated_features, delta_2d_pos + return translated_features + + @ensure_batched(features=3) + def translate_2d_to_zero( + self, + features: torch.Tensor, + return_delta_pos: bool = False, + ) -> torch.Tensor: + """Translate each sequence so frame-0 root ``(x, z)`` is at the origin.""" + target_2d_pos = torch.zeros(len(features), 2, device=features.device) + return self.translate_2d_to(features, target_2d_pos, return_delta_pos=return_delta_pos) + + @ensure_batched(features=3) + def canonicalize(self, features: torch.Tensor): + """Canonicalize heading and planar position at frame 0.""" + rotated_features = self.rotate_to_zero(features) + return 
self.translate_2d_to_zero(rotated_features) + + def normalize(self, features): + """Normalize features using per-part stats (global_root, local_root, body).""" + gr = slice(0, self.global_root_dim) + lr = slice(self.global_root_dim, self.global_root_dim + self.local_root_dim) + out = torch.empty_like(features, device=features.device, dtype=features.dtype) + out[..., gr] = self.global_root_stats.normalize(features[..., gr]) + out[..., lr] = self.local_root_stats.normalize(features[..., lr]) + out[..., self.body_slice] = self.body_stats.normalize(features[..., self.body_slice]) + return out + + def unnormalize(self, features): + """Undo feature normalization using per-part stats.""" + gr = slice(0, self.global_root_dim) + lr = slice(self.global_root_dim, self.global_root_dim + self.local_root_dim) + out = torch.empty_like(features, device=features.device, dtype=features.dtype) + out[..., gr] = self.global_root_stats.unnormalize(features[..., gr]) + out[..., lr] = self.local_root_stats.unnormalize(features[..., lr]) + out[..., self.body_slice] = self.body_stats.unnormalize(features[..., self.body_slice]) + return out + + def create_conditions_from_constraints( + self, + constraints_lst: list, + length: int, + to_normalize: bool, + device: str, + ): + """Create a conditioning tensor and mask from constraint objects.""" + index_dict, data_dict = build_condition_dicts(constraints_lst) + return self.create_conditions(index_dict, data_dict, length, to_normalize, device) + + def create_conditions_from_constraints_batched( + self, + constraints_lst: list | list[list], + lengths: torch.Tensor, + to_normalize: bool, + device: str, + ): + """Batched version of ``create_conditions_from_constraints``. + + Supports either one shared constraint list for all batch elements, or a per-sample list of + constraint lists. + """ + num_samples = len(lengths) + if not constraints_lst or not isinstance(constraints_lst[0], list): + # If no constraints, or constraints are shared across the batch, + # build once and repeat. + observed_motion, motion_mask = self.create_conditions_from_constraints( + constraints_lst, int(lengths.max()), to_normalize, device + ) + observed_motion = repeat(observed_motion, "t d -> b t d", b=num_samples) + motion_mask = repeat(motion_mask, "t d -> b t d", b=num_samples) + return observed_motion, motion_mask + + length = int(lengths.max()) + observed_motion_lst = [] + motion_mask_lst = [] + for constraints_lst_el in constraints_lst: + observed_motion, motion_mask = self.create_conditions_from_constraints( + constraints_lst_el, + length, + to_normalize, + device, + ) + observed_motion_lst.append(observed_motion) + motion_mask_lst.append(motion_mask) + observed_motion = torch.stack(observed_motion_lst, axis=0) + motion_mask = torch.stack(motion_mask_lst, axis=0) + return observed_motion, motion_mask diff --git a/kimodo/motion_rep/reps/kimodo_motionrep.py b/kimodo/motion_rep/reps/kimodo_motionrep.py new file mode 100644 index 0000000000000000000000000000000000000000..2471d6fc464cb3601c76bb2ece45b178d5af6768 --- /dev/null +++ b/kimodo/motion_rep/reps/kimodo_motionrep.py @@ -0,0 +1,301 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import einops +import torch +from torch import Tensor + +from kimodo.tools import to_numpy + +from ...geometry import cont6d_to_matrix, matrix_to_cont6d +from ...skeleton.kinematics import fk +from ...skeleton.transforms import global_rots_to_local_rots +from ...tools import ensure_batched +from ..conditioning import get_unique_index_and_data +from ..feature_utils import RotateFeatures, compute_heading_angle, compute_vel_xyz +from ..feet import foot_detect_from_pos_and_vel +from ..smooth_root import get_smooth_root_pos +from .base import MotionRepBase + + +class KimodoMotionRep(MotionRepBase): + """Global root / global joints rotations representation, relative to a smooth root.""" + + def __init__( + self, + skeleton, + fps, + stats_path: Optional[str] = None, + ): + nbjoints = skeleton.nbjoints + + self.size_dict = { + "smooth_root_pos": torch.Size([3]), + "global_root_heading": torch.Size([2]), + "local_joints_positions": torch.Size([nbjoints, 3]), + "global_rot_data": torch.Size([nbjoints, 6]), + "velocities": torch.Size([nbjoints, 3]), + "foot_contacts": torch.Size([4]), + } + self.last_root_feature = "global_root_heading" + self.local_root_size_dict = { + "local_root_rot_vel": torch.Size([1]), + "local_root_vel": torch.Size([2]), + "global_root_y": torch.Size([1]), + } + super().__init__(skeleton, fps, stats_path) + + @ensure_batched(local_joint_rots=5, root_positions=3, lengths=1) + def __call__( + self, + local_joint_rots: torch.Tensor, + root_positions: torch.Tensor, + to_normalize: bool, + lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Convert local rotations and root trajectory into smooth-root features. + + Args: + local_joint_rots: Local joint rotation matrices ``[B, T, J, 3, 3]``. + root_positions: Root positions ``[B, T, 3]``. + to_normalize: Whether to normalize output features. + lengths: Optional valid lengths for variable-length batches. + + Returns: + Motion features with shape ``[B, T, motion_rep_dim]``. + """ + device = local_joint_rots.device + if lengths is None: + assert local_joint_rots.shape[0] == 1, "If lenghts is not provided, the input should not be batched." 
+ lengths = torch.tensor([local_joint_rots.shape[1]], device=device) + + ( + global_joints_rots, + global_joints_positions, + local_joints_positions_origin_is_pelvis, + ) = fk(local_joint_rots, root_positions, self.skeleton) + + root_heading_angle = compute_heading_angle(global_joints_positions, self.skeleton) + global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1) + + smooth_root_pos = get_smooth_root_pos(root_positions) + hips_offset = root_positions - smooth_root_pos + hips_offset[..., 1] = root_positions[..., 1] + local_joints_positions = local_joints_positions_origin_is_pelvis + hips_offset[:, :, None] + + velocities = compute_vel_xyz(global_joints_positions, self.fps, lengths=lengths) + foot_contacts = foot_detect_from_pos_and_vel(global_joints_positions, velocities, self.skeleton, 0.15, 0.10) + global_rot_data = matrix_to_cont6d(global_joints_rots) + + features, _ = einops.pack( + [ + smooth_root_pos, + global_root_heading, + local_joints_positions, + global_rot_data, + velocities, + foot_contacts, + ], + "batch time *", + ) + + if to_normalize: + features = self.normalize(features) + return features + + @ensure_batched(features=3, angle=1) + def rotate(self, features: torch.Tensor, angle: torch.Tensor): + """Rotate root/joint positional and rotational features by heading.""" + # assume it is not normalized + bs = features.shape[0] + device = features.device + [ + smooth_root_pos, + global_root_heading, + local_joints_positions, + global_rot_data, + velocities, + foot_contacts, + ] = einops.unpack(features, self.ps, "batch time *") + + if not isinstance(angle, torch.Tensor): + angle = torch.tensor(angle, device=device) + if len(angle.shape) == 0: + angle = angle.repeat(bs) + + RF = RotateFeatures(angle) + new_features, _ = einops.pack( + [ + RF.rotate_positions(smooth_root_pos), + RF.rotate_2d_positions(global_root_heading), + RF.rotate_positions(local_joints_positions), + RF.rotate_6d_rotations(global_rot_data), + RF.rotate_positions(velocities), + foot_contacts, + ], + "batch time *", + ) + return new_features + + @ensure_batched(features=3, translation_2d=2) + def translate_2d( + self, + features: torch.Tensor, + translation_2d: torch.Tensor, + ) -> torch.Tensor: + """Translate smooth root planar position by ``(dx, dz)``.""" + # only move on the ground + # If we need a translate_3D function, we should not forget to move the local_joints_positions as well + bs = features.shape[0] + if len(translation_2d.shape) == 1: + translation_2d = translation_2d.repeat(bs, 1) + + new_features = features.clone() + new_smooth_root_pos = new_features[:, :, self.slice_dict["smooth_root_pos"]] + new_smooth_root_pos[:, :, 0] += translation_2d[:, [0]] + new_smooth_root_pos[:, :, 2] += translation_2d[:, [1]] + return new_features + + @ensure_batched(features=3) + def inverse( + self, + features: torch.Tensor, + is_normalized: bool, + posed_joints_from="rotations", + return_numpy: bool = False, + ) -> torch.Tensor: + """Decode smooth-root features into motion tensors.""" + assert posed_joints_from in [ + "rotations", + "positions", + ], "posed_joints_from should 'rotations' or 'positions'" + + if is_normalized: + features = self.unnormalize(features) + + [ + smooth_root_pos, + global_root_heading, + local_joints_positions, + global_rot_data, + velocities, + foot_contacts, + ] = einops.unpack(features, self.ps, "batch time *") + + global_rot_mats = cont6d_to_matrix(global_rot_data) + local_rot_mats = global_rots_to_local_rots(global_rot_mats, 
self.skeleton) + + posed_joints_from_pos = local_joints_positions.clone() + posed_joints_from_pos[..., 0] += smooth_root_pos[..., None, 0] + posed_joints_from_pos[..., 2] += smooth_root_pos[..., None, 2] + root_positions = posed_joints_from_pos[..., self.skeleton.root_idx, :] + foot_contacts = foot_contacts > 0.5 + + if posed_joints_from == "rotations": + _, posed_joints, _ = self.skeleton.fk( + local_rot_mats, + root_positions, + ) + else: + posed_joints = posed_joints_from_pos + + output_tensor_dict = { + "local_rot_mats": local_rot_mats, + "global_rot_mats": global_rot_mats, + "posed_joints": posed_joints, + "root_positions": root_positions, + "smooth_root_pos": smooth_root_pos, + "foot_contacts": foot_contacts, + "global_root_heading": global_root_heading, + } + if return_numpy: + return to_numpy(output_tensor_dict) + return output_tensor_dict + + def create_conditions( + self, + index_dict: dict[Tensor], + data_dict: dict[Tensor], + length: int, + to_normalize: bool, + device: str, + ): + """Build sparse conditioning tensors for smooth-root representation.""" + # create empty features and mask to be filled in + observed_motion = torch.zeros(length, self.motion_rep_dim, device=device) + motion_mask = torch.zeros(length, self.motion_rep_dim, dtype=bool, device=device) + + def _cat_indices(indices_list: list[Tensor]) -> Tensor: + indices = torch.cat([torch.tensor(x) if not isinstance(x, Tensor) else x for x in indices_list]) + return indices.to(device=device, dtype=torch.long) + + def _match_obs_dtype(tensor: Tensor) -> Tensor: + return tensor.to(device=device, dtype=observed_motion.dtype) + + if (fname := "smooth_root_2d") in index_dict and index_dict[fname]: + indices = _cat_indices(index_dict[fname]) + indices, smooth_root_2d = get_unique_index_and_data(indices, torch.cat(data_dict[fname])) + smooth_root_2d = _match_obs_dtype(smooth_root_2d) + f_sliced = observed_motion[:, self.slice_dict["smooth_root_pos"]] + f_sliced[indices, 0] = smooth_root_2d[:, 0] + f_sliced[indices, 2] = smooth_root_2d[:, 1] + m_sliced = motion_mask[:, self.slice_dict["smooth_root_pos"]] + m_sliced[indices, 0] = True + m_sliced[indices, 2] = True + + if (fname := "root_y_pos") in index_dict and index_dict[fname]: + indices = _cat_indices(index_dict[fname]) + indices, root_pos_Y = get_unique_index_and_data(indices, torch.cat(data_dict[fname])) + root_pos_Y = _match_obs_dtype(root_pos_Y) + f_sliced = observed_motion[:, self.slice_dict["smooth_root_pos"]] + f_sliced[indices, 1] = root_pos_Y + m_sliced = motion_mask[:, self.slice_dict["smooth_root_pos"]] + m_sliced[indices, 1] = True + + if (fname := "global_root_heading") in index_dict and index_dict[fname]: + indices = _cat_indices(index_dict[fname]) + indices, global_root_heading = get_unique_index_and_data(indices, torch.cat(data_dict[fname])) + global_root_heading = _match_obs_dtype(global_root_heading) + f_sliced = observed_motion[:, self.slice_dict[fname]] + f_sliced[indices] = global_root_heading + m_sliced = motion_mask[:, self.slice_dict[fname]] + m_sliced[indices] = True + + if (fname := "global_joints_rots") in index_dict and index_dict[fname]: + indices_lst = _cat_indices(index_dict[fname]) + indices_lst, global_joints_rots = get_unique_index_and_data(indices_lst, torch.cat(data_dict[fname])) + global_joints_rots = _match_obs_dtype(global_joints_rots) + global_rot_data = matrix_to_cont6d(global_joints_rots) + f_sliced = observed_motion[:, self.slice_dict["global_rot_data"]] + masking = torch.zeros(len(f_sliced) * self.nbjoints, 6, device=device, 
dtype=bool) + masking[indices_lst.T[0] * self.nbjoints + indices_lst.T[1]] = True + masking = masking.reshape(len(f_sliced), self.nbjoints * 6) + f_sliced[masking] = global_rot_data.flatten() + m_sliced = motion_mask[:, self.slice_dict["global_rot_data"]] + m_sliced[masking] = True + + if (fname := "global_joints_positions") in index_dict and index_dict[fname]: + indices_lst = _cat_indices(index_dict[fname]) + indices_lst, global_joints_positions = get_unique_index_and_data(indices_lst, torch.cat(data_dict[fname])) + global_joints_positions = _match_obs_dtype(global_joints_positions) + T_indices = indices_lst[:, 0].contiguous() + _test = motion_mask[T_indices, self.slice_dict["smooth_root_pos"]] + if not _test[:, [0, 2]].all(): + raise ValueError("For constraining global positions, the smooth root should also be constrained.") + smooth_root_pos = observed_motion[T_indices, self.slice_dict["smooth_root_pos"]].clone() + local_reference = smooth_root_pos.clone() + local_reference[..., 1] = 0.0 + local_joints_positions = global_joints_positions - local_reference + f_sliced = observed_motion[:, self.slice_dict["local_joints_positions"]] + masking = torch.zeros(len(f_sliced) * self.nbjoints, 3, device=device, dtype=bool) + masking[indices_lst.T[0] * self.nbjoints + indices_lst.T[1]] = True + masking = masking.reshape(len(f_sliced), self.nbjoints * 3) + f_sliced[masking] = local_joints_positions.flatten() + m_sliced = motion_mask[:, self.slice_dict["local_joints_positions"]] + m_sliced[masking] = True + + if to_normalize: + observed_motion = self.normalize(observed_motion) + return observed_motion, motion_mask diff --git a/kimodo/motion_rep/reps/tmr_motionrep.py b/kimodo/motion_rep/reps/tmr_motionrep.py new file mode 100644 index 0000000000000000000000000000000000000000..3468b50539e1e286eba61e637d95dfc1802058ba --- /dev/null +++ b/kimodo/motion_rep/reps/tmr_motionrep.py @@ -0,0 +1,222 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""TMR motion representation: global root, global joints, velocities, and foot contacts.""" + +from typing import Optional + +import einops +import torch + +from ...skeleton.kinematics import fk +from ...tools import ensure_batched, to_numpy +from ..feature_utils import RotateFeatures, compute_heading_angle, compute_vel_xyz +from ..feet import foot_detect_from_pos_and_vel +from .base import MotionRepBase + + +class TMRMotionRep(MotionRepBase): + """Motion representation with global root and global joint positions. 
+ + Feature layout: + - root position ``(x, y, z)`` + - root heading as ``(cos(theta), sin(theta))`` + - local joint positions (root removed, ground-referenced) + - global joint velocities + - binary foot contacts + """ + + def __init__( + self, + skeleton, + fps, + stats_path: Optional[str] = None, + ): + nbjoints = skeleton.nbjoints + + self.size_dict = { + "root_pos": torch.Size([3]), + "global_root_heading": torch.Size([2]), + "local_joints_positions": torch.Size([nbjoints - 1, 3]), + "velocities": torch.Size([nbjoints, 3]), + "foot_contacts": torch.Size([4]), + } + self.last_root_feature = "global_root_heading" + self.local_root_size_dict = { + "local_root_rot_vel": torch.Size([1]), + "local_root_vel": torch.Size([2]), + "global_root_y": torch.Size([1]), + } + super().__init__(skeleton, fps, stats_path) + + @ensure_batched(local_joint_rots=5, root_positions=3, posed_joints=4, lengths=1) + def __call__( + self, + local_joint_rots: Optional[torch.Tensor] = None, + root_positions: Optional[torch.Tensor] = None, + posed_joints: Optional[torch.Tensor] = None, + *, + to_normalize: bool, + lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Convert motion inputs to this feature representation. + + Args: + local_joint_rots: Local joint rotation matrices ``[B, T, J, 3, 3]``. + Required when ``posed_joints`` is not provided. + root_positions: Root translations ``[B, T, 3]``. Required when + ``posed_joints`` is not provided. + posed_joints: Optional precomputed global joint positions + ``[B, T, J, 3]``. If passed, FK is skipped. + to_normalize: Whether to normalize output features. + lengths: Optional valid lengths for variable-length batches. + + Returns: + Motion features with shape ``[B, T, motion_rep_dim]``. + """ + if posed_joints is not None: + device = posed_joints.device + nbatch, nbframes, nbjoints = posed_joints.shape[:3] + else: + device = local_joint_rots.device + nbatch, nbframes, nbjoints = local_joint_rots.shape[:3] + + if lengths is None: + assert nbatch == 1, "If lenghts is not provided, the input should not be batched." 
+ lengths = torch.tensor([nbframes], device=device) + + if posed_joints is None: + _, global_positions, local_joints_positions_origin_is_pelvis = fk( + local_joint_rots, root_positions, self.skeleton + ) + else: + global_positions = posed_joints + root_positions = posed_joints[:, :, 0] + local_joints_positions_origin_is_pelvis = posed_joints - root_positions[:, :, None] + + root_heading_angle = compute_heading_angle(global_positions, self.skeleton) + global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1) + + ground_offset = 0 * root_positions + ground_offset[..., 1] = root_positions[..., 1] + local_joints_positions = local_joints_positions_origin_is_pelvis[:, :, 1:] + ground_offset[:, :, None] + velocities = compute_vel_xyz(global_positions, self.fps, lengths=lengths) + foot_contacts = foot_detect_from_pos_and_vel(global_positions, velocities, self.skeleton, 0.15, 0.10) + + features, _ = einops.pack( + [ + root_positions, + global_root_heading, + local_joints_positions, + velocities, + foot_contacts, + ], + "batch time *", + ) + + if to_normalize: + features = self.normalize(features) + return features + + @ensure_batched(features=3, angle=1) + def rotate(self, features: torch.Tensor, angle: torch.Tensor): + """Rotate all spatial features by a heading delta (radians).""" + # rotate by the angle + # it add the angle to the current features + # assume it is not normalized + bs = features.shape[0] + device = features.device + [ + root_pos, + global_root_heading, + local_joints_positions, + velocities, + foot_contacts, + ] = einops.unpack(features, self.ps, "batch time *") + + if not isinstance(angle, torch.Tensor): + angle = torch.tensor(angle, device=device) + if len(angle.shape) == 0: + angle = angle.repeat(bs) + + RF = RotateFeatures(angle) + new_features, _ = einops.pack( + [ + RF.rotate_positions(root_pos), + RF.rotate_2d_positions(global_root_heading), + RF.rotate_positions(local_joints_positions), + RF.rotate_positions(velocities), + foot_contacts, + ], + "batch time *", + ) + return new_features + + @ensure_batched(features=3, translation_2d=2) + def translate_2d( + self, + features: torch.Tensor, + translation_2d: torch.Tensor, + ) -> torch.Tensor: + """Translate root planar position by ``(dx, dz)``.""" + # only move on the ground + # For 3D, we should not forget to move the local_joints_positions as well + bs = features.shape[0] + if len(translation_2d.shape) == 1: + translation_2d = translation_2d.repeat(bs, 1) + + new_features = features.clone() + new_root_pos = new_features[:, :, self.slice_dict["root_pos"]] + new_root_pos[:, :, 0] += translation_2d[:, 0] + new_root_pos[:, :, 2] += translation_2d[:, 1] + return new_features + + @ensure_batched(features=3) + def inverse( + self, + features: torch.Tensor, + is_normalized: bool, + posed_joints_from="positions", + return_numpy: bool = False, + ) -> torch.Tensor: + """Decode features back to a motion dictionary. + + Args: + features: Feature tensor ``[B, T, D]``. + is_normalized: Whether input features are normalized. + posed_joints_from: Must be ``"positions"`` for this representation. + return_numpy: Whether to convert tensors to numpy arrays. + + Returns: + Dictionary containing reconstructed positions and auxiliary data. 
+ """ + assert posed_joints_from == "positions" + if is_normalized: + features = self.unnormalize(features) + + [ + root_positions, + global_root_heading, + local_joints_positions, + velocities, + foot_contacts, + ] = einops.unpack(features, self.ps, "batch time *") + + dummy_root = 0 * local_joints_positions[:, :, [0]] + posed_joints_from_pos = torch.stack([dummy_root, local_joints_positions], axis=2) + posed_joints_from_pos[..., 0] += root_positions[..., None, 0] + posed_joints_from_pos[..., 2] += root_positions[..., None, 2] + root_positions = posed_joints_from_pos[..., self.skeleton.root_idx, :] + foot_contacts = foot_contacts > 0.5 + posed_joints = posed_joints_from_pos + + output_tensor_dict = { + "local_rot_mats": None, + "global_rot_mats": None, + "posed_joints": posed_joints, + "root_positions": root_positions, + "foot_contacts": foot_contacts, + "global_root_heading": global_root_heading, + } + if return_numpy: + return to_numpy(output_tensor_dict) + return output_tensor_dict diff --git a/kimodo/motion_rep/smooth_root.py b/kimodo/motion_rep/smooth_root.py new file mode 100644 index 0000000000000000000000000000000000000000..dd23f6f35d5a87241afbbf4d6e3106d636927bc7 --- /dev/null +++ b/kimodo/motion_rep/smooth_root.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Smooth root trajectory: ADMM-based smoother with margin constraints and get_smooth_root_pos helper.""" + +import math + +import numpy as np +import torch +from scipy import sparse +from scipy.sparse.linalg import splu + +from kimodo.tools import ensure_batched + + +class TrajectorySmoother: + """Modify trajectories to hit target values while respecting soft constraints. + + This smoother keeps the trajectory close to the original positions while minimizing + accelerations. Targets are enforced at specified frames via soft constraints. + """ + + def __init__( + self, + margins, + pos_weight=0.0, + loop=False, + admm_iters=100, + alpha_overrelax=1.0, + circle_project=False, + ): + """Initialize the TrajectorySmoother. + + Args: + margins: Array of margin values for each frame. 
+ margins[i] < 0: unconstrained + margins[i] == 0: pinned on this frame + margins[i] > 0: can deviate within the margin + pos_weight: Weight for position preservation + loop: Whether the trajectory should loop + admm_iters: Number of ADMM iterations + """ + self.pos_weight = pos_weight + self.admm_iters = admm_iters + self.alpha_overrelax = alpha_overrelax + self.circle_project = circle_project + N = len(margins) + + # Store margin information as numpy arrays + self.margin_vals = margins + + # Build acceleration matrix A + a_data = [] + a_rows = [] + a_cols = [] + + for i in range(1, N - 1): + scale = 1.0 + a_data.extend([-scale, 2.0 * scale, -scale]) + a_rows.extend([i, i, i]) + a_cols.extend([i - 1, i, i + 1]) + + if loop: + # Add periodic accelerations + scale = 1.0 + a_data.extend([-scale, 2.0 * scale, -scale]) + a_rows.extend([0, 0, 0]) + a_cols.extend([N - 1, 0, 1]) + + scale = 1.0 + a_data.extend([-scale, 2.0 * scale, -scale]) + a_rows.extend([N - 1, N - 1, N - 1]) + a_cols.extend([N - 2, N - 1, 0]) + + A = sparse.csr_matrix((a_data, (a_rows, a_cols)), shape=(N, N)) + + # Build identity matrix + identity_matrix = sparse.eye(N) + + # Build system matrix M + M = pos_weight * identity_matrix + A.T @ A + + # Calculate ADMM step size + diag_max = max(abs(M.diagonal())) + self.admm_stepsize = 0.25 * np.sqrt(diag_max) + + M = M + self.admm_stepsize * identity_matrix + self.system_lu = splu(M.tocsc()) + + def smooth(self, targets, x0): + """Interpolate between reference positions while satisfying constraints. + + Args: + observations: Target positions for constrained frames (numpy array) + ref_positions: Reference positions defining original shape + (numpy array) + + Returns: + Interpolated positions (numpy array) + """ + x_target = targets.copy() + x = x0.copy() + z = np.zeros_like(x) + u = np.zeros_like(x) + + for _ in range(self.admm_iters): + self.z_update(z, x, x_target, u) + self.u_update(u, x, z) + self.x_update(x, z, u, x_target) + + return x + + def x_update(self, x, z, u, x_t): + """Update x in the ADMM iteration.""" + + # x = (wp * I + A^T A + p I)^-1 (wp * x_orig + p (z - u)) + r = self.pos_weight * x_t + self.admm_stepsize * (z - u) + x[:] = self.system_lu.solve(r) + + def z_update(self, z, x, z_t, u): + """Update z in the ADMM iteration using vectorized operations.""" + # Compute the difference from target for all margin locations at once + z[:] = x + u - z_t + + # Check if we need to project back to margin + z_diff_norms = np.linalg.norm(z, axis=1) + mask = z_diff_norms > self.margin_vals + if np.any(mask): + scale_factors = self.margin_vals[mask] / z_diff_norms[mask] + z[mask] *= scale_factors[:, np.newaxis] + + # Add back the target + z[:] += z_t + + if self.circle_project: + z[:] = z / (np.linalg.norm(z, axis=1, keepdims=True) + 1.0e-6) + + def u_update(self, u, x, z): + """Update u in the ADMM iteration using vectorized operations.""" + u[:] += self.alpha_overrelax * (x - z) + + +def smooth_signal(x, margins, pos_weight=0, alpha_overrelax=1.8, admm_iters=500, circle_project=False): + """Multigrid trajectory smoothing with margin constraints. + + Args: + x: Input trajectory ``[T, D]`` as a NumPy array. + margins: Allowed radius around each target frame ``[T]``. + pos_weight: Weight for staying close to the original signal. + alpha_overrelax: ADMM over-relaxation coefficient. + admm_iters: ADMM iterations per multigrid level. + circle_project: If ``True``, project each vector to the unit sphere. + + Returns: + Smoothed trajectory of shape ``[T, D]``. 
+ """ + x_smoothed = x.copy() + x_smoothed[:] = x.mean(axis=0, keepdims=True) + + # smooth the signal, multigrid style by starting out coarse, + # doubling the resolution and repeating until we're at the full + # resolution, using the previous result as the initial guess. + levels = int(math.floor(math.log2(len(x)))) + levels = max(levels - 4, 1) + + stepsize = 2**levels + while True: + # smooth signals at this level: + num_steps = len(x_smoothed[::stepsize]) + smoother = TrajectorySmoother( + margins=margins[::stepsize], + pos_weight=pos_weight, + alpha_overrelax=alpha_overrelax, + admm_iters=admm_iters, + circle_project=circle_project, + ) + x_smoothed[::stepsize] = smoother.smooth(x[::stepsize], x_smoothed[::stepsize]) + + # interpolate to next level: + next_stepsize = stepsize // 2 + num_interleaved = len(x_smoothed[next_stepsize::stepsize]) + if num_interleaved == num_steps: + # linearly extrapolate the last value if we have to: + x_smoothed[next_stepsize::stepsize][-1] = ( + x_smoothed[::stepsize][-1] + (x_smoothed[::stepsize][-1] - x_smoothed[::stepsize][-2]) / 2 + ) + num_interleaved = num_interleaved - 1 + + # linearly interpolate the remaining values: + x_smoothed[next_stepsize::stepsize][:num_interleaved] = ( + x_smoothed[::stepsize][:-1] + x_smoothed[::stepsize][1:] + ) / 2 + + if stepsize == 1: + break + + stepsize //= 2 + + return x_smoothed + + +@ensure_batched(hip_translations=3) +def get_smooth_root_pos(hip_translations): + """Smooth root trajectory in the ground plane while preserving height. + + Args: + hip_translations: Root translations ``[B, T, 3]``. + + Returns: + Smoothed root translations ``[B, T, 3]`` where ``x/z`` are smoothed and + ``y`` remains unchanged. + """ + root_translations_xz = hip_translations[..., [0, 2]] + root_translations_y = hip_translations[..., [1]] + + batch_size, nframes = root_translations_xz.shape[:2] + margins = np.full(root_translations_xz.shape[1], 0.06) + + root_translations_smoothed_xz = [] + for batch in range(batch_size): + root_translations_smoothed_xz.append( + smooth_signal(root_translations_xz[batch].detach().cpu().numpy(), margins)[None] + ) + + root_translations_smoothed_xz = torch.tensor(np.concatenate(root_translations_smoothed_xz)) + + root_translations = torch.cat( + [ + root_translations_smoothed_xz.to(root_translations_y.device), + root_translations_y, + ], + dim=-1, + )[..., [0, 2, 1]] + + return root_translations diff --git a/kimodo/motion_rep/stats.py b/kimodo/motion_rep/stats.py new file mode 100644 index 0000000000000000000000000000000000000000..eecd9d417c1a4e54546e75afe3fb82f9e5e57e31 --- /dev/null +++ b/kimodo/motion_rep/stats.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Feature normalization statistics (mean/std) for motion representations.""" + +import logging +import os +from typing import Optional + +import numpy as np +import torch + +log = logging.getLogger(__name__) + + +class Stats(torch.nn.Module): + """Utility module for feature normalization statistics. 
+ + Normalization follows: + ``(data - mean) / sqrt(std**2 + eps)`` + """ + + def __init__( + self, + folder: Optional[str] = None, + load: bool = True, + eps=1e-05, + ): + super().__init__() + self.folder = folder + self.eps = eps + if folder is not None and load: + self.load() + + def sliced(self, indices): + """Return a new ``Stats`` object containing selected feature indices.""" + new_stats = Stats(folder=self.folder, load=False, eps=self.eps) + new_stats.register_from_tensors( + self.mean[..., indices].clone(), + self.std[..., indices].clone(), + ) + return new_stats + + def load(self): + """Load ``mean.npy`` and ``std.npy`` from ``self.folder``.""" + mean_path = os.path.join(self.folder, "mean.npy") + std_path = os.path.join(self.folder, "std.npy") + if not os.path.exists(mean_path) or not os.path.exists(std_path): + raise FileNotFoundError( + f"Missing stats files in '{self.folder}'. Expected:\n" + f" - {mean_path}\n" + f" - {std_path}\n\n" + "Make sure the checkpoint/stats have been downloaded and are mounted into the container.\n" + "If you're using Docker Compose, run it from the repo root so `./:/workspace` mounts the correct directory." + ) + + mean = torch.from_numpy(np.load(mean_path)) + std = torch.from_numpy(np.load(std_path)) + self.register_from_tensors(mean, std) + + def register_from_tensors(self, mean: torch.Tensor, std: torch.Tensor): + """Register mean/std tensors as non-persistent buffers.""" + self.register_buffer("mean", mean, persistent=False) + self.register_buffer("std", std, persistent=False) + + def normalize(self, data: torch.Tensor) -> torch.Tensor: + """Normalize data using the stored statistics.""" + mean = self.mean.to(device=data.device, dtype=data.dtype) + std = self.std.to(device=data.device, dtype=data.dtype) + # adjust std with eps + return (data - mean) / torch.sqrt(std**2 + self.eps) + + def unnormalize(self, data: torch.Tensor) -> torch.Tensor: + """Undo normalization using the stored statistics.""" + mean = self.mean.to(device=data.device, dtype=data.dtype) + std = self.std.to(device=data.device, dtype=data.dtype) + # adjust std with eps + return data * torch.sqrt(std**2 + self.eps) + mean + + def is_loaded(self): + """Return whether statistics are currently available.""" + return hasattr(self, "mean") + + def get_dim(self): + """Return feature dimensionality.""" + return self.mean.shape[0] + + def save( + self, + folder: Optional[str] = None, + mean: Optional[torch.Tensor] = None, + std: Optional[torch.Tensor] = None, + ): + """Save statistics to ``folder`` as ``mean.npy`` and ``std.npy``.""" + if folder is None: + folder = self.folder + if folder is None: + raise ValueError("No folder to save stats") + + if mean is None and std is None: + try: + mean = self.mean.cpu().numpy() + std = self.std.cpu().numpy() + except AttributeError: + raise ValueError("Stats were not loaded") + + # don't override stats folder + os.makedirs(folder, exist_ok=False) + + np.save(os.path.join(folder, "mean.npy"), mean) + np.save(os.path.join(folder, "std.npy"), std) + + def __eq__(self, other): + return (self.mean.cpu() == other.mean.cpu()).all() and (self.std.cpu() == other.std.cpu()).all() + + # should define a hash value for pytorch, as we defined __eq__ + def __hash__(self): + # Convert mean and std to bytes for a consistent hash value + mean_hash = hash(self.mean.detach().cpu().numpy().tobytes()) + std_hash = hash(self.std.detach().cpu().numpy().tobytes()) + return hash((mean_hash, std_hash)) + + def __repr__(self): + return f'Stats(folder="{self.folder}")' diff 
--git a/kimodo/pipeline/__init__.py b/kimodo/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6bbb3514c48cfc71263d9f92cb75def0a40cab57 --- /dev/null +++ b/kimodo/pipeline/__init__.py @@ -0,0 +1,28 @@ +"""Pipeline utilities for prompt/script to Kimodo generation flows.""" + +from .blend_quality import ( + BlendGuardrailConfig, + TransitionSettings, + apply_transition_guardrails, + harmonize_scene_transitions, +) +from .script_to_kimodo import ( + CharacterKimodoPlan, + build_character_plan, + generator_request_to_plans, + run_multi_character_generation, +) +from .scheduler_runtime import SceneScheduleResult, run_scheduled_scene + +__all__ = [ + "CharacterKimodoPlan", + "BlendGuardrailConfig", + "TransitionSettings", + "apply_transition_guardrails", + "harmonize_scene_transitions", + "build_character_plan", + "generator_request_to_plans", + "run_multi_character_generation", + "SceneScheduleResult", + "run_scheduled_scene", +] diff --git a/kimodo/pipeline/blend_quality.py b/kimodo/pipeline/blend_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..8d79e8f17634806ccbd90a39cda839e868181f12 --- /dev/null +++ b/kimodo/pipeline/blend_quality.py @@ -0,0 +1,116 @@ +"""Card 7 blend quality guardrails for transition blending safety and consistency.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class TransitionSettings: + """Transition settings passed to Kimodo generation.""" + + num_transition_frames: int + share_transition: bool + percentage_transition_override: float + + +@dataclass(frozen=True) +class BlendGuardrailConfig: + """Runtime safety bounds for transition blending.""" + + min_transition_frames: int = 1 + max_transition_frames: int = 12 + min_segment_frames_for_share: int = 12 + max_transition_ratio: float = 0.30 + max_shared_window_frames: int = 24 + harmonize_window: int = 2 + + +def _clamp(value: float, low: float, high: float) -> float: + return max(low, min(high, value)) + + +def apply_transition_guardrails( + segment_frames: list[int], + policies: list[str], + requested: TransitionSettings, + *, + config: BlendGuardrailConfig = BlendGuardrailConfig(), +) -> TransitionSettings: + """Clamp transition settings to safe ranges for short/long segments. + + Guardrails avoid transition windows that dominate short segments and reduce blending artifacts + for scripted interactions. + """ + if len(segment_frames) < 2: + safe_frames = int(_clamp(requested.num_transition_frames, config.min_transition_frames, config.max_transition_frames)) + return TransitionSettings( + num_transition_frames=safe_frames, + share_transition=False, + percentage_transition_override=0.0, + ) + + min_prev = min(segment_frames[:-1]) + min_next = min(segment_frames[1:]) + # Keep at least one non-transition frame in the shortest pair. 
+ shortest_pair_budget = max(config.min_transition_frames, min(min_prev, min_next) - 1) + + safe_frames = int( + _clamp( + requested.num_transition_frames, + config.min_transition_frames, + min(config.max_transition_frames, shortest_pair_budget), + ) + ) + + has_cut = "cut" in policies + can_share = ( + requested.share_transition + and not has_cut + and min_prev >= config.min_segment_frames_for_share + and min_next >= config.min_segment_frames_for_share + ) + + if not can_share: + return TransitionSettings( + num_transition_frames=safe_frames, + share_transition=False, + percentage_transition_override=0.0, + ) + + safe_pct = _clamp(requested.percentage_transition_override, 0.0, config.max_transition_ratio) + + # Cap shared overlap by configured hard ceiling and shortest-pair budget. + max_pct_from_shared_window = max(0.0, (config.max_shared_window_frames - safe_frames) / max(1, min_prev)) + max_pct_from_shortest_pair = max(0.0, (shortest_pair_budget - safe_frames) / max(1, min_prev)) + safe_pct = min(safe_pct, max_pct_from_shared_window, max_pct_from_shortest_pair) + + return TransitionSettings( + num_transition_frames=safe_frames, + share_transition=True, + percentage_transition_override=float(safe_pct), + ) + + +def harmonize_scene_transitions( + settings_by_character: dict[str, TransitionSettings], + *, + config: BlendGuardrailConfig = BlendGuardrailConfig(), +) -> dict[str, TransitionSettings]: + """Nudge transition-frame counts toward a scene median for multi-character consistency.""" + if len(settings_by_character) < 2: + return settings_by_character + + frame_values = sorted(setting.num_transition_frames for setting in settings_by_character.values()) + median = frame_values[len(frame_values) // 2] + low = max(config.min_transition_frames, median - config.harmonize_window) + high = min(config.max_transition_frames, median + config.harmonize_window) + + harmonized: dict[str, TransitionSettings] = {} + for character_id, setting in settings_by_character.items(): + harmonized[character_id] = TransitionSettings( + num_transition_frames=int(_clamp(setting.num_transition_frames, low, high)), + share_transition=setting.share_transition, + percentage_transition_override=setting.percentage_transition_override, + ) + return harmonized \ No newline at end of file diff --git a/kimodo/pipeline/scheduler_runtime.py b/kimodo/pipeline/scheduler_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..24f40fbf0a208aebdd6c98669cce1a2a4cbbbf6e --- /dev/null +++ b/kimodo/pipeline/scheduler_runtime.py @@ -0,0 +1,139 @@ +"""Card 8 runtime orchestration: deterministic multi-character scheduling.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Optional + +from kimodo.pipeline.script_to_kimodo import run_multi_character_generation +from kimodo.schemas import GeneratorRequest +from kimodo.scheduler import ( + CharacterState, + CharacterSegmentState, + ConflictResolutionPolicy, + DeterministicLoop, +) + +LOGGER = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class SceneScheduleResult: + """Structured result for scheduled scene execution.""" + + outputs: dict[str, dict[str, Any]] + errors: dict[str, str] + plans: dict[str, Any] + state_hashes: list[str] + interactions: list[tuple[int, str, str]] + completed_segments: dict[str, int] + + +def _activate_next_segment(loop: DeterministicLoop, character_id: str, plan: Any, segment_index: int) -> None: + """Set active segment in loop state for one character.""" + 
slot = loop.characters[character_id] + slot.segment_state = CharacterSegmentState( + character_id=character_id, + segment_index=segment_index, + frames_elapsed=0, + total_frames=plan.num_frames[segment_index], + ) + segment = plan.segment_transition_policies[segment_index] + # Interaction target is encoded in planner request segments; set later in per-tick update. + slot.current_state = CharacterState.BUSY if segment != "cut" else CharacterState.TRANSITIONING + + +def run_scheduled_scene( + model: Any, + request: GeneratorRequest, + *, + fps: float, + seed: int = 42, + conflict_policy: ConflictResolutionPolicy = ConflictResolutionPolicy.COOLDOWN, + diffusion_steps: int = 100, + cfg_weight: Optional[list[float]] = None, + cfg_type: Optional[str] = None, + post_processing: bool = True, + root_margin: float = 0.04, + constraint_resolver: Optional[Any] = None, + continue_on_error: bool = False, +) -> SceneScheduleResult: + """Run generation then deterministic timeline scheduling for all characters in a scene.""" + LOGGER.info("card8.run_scheduled_scene.start scene_id=%s chars=%s", request.scene_id, len(request.characters)) + + outputs, errors, plans = run_multi_character_generation( + model, + request, + fps=fps, + diffusion_steps=diffusion_steps, + cfg_weight=cfg_weight, + cfg_type=cfg_type, + post_processing=post_processing, + root_margin=root_margin, + constraint_resolver=constraint_resolver, + continue_on_error=continue_on_error, + ) + + loop = DeterministicLoop( + fps=int(fps), + seed=seed, + conflict_policy=conflict_policy, + ) + + for priority, character in enumerate(request.characters): + loop.register_character(character.character_id, character.skeleton_type, priority=priority) + + segment_indices = {character.character_id: 0 for character in request.characters} + completed_segments = {character.character_id: 0 for character in request.characters} + + for character in request.characters: + plan = plans.get(character.character_id) + if plan is None: + continue + if not plan.num_frames: + continue + _activate_next_segment(loop, character.character_id, plan, segment_index=0) + first_segment = character.segments[0] + loop.characters[character.character_id].interaction_target = first_segment.interaction_target + + total_scene_frames = max((plan.total_frames for plan in plans.values()), default=0) + state_hashes: list[str] = [] + interactions: list[tuple[int, str, str]] = [] + + for _ in range(total_scene_frames): + tick = loop.advance_tick({}) + state_hashes.append(loop.get_state_hash()) + + for winner, loser in tick.interactions: + interactions.append((tick.tick_number, winner, loser)) + + for character_id in tick.completed_segments: + plan = plans.get(character_id) + if plan is None: + continue + completed_segments[character_id] += 1 + next_index = segment_indices[character_id] + 1 + if next_index < len(plan.num_frames): + segment_indices[character_id] = next_index + _activate_next_segment(loop, character_id, plan, next_index) + source_char = next(c for c in request.characters if c.character_id == character_id) + loop.characters[character_id].interaction_target = source_char.segments[next_index].interaction_target + else: + loop.characters[character_id].segment_state = None + loop.characters[character_id].interaction_target = None + + LOGGER.info( + "card8.run_scheduled_scene.exit scene_id=%s hashes=%s interactions=%s", + request.scene_id, + len(state_hashes), + len(interactions), + ) + return SceneScheduleResult( + outputs=outputs, + errors=errors, + plans=plans, + 
state_hashes=state_hashes, + interactions=interactions, + completed_segments=completed_segments, + ) \ No newline at end of file diff --git a/kimodo/pipeline/script_to_kimodo.py b/kimodo/pipeline/script_to_kimodo.py new file mode 100644 index 0000000000000000000000000000000000000000..08a53cece2a4efe04271cfb421cfbf5d0df15f05 --- /dev/null +++ b/kimodo/pipeline/script_to_kimodo.py @@ -0,0 +1,270 @@ +"""Card 6 mapping layer: planner scripts -> Kimodo generation inputs.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional + +from kimodo.pipeline.blend_quality import ( + TransitionSettings, + apply_transition_guardrails, + harmonize_scene_transitions, +) +from kimodo.schemas import CharacterGenerationState, GeneratorRequest, MotionSegment + +LOGGER = logging.getLogger(__name__) + +ConstraintResolver = Callable[[CharacterGenerationState, int], list[Any]] + + +@dataclass(frozen=True) +class CharacterKimodoPlan: + """Resolved per-character generation plan consumable by Kimodo model(...).""" + + character_id: str + prompts: list[str] + num_frames: list[int] + total_frames: int + constraint_lst: list[Any] + num_transition_frames: int + share_transition: bool + percentage_transition_override: float + segment_transition_policies: list[str] + + +def seconds_to_frames(duration_sec: float, fps: float) -> int: + """Convert segment duration in seconds to frames with a hard minimum of one frame.""" + return max(1, int(round(float(duration_sec) * float(fps)))) + + +def _transition_from_segments(segments: list[MotionSegment]) -> tuple[int, bool, float]: + """Aggregate segment transition policy into model-level transition parameters. + + Kimodo applies transition settings at call-level, so we choose a conservative aggregate: + - If any segment requests cut, disable shared transitions and lower overlap. + - If any segment requests overlap, increase transition blending. + - Otherwise use smooth defaults. + """ + policies = {segment.transition_policy.value for segment in segments} + + if "cut" in policies: + return 1, False, 0.0 + if "overlap" in policies: + return 8, True, 0.2 + if "hold" in policies: + return 3, False, 0.05 + return 5, True, 0.10 + + +def build_character_plan( + character: CharacterGenerationState, + *, + fps: float, + constraint_resolver: Optional[ConstraintResolver] = None, + apply_blend_guardrails: bool = True, +) -> CharacterKimodoPlan: + """Build one character generation plan from script segments. + + Entry/exit logs are intentionally verbose to make runtime mapping diagnostics explicit. + """ + LOGGER.info( + "card6.build_character_plan.start character_id=%s segments=%s fps=%.2f", + character.character_id, + len(character.segments), + fps, + ) + + prompts = [segment.action_text for segment in character.segments] + num_frames = [seconds_to_frames(segment.duration_sec, fps) for segment in character.segments] + total_frames = sum(num_frames) + + if character.constraints: + if constraint_resolver is None: + raise ValueError( + f"Constraints were provided for character '{character.character_id}' but no constraint_resolver was supplied" + ) + # Constraint translation is caller-owned because target constraint classes are model/skeleton specific. 
+ constraint_lst = constraint_resolver(character, total_frames) + else: + constraint_lst = [] + + num_transition_frames, share_transition, percentage_transition_override = _transition_from_segments(character.segments) + + if apply_blend_guardrails: + guarded = apply_transition_guardrails( + num_frames, + [segment.transition_policy.value for segment in character.segments], + TransitionSettings( + num_transition_frames=num_transition_frames, + share_transition=share_transition, + percentage_transition_override=percentage_transition_override, + ), + ) + num_transition_frames = guarded.num_transition_frames + share_transition = guarded.share_transition + percentage_transition_override = guarded.percentage_transition_override + + plan = CharacterKimodoPlan( + character_id=character.character_id, + prompts=prompts, + num_frames=num_frames, + total_frames=total_frames, + constraint_lst=constraint_lst, + num_transition_frames=num_transition_frames, + share_transition=share_transition, + percentage_transition_override=percentage_transition_override, + segment_transition_policies=[segment.transition_policy.value for segment in character.segments], + ) + + LOGGER.info( + "card6.build_character_plan.exit character_id=%s total_frames=%s transitions=(frames=%s share=%s pct=%.2f)", + plan.character_id, + plan.total_frames, + plan.num_transition_frames, + plan.share_transition, + plan.percentage_transition_override, + ) + return plan + + +def generator_request_to_plans( + request: GeneratorRequest, + *, + fps: float, + constraint_resolver: Optional[ConstraintResolver] = None, + apply_blend_guardrails: bool = True, +) -> dict[str, CharacterKimodoPlan]: + """Map all characters in a generator request to executable per-character Kimodo plans.""" + LOGGER.info("card6.generator_request_to_plans.start scene_id=%s chars=%s", request.scene_id, len(request.characters)) + + plans = { + character.character_id: build_character_plan( + character, + fps=fps, + constraint_resolver=constraint_resolver, + apply_blend_guardrails=apply_blend_guardrails, + ) + for character in request.characters + } + + if apply_blend_guardrails and len(plans) > 1: + harmonized = harmonize_scene_transitions( + { + character_id: TransitionSettings( + num_transition_frames=plan.num_transition_frames, + share_transition=plan.share_transition, + percentage_transition_override=plan.percentage_transition_override, + ) + for character_id, plan in plans.items() + } + ) + plans = { + character_id: CharacterKimodoPlan( + character_id=plan.character_id, + prompts=plan.prompts, + num_frames=plan.num_frames, + total_frames=plan.total_frames, + constraint_lst=plan.constraint_lst, + num_transition_frames=harmonized[character_id].num_transition_frames, + share_transition=harmonized[character_id].share_transition, + percentage_transition_override=harmonized[character_id].percentage_transition_override, + segment_transition_policies=plan.segment_transition_policies, + ) + for character_id, plan in plans.items() + } + + LOGGER.info("card6.generator_request_to_plans.exit scene_id=%s plans=%s", request.scene_id, len(plans)) + return plans + + +def run_character_generation( + model: Any, + plan: CharacterKimodoPlan, + *, + diffusion_steps: int, + num_samples: int, + cfg_weight: Optional[list[float]] = None, + cfg_type: Optional[str] = None, + post_processing: bool = True, + root_margin: float = 0.04, +) -> dict[str, Any]: + """Execute Kimodo generation for one character plan.""" + LOGGER.info( + "card6.run_character_generation.start character_id=%s prompts=%s 
total_frames=%s", + plan.character_id, + len(plan.prompts), + plan.total_frames, + ) + + result = model( + plan.prompts, + plan.num_frames, + constraint_lst=plan.constraint_lst, + num_denoising_steps=diffusion_steps, + num_samples=num_samples, + multi_prompt=True, + cfg_weight=cfg_weight or [2.0, 2.0], + cfg_type=cfg_type, + num_transition_frames=plan.num_transition_frames, + share_transition=plan.share_transition, + percentage_transition_override=plan.percentage_transition_override, + post_processing=post_processing, + root_margin=root_margin, + ) + + LOGGER.info("card6.run_character_generation.exit character_id=%s", plan.character_id) + return result + + +def run_multi_character_generation( + model: Any, + request: GeneratorRequest, + *, + fps: float, + diffusion_steps: int = 100, + cfg_weight: Optional[list[float]] = None, + cfg_type: Optional[str] = None, + post_processing: bool = True, + root_margin: float = 0.04, + constraint_resolver: Optional[ConstraintResolver] = None, + continue_on_error: bool = False, +) -> tuple[dict[str, dict[str, Any]], dict[str, str], dict[str, CharacterKimodoPlan]]: + """Run per-character Kimodo generation for all scripts in a scene. + + Returns: + outputs: per-character model outputs. + errors: per-character error message for failed runs. + plans: resolved per-character plans for auditing and verification. + """ + LOGGER.info("card6.run_multi_character_generation.start scene_id=%s", request.scene_id) + + plans = generator_request_to_plans(request, fps=fps, constraint_resolver=constraint_resolver) + outputs: dict[str, dict[str, Any]] = {} + errors: dict[str, str] = {} + + for character_id, plan in plans.items(): + try: + outputs[character_id] = run_character_generation( + model, + plan, + diffusion_steps=diffusion_steps, + num_samples=request.num_samples, + cfg_weight=cfg_weight, + cfg_type=cfg_type, + post_processing=post_processing, + root_margin=root_margin, + ) + except Exception as exc: # pylint: disable=broad-except + LOGGER.exception("card6.run_multi_character_generation.error character_id=%s", character_id) + errors[character_id] = str(exc) + if not continue_on_error: + break + + LOGGER.info( + "card6.run_multi_character_generation.exit scene_id=%s success=%s errors=%s", + request.scene_id, + len(outputs), + len(errors), + ) + return outputs, errors, plans diff --git a/kimodo/planner/__init__.py b/kimodo/planner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f51cf075bfb0e4029fe0b1dce4d285f555934f4 --- /dev/null +++ b/kimodo/planner/__init__.py @@ -0,0 +1,5 @@ +"""Planner modules for prompt-to-script conversion.""" + +from .qwen_adapter import QwenPlannerAdapter + +__all__ = ["QwenPlannerAdapter"] diff --git a/kimodo/planner/qwen_adapter.py b/kimodo/planner/qwen_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..bb56557adfc3abb21024a5eb439159ec8e96b9b4 --- /dev/null +++ b/kimodo/planner/qwen_adapter.py @@ -0,0 +1,362 @@ +"""Qwen planner adapter: story prompt -> validated multi-character motion scripts.""" + +from __future__ import annotations + +import json +import logging +import os +import re +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen +from dataclasses import dataclass +from typing import Any, Iterable, Protocol + +from huggingface_hub import InferenceClient + +from kimodo.schemas import MotionSegment, PlannerRequest, PlannerResponse, TransitionPolicy + +LOGGER = logging.getLogger(__name__) + +_ALLOWED_TRANSITIONS = {item.value for item 
in TransitionPolicy} +_STATUS_VALUES = {"success", "partial", "error"} + + +class PlannerClient(Protocol): + """Protocol for LLM client implementations.""" + + def text_generation(self, prompt: str, model: str, max_new_tokens: int, temperature: float) -> str: + """Return generated text for a given model and prompt.""" + + +class HFInferencePlannerClient: + """Hugging Face Inference client wrapper used in production.""" + + def __init__(self, token: str | None = None): + self._client = InferenceClient(token=token) + + def text_generation(self, prompt: str, model: str, max_new_tokens: int, temperature: float) -> str: + return self._client.text_generation( + model=model, + prompt=prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + ) + + +class FireworksPlannerClient: + """Fireworks Inference API client using OpenAI-compatible chat completions.""" + + def __init__(self, api_key: str, base_url: str = "https://api.fireworks.ai/inference/v1"): + self._api_key = api_key + self._base_url = base_url.rstrip("/") + + def text_generation(self, prompt: str, model: str, max_new_tokens: int, temperature: float) -> str: + payload = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": max_new_tokens, + "temperature": temperature, + } + body = json.dumps(payload).encode("utf-8") + req = Request( + f"{self._base_url}/chat/completions", + data=body, + headers={ + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + }, + method="POST", + ) + try: + with urlopen(req, timeout=30) as response: + data = json.loads(response.read().decode("utf-8")) + except HTTPError as exc: + detail = exc.read().decode("utf-8", errors="ignore") + raise RuntimeError(f"fireworks http {exc.code}: {detail}") from exc + except URLError as exc: + raise RuntimeError(f"fireworks network error: {exc}") from exc + + choices = data.get("choices") if isinstance(data, dict) else None + if not isinstance(choices, list) or not choices: + raise RuntimeError("fireworks response missing choices") + message = choices[0].get("message") if isinstance(choices[0], dict) else None + content = message.get("content") if isinstance(message, dict) else None + if not isinstance(content, str) or not content.strip(): + raise RuntimeError("fireworks response missing content") + return content + + +@dataclass(frozen=True) +class PlannerConfig: + """Runtime settings for Qwen planner generation.""" + + model_candidates: tuple[str, ...] 
= ( + "Qwen/Qwen2.5-7B-Instruct", + "Qwen/Qwen2.5-3B-Instruct", + "Qwen/Qwen2.5-1.5B-Instruct", + ) + max_retries_per_model: int = 2 + max_new_tokens: int = 700 + temperature: float = 0.2 + + @classmethod + def from_env(cls) -> PlannerConfig: + raw_models = (os.getenv("KIMODO_PLANNER_MODELS") or "").strip() + model_candidates = tuple(m.strip() for m in raw_models.split(",") if m.strip()) + if not model_candidates: + model_candidates = cls.model_candidates + return cls(model_candidates=model_candidates) + + +class QwenPlannerAdapter: + """Convert high-level scene prompts to schema-validated planner scripts.""" + + def __init__(self, client: PlannerClient | None = None, config: PlannerConfig | None = None): + self.config = config or PlannerConfig.from_env() + self.client = client or self._build_default_client() + + def _build_default_client(self) -> PlannerClient: + provider = (os.getenv("KIMODO_PLANNER_PROVIDER") or "hf").strip().lower() + if provider == "fireworks": + api_key = os.getenv("FIREWORKS_API_KEY") + if not api_key: + raise ValueError("FIREWORKS_API_KEY is required when KIMODO_PLANNER_PROVIDER=fireworks") + base_url = os.getenv("FIREWORKS_BASE_URL") or "https://api.fireworks.ai/inference/v1" + return FireworksPlannerClient(api_key=api_key, base_url=base_url) + return HFInferencePlannerClient(token=os.getenv("HF_TOKEN")) + + def plan(self, request: PlannerRequest) -> PlannerResponse: + """Generate a planner response; fallback template is used if model outputs are malformed.""" + LOGGER.info("planner.plan.start scene_id=%s chars=%d", request.scene_id, len(request.characters)) + prompt = self._build_prompt(request) + + errors: list[str] = [] + for model_name in self.config.model_candidates: + for attempt in range(1, self.config.max_retries_per_model + 1): + try: + raw = self.client.text_generation( + prompt=prompt, + model=model_name, + max_new_tokens=self.config.max_new_tokens, + temperature=self.config.temperature, + ) + payload = self._parse_json_payload(raw) + normalized = self._normalize_payload(payload, request) + response = PlannerResponse(**normalized) + LOGGER.info( + "planner.plan.success scene_id=%s model=%s attempt=%d", + request.scene_id, + model_name, + attempt, + ) + return response + except Exception as exc: # pylint: disable=broad-except + err = f"model={model_name} attempt={attempt} error={exc}" + LOGGER.warning("planner.plan.retry scene_id=%s %s", request.scene_id, err) + errors.append(err) + + fallback = self._fallback_response(request, errors) + LOGGER.info("planner.plan.fallback scene_id=%s", request.scene_id) + return fallback + + def _build_prompt(self, request: PlannerRequest) -> str: + LOGGER.info("planner.build_prompt.start scene_id=%s", request.scene_id) + char_block = "\n".join( + f"- {c.character_id} (skeleton={c.skeleton_type}, desc={c.description or 'none'})" + for c in request.characters + ) + prompt = ( + "You are a motion-planning copilot. 
Return strict JSON only, no markdown.\\n" + "Generate per-character motion segments for this scene.\\n" + "Rules:\\n" + "1) Return object keys: status, scripts, total_duration_sec.\\n" + "2) status must be success|partial|error.\\n" + "3) scripts is object mapping character_id -> list of segments.\\n" + "4) segment keys: segment_id(int), action_text(str), duration_sec(float 0.5-30), transition_policy(smooth|cut|hold|overlap), interaction_target(optional str), constraints(optional object).\\n" + "5) Use only provided character_ids.\\n" + "6) Keep total duration <= duration_limit_sec.\\n\\n" + f"scene_id: {request.scene_id}\\n" + f"duration_limit_sec: {request.duration_limit_sec}\\n" + f"user_prompt: {request.user_prompt}\\n" + f"characters:\\n{char_block}\\n" + ) + LOGGER.info("planner.build_prompt.exit scene_id=%s prompt_chars=%s", request.scene_id, len(prompt)) + return prompt + + def _parse_json_payload(self, raw_text: str) -> dict[str, Any]: + LOGGER.info("planner.parse_json.start") + text = raw_text.strip() + if not text: + raise ValueError("empty model response") + + # Accept plain JSON or fenced JSON blocks and parse strictly. + if text.startswith("```"): + match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", text, flags=re.DOTALL) + if not match: + raise ValueError("fenced block found without JSON object") + text = match.group(1).strip() + + if not text.startswith("{"): + start = text.find("{") + end = text.rfind("}") + if start == -1 or end == -1 or end <= start: + raise ValueError("no JSON object found in model response") + text = text[start : end + 1] + + payload = json.loads(text) + if not isinstance(payload, dict): + raise ValueError("planner response must be a JSON object") + LOGGER.info("planner.parse_json.exit keys=%s", sorted(payload.keys())) + return payload + + def _normalize_payload(self, payload: dict[str, Any], request: PlannerRequest) -> dict[str, Any]: + LOGGER.info("planner.normalize_payload.start scene_id=%s", request.scene_id) + normalized: dict[str, Any] = { + "scene_id": request.scene_id, + "status": self._normalize_status(payload.get("status")), + "error_message": payload.get("error_message"), + "scripts": {}, + "metadata": payload.get("metadata") or {}, + "total_duration_sec": float(payload.get("total_duration_sec") or 0.0), + } + + requested_ids = [c.character_id for c in request.characters] + requested_set = set(requested_ids) + raw_scripts = payload.get("scripts") or {} + + if not isinstance(raw_scripts, dict): + raise ValueError("scripts must be an object") + + for raw_char_id, raw_segments in raw_scripts.items(): + char_id = self._normalize_character_id(str(raw_char_id)) + if char_id not in requested_set: + continue + normalized["scripts"][char_id] = self._normalize_segments(raw_segments, requested_set, request.duration_limit_sec) + + # Guarantee all requested characters have at least one segment. 
+ for char_id in requested_ids: + if char_id not in normalized["scripts"] or not normalized["scripts"][char_id]: + normalized["scripts"][char_id] = self._default_segments(char_id, request.user_prompt) + + normalized["total_duration_sec"] = min( + request.duration_limit_sec, + max( + sum(float(seg.duration_sec) for seg in segs) + for segs in normalized["scripts"].values() + ), + ) + + normalized["metadata"] = { + **(normalized["metadata"] if isinstance(normalized["metadata"], dict) else {}), + "normalized": True, + } + LOGGER.info( + "planner.normalize_payload.exit scene_id=%s script_chars=%s total_duration=%.2f", + request.scene_id, + len(normalized["scripts"]), + float(normalized["total_duration_sec"]), + ) + return normalized + + def _normalize_status(self, value: Any) -> str: + status = str(value or "success").strip().lower() + return status if status in _STATUS_VALUES else "partial" + + def _normalize_character_id(self, value: str) -> str: + normalized = re.sub(r"[^a-zA-Z0-9_-]", "_", value.strip()) + normalized = re.sub(r"_+", "_", normalized) + normalized = normalized.strip("_") + return normalized[:50] or "character" + + def _normalize_segments( + self, + raw_segments: Any, + valid_ids: set[str], + duration_limit_sec: float, + ) -> list[MotionSegment]: + LOGGER.info("planner.normalize_segments.start") + if not isinstance(raw_segments, list): + raise ValueError("character script must be a list") + + normalized: list[MotionSegment] = [] + total = 0.0 + for idx, seg in enumerate(raw_segments): + if not isinstance(seg, dict): + continue + action_text = str(seg.get("action_text") or "hold idle stance").strip() + if len(action_text) < 3: + action_text = "hold idle stance" + duration = self._clamp_duration(seg.get("duration_sec")) + if total + duration > duration_limit_sec: + duration = max(0.5, duration_limit_sec - total) + total += duration + + interaction_target = seg.get("interaction_target") + if interaction_target is not None: + interaction_target = self._normalize_character_id(str(interaction_target)) + if interaction_target not in valid_ids: + interaction_target = None + + transition = str(seg.get("transition_policy") or TransitionPolicy.SMOOTH.value).lower() + if transition not in _ALLOWED_TRANSITIONS: + transition = TransitionPolicy.SMOOTH.value + + normalized.append( + MotionSegment( + segment_id=idx, + action_text=action_text[:500], + duration_sec=duration, + transition_policy=TransitionPolicy(transition), + interaction_target=interaction_target, + constraints=seg.get("constraints") if isinstance(seg.get("constraints"), dict) else {}, + ) + ) + + LOGGER.info("planner.normalize_segments.exit segment_count=%s", len(normalized)) + return normalized + + def _clamp_duration(self, value: Any) -> float: + try: + number = float(value) + except (TypeError, ValueError): + return 2.0 + return min(30.0, max(0.5, number)) + + def _default_segments(self, char_id: str, prompt: str) -> list[MotionSegment]: + base_text = prompt.strip()[:120] + return [ + MotionSegment( + segment_id=0, + action_text=f"{char_id} starts: {base_text}", + duration_sec=2.0, + transition_policy=TransitionPolicy.SMOOTH, + interaction_target=None, + constraints={}, + ), + MotionSegment( + segment_id=1, + action_text=f"{char_id} continues with controlled motion and stable stance", + duration_sec=2.0, + transition_policy=TransitionPolicy.SMOOTH, + interaction_target=None, + constraints={}, + ), + ] + + def _fallback_response(self, request: PlannerRequest, errors: Iterable[str]) -> PlannerResponse: + 
LOGGER.info("planner.fallback_response.start scene_id=%s", request.scene_id) + scripts = { + c.character_id: self._default_segments(c.character_id, request.user_prompt) + for c in request.characters + } + total_duration = max(sum(seg.duration_sec for seg in segs) for segs in scripts.values()) + response = PlannerResponse( + scene_id=request.scene_id, + status="partial", + error_message="; ".join(errors)[:1200] if errors else "model output malformed; fallback used", + scripts=scripts, + metadata={"fallback_used": True, "normalized": True}, + total_duration_sec=min(request.duration_limit_sec, total_duration), + ) + LOGGER.info("planner.fallback_response.exit scene_id=%s", request.scene_id) + return response diff --git a/kimodo/postprocess.py b/kimodo/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8c0120d7776c4dce7478e3ce8a51f5308eaf69 --- /dev/null +++ b/kimodo/postprocess.py @@ -0,0 +1,346 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Post-processing utilities for motion generation output.""" + +from types import SimpleNamespace +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch + +from .constraints import ( + EndEffectorConstraintSet, + FullBodyConstraintSet, + Root2DConstraintSet, +) +from .geometry import matrix_to_quaternion, quaternion_to_matrix +from .skeleton import ( + G1Skeleton34, + SkeletonBase, + SMPLXSkeleton22, + SOMASkeleton30, + SOMASkeleton77, + fk, +) + + +def extract_input_motion_from_constraints( + constraint_lst: List, + skeleton: SkeletonBase, + num_frames: int, + num_joints: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Extract hip translations and local rotations from constraints for postprocessing. + + Args: + constraint_lst: List of constraints (FullBodyConstraintSet, EndEffectorConstraintSet, etc.) 
+ skeleton: Skeleton instance + num_frames: Total number of frames in the motion + num_joints: Number of joints + + Returns: + Tuple of (hip_translations_input, rotations_input): + - hip_translations_input: Hip translations, shape (T, 3) + - rotations_input: Local joint rotations as quaternions, shape (T, J, 4) + """ + # Initialize with zeros for all frames + hip_translations_input = torch.zeros(num_frames, 3) + rotations_input = torch.zeros(num_frames, num_joints, 4) + rotations_input[..., 0] = 1.0 # Initialize as identity quaternions (w=1, x=y=z=0) + + def _match_hip_dtype(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(device=hip_translations_input.device, dtype=hip_translations_input.dtype) + + def _match_rot_dtype(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(device=rotations_input.device, dtype=rotations_input.dtype) + + if not constraint_lst: + return hip_translations_input, rotations_input + + for constraint in constraint_lst: + frame_indices = constraint.frame_indices + if isinstance(frame_indices, torch.Tensor): + valid_mask = frame_indices < num_frames + if valid_mask.sum() == 0: + continue + frame_indices = frame_indices[valid_mask] + else: + valid_positions = [i for i, idx in enumerate(frame_indices) if idx < num_frames] + if not valid_positions: + continue + frame_indices = [frame_indices[i] for i in valid_positions] + + # Handle Root2DConstraintSet separately - only assign smooth_root_2d at xz dimensions + if isinstance(constraint, Root2DConstraintSet): + smooth_root_2d = constraint.smooth_root_2d # (K, 2) where K = len(frame_indices) + if isinstance(frame_indices, torch.Tensor): + smooth_root_2d = smooth_root_2d[valid_mask] + else: + smooth_root_2d = smooth_root_2d[valid_positions] + smooth_root_2d = _match_hip_dtype(smooth_root_2d) + hip_translations_input[frame_indices, 0] = smooth_root_2d[:, 0] # x + hip_translations_input[frame_indices, 2] = smooth_root_2d[:, 1] # z + continue + elif isinstance(constraint, FullBodyConstraintSet) or isinstance(constraint, EndEffectorConstraintSet): + global_rots = constraint.global_joints_rots # (K, J, 3, 3) where K = len(frame_indices) + global_positions = constraint.global_joints_positions # (K, J, 3) + if isinstance(frame_indices, torch.Tensor): + global_rots = global_rots[valid_mask] + global_positions = global_positions[valid_mask] + smooth_root_2d = constraint.smooth_root_2d[valid_mask] + else: + global_rots = global_rots[valid_positions] + global_positions = global_positions[valid_positions] + smooth_root_2d = constraint.smooth_root_2d[valid_positions] + + root_positions = global_positions[:, skeleton.root_idx] # (K, 3) + # Replace xz with smooth_root_2d values. + root_positions[:, 0] = smooth_root_2d[:, 0] # x + root_positions[:, 2] = smooth_root_2d[:, 1] # z + + local_rot_mats = skeleton.global_rots_to_local_rots(global_rots) # (K, J, 3, 3) + local_rot_quats = matrix_to_quaternion(local_rot_mats) # (K, J, 4) + + hip_translations_input[frame_indices] = _match_hip_dtype(root_positions) + rotations_input[frame_indices] = _match_rot_dtype(local_rot_quats) + else: + NotImplementedError(f"Constraint {constraint.name} is not supported") + + return hip_translations_input, rotations_input + + +def create_working_rig_from_skeleton( + skeleton: SkeletonBase, above_ground_offset: float = 0.007 +) -> List[SimpleNamespace]: + """Create the working rig as a list of SimpleNamespace objects from skeleton. 
+ + Args: + skeleton: SkeletonBase instance with bone_order_names, neutral_joints, joint_parents + above_ground_offset: Additional offset to position the rig slightly above ground + Returns: + List of SimpleNamespace objects representing the working rig + """ + working_rig_joints = [] + + joint_names = skeleton.bone_order_names + neutral_positions = skeleton.neutral_joints.cpu().numpy() + parent_indices = skeleton.joint_parents.cpu().numpy() + + if isinstance(skeleton, (G1Skeleton34, SMPLXSkeleton22)): + retarget_map = { + skeleton.bone_order_names[skeleton.root_idx]: "Hips", + skeleton.left_hand_joint_names[0]: "LeftHand", + skeleton.right_hand_joint_names[0]: "RightHand", + skeleton.left_foot_joint_names[0]: "LeftFoot", + skeleton.right_foot_joint_names[0]: "RightFoot", + } + else: + # works for SOMA + retarget_map = { + "Hips": "Hips", + "Head": "Head", + "LeftHand": "LeftHand", + "RightHand": "RightHand", + "LeftFoot": "LeftFoot", + "RightFoot": "RightFoot", + } + + for i, joint_name in enumerate(joint_names): + parent_name = None if parent_indices[i] == -1 else joint_names[parent_indices[i]] + + # Calculate local translation relative to parent + if parent_indices[i] == -1: + # Move the rig so that the lowest point (toe) is at ground level (y=0), + # plus a small offset to position the rig slightly above ground + toe_height = neutral_positions[:, 1].min() # lowest y-coordinate (toe) + local_translation = ( + neutral_positions[i] + np.array([0.0, -toe_height + above_ground_offset, 0.0]) + ).tolist() + else: + parent_idx = parent_indices[i] + parent_position = neutral_positions[parent_idx] + joint_position = neutral_positions[i] + local_translation = (joint_position - parent_position).tolist() + + # Default rotation (identity quaternion: x=0, y=0, z=0, w=1) + default_rotation = [0.0, 0.0, 0.0, 1.0] + + joint_info = SimpleNamespace( + name=joint_name, + parent=parent_name, + t_pose_rotation=default_rotation, + t_pose_translation=local_translation, + retarget_tag=retarget_map.get(joint_name), + ) + + working_rig_joints.append(joint_info) + + return working_rig_joints + + +def post_process_motion( + local_rot_mats: torch.Tensor, + root_positions: torch.Tensor, + contacts: torch.Tensor, + skeleton: SkeletonBase, + constraint_lst: Optional[List] = None, + contact_threshold: float = 0.5, + root_margin: float = 0.04, +) -> Dict[str, torch.Tensor]: + """Post-process generated motion to reduce foot skating and improve quality. + + Args: + local_rot_mats: Local joint rotation matrices, shape (B, T, J, 3, 3) + root_positions: Root joint positions, shape (B, T, 3) + contacts: Foot contact labels, shape (B, T, num_contacts) + skeleton: Skeleton instance + constraint_lst: Optional list of constraints (or list of lists of constraints for batched inference)(FullBodyConstraintSet, etc.) 
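+            When a list of lists is passed, constraint_lst[b] is the constraint
+            list for batch sample b.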
+ contact_threshold: Threshold for foot contact detection + root_margin: Margin for root position correction + + Returns: + Dictionary with corrected motion data: + - local_rot_mats: Corrected local rotation matrices (B, T, J, 3, 3) + - root_positions: Corrected root positions (B, T, 3) + - posed_joints: Corrected global joint positions (B, T, J, 3) + - global_rot_mats: Corrected global rotation matrices (B, T, J, 3, 3) + """ + # Ensure batch dimension + assert local_rot_mats.dim() == 5, "local_rot_mats should be 5D, make sure to include the batch dimension" + + batch_size, num_frames, num_joints = local_rot_mats.shape[:3] + + def _build_constraint_masks_dict(constraints: List) -> Dict[str, torch.Tensor]: + out = { + key: torch.zeros(num_frames, dtype=torch.float32) + for key in [ + "FullBody", + "LeftFoot", + "RightFoot", + "LeftHand", + "RightHand", + "Root", + ] + } + for constraint in constraints: + frame_indices = constraint.frame_indices + if isinstance(frame_indices, torch.Tensor): + frame_indices = frame_indices[frame_indices < num_frames] + if frame_indices.numel() == 0: + continue + else: + frame_indices = [idx for idx in frame_indices if idx < num_frames] + if not frame_indices: + continue + if constraint.name == "fullbody": + out["FullBody"][frame_indices] = 1.0 + elif constraint.name == "left-foot": + out["LeftFoot"][frame_indices] = 1.0 + elif constraint.name == "right-foot": + out["RightFoot"][frame_indices] = 1.0 + elif constraint.name == "left-hand": + out["LeftHand"][frame_indices] = 1.0 + elif constraint.name == "right-hand": + out["RightHand"][frame_indices] = 1.0 + elif constraint.name == "root2d": + out["Root"][frame_indices] = 1.0 + return out + + # Create constraint masks from constraint_lst (one dict per batch item when batched) + batched_constraints = bool(constraint_lst) and isinstance(constraint_lst[0], list) + if batched_constraints: + constraint_masks_dict_lst = [_build_constraint_masks_dict(constraint_lst[b]) for b in range(batch_size)] + else: + constraint_masks_dict = ( + _build_constraint_masks_dict(constraint_lst) + if constraint_lst + else { + key: torch.zeros(num_frames, dtype=torch.float32) + for key in [ + "FullBody", + "LeftFoot", + "RightFoot", + "LeftHand", + "RightHand", + "Root", + ] + } + ) + + # Create working rig + above_ground_offset = 0.02 if isinstance(skeleton, (SOMASkeleton30, SOMASkeleton77)) else 0.007 + # larger offset for SOMA since model tends to generate lower to the ground + working_rig = create_working_rig_from_skeleton(skeleton, above_ground_offset=above_ground_offset) + has_double_ankle_joints = isinstance(skeleton, G1Skeleton34) + + # Prepare input tensors. The generated motion will be modified in place. Clone first. 
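+    # The correction itself runs on CPU tensors; corrected results are moved back
+    # to the original device for the final FK pass below.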
+ neutral_joints_pelvis_offset = skeleton.neutral_joints[0].cpu().clone() + hip_translations_corrected = root_positions.cpu().clone() + rotations_corrected = matrix_to_quaternion(local_rot_mats).cpu().clone() # (B, T, J, 4) + contacts = contacts.cpu() + + # Extract input motion (target keyframes) from constraints for each batch + # For constrained keyframes, use the original motion from constraints + # For non-constrained frames, zeros are used + hip_translations_input = torch.zeros(batch_size, num_frames, 3) + rotations_input = torch.zeros(batch_size, num_frames, num_joints, 4) + rotations_input[..., 0] = 1.0 # Initialize as identity quaternions (w=1, x=y=z=0) + + if constraint_lst: + for b in range(batch_size): + # Get constraints for this batch item (if batched) or use the same list + constraints_lst_el = ( + constraint_lst[b] + if isinstance( + constraint_lst[0], list + ) # when the constraint_list is in batch format, each item in a list is a constraintlist for one sample + else constraint_lst # single constraint list shared for all samples in the batch + ) + hip_translations_input[b], rotations_input[b] = extract_input_motion_from_constraints( + constraints_lst_el, + skeleton, + num_frames, + num_joints, + ) + + # Call the motion correction for each batch (optional package) + try: + from motion_correction import motion_postprocess + except ImportError as e: + raise RuntimeError( + "Motion correction is required for this postprocessing path but the " + "motion_correction package is not installed. Install with: pip install -e ." + ) from e + for b in range(batch_size): + masks_b = constraint_masks_dict_lst[b] if batched_constraints else constraint_masks_dict + motion_postprocess.correct_motion( + hip_translations_corrected[b : b + 1], + rotations_corrected[b : b + 1], + contacts[b : b + 1], + hip_translations_input[b : b + 1], + rotations_input[b : b + 1], + masks_b, + contact_threshold, + root_margin, + working_rig, + has_double_ankle_joints, + ) + + local_rot_mats_corrected = quaternion_to_matrix(rotations_corrected) + + # Compute posed joints using FK + device = local_rot_mats.device + global_rot_mats, posed_joints, _ = fk( + local_rot_mats_corrected.to(device), + hip_translations_corrected.to(device), + skeleton, + ) + + result = { + "local_rot_mats": local_rot_mats_corrected.to(device), + "root_positions": hip_translations_corrected.to(device), + "posed_joints": posed_joints, + "global_rot_mats": global_rot_mats, + } + + return result diff --git a/kimodo/runtime/__init__.py b/kimodo/runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65bae40bb148a8df5e7992f03c5e929c69178f61 --- /dev/null +++ b/kimodo/runtime/__init__.py @@ -0,0 +1,9 @@ +"""Runtime helpers for device selection and backend health checks.""" + +from .device import RuntimeHealthReport, select_runtime_device, runtime_health_report + +__all__ = [ + "RuntimeHealthReport", + "select_runtime_device", + "runtime_health_report", +] \ No newline at end of file diff --git a/kimodo/runtime/device.py b/kimodo/runtime/device.py new file mode 100644 index 0000000000000000000000000000000000000000..7880988c4ce49ce6f89b6bd325214d68e05b7c08 --- /dev/null +++ b/kimodo/runtime/device.py @@ -0,0 +1,172 @@ +"""Card 9 runtime device bootstrap helpers (AMD/ROCm-friendly).""" + +from __future__ import annotations + +import logging +import os +from dataclasses import asdict, dataclass +from typing import Optional + +import torch + +LOGGER = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class 
RuntimeHealthReport: + """Runtime/backend detection report for startup health checks.""" + + requested_device: str + selected_device: str + backend: str + cuda_available: bool + rocm_available: bool + mps_available: bool + strict_mode: bool + reason: str + + def to_dict(self) -> dict: + return asdict(self) + + +def _env_bool(name: str, default: bool = False) -> bool: + raw = os.environ.get(name) + if raw is None: + return default + return str(raw).strip().lower() in ("1", "true", "yes", "on") + + +def _normalize_requested_device(requested: Optional[str]) -> str: + value = requested or os.environ.get("KIMODO_DEVICE") or os.environ.get("DEVICE") or "auto" + return str(value).strip().lower() + + +def _has_mps() -> bool: + backends = getattr(torch, "backends", None) + mps = getattr(backends, "mps", None) + if mps is None: + return False + is_available = getattr(mps, "is_available", None) + if callable(is_available): + try: + return bool(is_available()) + except Exception: # pragma: no cover + return False + return False + + +def _backend_name(cuda_available: bool, rocm_available: bool, mps_available: bool) -> str: + if rocm_available: + return "rocm" + if cuda_available: + return "cuda" + if mps_available: + return "mps" + return "cpu" + + +def select_runtime_device(requested: Optional[str] = None) -> str: + """Resolve runtime device with ROCm/CUDA/CPU fallback. + + Resolution order: + - explicit requested argument + - environment variable KIMODO_DEVICE (or DEVICE) + - auto + + If KIMODO_STRICT_DEVICE=true and requested accelerator is unavailable, raises ValueError. + """ + LOGGER.info("card9.select_runtime_device.start requested=%s", requested) + strict_mode = _env_bool("KIMODO_STRICT_DEVICE", default=False) + req = _normalize_requested_device(requested) + + cuda_available = bool(torch.cuda.is_available()) + rocm_available = cuda_available and bool(getattr(torch.version, "hip", None)) + mps_available = _has_mps() + + accelerator_aliases = {"cuda", "cuda:0", "gpu", "rocm", "hip", "amd"} + + if req == "cpu": + selected = "cpu" + reason = "explicit_cpu" + elif req in ("mps", "apple"): + if mps_available: + selected = "mps" + reason = "explicit_mps" + elif strict_mode: + raise ValueError("Requested MPS device but MPS backend is unavailable") + else: + selected = "cpu" + reason = "mps_unavailable_fallback_cpu" + elif req in accelerator_aliases: + if cuda_available: + selected = "cuda:0" + reason = "explicit_accelerator_available" + elif strict_mode: + raise ValueError(f"Requested accelerator '{req}' but no torch accelerator is available") + else: + selected = "cpu" + reason = "accelerator_unavailable_fallback_cpu" + elif req == "auto": + if cuda_available: + selected = "cuda:0" + reason = "auto_accelerator" + elif mps_available: + selected = "mps" + reason = "auto_mps" + else: + selected = "cpu" + reason = "auto_cpu" + else: + # Preserve explicit torch device strings (e.g. cuda:1, cpu) when possible. 
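+        # e.g. "cuda:1" is kept verbatim when a CUDA/ROCm backend is present; in
+        # strict mode an unavailable or unknown specifier raises instead of
+        # falling back to CPU.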
+        if req.startswith("cuda"):
+            if cuda_available:
+                selected = req
+                reason = "explicit_cuda_index"
+            elif strict_mode:
+                raise ValueError(f"Requested device '{req}' but CUDA/ROCm backend is unavailable")
+            else:
+                selected = "cpu"
+                reason = "explicit_cuda_unavailable_fallback_cpu"
+        else:
+            if strict_mode:
+                raise ValueError(f"Unknown device specifier '{req}'")
+            selected = "cpu"
+            reason = "unknown_device_fallback_cpu"
+
+    LOGGER.info("card9.select_runtime_device.exit selected=%s reason=%s", selected, reason)
+    return selected
+
+
+def runtime_health_report(requested: Optional[str] = None) -> RuntimeHealthReport:
+    """Return a startup runtime report suitable for health checks and logs."""
+    LOGGER.info("card9.runtime_health_report.start requested=%s", requested)
+
+    strict_mode = _env_bool("KIMODO_STRICT_DEVICE", default=False)
+    req = _normalize_requested_device(requested)
+    cuda_available = bool(torch.cuda.is_available())
+    rocm_available = cuda_available and bool(getattr(torch.version, "hip", None))
+    mps_available = _has_mps()
+
+    selected = select_runtime_device(req)
+    reason = "ok"
+    if selected == "cpu" and req in {"cuda", "cuda:0", "gpu", "rocm", "hip", "amd"}:
+        reason = "fallback_cpu"
+
+    report = RuntimeHealthReport(
+        requested_device=req,
+        selected_device=selected,
+        backend=_backend_name(cuda_available, rocm_available, mps_available),
+        cuda_available=cuda_available,
+        rocm_available=rocm_available,
+        mps_available=mps_available,
+        strict_mode=strict_mode,
+        reason=reason,
+    )
+    LOGGER.info(
+        "card9.runtime_health_report.exit requested=%s selected=%s backend=%s reason=%s",
+        report.requested_device,
+        report.selected_device,
+        report.backend,
+        report.reason,
+    )
+    return report
\ No newline at end of file
diff --git a/kimodo/sanitize.py b/kimodo/sanitize.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3728e2fd1b3c3c949aa9d69a94db1570c58e230
--- /dev/null
+++ b/kimodo/sanitize.py
@@ -0,0 +1,95 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Text prompt sanitization for motion generation (whitespace, punctuation, capitalization)."""
+
+
+def sanitize_text(text: str, paragraph: bool = True) -> str:
+    """Sanitize a text prompt: strip, collapse spaces, capitalize, trim non-alphanumeric, add/fix final punctuation.
+
+    Args:
+        text: Input text prompt.
+        paragraph: If True, capitalize after each sentence break and normalize spacing between sentences.
+
+    Returns:
+        Sanitized text (empty string if no alphanumeric content remains).
+    """
+    # remove any trailing or leading whitespace
+    text = text.strip()
+
+    # https://stackoverflow.com/a/1546251
+    # replace duplicate spaces by one space
+    text = " ".join(text.split())
+
+    if text == "":
+        return text
+
+    # remove leading non-alphanumeric characters; if the text contains no
+    # alphanumeric character at all, there is no usable prompt left
+    for i, c in enumerate(text):
+        if str.isalnum(c):
+            break
+    else:
+        return ""
+    text = text[i:]
+
+    # Capitalize
+    text = text.capitalize()
+
+    final_punctuations = ".!?\"])'"
+    # remove trailing non-alphanumeric characters, except final punctuation
+    for i, c in reversed(list(enumerate(text))):
+        if str.isalnum(c) or c in final_punctuations:
+            break
+    text = text[: i + 1]
+
+    # Adding period at the end if needed
+    if text[-1] not in ".!?":
+        text = text + "."
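+
+    # At this point the text is a single trimmed, capitalized string ending in
+    # punctuation, e.g. "go forward" -> "Go forward."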
+ + if paragraph: + # fix end of sentences if several sentences + for sentence_break in ".!?": + subtexts = text.split(sentence_break) + text = f"{sentence_break} ".join( # put back a space after the break + [ + y[0].capitalize() + y[1:] # only capitalize the first character + if y + else y # y is empty at the end + for x in subtexts + for y in [x.strip()] # remove extra spaces + ] + ).strip() # remove extra space at the end + return text + + +def sanitize_texts(texts: list[str]) -> list[str]: + """Sanitize each text prompt in the list (see sanitize_text). + + Args: + texts: List of input text prompts. + + Returns: + List of sanitized texts. + """ + return [sanitize_text(text) for text in texts] + + +if __name__ == "__main__": + texts = [ + " A person is walking.", + "someone go forward", + "jump", + "jumping!", + "jumping)", + "-go", + "blocasdji -----", + "", + ] + + print("Old texts") + print("\n".join(texts)) + print() + + new_texts = sanitize_texts(texts) + print("Sanitized texts") + print("\n".join(new_texts)) diff --git a/kimodo/scheduler.py b/kimodo/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ebdd5b57484e41dda120c911532857bc76fd47ae --- /dev/null +++ b/kimodo/scheduler.py @@ -0,0 +1,492 @@ +""" +Card 3: Shared State Loop (Deterministic Event Scheduler) + +Defines deterministic event ordering for multi-character interactions in one port. +- Synchronized time with per-character state containers +- Deterministic conflict resolution (same seed → same outcome) +- Support for ≥3 active characters +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Any +import hashlib +import json +import logging + + +LOGGER = logging.getLogger(__name__) + + +class CharacterState(Enum): + """Character lifecycle state.""" + IDLE = "idle" + BUSY = "busy" + TRANSITIONING = "transitioning" + INTERACTING = "interacting" + + +class ConflictResolutionPolicy(str, Enum): + """How to resolve conflicting interactions.""" + PRIORITY_BASED = "priority_based" # Higher priority character wins + FIFO = "fifo" # First in, first out + COOLDOWN = "cooldown" # Enforce cooldown between interactions + NEGOTIATION = "negotiation" # Custom negotiation logic + + +@dataclass +class CharacterSegmentState: + """Current state of a character's motion segment execution.""" + + character_id: str + segment_index: int # Current segment in script + frames_elapsed: int # Frames executed in current segment + total_frames: int # Total frames for current segment + is_complete: bool = False + + def progress(self) -> float: + """Return 0-1 progress through current segment.""" + if self.total_frames == 0: + return 1.0 + return min(1.0, self.frames_elapsed / self.total_frames) + + +@dataclass +class CharacterSlot: + """Per-character state container in shared loop.""" + + character_id: str + skeleton_type: str + current_state: CharacterState = CharacterState.IDLE + segment_state: Optional[CharacterSegmentState] = None + + # Interaction tracking + interaction_target: Optional[str] = None + last_interaction_time_ms: int = 0 + interaction_cooldown_ms: int = 500 # Prevent rapid re-interactions + + # Metadata + priority: int = 0 # For conflict resolution + cycle_count: int = 0 # Lifecycle counter + + def is_busy(self) -> bool: + """Check if character is currently executing motion.""" + return self.current_state in [ + CharacterState.BUSY, + CharacterState.TRANSITIONING, + CharacterState.INTERACTING + ] + + def can_interact(self, current_time_ms: int) -> 
bool: + """Check if character can start new interaction.""" + time_since_last = current_time_ms - self.last_interaction_time_ms + return time_since_last >= self.interaction_cooldown_ms + + +@dataclass +class LoopTick: + """Single tick in the deterministic event loop.""" + + tick_number: int + frame_number: int + time_ms: float + fps: int = 30 + + # Per-tick events + character_updates: Dict[str, CharacterSlot] = field(default_factory=dict) + completed_segments: List[str] = field(default_factory=list) + interactions: List[tuple] = field(default_factory=list) # [(from_id, to_id), ...] + + def get_timestamp(self) -> dict: + """Return tick metadata for auditing.""" + return { + "tick_number": self.tick_number, + "frame_number": self.frame_number, + "time_ms": self.time_ms, + "fps": self.fps, + } + + +class DeterministicLoop: + """ + Deterministic multi-character event loop. + + Ensures: + - Same seed → same outputs (for testing replay) + - No race conditions (total determinism within single process) + - Clear conflict resolution (priority/FIFO/cooldown) + - Synchronized timeline for all characters + """ + + def __init__( + self, + fps: int = 30, + seed: int = 42, + conflict_policy: ConflictResolutionPolicy = ConflictResolutionPolicy.COOLDOWN, + ): + LOGGER.info( + "scheduler.loop.init.start fps=%s seed=%s conflict_policy=%s", + fps, + seed, + conflict_policy.value, + ) + self.fps = fps + self.seed = seed + self.conflict_policy = conflict_policy + + # Derive deterministic RNG state from seed + self._rng_state = seed + + # State tracking + self.tick_number = 0 + self.frame_number = 0 + self.time_ms = 0.0 + self.ms_per_frame = 1000.0 / fps + + # Per-character state + self.characters: Dict[str, CharacterSlot] = {} + + # Event log for auditing + self.tick_history: List[LoopTick] = [] + LOGGER.info("scheduler.loop.init.exit") + + def register_character( + self, + character_id: str, + skeleton_type: str, + priority: int = 0, + ) -> None: + """Register a character for this loop.""" + LOGGER.info( + "scheduler.register_character.start character_id=%s skeleton=%s priority=%s", + character_id, + skeleton_type, + priority, + ) + if character_id in self.characters: + raise ValueError(f"Character {character_id} already registered") + + self.characters[character_id] = CharacterSlot( + character_id=character_id, + skeleton_type=skeleton_type, + priority=priority, + ) + LOGGER.info("scheduler.register_character.exit character_id=%s", character_id) + + def _deterministic_rng(self) -> float: + """Generate deterministic pseudo-random number (0-1).""" + # Simple linear congruential generator seeded with loop state + self._rng_state = (self._rng_state * 1103515245 + 12345) & 0x7fffffff + return (self._rng_state / 0x7fffffff) + + def _resolve_conflict( + self, + char1_id: str, + char2_id: str, + ) -> str: + """ + Deterministically resolve conflict between two characters. + + Returns: character_id that wins the interaction. 
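+
+        Example (illustrative): under PRIORITY_BASED with equal priorities, the
+        alphabetically smaller id wins, so resolving ("b", "a") returns "a".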
+ """ + char1 = self.characters[char1_id] + char2 = self.characters[char2_id] + + if self.conflict_policy == ConflictResolutionPolicy.PRIORITY_BASED: + # Higher priority wins + if char1.priority > char2.priority: + return char1_id + elif char2.priority > char1.priority: + return char2_id + # Equal priority: use deterministic tiebreaker (alphabetical) + return min(char1_id, char2_id) + + elif self.conflict_policy == ConflictResolutionPolicy.FIFO: + # Earlier interaction time wins + if char1.last_interaction_time_ms < char2.last_interaction_time_ms: + return char1_id + else: + return char2_id + + elif self.conflict_policy == ConflictResolutionPolicy.COOLDOWN: + # Both can interact if cooldown satisfied + char1_ready = char1.can_interact(int(self.time_ms)) + char2_ready = char2.can_interact(int(self.time_ms)) + + if char1_ready and not char2_ready: + return char1_id + elif char2_ready and not char1_ready: + return char2_id + # Both or neither ready: use priority tiebreaker + if char1.priority > char2.priority: + return char1_id + else: + return char2_id + + else: # NEGOTIATION (placeholder) + return min(char1_id, char2_id) + + def advance_tick( + self, + character_motions: Dict[str, Dict[str, Any]], + ) -> LoopTick: + """ + Advance one tick forward with deterministic character updates. + + Args: + character_motions: Dict[character_id] → motion data for this frame + + Returns: + LoopTick with event history for this frame + """ + LOGGER.info( + "scheduler.advance_tick.start tick=%s frame=%s chars=%s motions=%s", + self.tick_number, + self.frame_number, + len(self.characters), + len(character_motions), + ) + tick = LoopTick( + tick_number=self.tick_number, + frame_number=self.frame_number, + time_ms=self.time_ms, + fps=self.fps, + ) + + # 1. Update character segment states (deterministic progression) + for char_id, char_slot in self.characters.items(): + if char_slot.segment_state is None: + continue + + # Advance frame counter + char_slot.segment_state.frames_elapsed += 1 + + # Check if segment complete + if char_slot.segment_state.frames_elapsed >= char_slot.segment_state.total_frames: + char_slot.segment_state.is_complete = True + tick.completed_segments.append(char_id) + char_slot.current_state = CharacterState.IDLE + else: + char_slot.current_state = CharacterState.BUSY + + tick.character_updates[char_id] = char_slot + + # 2. Detect and resolve conflicts + pending_interactions = [] + for char_id, char_slot in self.characters.items(): + if char_slot.interaction_target: + pending_interactions.append((char_id, char_slot.interaction_target)) + + # Resolve conflicts deterministically + for char1_id, char2_id in pending_interactions: + winner_id = self._resolve_conflict(char1_id, char2_id) + tick.interactions.append((winner_id, char2_id if winner_id == char1_id else char1_id)) + + # Update last interaction time + self.characters[winner_id].last_interaction_time_ms = int(self.time_ms) + + # 3. Advance time + self.tick_number += 1 + self.frame_number += 1 + self.time_ms += self.ms_per_frame + + # 4. Record tick + self.tick_history.append(tick) + + LOGGER.info( + "scheduler.advance_tick.exit tick=%s completed=%s interactions=%s", + tick.tick_number, + len(tick.completed_segments), + len(tick.interactions), + ) + + return tick + + def get_state_hash(self) -> str: + """ + Compute deterministic hash of current loop state. + + Used for seeded replay verification: + Same seed → same state hash at corresponding tick. 
+ """ + state_dict = { + "tick_number": self.tick_number, + "frame_number": self.frame_number, + "time_ms": self.time_ms, + "rng_state": self._rng_state, + "characters": { + char_id: { + "state": char_slot.current_state.value, + "frames_elapsed": char_slot.segment_state.frames_elapsed if char_slot.segment_state else 0, + } + for char_id, char_slot in self.characters.items() + } + } + + state_json = json.dumps(state_dict, sort_keys=True) + return hashlib.sha256(state_json.encode()).hexdigest()[:16] + + def reset(self) -> None: + """Reset loop to initial state (for replay).""" + LOGGER.info( + "scheduler.reset.start tick=%s frame=%s registered_chars=%s", + self.tick_number, + self.frame_number, + len(self.characters), + ) + self.tick_number = 0 + self.frame_number = 0 + self.time_ms = 0.0 + self._rng_state = self.seed + self.tick_history = [] + + for char_slot in self.characters.values(): + char_slot.current_state = CharacterState.IDLE + char_slot.segment_state = None + LOGGER.info("scheduler.reset.exit") + + +# ============================================================================ +# Deterministic Test Scenarios +# ============================================================================ + +def two_character_interaction_scenario() -> tuple[DeterministicLoop, List[dict]]: + """ + Test scenario: Two characters dancing with synchronized transitions. + + Returns: + (loop, motion_frames_per_char) + """ + loop = DeterministicLoop(fps=30, seed=42) + + # Register characters + loop.register_character("dancer1", "soma", priority=1) + loop.register_character("dancer2", "soma", priority=1) + + # Simulate 2 segments x 30 frames each = 60 frames total + motion_sequence = [ + { + "dancer1": {"action": "walk_forward", "frame": i} for i in range(30) + }, + { + "dancer2": {"action": "follow", "frame": i} for i in range(30) + }, + ] + + return loop, motion_sequence + + +def three_character_scenario() -> tuple[DeterministicLoop, List[dict]]: + """ + Test scenario: Three characters with controlled interactions. + + Returns: + (loop, motion_frames) + """ + loop = DeterministicLoop(fps=30, seed=43, conflict_policy=ConflictResolutionPolicy.PRIORITY_BASED) + + # Register with different priorities + loop.register_character("leader", "soma", priority=3) + loop.register_character("follower1", "soma", priority=2) + loop.register_character("follower2", "soma", priority=1) + + motion_sequence = [ + { + "leader": {"action": "lead", "frame": i}, + "follower1": {"action": "follow", "frame": i}, + "follower2": {"action": "match", "frame": i}, + } + for i in range(60) + ] + + return loop, motion_sequence + + +def test_deterministic_replay(): + """ + Verify deterministic replay: same seed produces identical state hashes. 
+ """ + print("=== Card 3: Deterministic Loop Test ===\n") + + # Scenario 1: Two-character deterministic replay + print("Test 1: Two-character deterministic replay") + + loop1, motions1 = two_character_interaction_scenario() + loop2, motions2 = two_character_interaction_scenario() + + hashes1 = [] + hashes2 = [] + + for tick_num in range(60): + loop1.advance_tick({}) + loop2.advance_tick({}) + + hash1 = loop1.get_state_hash() + hash2 = loop2.get_state_hash() + + hashes1.append(hash1) + hashes2.append(hash2) + + if hashes1 == hashes2: + print("✓ Deterministic replay (2-char): PASS") + else: + print(f"✗ Deterministic replay (2-char): FAIL") + print(f" Mismatch at frame: {[i for i, (h1, h2) in enumerate(zip(hashes1, hashes2)) if h1 != h2]}") + + print() + + # Scenario 2: Three-character with priority conflict resolution + print("Test 2: Three-character priority-based conflict resolution") + + loop3, motions3 = three_character_scenario() + loop4, motions4 = three_character_scenario() + + hashes3 = [] + hashes4 = [] + + for tick_num in range(60): + loop3.advance_tick({}) + loop4.advance_tick({}) + + hash3 = loop3.get_state_hash() + hash4 = loop4.get_state_hash() + + hashes3.append(hash3) + hashes4.append(hash4) + + if hashes3 == hashes4: + print("✓ Deterministic replay (3-char): PASS") + else: + print(f"✗ Deterministic replay (3-char): FAIL") + + print() + + # Scenario 3: Different seed produces different hashes + print("Test 3: Different seed produces different outcome") + + loop_seed42, _ = two_character_interaction_scenario() + loop_seed99 = DeterministicLoop(fps=30, seed=99) + loop_seed99.register_character("dancer1", "soma", priority=1) + loop_seed99.register_character("dancer2", "soma", priority=1) + + hashes42 = [] + hashes99 = [] + + for tick_num in range(30): + loop_seed42.advance_tick({}) + loop_seed99.advance_tick({}) + + hashes42.append(loop_seed42.get_state_hash()) + hashes99.append(loop_seed99.get_state_hash()) + + if hashes42 != hashes99: + print("✓ Different seeds produce different outcomes: PASS") + else: + print("✗ Different seeds should differ: FAIL") + + print() + print("=== All Deterministic Tests Complete ===") + + +if __name__ == "__main__": + test_deterministic_replay() diff --git a/kimodo/schemas.py b/kimodo/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..8c035919cf39a7186407b0776904c7dda195a4fb --- /dev/null +++ b/kimodo/schemas.py @@ -0,0 +1,593 @@ +""" +Card 2: Service Contracts (Pydantic Schemas) + +Defines strict request/response contracts for: +1. Qwen Planner API (story prompt → validated motion script) +2. Kimodo Generator API (motion script → motion generation) + +All schemas include defensive validation, error codes, and examples. 
+""" + +from enum import Enum +from typing import List, Optional, Dict, Any +from pydantic import BaseModel, Field, field_validator +import json +import logging + + +LOGGER = logging.getLogger(__name__) + + +# ============================================================================ +# Enums +# ============================================================================ + +class TransitionPolicy(str, Enum): + """How to transition between motion segments.""" + SMOOTH = "smooth" # Blend final frame of A with initial frame of B + CUT = "cut" # Hard cut from A to B + HOLD = "hold" # Hold final pose of A before B starts + OVERLAP = "overlap" # Overlap A and B for N frames + + +class ConstraintType(str, Enum): + """Types of kinematic constraints.""" + POSITIONAL = "positional" # XYZ position constraints + ROTATIONAL = "rotational" # Joint angle constraints + VELOCITY = "velocity" # Movement speed limits + CONTACT = "contact" # Foot contact, hand placement + NONE = "none" + + +# ============================================================================ +# Planner API Schemas (Qwen LLM → Motion Script) +# ============================================================================ + +class CharacterDefinition(BaseModel): + """Definition of a character in the scene.""" + + character_id: str = Field( + ..., + min_length=1, + max_length=50, + pattern="^[a-zA-Z0-9_-]+$", + description="Unique identifier for this character (alphanumeric + _ -)." + ) + + skeleton_type: str = Field( + default="soma", + description="Skeleton rig type (soma, g1, smpl-x, etc.)" + ) + + description: Optional[str] = Field( + default=None, + max_length=200, + description="Brief character description (e.g., 'tall female dancer')." + ) + + @field_validator("character_id") + @classmethod + def validate_char_id(cls, v): + if not v or len(v) > 50: + raise ValueError("character_id must be 1-50 chars") + return v.strip() + + +class MotionSegment(BaseModel): + """A single motion segment for one character.""" + + segment_id: int = Field( + ..., + ge=0, + description="Sequence order (0-based) within character's script." + ) + + action_text: str = Field( + ..., + min_length=3, + max_length=500, + description="Natural language action description (e.g., 'walk forward with arms raised')." + ) + + duration_sec: float = Field( + default=2.0, + ge=0.5, + le=30.0, + description="Duration of this motion segment in seconds (0.5-30s)." + ) + + transition_policy: TransitionPolicy = Field( + default=TransitionPolicy.SMOOTH, + description="How to transition to the next segment." + ) + + interaction_target: Optional[str] = Field( + default=None, + description="Another character_id to interact with (optional)." + ) + + constraints: Optional[Dict[str, Any]] = Field( + default_factory=dict, + description="Kinematic constraints dict (e.g., {'floor_contact': True})." + ) + + @field_validator("action_text") + @classmethod + def validate_action_text(cls, v): + if not v or len(v) < 3: + raise ValueError("action_text must be >= 3 chars") + return v.strip() + + @field_validator("duration_sec") + @classmethod + def validate_duration(cls, v): + if not (0.5 <= v <= 30.0): + raise ValueError("duration_sec must be between 0.5 and 30.0 seconds") + return v + + +class PlannerRequest(BaseModel): + """Request from frontend to Qwen planner.""" + + scene_id: str = Field( + ..., + min_length=1, + max_length=100, + description="Unique identifier for this scene/request." 
+ ) + + user_prompt: str = Field( + ..., + min_length=10, + max_length=2000, + description="High-level story or interaction prompt from user (10-2000 chars)." + ) + + characters: List[CharacterDefinition] = Field( + ..., + min_length=1, + max_length=10, + description="List of characters in the scene (1-10 characters)." + ) + + duration_limit_sec: float = Field( + default=60.0, + ge=10.0, + le=600.0, + description="Maximum total duration for the scene (10-600 seconds)." + ) + + interactive_mode: bool = Field( + default=False, + description="If True, planner may request user input for interactions." + ) + + @field_validator("scene_id") + @classmethod + def validate_scene_id(cls, v): + if not v or len(v) > 100: + raise ValueError("scene_id must be 1-100 chars") + return v.strip() + + @field_validator("user_prompt") + @classmethod + def validate_prompt(cls, v): + if not v or len(v) < 10: + raise ValueError("user_prompt must be >= 10 chars") + return v.strip() + + @field_validator("characters") + @classmethod + def validate_characters(cls, v): + if not v or len(v) > 10: + raise ValueError("Must have 1-10 characters") + # Check for duplicate character_ids + ids = [c.character_id for c in v] + if len(ids) != len(set(ids)): + raise ValueError("Duplicate character_id found") + return v + + +class PlannerResponse(BaseModel): + """Response from Qwen planner (validated motion script).""" + + scene_id: str = Field( + ..., + description="Echo of request scene_id." + ) + + status: str = Field( + default="success", + description="Planner status (success/partial/error).", + json_schema_extra={"enum": ["success", "partial", "error"]} + ) + + error_message: Optional[str] = Field( + default=None, + description="Error details if status != success." + ) + + scripts: Dict[str, List[MotionSegment]] = Field( + default_factory=dict, + description="Per-character motion scripts: {character_id: [segments]}." + ) + + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Optional metadata (model name, version, timestamp, etc.)." + ) + + total_duration_sec: float = Field( + default=0.0, + ge=0.0, + description="Computed total duration of all characters combined." + ) + + @field_validator("total_duration_sec") + @classmethod + def validate_total_duration(cls, v): + if v < 0: + raise ValueError("total_duration_sec must be non-negative") + return v + + +# ============================================================================ +# Generator API Schemas (Motion Script → Kimodo Generation) +# ============================================================================ + +class GenerationConstraint(BaseModel): + """Per-character generation constraint.""" + + constraint_type: ConstraintType = Field( + default=ConstraintType.NONE, + description="Type of constraint to apply." + ) + + params: Dict[str, Any] = Field( + default_factory=dict, + description="Constraint-specific parameters (e.g., target position, velocity limit)." + ) + + priority: int = Field( + default=1, + ge=1, + le=10, + description="Priority level (1-10, higher = enforced more strictly)." + ) + + +class CharacterGenerationState(BaseModel): + """Per-character state for generation.""" + + character_id: str = Field( + ..., + description="Character identifier." + ) + + skeleton_type: str = Field( + default="soma", + description="Skeleton rig type." + ) + + segments: List[MotionSegment] = Field( + ..., + description="Motion segments for this character." 
+ ) + + initial_pose: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional initial pose (joint angles or transformation)." + ) + + constraints: Optional[List[GenerationConstraint]] = Field( + default=None, + description="List of kinematic constraints." + ) + + @field_validator("character_id") + @classmethod + def validate_id(cls, v): + if not v: + raise ValueError("character_id required") + return v.strip() + + +class GeneratorRequest(BaseModel): + """Request to Kimodo generator (from planner response).""" + + scene_id: str = Field( + ..., + description="Scene identifier." + ) + + characters: List[CharacterGenerationState] = Field( + ..., + min_length=1, + max_length=10, + description="Character states and motion segments." + ) + + seed: int = Field( + default=42, + ge=0, + description="Random seed for deterministic generation." + ) + + num_samples: int = Field( + default=1, + ge=1, + le=5, + description="Number of motion samples to generate (1-5)." + ) + + device: Optional[str] = Field( + default=None, + description="Compute device (cuda, rocm, cpu). If None, auto-detect." + ) + + @field_validator("characters") + @classmethod + def validate_chars(cls, v): + if not v or len(v) > 10: + raise ValueError("Must have 1-10 characters") + return v + + +class MotionOutput(BaseModel): + """Generated motion output for a single character.""" + + character_id: str = Field( + ..., + description="Character identifier." + ) + + motion_data: Dict[str, Any] = Field( + ..., + description="Motion data (NPZ converted to dict: posed_joints, rotation_mats, foot_contacts, etc.)." + ) + + duration_sec: float = Field( + default=0.0, + description="Actual duration of generated motion." + ) + + frame_count: int = Field( + default=0, + description="Number of frames in motion." + ) + + fps: int = Field( + default=30, + description="Frames per second." + ) + + quality_score: Optional[float] = Field( + default=None, + description="Optional quality metric (0-1)." + ) + + +class GeneratorResponse(BaseModel): + """Response from Kimodo generator.""" + + scene_id: str = Field( + ..., + description="Scene identifier." + ) + + status: str = Field( + default="success", + description="Generation status.", + json_schema_extra={"enum": ["success", "partial", "error"]} + ) + + error_message: Optional[str] = Field( + default=None, + description="Error details if status != success." + ) + + motions: List[MotionOutput] = Field( + default_factory=list, + description="Generated motions per character." + ) + + total_frames: int = Field( + default=0, + description="Total frames across all characters." + ) + + generation_time_sec: float = Field( + default=0.0, + description="Wall-clock time to generate (seconds)." + ) + + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Metadata (model version, device, seed, etc.)." 
+ ) + + +# ============================================================================ +# Validation Examples +# ============================================================================ + +def example_planner_request() -> PlannerRequest: + """Example valid planner request.""" + return PlannerRequest( + scene_id="scene_001", + user_prompt="Two dancers interact in the middle of a stage, one leads a waltz while the other follows.", + characters=[ + CharacterDefinition(character_id="dancer1", skeleton_type="soma", description="Lead dancer"), + CharacterDefinition(character_id="dancer2", skeleton_type="soma", description="Follow dancer"), + ], + duration_limit_sec=60.0, + interactive_mode=False + ) + + +def example_planner_response() -> PlannerResponse: + """Example valid planner response.""" + return PlannerResponse( + scene_id="scene_001", + status="success", + scripts={ + "dancer1": [ + MotionSegment( + segment_id=0, + action_text="Walk forward with arms extended in waltz position", + duration_sec=5.0, + transition_policy=TransitionPolicy.SMOOTH, + interaction_target="dancer2" + ), + MotionSegment( + segment_id=1, + action_text="Turn left while leading dancer2 in circular motion", + duration_sec=5.0, + transition_policy=TransitionPolicy.SMOOTH + ), + ], + "dancer2": [ + MotionSegment( + segment_id=0, + action_text="Follow dancer1 with arms extended, matching their tempo", + duration_sec=5.0, + transition_policy=TransitionPolicy.SMOOTH, + interaction_target="dancer1" + ), + MotionSegment( + segment_id=1, + action_text="Turn right while being led by dancer1", + duration_sec=5.0, + transition_policy=TransitionPolicy.SMOOTH + ), + ] + }, + metadata={"model": "Qwen2.5-7B-Instruct", "created_at": "2026-05-09T12:00:00Z"}, + total_duration_sec=10.0 + ) + + +def example_generator_request() -> GeneratorRequest: + """Example valid generator request.""" + planner_resp = example_planner_response() + resp_dict = planner_resp.model_dump() + + chars = [ + CharacterGenerationState( + character_id="dancer1", + skeleton_type="soma", + segments=resp_dict["scripts"]["dancer1"] + ), + CharacterGenerationState( + character_id="dancer2", + skeleton_type="soma", + segments=resp_dict["scripts"]["dancer2"] + ), + ] + + return GeneratorRequest( + scene_id="scene_001", + characters=chars, + seed=42, + num_samples=1 + ) + + +# ============================================================================ +# Test & Validation Functions +# ============================================================================ + +def validate_schema_examples(): + """Test all schemas with valid and invalid payloads.""" + LOGGER.info("schemas.validate_schema_examples.start") + + print("=== Card 2: Schema Validation Tests ===\n") + + # Test 1: Valid Planner Request + try: + req = example_planner_request() + print("✓ Valid Planner Request: PASS") + except Exception as e: + print(f"✗ Valid Planner Request: FAIL - {e}") + + # Test 2: Valid Planner Response + try: + resp = example_planner_response() + print("✓ Valid Planner Response: PASS") + except Exception as e: + print(f"✗ Valid Planner Response: FAIL - {e}") + + # Test 3: Valid Generator Request + try: + req = example_generator_request() + print("✓ Valid Generator Request: PASS") + except Exception as e: + print(f"✗ Valid Generator Request: FAIL - {e}") + + print() + + # Test 4: Invalid Planner Request (missing required field) + try: + PlannerRequest( + scene_id="test", + # Missing user_prompt - should fail + characters=[ + CharacterDefinition(character_id="c1") + ] + ) + print("✗ 
Invalid Request (missing user_prompt): FAIL - should have raised error") + except Exception as e: + print("✓ Invalid Request (missing user_prompt): PASS - correctly rejected") + + # Test 5: Invalid Planner Request (user_prompt too short) + try: + PlannerRequest( + scene_id="test", + user_prompt="short", # < 10 chars + characters=[ + CharacterDefinition(character_id="c1") + ] + ) + print("✗ Invalid Request (short prompt): FAIL - should have raised error") + except Exception as e: + print("✓ Invalid Request (short prompt): PASS - correctly rejected") + + # Test 6: Invalid character_id (special chars not allowed) + try: + CharacterDefinition(character_id="c1@invalid") + print("✗ Invalid character_id: FAIL - should have raised error") + except Exception as e: + print("✓ Invalid character_id: PASS - correctly rejected") + + # Test 7: Invalid duration_sec (out of range) + try: + MotionSegment( + segment_id=0, + action_text="test action", + duration_sec=60.0 # > 30.0 max + ) + print("✗ Invalid duration_sec: FAIL - should have raised error") + except Exception as e: + print("✓ Invalid duration_sec: PASS - correctly rejected") + + # Test 8: Duplicate character_ids + try: + PlannerRequest( + scene_id="test", + user_prompt="this is a long enough prompt", + characters=[ + CharacterDefinition(character_id="char1"), + CharacterDefinition(character_id="char1"), # Duplicate + ] + ) + print("✗ Duplicate character_ids: FAIL - should have raised error") + except Exception as e: + print("✓ Duplicate character_ids: PASS - correctly rejected") + + print() + print("=== All Schema Tests Complete ===") + LOGGER.info("schemas.validate_schema_examples.exit") + + +if __name__ == "__main__": + validate_schema_examples() diff --git a/kimodo/scripts/__init__.py b/kimodo/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/kimodo/scripts/bones_seed.py b/kimodo/scripts/bones_seed.py new file mode 100644 index 0000000000000000000000000000000000000000..d18e6405a653a2ea7a4a86d1236e4857cb96b8a2 --- /dev/null +++ b/kimodo/scripts/bones_seed.py @@ -0,0 +1,355 @@ +"""Browse and download files from the BONES SEED Hugging Face dataset repository.""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +from dataclasses import dataclass, asdict +from datetime import datetime, timezone +from pathlib import Path +from typing import Iterable, Sequence +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +from huggingface_hub import HfApi, get_token, hf_hub_download +from huggingface_hub.errors import HfHubHTTPError + + +DEFAULT_REPO_ID = "bones-studio/seed" +DEFAULT_REPO_TYPE = "dataset" +DEFAULT_SPACE_ID = "lablab-ai-amd-developer-hackathon/movimento" + +LOGGER = logging.getLogger(__name__) + + +def _resolve_token(token: str | None = None) -> str | None: + LOGGER.info("bones_seed.resolve_token.start") + if token: + LOGGER.info("bones_seed.resolve_token.exit source=arg") + return token + for env_name in ("HUGGING_FACE_HUB_TOKEN", "HF_TOKEN", "HF_API_TOKEN"): + value = os.environ.get(env_name) + if value: + LOGGER.info("bones_seed.resolve_token.exit source=env var=%s", env_name) + return value + resolved = get_token() + LOGGER.info("bones_seed.resolve_token.exit source=cache found=%s", bool(resolved)) + return resolved + + +@dataclass(frozen=True) +class DownloadManifest: + repo_id: str + repo_type: str + revision: str | None + local_dir: str + files: list[str] + 
downloaded_at: str + + +@dataclass(frozen=True) +class SpaceLogCheckResult: + space_id: str + run_status_code: int + build_status_code: int + run_ok: bool + build_ok: bool + + +def list_repo_files( + repo_id: str = DEFAULT_REPO_ID, + *, + repo_type: str = DEFAULT_REPO_TYPE, + revision: str | None = None, + token: str | None = None, +) -> list[str]: + """Return all files in a Hugging Face dataset repository.""" + LOGGER.info("bones_seed.list_repo_files.start repo_id=%s revision=%s", repo_id, revision) + api = HfApi(token=_resolve_token(token)) + files = sorted(api.list_repo_files(repo_id=repo_id, repo_type=repo_type, revision=revision)) + LOGGER.info("bones_seed.list_repo_files.exit count=%s", len(files)) + return files + + +def download_repo_files( + filenames: Sequence[str], + *, + repo_id: str = DEFAULT_REPO_ID, + repo_type: str = DEFAULT_REPO_TYPE, + revision: str | None = None, + local_dir: str | Path = "bones_seed", + token: str | None = None, +) -> list[Path]: + """Download selected files from a Hugging Face dataset repository.""" + LOGGER.info("bones_seed.download_repo_files.start repo_id=%s files=%s", repo_id, len(filenames)) + resolved_token = _resolve_token(token) + output_dir = Path(local_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + downloaded: list[Path] = [] + for filename in filenames: + # Each file is downloaded independently so partial progress is visible in logs. + local_path = hf_hub_download( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + token=resolved_token, + local_dir=output_dir, + ) + downloaded.append(Path(local_path)) + LOGGER.info("bones_seed.download_repo_files.exit downloaded=%s", len(downloaded)) + return downloaded + + +def download_by_prefix( + prefix: str, + *, + repo_id: str = DEFAULT_REPO_ID, + repo_type: str = DEFAULT_REPO_TYPE, + revision: str | None = None, + local_dir: str | Path = "bones_seed", + token: str | None = None, +) -> list[Path]: + """Download files matching a prefix from the repository listing.""" + LOGGER.info("bones_seed.download_by_prefix.start prefix=%s", prefix) + files = [name for name in list_repo_files(repo_id, repo_type=repo_type, revision=revision, token=token) if name.startswith(prefix)] + if not files: + raise ValueError(f"No files matched prefix '{prefix}' in {repo_id}.") + downloaded = download_repo_files( + files, + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + local_dir=local_dir, + token=token, + ) + LOGGER.info("bones_seed.download_by_prefix.exit matched=%s", len(downloaded)) + return downloaded + + +def write_manifest( + local_dir: str | Path, + files: Iterable[Path], + *, + repo_id: str = DEFAULT_REPO_ID, + repo_type: str = DEFAULT_REPO_TYPE, + revision: str | None = None, +) -> Path: + """Write a manifest that records what was downloaded.""" + LOGGER.info("bones_seed.write_manifest.start local_dir=%s", local_dir) + output_dir = Path(local_dir) + manifest = DownloadManifest( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + local_dir=str(output_dir), + files=[str(path) for path in files], + downloaded_at=datetime.now(timezone.utc).isoformat(), + ) + manifest_path = output_dir / "manifest.json" + manifest_path.write_text(json.dumps(asdict(manifest), indent=2, sort_keys=True) + "\n", encoding="utf-8") + LOGGER.info("bones_seed.write_manifest.exit path=%s", manifest_path) + return manifest_path + + +def upload_manifest_to_space( + manifest_path: str | Path, + *, + space_id: str = DEFAULT_SPACE_ID, + token: str | None = None, + path_in_repo: 
str = "data/bones_seed/manifest.json", + commit_message: str = "Update BONES-SEED ingestion manifest", + create_pr: bool = True, +) -> str: + """Upload manifest file into a Space repository path for lablab ingestion traceability.""" + LOGGER.info("bones_seed.upload_manifest_to_space.start space_id=%s", space_id) + manifest = Path(manifest_path) + if not manifest.exists(): + raise FileNotFoundError(f"Manifest file does not exist: {manifest}") + + api = HfApi(token=_resolve_token(token)) + try: + uploaded = api.upload_file( + path_or_fileobj=str(manifest), + path_in_repo=path_in_repo, + repo_id=space_id, + repo_type="space", + commit_message=commit_message, + create_pr=False, + ) + LOGGER.info("bones_seed.upload_manifest_to_space.exit mode=direct") + return uploaded + except HfHubHTTPError as exc: + if create_pr and "create_pr=1" in str(exc): + uploaded = api.upload_file( + path_or_fileobj=str(manifest), + path_in_repo=path_in_repo, + repo_id=space_id, + repo_type="space", + commit_message=commit_message, + create_pr=True, + ) + LOGGER.info("bones_seed.upload_manifest_to_space.exit mode=create_pr") + return uploaded + raise + + +def _check_logs_endpoint(url: str, token: str | None, timeout_sec: float) -> tuple[int, bool]: + LOGGER.info("bones_seed.check_logs_endpoint.start url=%s", url) + headers = {} + resolved = _resolve_token(token) + if resolved: + headers["Authorization"] = f"Bearer {resolved}" + request = Request(url=url, headers=headers, method="GET") + try: + with urlopen(request, timeout=timeout_sec) as response: + status = int(getattr(response, "status", 0)) + LOGGER.info("bones_seed.check_logs_endpoint.exit status=%s", status) + return status, 200 <= status < 300 + except HTTPError as exc: + LOGGER.warning("bones_seed.check_logs_endpoint.http_error status=%s", exc.code) + return int(exc.code), False + except URLError: + LOGGER.warning("bones_seed.check_logs_endpoint.network_error") + return 0, False + + +def verify_space_logs( + *, + space_id: str = DEFAULT_SPACE_ID, + token: str | None = None, + timeout_sec: float = 10.0, +) -> SpaceLogCheckResult: + """Verify build and runtime log endpoints are reachable for the target Space.""" + LOGGER.info("bones_seed.verify_space_logs.start space_id=%s", space_id) + base = f"https://huggingface.co/api/spaces/{space_id}/logs" + run_status, run_ok = _check_logs_endpoint(f"{base}/run", token, timeout_sec) + build_status, build_ok = _check_logs_endpoint(f"{base}/build", token, timeout_sec) + result = SpaceLogCheckResult( + space_id=space_id, + run_status_code=run_status, + build_status_code=build_status, + run_ok=run_ok, + build_ok=build_ok, + ) + LOGGER.info("bones_seed.verify_space_logs.exit run_ok=%s build_ok=%s", run_ok, build_ok) + return result + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Browse and download BONES SEED dataset files from Hugging Face.") + parser.add_argument( + "command", + choices=("list", "download", "prefix", "verify-logs"), + help="List files, download selected files, download files by prefix, or verify Space log endpoints.", + ) + parser.add_argument("files", nargs="*", help="Exact file paths inside the dataset repository.") + parser.add_argument("--repo-id", default=DEFAULT_REPO_ID, help="Hugging Face dataset repository id.") + parser.add_argument("--repo-type", default=DEFAULT_REPO_TYPE, help="Hugging Face repo type.") + parser.add_argument("--revision", default=None, help="Optional repository revision or branch.") + parser.add_argument("--local-dir", 
default="bones_seed", help="Directory where files will be stored.") + parser.add_argument("--token", default=None, help="Hugging Face token override.") + parser.add_argument("--prefix", default=None, help="File prefix to match when using the prefix command.") + parser.add_argument("--manifest", action="store_true", help="Write a manifest.json after download.") + parser.add_argument("--space-id", default=DEFAULT_SPACE_ID, help="Target Space id for manifest publish or logs checks.") + parser.add_argument( + "--space-manifest-path", + default="data/bones_seed/manifest.json", + help="Path inside target Space repo where manifest will be uploaded.", + ) + parser.add_argument( + "--publish-manifest-to-space", + action="store_true", + help="Upload generated manifest to the Space repo destination.", + ) + parser.add_argument( + "--space-upload-create-pr", + action="store_true", + help="Force upload as a PR in target Space repo when direct commits are forbidden.", + ) + parser.add_argument( + "--logs-timeout-sec", + type=float, + default=10.0, + help="Timeout for log endpoint verification requests.", + ) + return parser + + +def main(argv: Sequence[str] | None = None) -> int: + LOGGER.info("bones_seed.main.start") + parser = build_parser() + args = parser.parse_args(argv) + + if args.command == "list": + try: + for name in list_repo_files(args.repo_id, repo_type=args.repo_type, revision=args.revision, token=args.token): + print(name) + except BrokenPipeError: + LOGGER.info("bones_seed.main.exit broken_pipe") + return 0 + LOGGER.info("bones_seed.main.exit command=list") + return 0 + + if args.command == "verify-logs": + result = verify_space_logs(space_id=args.space_id, token=args.token, timeout_sec=args.logs_timeout_sec) + print(json.dumps(asdict(result), indent=2, sort_keys=True)) + LOGGER.info("bones_seed.main.exit command=verify-logs") + return 0 if (result.run_ok and result.build_ok) else 2 + + if args.command == "download": + if not args.files: + raise SystemExit("download requires at least one file path") + downloaded = download_repo_files( + args.files, + repo_id=args.repo_id, + repo_type=args.repo_type, + revision=args.revision, + local_dir=args.local_dir, + token=args.token, + ) + else: + if not args.prefix: + raise SystemExit("prefix requires --prefix") + downloaded = download_by_prefix( + args.prefix, + repo_id=args.repo_id, + repo_type=args.repo_type, + revision=args.revision, + local_dir=args.local_dir, + token=args.token, + ) + + for path in downloaded: + print(path) + + if args.manifest: + manifest_path = write_manifest( + args.local_dir, + downloaded, + repo_id=args.repo_id, + repo_type=args.repo_type, + revision=args.revision, + ) + print(manifest_path) + if args.publish_manifest_to_space: + uploaded = upload_manifest_to_space( + manifest_path, + space_id=args.space_id, + token=args.token, + path_in_repo=args.space_manifest_path, + create_pr=args.space_upload_create_pr, + ) + print(uploaded) + elif args.publish_manifest_to_space: + raise SystemExit("--publish-manifest-to-space requires --manifest") + + LOGGER.info("bones_seed.main.exit command=%s", args.command) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) \ No newline at end of file diff --git a/kimodo/scripts/docker-entrypoint.sh b/kimodo/scripts/docker-entrypoint.sh new file mode 100644 index 0000000000000000000000000000000000000000..db2246e558d0aad8a6b1c0c6241e429e85e7898d --- /dev/null +++ b/kimodo/scripts/docker-entrypoint.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +set -euo pipefail + 
+HOST_UID="${HOST_UID:-}"
+HOST_GID="${HOST_GID:-}"
+HOST_USER="${HOST_USER:-user}"
+
+if [[ -z "${HOST_UID}" || -z "${HOST_GID}" ]]; then
+  if [[ -d /workspace ]]; then
+    HOST_UID="$(stat -c %u /workspace)"
+    HOST_GID="$(stat -c %g /workspace)"
+  else
+    HOST_UID="${HOST_UID:-1000}"
+    HOST_GID="${HOST_GID:-1000}"
+  fi
+fi
+
+if ! getent group "${HOST_GID}" >/dev/null 2>&1; then
+  groupadd -g "${HOST_GID}" "${HOST_USER}"
+fi
+
+if ! getent passwd "${HOST_UID}" >/dev/null 2>&1; then
+  useradd -m -u "${HOST_UID}" -g "${HOST_GID}" -s /bin/bash "${HOST_USER}"
+fi
+
+exec gosu "${HOST_UID}:${HOST_GID}" "$@"
diff --git a/kimodo/scripts/generate.py b/kimodo/scripts/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf92c64b3436f40d9b5274a869cb30f4245e5757
--- /dev/null
+++ b/kimodo/scripts/generate.py
@@ -0,0 +1,422 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import os
+from typing import Any, Dict, Optional
+
+import torch
+
+from kimodo import DEFAULT_MODEL, load_model
+from kimodo.constraints import load_constraints_lst
+from kimodo.exports.motion_io import save_kimodo_npz
+from kimodo.meta import load_prompts_from_meta
+from kimodo.model.cfg import CFG_TYPES
+from kimodo.model.registry import get_model_info
+from kimodo.runtime import runtime_health_report
+from kimodo.tools import load_json, seed_everything
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Command line API for generating motions with kimodo")
+    parser.add_argument(
+        "prompt",
+        nargs="?",
+        type=str,
+        default=None,
+        help="Text prompt describing the motion to generate, or several prompts separated by periods.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=DEFAULT_MODEL,
+        help="Name of the model (e.g. Kimodo-SOMA-RP-v1, etc).",
+    )
+    parser.add_argument(
+        "--duration",
+        type=str,
+        default="5.0",
+        help="Duration in seconds (default: 5.0). Separate durations by spaces in a single string to set a different duration per prompt",
+    )
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=1,
+        help="Number of samples to generate (default: 1)",
+    )
+    parser.add_argument(
+        "--diffusion_steps",
+        type=int,
+        default=100,
+        help="Number of diffusion steps (default: 100)",
+    )
+    parser.add_argument(
+        "--num_transition_frames",
+        type=int,
+        default=5,
+        help="Number of frames used to transition between prompts (default: 5)",
+    )
+    parser.add_argument(
+        "--constraints",
+        type=str,
+        default=None,
+        help="Path to a saved constraint list",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="output",
+        help="Output stem name: with one sample writes a single file per format (e.g. test.npz, test.csv); with multiple samples creates a folder and writes test_00.npz, test_01.npz, ... inside it. Used for NPZ, AMASS NPZ, CSV, and BVH.",
+    )
+    parser.add_argument(
+        "--bvh",
+        action="store_true",
+        help="Also export BVH (SOMA models only); uses the same stem as --output.",
+    )
+    parser.add_argument(
+        "--no-postprocess",
+        action="store_true",
+        help="Don't apply motion post-processing to reduce foot skating (ignored for G1)",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="Seed for reproducible results",
+    )
+    parser.add_argument(
+        "--input_folder",
+        type=str,
+        default=None,
+        help="Folder containing meta.json and optional constraints.json.
If set, generation settings are loaded from meta.json.", + ) + parser.add_argument( + "--cfg_type", + type=str, + default=argparse.SUPPRESS, + choices=CFG_TYPES, + help=( + "Classifier-free guidance mode: nocfg (no CFG), regular (single scale on cond vs uncond), " + "or separated (custom: separate text and constraint scales). " + "Use with --cfg_weight as required by the mode." + ), + ) + parser.add_argument( + "--cfg_weight", + type=float, + nargs="*", + default=argparse.SUPPRESS, + help=( + "CFG scale(s): one float for regular, or two floats [text_weight, constraint_weight] for separated. " + "Omit with --cfg_type nocfg. If omitted, two floats alone imply separated; one float alone implies regular." + ), + ) + return parser.parse_args() + + +def get_texts_and_num_frames_from_prompt(prompt: str, duration: str, fps: float): + # Get the texts + texts = [text.strip() for text in prompt.split(".")] + texts = [text + "." for text in texts if text] + + nb_prompts = len(texts) + + # Get the durations + if " " not in duration: + duration_sec = float(duration) + # same for all the prompts + num_frames = [int(duration_sec * fps)] * nb_prompts + else: + durations = duration.split(" ") + assert len(durations) == len(texts), "The number of durations should match the number of prompts" + num_frames = [int(float(duration.strip()) * fps) for duration in durations] + assert len(num_frames) == nb_prompts, "The number of durations should be 1 or match the number of texts" + + return texts, num_frames + + +def _single_file_path(path: str, ext: str) -> str: + """Return path for a single output file (no folder). + + Adds ext if missing; creates parent dirs if any. + """ + if not path.endswith(ext): + path = path.rstrip(os.sep) + ext + parent = os.path.dirname(path) + if parent: + os.makedirs(parent, exist_ok=True) + return path + + +def _output_dir_and_path(path: str, default_base: str, ext: str): + """Create output folder from path and return (dir_path, path_for_file_with_suffix, base_name). + + If path has an extension, folder name is the path stem; else the path is the folder name. + base_name is the folder basename for _00, _01, ... when n_samples > 1. + """ + folder = os.path.splitext(path)[0] if os.path.splitext(path)[1] else path + os.makedirs(folder, exist_ok=True) + base_name = os.path.basename(folder.rstrip(os.sep)) + return folder, os.path.join(folder, default_base + ext), base_name + + +def resolve_cfg_kwargs(args: argparse.Namespace, meta: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """Resolve cfg_type / cfg_weight for model(...). + + Precedence: explicit CLI (--cfg_type / --cfg_weight) overrides meta.json ``cfg``; + if neither applies, returns {} so the model uses its own defaults. 
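+
+    Example resolutions (illustrative values, mirroring the rules above):
+        * ``--cfg_weight 2.5`` alone resolves to ``{"cfg_type": "regular", "cfg_weight": 2.5}``
+        * ``--cfg_weight 2.0 3.0`` alone resolves to ``{"cfg_type": "separated", "cfg_weight": [2.0, 3.0]}``
+        * no CLI flags with ``meta["cfg"] = {"enabled": False}`` resolves to ``{"cfg_type": "nocfg"}``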
+ """ + ns = vars(args) + has_type = "cfg_type" in ns + has_wflag = "cfg_weight" in ns + cli_type = ns.get("cfg_type") + cli_w = ns.get("cfg_weight") + + if has_wflag: + if cli_w is None or len(cli_w) == 0: + raise ValueError("--cfg_weight requires one float (regular) or two floats (separated).") + + if has_type and cli_type == "nocfg": + if has_wflag: + raise ValueError("--cfg_weight is not used with --cfg_type nocfg.") + return {"cfg_type": "nocfg"} + + if has_type or has_wflag: + if has_type: + eff_type = cli_type + if has_wflag: + if eff_type == "regular" and len(cli_w) != 1: + raise ValueError("--cfg_type regular requires exactly one --cfg_weight value.") + if eff_type == "separated" and len(cli_w) != 2: + raise ValueError("--cfg_type separated requires exactly two --cfg_weight values.") + else: + if eff_type == "regular": + raise ValueError("--cfg_type regular requires --cfg_weight with one float.") + if eff_type == "separated": + raise ValueError("--cfg_type separated requires --cfg_weight with two floats.") + else: + if len(cli_w) == 1: + eff_type = "regular" + elif len(cli_w) == 2: + eff_type = "separated" + else: + raise ValueError("--cfg_weight expects 1 float (regular) or 2 floats (separated).") + + if eff_type == "regular": + return {"cfg_type": "regular", "cfg_weight": float(cli_w[0])} + return {"cfg_type": "separated", "cfg_weight": [float(cli_w[0]), float(cli_w[1])]} + + if meta and isinstance(meta.get("cfg"), dict): + cfg = meta["cfg"] + enabled = cfg.get("enabled", True) + if not enabled: + return {"cfg_type": "nocfg"} + return { + "cfg_type": "separated", + "cfg_weight": [ + float(cfg.get("text_weight", 2.0)), + float(cfg.get("constraint_weight", 2.0)), + ], + } + + return {} + + +def get_generation_inputs(args, fps: float): + """Get texts/num_frames and parameter overrides from either CLI or input_folder.""" + if args.input_folder is None: + if not args.prompt: + raise ValueError("Either provide 'prompt' or '--input_folder'.") + texts, num_frames = get_texts_and_num_frames_from_prompt(args.prompt, args.duration, fps) + return { + "texts": texts, + "num_frames": num_frames, + "num_samples": args.num_samples, + "diffusion_steps": args.diffusion_steps, + "seed": args.seed, + "constraints_path": args.constraints, + "meta": None, + } + + meta_path = os.path.join(args.input_folder, "meta.json") + meta = load_json(meta_path) + texts, durations_sec = load_prompts_from_meta(meta_path) + num_frames = [int(float(duration) * fps) for duration in durations_sec] + + constraints_path = args.constraints + default_constraints_path = os.path.join(args.input_folder, "constraints.json") + if constraints_path is None and os.path.exists(default_constraints_path): + constraints_path = default_constraints_path + + return { + "texts": texts, + "num_frames": num_frames, + "num_samples": meta.get("num_samples", args.num_samples), + "diffusion_steps": meta.get("diffusion_steps", args.diffusion_steps), + "seed": meta.get("seed", args.seed), + "constraints_path": constraints_path, + "meta": meta, + } + + +def main(): + requested_device = os.environ.get("KIMODO_DEVICE") + report = runtime_health_report(requested_device) + device = report.selected_device + print( + "Runtime health: " + f"requested={report.requested_device} " + f"selected={report.selected_device} " + f"backend={report.backend} " + f"reason={report.reason}" + ) + + args = parse_args() + + # Load model (resolution of name done inside load_model) + model, resolved_model = load_model( + args.model, + device=device, + 
default_family="Kimodo",
+        return_resolved_name=True,
+    )
+    info = get_model_info(resolved_model)
+    display = info.display_name if info else resolved_model
+    print(f"Loaded model: {display} ({resolved_model})")
+
+    # Get generation inputs
+    generation_inputs = get_generation_inputs(args, model.fps)
+    texts = generation_inputs["texts"]
+    num_frames = generation_inputs["num_frames"]
+    print("Will generate motions with the following prompts:")
+    for text, num_frame in zip(texts, num_frames):
+        print(f"  '{text}' with {num_frame} frames")
+
+    # Load constraints
+    constraints_path = generation_inputs["constraints_path"]
+    if constraints_path:
+        constraint_lst = load_constraints_lst(constraints_path, model.skeleton)
+    else:
+        constraint_lst = []
+
+    if constraint_lst:
+        print(f"Using {len(constraint_lst)} constraint set(s)")
+        for constraint in constraint_lst:
+            print(f"  {constraint}")
+
+    if generation_inputs["seed"] is not None:
+        seed_everything(generation_inputs["seed"])
+
+    cfg_kwargs = resolve_cfg_kwargs(args, generation_inputs.get("meta"))
+    if cfg_kwargs:
+        ct = cfg_kwargs.get("cfg_type")
+        cw = cfg_kwargs.get("cfg_weight")
+        if cw is not None:
+            print(f"Using CFG: cfg_type={ct!r}, cfg_weight={cw!r}")
+        else:
+            print(f"Using CFG: cfg_type={ct!r}")
+
+    # G1: postprocessing is disabled (does not work well for this model).
+    use_postprocess = False if "g1" in resolved_model else (not args.no_postprocess)
+    output = model(
+        texts,
+        num_frames,
+        constraint_lst=constraint_lst,
+        num_denoising_steps=generation_inputs["diffusion_steps"],
+        num_samples=generation_inputs["num_samples"],
+        multi_prompt=True,
+        num_transition_frames=args.num_transition_frames,
+        post_processing=use_postprocess,
+        return_numpy=True,
+        **cfg_kwargs,
+    )
+
+    n_samples = int(output["posed_joints"].shape[0])
+    # Parse the output stem once; all formats (NPZ, AMASS NPZ, CSV, BVH) use this base name.
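+    # Illustrative naming (hypothetical stem "test"): with one sample this writes test.npz;
+    # with three samples it creates test/ and writes test/test_00.npz ... test/test_02.npz.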
+ output_base = args.output + + if n_samples == 1: + npz_path = _single_file_path(output_base, ".npz") + print(f"Saving the npz output to {npz_path}") + single = { + k: (v[0] if hasattr(v, "shape") and len(v.shape) > 0 and v.shape[0] == n_samples else v) + for k, v in output.items() + } + save_kimodo_npz(npz_path, single) + else: + out_dir, _, base_name = _output_dir_and_path(output_base, "motion", ".npz") + print(f"Saving the npz output to {out_dir}/ ({base_name}_00.npz ...)") + for i in range(n_samples): + single = { + k: (v[i] if hasattr(v, "shape") and len(v.shape) > 0 and v.shape[0] == n_samples else v) + for k, v in output.items() + } + save_kimodo_npz(os.path.join(out_dir, f"{base_name}_{i:02d}.npz"), single) + + if resolved_model == "kimodo-smplx-rp": + from kimodo.exports.smplx import AMASSConverter + + converter = AMASSConverter(skeleton=model.skeleton, fps=model.fps) + if n_samples == 1: + # Use distinct name so AMASS NPZ does not overwrite the main NPZ + amass_single_path = _single_file_path(output_base + "_amass", ".npz") + print(f"Saving the amass output to {amass_single_path}") + converter.convert_save_npz(output, amass_single_path) + else: + out_dir, _, base_name = _output_dir_and_path(output_base, "amass", ".npz") + print(f"Saving the amass output to {out_dir}/ (amass_00.npz ...)") + converter.convert_save_npz(output, os.path.join(out_dir, "amass.npz")) + + if resolved_model == "kimodo-g1-rp": + from kimodo.exports.mujoco import MujocoQposConverter + + converter = MujocoQposConverter(model.skeleton) + qpos = converter.dict_to_qpos(output, device) + if n_samples == 1: + csv_path = _single_file_path(output_base, ".csv") + print(f"Saving the csv output to {csv_path}") + converter.save_csv(qpos, csv_path) + else: + out_dir, _, base_name = _output_dir_and_path(output_base, "qpos", ".csv") + print(f"Saving the csv output to {out_dir}/ ({base_name}_00.csv ...)") + converter.save_csv(qpos, os.path.join(out_dir, base_name + ".csv")) + + if args.bvh: + skeleton = model.skeleton + if "somaskel" not in skeleton.name: + print("BVH export is only supported for SOMA skeletons. 
Skipping --bvh.") + else: + from kimodo.exports.bvh import save_motion_bvh + from kimodo.skeleton import SOMASkeleton30, global_rots_to_local_rots + + if isinstance(skeleton, SOMASkeleton30): + # Motion has already been converted to somaskel77 within the model for output + skeleton = skeleton.somaskel77.to(device) + + if n_samples == 1: + bvh_path = _single_file_path(output_base, ".bvh") + print(f"Saving the BVH output to {bvh_path}") + joints_pos = torch.from_numpy(output["posed_joints"][0]).to(device) + joints_rot = torch.from_numpy(output["global_rot_mats"][0]).to(device) + local_rot_mats = global_rots_to_local_rots(joints_rot, skeleton) + root_positions = joints_pos[:, skeleton.root_idx, :] + save_motion_bvh(bvh_path, local_rot_mats, root_positions, skeleton=skeleton, fps=model.fps) + else: + out_dir, _, base_name = _output_dir_and_path(output_base, "motion", ".bvh") + print(f"Saving the BVH output to {out_dir}/ ({base_name}_00.bvh ...)") + for i in range(n_samples): + joints_pos = torch.from_numpy(output["posed_joints"][i]).to(device) + joints_rot = torch.from_numpy(output["global_rot_mats"][i]).to(device) + local_rot_mats = global_rots_to_local_rots(joints_rot, skeleton) + root_positions = joints_pos[:, skeleton.root_idx, :] + save_motion_bvh( + os.path.join(out_dir, f"{base_name}_{i:02d}.bvh"), + local_rot_mats, + root_positions, + skeleton=skeleton, + fps=model.fps, + ) + + +if __name__ == "__main__": + main() diff --git a/kimodo/scripts/gradio_theme.py b/kimodo/scripts/gradio_theme.py new file mode 100644 index 0000000000000000000000000000000000000000..2a11e50bfab6b12d4f17e7bd0d07c81705123a32 --- /dev/null +++ b/kimodo/scripts/gradio_theme.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+
+import gradio as gr
+
+
+def get_gradio_theme(remove_gradio_footer=False):
+    theme = gr.themes.Base(
+        primary_hue="blue",
+        text_size=gr.themes.Size(lg="16px", md="14px", sm="12px", xl="22px", xs="10px", xxl="35px", xxs="9px"),
+        font=[
+            gr.themes.GoogleFont("Source Sans Pro"),
+            "BlinkMacSystemFont",
+            "Segoe UI",
+            "Roboto",
+        ],
+    ).set(
+        body_text_color="*neutral_900",
+        body_text_color_subdued="*neutral_500",
+        body_text_color_subdued_dark="*neutral_500",
+    )
+
+    css = """
+    @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600;700;900&display=swap');
+
+    /* Base text */
+    body, .gradio-container {
+        font-family: 'Source Sans Pro', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen-Sans, Ubuntu, Cantarell, 'Helvetica Neue', sans-serif !important;
+        font-size: 16px !important;
+    }
+
+    h1 {
+        /* font-family: 'Source Sans Pro', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif !important; */
+        font-weight: 700 !important;
+        font-size: 2.75rem !important;
+        /* margin: 0px; */
+        padding: 1.5rem 0px 0px 0px;
+        /* line-height: 1.2; */
+    }
+    h2 {
+        /* font-family: 'Source Sans Pro', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif !important; */
+        font-weight: 600 !important;
+        font-size: 1.5rem !important;
+    }
+    """
+    if remove_gradio_footer:
+        css += """
+        footer {
+            display: none !important;
+        }
+        """
+    return theme, css
diff --git a/kimodo/scripts/lock_requirements.py b/kimodo/scripts/lock_requirements.py
new file mode 100755
index 0000000000000000000000000000000000000000..0313e8b3ca4fff287cab29f1037d4db7bc6479a7
--- /dev/null
+++ b/kimodo/scripts/lock_requirements.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Regenerate `docker_requirements.txt` from `docker_requirements.in` using `uv`, targeting the
+Docker image runtime, and filter out `torch` + CUDA wheels so Docker doesn't try to reinstall
+PyTorch.
+
+Usage:
+    python3 kimodo/scripts/lock_requirements.py
+
+Optional args:
+    --python-version 3.10
+    --python-platform x86_64-manylinux2014
+    --in docker_requirements.in
+    --out docker_requirements.txt
+"""
+
+import argparse
+import shutil
+import subprocess
+from pathlib import Path
+from typing import Iterable
+
+DEFAULT_PYTHON_VERSION = "3.10"
+DEFAULT_PYTHON_PLATFORM = "x86_64-manylinux2014"
+
+# Packages to omit from the lockfile because the Docker base image already provides torch+CUDA.
+OMIT_NAMES = {"torch", "triton", "networkx", "sympy", "mpmath"}
+OMIT_PREFIXES = ("nvidia-",)
+
+
+def _run(cmd: list[str]) -> None:
+    print("+", " ".join(cmd))
+    subprocess.run(cmd, check=True)
+
+
+def _ensure_uv() -> None:
+    if shutil.which("uv") is None:
+        raise SystemExit(
+            "ERROR: `uv` is not installed or not on PATH.\n"
+            "Install it (one of):\n"
+            "  - pipx install uv\n"
+            "  - python -m pip install --user uv\n"
+            "Then rerun this script."
+        )
+
+
+def _parse_req_name(line: str) -> str:
+    # uv emits `name==version` lines.
+    s = line.strip()
+    if "==" in s:
+        return s.split("==", 1)[0].strip()
+    # Fallback: treat the whole token before the first space as the name.
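+    # For example, "torch==2.1.0" -> "torch", and a line like
+    # "somepkg @ https://example.com/somepkg.whl" -> "somepkg" (hypothetical inputs).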
+ return s.split()[0].strip() + + +def _iter_blocks(lines: list[str]) -> Iterable[list[str]]: + """Split a docker_requirements.txt into blocks: [top-level req line + indented comments].""" + i = 0 + n = len(lines) + while i < n: + line = lines[i] + # Header/comments/blank + if line.startswith("#") or line.strip() == "": + yield [line] + i += 1 + continue + + # Top-level requirement line + if not line.startswith(" "): + block = [line] + i += 1 + while i < n and (lines[i].startswith(" ") or lines[i].strip() == "" or lines[i].startswith("#")): + # Stop if we hit another top-level requirement line + if not lines[i].startswith(" ") and not lines[i].startswith("#") and lines[i].strip() != "": + break + block.append(lines[i]) + i += 1 + yield block + continue + + # Indented line without a requirement header (shouldn't happen, but keep) + yield [line] + i += 1 + + +def _should_omit(req_line: str) -> bool: + name = _parse_req_name(req_line) + if name in OMIT_NAMES: + return True + for pfx in OMIT_PREFIXES: + if name.startswith(pfx): + return True + return False + + +def filter_lockfile(path: Path) -> None: + lines = path.read_text(encoding="utf-8").splitlines(True) + out: list[str] = [] + + inserted_note = False + for block in _iter_blocks(lines): + first = block[0] + + # After the uv header lines, insert a short note once. + if (not inserted_note) and first.startswith("# This file was autogenerated by uv"): + out.extend(block) + out.append( + "# NOTE: `torch` (and its CUDA wheels) are intentionally omitted from this lockfile.\n" + "# The Docker base image (nvcr.io/nvidia/pytorch) already provides a tested PyTorch build.\n" + "#\n" + ) + inserted_note = True + continue + + if first.startswith("#") or first.strip() == "": + out.extend(block) + continue + + if _should_omit(first): + continue + + out.extend(block) + + path.write_text("".join(out), encoding="utf-8") + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--in", dest="in_file", default="docker_requirements.in") + ap.add_argument("--out", dest="out_file", default="docker_requirements.txt") + ap.add_argument("--python-version", default=DEFAULT_PYTHON_VERSION) + ap.add_argument("--python-platform", default=DEFAULT_PYTHON_PLATFORM) + args = ap.parse_args() + + _ensure_uv() + + in_path = Path(args.in_file) + out_path = Path(args.out_file) + if not in_path.exists(): + raise SystemExit(f"ERROR: missing {in_path}") + + _run( + [ + "uv", + "pip", + "compile", + "-U", + str(in_path), + "-o", + str(out_path), + "--python-version", + args.python_version, + "--python-platform", + args.python_platform, + ] + ) + filter_lockfile(out_path) + print(f"OK: wrote {out_path} (filtered torch/CUDA wheels)") + + +if __name__ == "__main__": + main() diff --git a/kimodo/scripts/motion_convert.py b/kimodo/scripts/motion_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5a2263107cb3e4c0267a51330f45a23de206db --- /dev/null +++ b/kimodo/scripts/motion_convert.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""CLI entry-point for motion format conversion. + +Library conversion logic lives in :mod:`kimodo.exports.motion_convert_lib`. +Format detection utilities live in :mod:`kimodo.exports.motion_formats`. 
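+
+Example invocations (file names are hypothetical):
+
+    python -m kimodo.scripts.motion_convert walk.bvh walk.npz --to kimodo
+    python -m kimodo.scripts.motion_convert walk.npz walk_amass.npz --from kimodo --to amass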
+""" + +from __future__ import annotations + +import argparse +import sys + +from kimodo.exports.motion_convert_lib import convert_motion_files + + +def run_convert( + input_path: str, + output_path: str, + from_fmt: str | None, + to_fmt: str | None, + source_fps: float | None, + z_up: bool, + mujoco_rest_zero: bool, +) -> None: + """Thin wrapper kept for backward compatibility; delegates to :func:`convert_motion_files`.""" + convert_motion_files( + input_path, + output_path, + from_fmt=from_fmt, + to_fmt=to_fmt, + source_fps=source_fps, + z_up=z_up, + mujoco_rest_zero=mujoco_rest_zero, + ) + + +def build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + description="Convert Kimodo NPZ, AMASS NPZ, SOMA BVH, and G1 MuJoCo CSV.", + ) + p.add_argument("input", help="Input file path") + p.add_argument("output", help="Output file path") + p.add_argument( + "--from", + dest="from_fmt", + choices=("amass", "kimodo", "soma-bvh", "g1-csv"), + default=None, + help="Input format (default: infer from file contents/extension)", + ) + p.add_argument( + "--to", + dest="to_fmt", + choices=("kimodo", "amass", "soma-bvh", "g1-csv"), + default=None, + help="Output format (default: infer from output extension)", + ) + p.add_argument( + "--source-fps", + "--fps", + dest="source_fps", + type=float, + default=None, + help=( + "Source motion frame rate in Hz (default: auto-detected from " + "BVH Frame Time / AMASS mocap_frame_rate, or 30 Hz). " + "Kimodo NPZ output is always resampled to 30 Hz." + ), + ) + p.add_argument( + "--no-z-up", + action="store_true", + help="For AMASS paths: disable Z-up transform (treat trans/orient as already Kimodo Y-up).", + ) + p.add_argument( + "--mujoco-rest-zero", + action="store_true", + default=False, + help="For G1 CSV: joint angles relative to MuJoCo rest (must match export).", + ) + return p + + +def main(argv: list[str] | None = None) -> int: + args = build_argparser().parse_args(argv) + try: + convert_motion_files( + args.input, + args.output, + from_fmt=args.from_fmt, + to_fmt=args.to_fmt, + source_fps=args.source_fps, + z_up=not args.no_z_up, + mujoco_rest_zero=args.mujoco_rest_zero, + ) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/kimodo/scripts/mujoco_load.py b/kimodo/scripts/mujoco_load.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e201ac84b37ab65b4a885b9484192eb79a1ac5 --- /dev/null +++ b/kimodo/scripts/mujoco_load.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import time + +import mujoco +import mujoco.viewer +import numpy as np + +from kimodo.assets import skeleton_asset_path + +qpos = np.loadtxt("motion.csv", delimiter=",") +model = mujoco.MjModel.from_xml_path(str(skeleton_asset_path("g1skel34", "xml", "g1.xml"))) +data = mujoco.MjData(model) + +fps = 30 # adjust to your intended playback rate + +with mujoco.viewer.launch_passive(model, data) as viewer: + # loop the motion + while viewer.is_running(): + for frame in qpos: + data.qpos[:] = frame + mujoco.mj_forward(model, data) + viewer.sync() + time.sleep(1.0 / fps) diff --git a/kimodo/scripts/qwen_planner.py b/kimodo/scripts/qwen_planner.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff0b98062cb5e468ceb521a180cff22c0647b26 --- /dev/null +++ b/kimodo/scripts/qwen_planner.py @@ -0,0 +1,51 @@ +"""CLI entrypoint for Qwen planner adapter.""" + +from __future__ import annotations + +import argparse +import json +from typing import Sequence + +from kimodo.planner import QwenPlannerAdapter +from kimodo.schemas import CharacterDefinition, PlannerRequest + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Generate planner scripts from a story prompt using Qwen.") + parser.add_argument("--scene-id", required=True, help="Scene identifier.") + parser.add_argument("--prompt", required=True, help="User story prompt.") + parser.add_argument( + "--character", + action="append", + required=True, + help="Character definition in form id[:skeleton[:description]]. Repeat for multiple characters.", + ) + parser.add_argument("--duration-limit-sec", type=float, default=60.0) + return parser + + +def _parse_character_arg(raw: str) -> CharacterDefinition: + parts = raw.split(":", 2) + character_id = parts[0] + skeleton_type = parts[1] if len(parts) >= 2 and parts[1] else "soma" + description = parts[2] if len(parts) == 3 and parts[2] else None + return CharacterDefinition(character_id=character_id, skeleton_type=skeleton_type, description=description) + + +def main(argv: Sequence[str] | None = None) -> int: + args = build_parser().parse_args(argv) + request = PlannerRequest( + scene_id=args.scene_id, + user_prompt=args.prompt, + duration_limit_sec=args.duration_limit_sec, + characters=[_parse_character_arg(item) for item in args.character], + ) + + adapter = QwenPlannerAdapter() + response = adapter.plan(request) + print(json.dumps(response.model_dump(), indent=2)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/kimodo/scripts/run_text_encoder_server.py b/kimodo/scripts/run_text_encoder_server.py new file mode 100644 index 0000000000000000000000000000000000000000..3060675a89716ca3b770f21bd87e85f862b5d7f4 --- /dev/null +++ b/kimodo/scripts/run_text_encoder_server.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os + +import gradio as gr +import numpy as np + +from kimodo.model import resolve_target + +from .gradio_theme import get_gradio_theme + +os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES" +DEFAULT_TEXT = "A person walks and falls to the ground." 
+DEFAULT_SERVER_NAME = "0.0.0.0" +DEFAULT_SERVER_PORT = 9550 +DEFAULT_TMP_FOLDER = "/tmp/text_encoder/" +DEFAULT_TEXT_ENCODER = "llm2vec" +TEXT_ENCODER_PRESETS = { + "llm2vec": { + "target": "kimodo.model.LLM2VecEncoder", + "kwargs": { + "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp", + "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised", + "dtype": "bfloat16", + "llm_dim": 4096, + }, + "display_name": "LLM2Vec", + } +} + + +class DemoWrapper: + def __init__(self, text_encoder_name, tmp_folder): + self.text_encoder_name = text_encoder_name + self.text_encoder = None + self.init_error = None + self.tmp_folder = tmp_folder + + def _get_text_encoder(self): + if self.text_encoder is not None: + return self.text_encoder + if self.init_error is not None: + raise RuntimeError(self.init_error) + try: + self.text_encoder = _build_text_encoder(self.text_encoder_name) + return self.text_encoder + except Exception as error: + self.init_error = error + raise + + def __call__(self, text, filename, progress=gr.Progress()): + try: + text_encoder = self._get_text_encoder() + except Exception as error: + output_title = gr.Markdown(visible=True, value="## Encoder initialization failed") + output_text = gr.Markdown( + visible=True, + value=( + "Text encoder could not initialize. " + "If you use gated Hugging Face models, configure a valid HF token in the runtime env.\n\n" + f"Error: `{type(error).__name__}: {error}`" + ), + ) + download = gr.DownloadButton(visible=False) + return download, output_title, output_text + + # Compute text embedding + tensor, length = text_encoder(text) + embedding = tensor[:length] + embedding = embedding.cpu().numpy() + + # Save text embedding + path = os.path.join(self.tmp_folder, filename) + np.save(path, embedding) + + output_title = gr.Markdown(visible=True) + output_text = gr.Markdown(visible=True, value=f"Text: {text}") + download = gr.DownloadButton(visible=True, value=path) + return download, output_title, output_text + + +def _get_env(name: str, default): + return os.getenv(name, default) + + +def _build_text_encoder(name: str): + if name not in TEXT_ENCODER_PRESETS: + available = ", ".join(sorted(TEXT_ENCODER_PRESETS)) + raise ValueError(f"Unknown TEXT_ENCODER='{name}'. 
Available: {available}")
+    preset = TEXT_ENCODER_PRESETS[name]
+    target_cls = resolve_target(preset["target"])
+    return target_cls(**preset["kwargs"])
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Run text encoder Gradio server.")
+    parser.add_argument(
+        "--text-encoder",
+        default=_get_env("TEXT_ENCODER", DEFAULT_TEXT_ENCODER),
+        choices=sorted(TEXT_ENCODER_PRESETS.keys()),
+        help="Text encoder preset.",
+    )
+    parser.add_argument(
+        "--tmp-folder",
+        default=_get_env("TEXT_ENCODER_TMP_FOLDER", DEFAULT_TMP_FOLDER),
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    server_name = _get_env("GRADIO_SERVER_NAME", DEFAULT_SERVER_NAME)
+    server_port = int(os.environ.get("GRADIO_SERVER_PORT") or os.environ.get("PORT", str(DEFAULT_SERVER_PORT)))
+    theme, css = get_gradio_theme()
+    os.makedirs(args.tmp_folder, exist_ok=True)
+    display_name = TEXT_ENCODER_PRESETS[args.text_encoder]["display_name"]
+
+    # Defer model loading: DemoWrapper builds the encoder lazily on the first request,
+    # so the UI can start even if encoder initialization would fail.
+    demo_wrapper_fn = DemoWrapper(args.text_encoder, args.tmp_folder)
+
+    with gr.Blocks(title="Text encoder", css=css, theme=theme) as demo:
+        gr.Markdown(f"# Text encoder: {display_name}")
+        gr.Markdown("## Description")
+        gr.Markdown("Get an embedding from a text prompt.")
+
+        gr.Markdown("## Inputs")
+        with gr.Row():
+            text = gr.Textbox(
+                placeholder="Type the motion you want to generate with a sentence",
+                show_label=True,
+                label="Text prompt",
+                value=DEFAULT_TEXT,
+                type="text",
+            )
+        with gr.Row(scale=3):
+            with gr.Column(scale=1):
+                btn = gr.Button("Encode", variant="primary")
+            with gr.Column(scale=1):
+                clear = gr.Button("Clear", variant="secondary")
+            with gr.Column(scale=3):
+                pass
+
+        output_title = gr.Markdown("## Outputs", visible=False)
+        output_text = gr.Markdown("", visible=False)
+        with gr.Row(scale=3):
+            with gr.Column(scale=1):
+                download = gr.DownloadButton("Download", variant="primary", visible=False)
+            with gr.Column(scale=4):
+                pass
+
+        filename = gr.Textbox(
+            visible=False,
+            value="embedding.npy",
+        )
+
+        def clear_fn():
+            return [
+                gr.DownloadButton(visible=False),
+                gr.Markdown(visible=False),
+                gr.Markdown(visible=False),
+            ]
+
+        outputs = [download, output_title, output_text]
+
+        gr.on(
+            triggers=[text.submit, btn.click],
+            fn=clear_fn,
+            inputs=None,
+            outputs=outputs,
+        ).then(
+            fn=demo_wrapper_fn,
+            inputs=[text, filename],
+            outputs=outputs,
+        )
+
+        def download_file():
+            return gr.DownloadButton()
+
+        download.click(
+            fn=download_file,
+            inputs=None,
+            outputs=[download],
+        )
+        clear.click(fn=clear_fn, inputs=None, outputs=outputs)
+
+    demo.launch(server_name=server_name, server_port=server_port)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/kimodo/scripts/runtime_health.py b/kimodo/scripts/runtime_health.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4ecbc5bf29e577efe712f315d626920ae257526
--- /dev/null
+++ b/kimodo/scripts/runtime_health.py
@@ -0,0 +1,39 @@
+"""Card 9 runtime health check entrypoint for backend startup validation."""
+
+from __future__ import annotations
+
+import argparse
+import json
+
+from kimodo.runtime import runtime_health_report
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Kimodo runtime/backend health check")
+    parser.add_argument(
+        "--device",
+        type=str,
+        default=None,
+        help="Requested device (auto, rocm, cuda, amd, cpu, mps, cuda:0, etc.)",
+    )
+    parser.add_argument(
"--require-accelerator", + action="store_true", + help="Fail if selected runtime device is CPU.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + report = runtime_health_report(args.device) + print(json.dumps(report.to_dict(), indent=2, sort_keys=True)) + + if args.require_accelerator and report.selected_device == "cpu": + print("ERROR: accelerator required but runtime selected CPU") + return 2 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) \ No newline at end of file diff --git a/kimodo/skeleton/__init__.py b/kimodo/skeleton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a113cef78d5cd87c93fd998712db64ed2c3c9a --- /dev/null +++ b/kimodo/skeleton/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Skeleton definitions and utilities used across kimodo.""" + +from .base import SkeletonBase +from .definitions import ( + G1Skeleton34, + SMPLXSkeleton22, + SOMASkeleton30, + SOMASkeleton77, +) +from .kinematics import batch_rigid_transform, fk +from .registry import build_skeleton +from .transforms import global_rots_to_local_rots, to_standard_tpose + +__all__ = [ + "SkeletonBase", + "G1Skeleton34", + "SOMASkeleton30", + "SOMASkeleton77", + "SMPLXSkeleton22", + "batch_rigid_transform", + "fk", + "build_skeleton", + "global_rots_to_local_rots", + "to_standard_tpose", +] diff --git a/kimodo/skeleton/base.py b/kimodo/skeleton/base.py new file mode 100644 index 0000000000000000000000000000000000000000..19590df05e1dda0f7b3431544551c5a92debff74 --- /dev/null +++ b/kimodo/skeleton/base.py @@ -0,0 +1,278 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Base skeleton class: hierarchy, joint metadata, and helpers for kinematics and motion.""" + +from pathlib import Path +from typing import Optional + +import torch + +from kimodo.assets import skeleton_asset_path + +from .kinematics import fk +from .transforms import ( + from_standard_tpose, + global_rots_to_local_rots, + to_standard_tpose, +) + + +class SkeletonBase(torch.nn.Module): + """Base class that stores a skeleton hierarchy and helper metadata. + + Subclasses define the static joint layout (joint names and parent links) and semantic groups + (feet, hands, hips). This class builds index mappings, parent tensors, and convenience helpers + used by kinematics, constraints, and motion conversion utilities. + """ + + # these should be defined in the subclass + name = None + bone_order_names_with_parents = None + bone_order_names_no_root = None + root_idx = None + foot_joint_names = None + foot_joint_idx = None + hip_joint_names = None # in order [right, left] + hip_joint_idx = None # in order [right, left] + + def __init__( + self, + folder: Optional[str] = None, + name: Optional[str] = None, + load: bool = True, + **kwargs, # to catch addition args in configs + ): + """Initialize a skeleton instance and optional neutral-pose assets. + + Args: + folder: Folder containing serialized skeleton assets (for example + `joints.p` and optional `standard_t_pose_global_offsets_rots.p`). + name: Optional runtime name used to validate subclass compatibility. + load: Whether to load tensor assets from `folder`. + **kwargs: Unused extra config keys kept for config compatibility. 
+        """
+        super().__init__()
+
+        if name is not None:
+            # Check that the provided runtime name is compatible with this skeleton class
+            assert self.name in name
+            self.name = name
+
+        if folder is None:
+            # Take the skeleton asset folder of the repo from the name
+            # in case we don't override it
+            folder = str(skeleton_asset_path(self.name))
+
+        self.folder = folder
+
+        self.dim = len(self.bone_order_names_with_parents)
+
+        if load and folder is not None:
+            pfolder = Path(folder)
+            neutral_joints = torch.load(pfolder / "joints.p").squeeze()
+            self.register_buffer("neutral_joints", neutral_joints, persistent=False)
+
+            if (pfolder / "bvh_joints.p").exists():
+                bvh_neutral_joints = torch.load(pfolder / "bvh_joints.p").squeeze()
+                self.register_buffer("bvh_neutral_joints", bvh_neutral_joints, persistent=False)
+
+            global_offset_path = pfolder / "standard_t_pose_global_offsets_rots.p"
+            if global_offset_path.exists():
+                global_rot_offsets = torch.load(global_offset_path).squeeze()
+                self.register_buffer("global_rot_offsets", global_rot_offsets, persistent=False)
+            # Useful for G1, where the rest pose is not zero
+            baked_rest_path = pfolder / "rest_pose_local_rot.p"
+            if baked_rest_path.exists():
+                rest_pose_local_rot = torch.load(baked_rest_path).squeeze()
+                self.register_buffer("rest_pose_local_rot", rest_pose_local_rot, persistent=False)
+
+        self.bone_order_names = [x for x, y in self.bone_order_names_with_parents]
+
+        self.bone_parents = dict(self.bone_order_names_with_parents)
+        self.bone_index = {x: idx for idx, x in enumerate(self.bone_order_names)}
+        self.bone_order_names_index = self.bone_index
+
+        # create the parents tensor on the fly
+        joint_parents = torch.tensor(
+            [-1 if (y := self.bone_parents[x]) is None else self.bone_index[y] for x in self.bone_order_names]
+        )
+        self.register_buffer("joint_parents", joint_parents, persistent=False)
+
+        self.nbjoints = len(self.bone_order_names)
+
+        # check lengths (registered buffers live in _buffers, not __dict__, so use getattr)
+        assert self.nbjoints == len(self.joint_parents)
+        if getattr(self, "neutral_joints", None) is not None:
+            assert self.nbjoints == len(self.neutral_joints)
+
+        root_indices = torch.where(joint_parents == -1)[0]
+        assert len(root_indices) == 1  # should be one root only
+        self.root_idx = root_indices[0].item()
+
+        if getattr(self, "neutral_joints", None) is not None:
+            assert (self.neutral_joints[0] == 0).all()
+
+        # remove the root
+        self.bone_order_names_no_root = (
+            self.bone_order_names[: self.root_idx] + self.bone_order_names[self.root_idx + 1 :]
+        )
+
+        self.foot_joint_names = self.left_foot_joint_names + self.right_foot_joint_names
+        self.foot_joint_names_index = {x: idx for idx, x in enumerate(self.foot_joint_names)}
+
+        self.left_foot_joint_idx = [
+            self.bone_order_names.index(foot_joint) for foot_joint in self.left_foot_joint_names
+        ]
+
+        self.right_foot_joint_idx = [
+            self.bone_order_names.index(foot_joint) for foot_joint in self.right_foot_joint_names
+        ]
+
+        self.foot_joint_idx = self.left_foot_joint_idx + self.right_foot_joint_idx
+
+        self.hip_joint_idx = [self.bone_order_names.index(hip_joint) for hip_joint in self.hip_joint_names]
+
+    def expand_joint_names(self, joint_names):
+        """Expand base EE names [LeftFoot, RightFoot, LeftHand, RightHand] into the actual joint
+        names used to constrain positions and rotations.
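+
+        For example (illustrative), ["LeftHand"] expands to the full left hand chain for
+        position constraints and to just ["LeftHand"] (the base of the chain) for rotation
+        constraints.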
+
+        Args:
+            joint_names: list of base EE names to constrain
+
+        Returns:
+            rot_joint_names: list of joint names to constrain in rotation
+            pos_joint_names: list of joint names to constrain in position
+        """
+
+        base_ee = ["LeftFoot", "RightFoot", "LeftHand", "RightHand", "Hips"]
+
+        pelvis_name = self.bone_order_names[self.root_idx]
+
+        base_pos_names = [
+            self.left_foot_joint_names,
+            self.right_foot_joint_names,
+            self.left_hand_joint_names,
+            self.right_hand_joint_names,
+            [pelvis_name],
+        ]
+        # base of each chain
+        base_rot_names = [
+            self.left_foot_joint_names[:1],
+            self.right_foot_joint_names[:1],
+            self.left_hand_joint_names[:1],
+            self.right_hand_joint_names[:1],
+            [pelvis_name],
+        ]
+        rot_joint_names = []
+        pos_joint_names = []
+        # loop through each EE joint group to constrain in the current keyframe
+        for jname in joint_names:
+            idx = base_ee.index(jname)
+            rot_joint_names += base_rot_names[idx]
+            pos_joint_names += base_pos_names[idx]
+        return rot_joint_names, pos_joint_names
+
+    def expand_joint_names_batched(self, joint_names):
+        """Expand base EE names [LeftFoot, RightFoot, LeftHand, RightHand] into the actual joint
+        names used to constrain positions and rotations, for a batch of keyframes.
+
+        Args:
+            joint_names: list (one entry per keyframe) of lists of base EE names to constrain
+
+        Returns:
+            rot_joint_names: list of lists of joint names to constrain in rotation
+            pos_joint_names: list of lists of joint names to constrain in position
+        """
+
+        base_ee = ["LeftFoot", "RightFoot", "LeftHand", "RightHand", "Hips"]
+
+        pelvis_name = self.bone_order_names[self.root_idx]
+
+        base_pos_names = [
+            self.left_foot_joint_names,
+            self.right_foot_joint_names,
+            self.left_hand_joint_names,
+            self.right_hand_joint_names,
+            [pelvis_name],
+        ]
+        # base of each chain
+        base_rot_names = [
+            self.left_foot_joint_names[:1],
+            self.right_foot_joint_names[:1],
+            self.left_hand_joint_names[:1],
+            self.right_hand_joint_names[:1],
+            [pelvis_name],
+        ]
+        # loop through each keyframe
+        rot_joint_names = []
+        pos_joint_names = []
+        for key_joint_names in joint_names:
+            key_rot_names = []
+            key_pos_names = []
+            # loop through each EE joint group to constrain in the current keyframe
+            for jname in key_joint_names:
+                idx = base_ee.index(jname)
+                key_rot_names += base_rot_names[idx]
+                key_pos_names += base_pos_names[idx]
+            rot_joint_names.append(key_rot_names)
+            pos_joint_names.append(key_pos_names)
+        return rot_joint_names, pos_joint_names
+
+    def __repr__(self):
+        if self.folder is None:
+            return f"{self.__class__.__name__}()"
+        return f'{self.__class__.__name__}(folder="{self.folder}")'
+
+    @property
+    def device(self):
+        """Device where neutral-joint buffers are stored.
+
+        Returns 'cpu' if neutral_joints is not present.
+        """
+        if getattr(self, "neutral_joints", None) is None:
+            return "cpu"
+        return self.neutral_joints.device
+
+    def fk(self, local_joint_rots: torch.Tensor, root_positions: torch.Tensor):
+        """Run forward kinematics for this skeleton layout.
+
+        Args:
+            local_joint_rots: Local joint rotation matrices with shape
+                `(..., J, 3, 3)`.
+            root_positions: Root translations with shape `(..., 3)`.
+
+        Returns:
+            Tuple of `(global_joint_rots, posed_joints, posed_joints_norootpos)`.
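+
+        Example (a shape sketch, assuming batch dimensions are preserved): inputs of shape
+        ``(T, J, 3, 3)`` and ``(T, 3)`` yield ``(T, J, 3, 3)`` global rotations and two
+        ``(T, J, 3)`` joint position tensors.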
+        """
+        global_joint_rots, posed_joints, posed_joints_norootpos = fk(local_joint_rots, root_positions, self)
+        return global_joint_rots, posed_joints, posed_joints_norootpos
+
+    def to_standard_tpose(self, local_rot_mats: torch.Tensor):
+        """Convert local rotations into the skeleton's standard T-pose frame."""
+        return to_standard_tpose(local_rot_mats, self)
+
+    def from_standard_tpose(self, local_rot_mats: torch.Tensor):
+        """Convert local rotations from the skeleton's standard T-pose frame."""
+        return from_standard_tpose(local_rot_mats, self)
+
+    def global_rots_to_local_rots(self, global_joint_rots: torch.Tensor):
+        """Convert global joint rotations to local rotations for this hierarchy."""
+        return global_rots_to_local_rots(global_joint_rots, self)
+
+    def get_skel_slice(self, skeleton: "SkeletonBase"):
+        """Build index mapping from another skeleton into this skeleton order.
+
+        Args:
+            skeleton: Source skeleton whose joint order is used by input tensors.
+
+        Returns:
+            A list of source indices ordered as `self.bone_order_names`.
+
+        Raises:
+            ValueError: If at least one required joint is missing from `skeleton`.
+        """
+        try:
+            skel_slice = [skeleton.bone_index[x] for x in self.bone_order_names]
+        except KeyError:
+            raise ValueError("The current skeleton contains joints that are missing from the input skeleton")
+        return skel_slice
diff --git a/kimodo/skeleton/bvh.py b/kimodo/skeleton/bvh.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fae1c337303acbe8d7ad1bd4920b2d1a92f5fc8
--- /dev/null
+++ b/kimodo/skeleton/bvh.py
@@ -0,0 +1,578 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""BVH parsing utilities and skeleton/animation conversion helpers."""
+
+import re
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from scipy.spatial.transform import Rotation
+
+
+class BvhNode:
+    """Lightweight tree node used to represent parsed BVH hierarchy lines."""
+
+    def __init__(self, value=None, parent=None):
+        """Create a node from tokenized BVH line values."""
+        # avoid a mutable default list shared across instances
+        self.value = value if value is not None else []
+        self.children = []
+        self.parent = parent
+        if self.parent:
+            self.parent.add_child(self)
+
+    def add_child(self, item):
+        """Attach a child node and set its parent reference."""
+        item.parent = self
+        self.children.append(item)
+
+    def filter(self, key):
+        """Yield direct children whose first token matches `key`."""
+        for child in self.children:
+            if child.value[0] == key:
+                yield child
+
+    def __iter__(self):
+        for child in self.children:
+            yield child
+
+    def __getitem__(self, key):
+        """Return all tokens following `key` from the first matching child node."""
+        for child in self.children:
+            for index, item in enumerate(child.value):
+                if item == key:
+                    if index + 1 >= len(child.value):
+                        return None
+                    else:
+                        return child.value[index + 1 :]
+        raise IndexError("key {} not found".format(key))
+
+    def __repr__(self):
+        return str(" ".join(self.value))
+
+    @property
+    def name(self):
+        """Joint name for `ROOT`/`JOINT` entries."""
+        return self.value[1]
+
+
+class Bvh:
+    """Parsed BVH file with hierarchy graph and per-frame channel values."""
+
+    def __init__(self, data: str, backend: Optional[str] = "graph"):
+        """
+        Args:
+            data: Raw BVH file content.
+            backend: Parsing mode. `"graph"` keeps list-based frame storage,
+                while `"np"` precomputes a NumPy array and index caches.
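+
+        Example (a sketch; the path is hypothetical):
+            with open("motion.bvh") as f:
+                mocap = Bvh(f.read(), backend="np")
+            n_frames, dt = mocap.nframes, mocap.frame_time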
+ """ + self.data = data + self.root = BvhNode() + self.frames = [] + self.backend = backend + self.tokenize() + if self.backend == "np": + # cache important info for quick access later + self.build_data_array() + elif self.backend == "graph": + pass + else: + raise ValueError(f"Unknown backend for BVH loading: {backend}") + + def build_data_array(self): + """Build cached channel indices and contiguous frame data for `"np"` backend.""" + joints = self.get_joints() + self.joint2idx = dict() + self.joint2channels = dict() + cur_idx = 0 + for joint in joints: + self.joint2idx[joint.value[1]] = cur_idx + cur_idx += int(joint["CHANNELS"][0]) + self.joint2channels[joint.value[1]] = joint["CHANNELS"][1:] + self.np_data_array = np.array(self.frames, dtype=np.float32) + + def tokenize(self): + """Tokenize BVH text and populate hierarchy plus frame values.""" + first_round = [] + accumulator = "" + for char in self.data: + if char not in ("\n", "\r"): + accumulator += char + elif accumulator: + first_round.append(re.split("\\s+", accumulator.strip())) + accumulator = "" + node_stack = [self.root] + frame_time_found = False + node = None + for item in first_round: + if frame_time_found: + self.frames.append(item) + continue + key = item[0] + if key == "{": + node_stack.append(node) + elif key == "}": + node_stack.pop() + else: + node = BvhNode(item) + # print("new node: ", node, "\nparent: ", node_stack[-1]) + node_stack[-1].add_child(node) + if item[0] == "Frame" and item[1] == "Time:": + frame_time_found = True + + def search(self, *items): + """Depth-first search for nodes matching a prefix of tokens.""" + found_nodes = [] + + def check_children(node): + if len(node.value) >= len(items): + failed = False + for index, item in enumerate(items): + if node.value[index] != item: + failed = True + break + if not failed: + found_nodes.append(node) + for child in node: + check_children(child) + + check_children(self.root) + return found_nodes + + def get_joints(self): + """Return all `ROOT`/`JOINT` hierarchy joints in BVH traversal order.""" + joints = [] + + def iterate_joints(joint): + joints.append(joint) + for child in joint.filter("JOINT"): + iterate_joints(child) + + iterate_joints(next(self.root.filter("ROOT"))) + return joints + + def get_joints_names(self): + """Return joint names in the same order as :meth:`get_joints`.""" + joints = [] + + def iterate_joints(joint): + joints.append(joint.value[1]) + for child in joint.filter("JOINT"): + iterate_joints(child) + + iterate_joints(next(self.root.filter("ROOT"))) + return joints + + def joint_direct_children(self, name): + """Return direct child joints of the given joint name.""" + joint = self.get_joint(name) + return [child for child in joint.filter("JOINT")] + + def get_joint_index(self, name): + """Return hierarchy index of the named joint.""" + return self.get_joints().index(self.get_joint(name)) + + def get_joint(self, name): + """Return hierarchy node for a joint name.""" + found = self.search("ROOT", name) + if not found: + found = self.search("JOINT", name) + if found: + return found[0] + raise LookupError("joint not found") + + def joint_offset(self, name, idx=[0, 1, 2]): + """Return selected `OFFSET` components for a joint.""" + joint = self.get_joint(name) + offset = joint["OFFSET"] + if len(offset) < max(idx): + return None + return (float(offset[idx[0]]), float(offset[idx[1]]), float(offset[idx[2]])) + + def joint_offset_rot(self, name): + """Return optional rotational offset components from custom BVH files.""" + return 
self.joint_offset(name, idx=[3, 4, 5])
+
+    def joint_channels(self, name):
+        """Return channel names declared for a joint."""
+        if self.backend == "np":
+            return self.joint2channels[name]
+        else:
+            joint = self.get_joint(name)
+            return joint["CHANNELS"][1:]
+
+    def get_joint_channels_index(self, joint_name):
+        """Return the flattened starting channel index for one joint."""
+        if self.backend == "np":
+            return self.joint2idx[joint_name]
+        else:
+            index = 0
+            for joint in self.get_joints():
+                if joint.value[1] == joint_name:
+                    return index
+                index += int(joint["CHANNELS"][0])
+            raise LookupError("joint not found")
+
+    def get_joint_channel_index(self, joint, channel):
+        """Return per-joint channel offset for a specific channel name."""
+        channels = self.joint_channels(joint)
+        if channel in channels:
+            channel_index = channels.index(channel)
+        else:
+            raise ValueError(f"Channel {channel} not found in {channels}")
+        return channel_index
+
+    def frame_joint_channel(self, frame_index, joint, channel, value=None):
+        """Return one channel value for one joint at one frame index."""
+        joint_index = self.get_joint_channels_index(joint)
+        channel_index = self.get_joint_channel_index(joint, channel)
+        if channel_index == -1 and value is not None:
+            return value
+        if self.backend == "np":
+            return self.np_data_array[frame_index, joint_index + channel_index]
+        else:
+            return float(self.frames[frame_index][joint_index + channel_index])
+
+    def frame_joint_channels(self, frame_index, joint, channels, value=None):
+        """Get single frame data for one specific joint from multiple specific channels (e.g.
+        Xrotation, Yrotation, Zrotation)."""
+        values = []
+        joint_index = self.get_joint_channels_index(joint)
+        if self.backend == "np":
+            channel_idx = [self.get_joint_channel_index(joint, channel) for channel in channels]
+            channel_idx = np.array(channel_idx) + joint_index
+            values = self.np_data_array[frame_index, channel_idx]
+        else:
+            for channel in channels:
+                channel_index = self.get_joint_channel_index(joint, channel)
+                if channel_index == -1 and value is not None:
+                    values.append(value)
+                else:
+                    values.append(float(self.frames[frame_index][joint_index + channel_index]))
+        return values
+
+    def frames_joint_channels(self, joint, channels, value=None):
+        """Get all frame data for one joint from multiple channels (e.g.
Xrotation, Yrotation, + Zrotation).""" + joint_index = self.get_joint_channels_index(joint) + if self.backend == "np": + channel_idx = [self.get_joint_channel_index(joint, channel) for channel in channels] + channel_idx = np.array(channel_idx) + joint_index + all_frames = self.np_data_array[:, channel_idx] + else: + all_frames = [] + for frame in self.frames: + values = [] + for channel in channels: + channel_index = self.get_joint_channel_index(joint, channel) + if channel_index == -1 and value is not None: + values.append(value) + else: + values.append(float(frame[joint_index + channel_index])) + all_frames.append(values) + return all_frames + + def frames_joints_channels(self, joint_names, channels): + """Get all frames for all specified joints with one specified set of channels.""" + if self.backend != "np": + raise NotImplementedError("Only np backend is supported for this function") + joint_indices = [(joint_name, self.joint2idx[joint_name]) for joint_name in joint_names] + data_indices = [] + for joint_name, joint_idx in joint_indices: + channel_indices = [self.get_joint_channel_index(joint_name, channel) for channel in channels] + data_indices.extend([joint_idx + channel_idx for channel_idx in channel_indices]) + all_frames = self.np_data_array[:, data_indices] + all_frames = all_frames.reshape(-1, len(joint_names), len(channels)) + return all_frames + + def joint_parent(self, name): + """Return parent joint node, or `None` for the root.""" + joint = self.get_joint(name) + if joint.parent == self.root: + return None + return joint.parent + + def joint_parent_index(self, name): + """Return parent joint index, or `-1` for the root.""" + joint = self.get_joint(name) + if joint.parent == self.root: + return -1 + return self.get_joints().index(joint.parent) + + @property + def nframes(self): + """Number of motion frames declared in the BVH header.""" + try: + return int(next(self.root.filter("Frames:")).value[1]) + except StopIteration: + raise LookupError("number of frames not found") + + @property + def frame_time(self): + """Frame duration in seconds declared in the BVH header.""" + try: + return float(next(self.root.filter("Frame")).value[2]) + except StopIteration: + raise LookupError("frame time not found") + + +class Bone: + """Container for one skeleton bone and its kinematic metadata.""" + + def __init__(self): + # original bone info + self.id = None + self.name = None + self.orient = np.identity(3) + self.dof_index = [] + self.channels = [] # bvh only + self.lb = [] + self.ub = [] + self.parent = None + self.child = [] + + # asf specific + self.dir = np.zeros(3) + self.len = 0 + # bvh specific + self.offset = np.zeros(3) # default offset for position + self.offset_rot = None # rotation for custom nv bvh + + # inferred info + self.pos = np.zeros(3) + self.end = np.zeros(3) + + def __repr__(self): + return f"{self.name}" + + +class SkeletonBvh: + """Skeleton structure reconstructed from BVH hierarchy metadata.""" + + def __init__(self): + self.bones = [] + self.name2bone = {} + self.mass_scale = 1.0 + self.len_scale = 1.0 + self.dof_name = ["x", "y", "z"] + self.root = None + + def get_bones_names(self): + """Return bone names in skeleton order.""" + return [x.name for x in self.bones] + + def get_parent_indices(self): + """Return parent index array aligned with `self.bones`.""" + parent_indices = [-1] * len(self.bones) + for bone in self.bones: + if bone.parent: + parent_indices[bone.id] = bone.parent.id + return parent_indices + + def get_neutral_joints(self): + """Return 
neutral/rest joint positions as a NumPy array `(J, 3)`.""" + joints = [] + for bone in self.bones: + joints.append(bone.pos) + joints = np.stack(joints, axis=0) + return joints + + def load_from_bvh(self, fname, exclude_bones=None, spec_channels=None, mocap=None): + """Load skeleton hierarchy and rest offsets from a BVH file. + + Args: + fname: Path to a BVH file (ignored when *mocap* is given). + exclude_bones: Bone-name substrings to ignore while constructing the + skeleton. + spec_channels: Optional per-joint channel overrides. + mocap: Pre-parsed :class:`Bvh` object. When provided the file is + not re-read from disk. + """ + if exclude_bones is None: + exclude_bones = {} + if spec_channels is None: + spec_channels = dict() + if mocap is None: + with open(fname) as f: + mocap = Bvh(f.read()) + + joint_names = list( + filter( + lambda x: all([t not in x for t in exclude_bones]), + mocap.get_joints_names(), + ) + ) + dof_ind = {"x": 0, "y": 1, "z": 2} + self.len_scale = 1.0 + self.root = Bone() + self.root.id = 0 + self.root.name = joint_names[0] + self.root.channels = mocap.joint_channels(self.root.name) + self.root.offset = np.array(mocap.joint_offset(self.root.name)) * self.len_scale + self.root.offset_rot = mocap.joint_offset_rot(self.root.name) + if self.root.offset_rot is not None: + self.root.offset_rot = np.array(self.root.offset_rot) + # self.root.offset = np.zeros_like(self.root.offset) # TODO: remove this + self.name2bone[self.root.name] = self.root + self.bones.append(self.root) + for i, joint in enumerate(joint_names[1:]): + bone = Bone() + bone.id = i + 1 + bone.name = joint + bone.channels = spec_channels[joint] if joint in spec_channels.keys() else mocap.joint_channels(joint) + bone.dof_index = [dof_ind[x[0].lower()] for x in bone.channels] + bone.offset = np.array(mocap.joint_offset(joint)) * self.len_scale + bone.offset_rot = mocap.joint_offset_rot(joint) + if bone.offset_rot is not None: + bone.offset_rot = np.array(bone.offset_rot) + bone.lb = [-180.0] * 3 + bone.ub = [180.0] * 3 + self.bones.append(bone) + self.name2bone[joint] = bone + + # for bone in self.bones: + # print(bone.name, bone.channels, bone.offset) + + for bone in self.bones[1:]: + parent_name = mocap.joint_parent(bone.name).name + if parent_name in self.name2bone.keys(): + bone_p = self.name2bone[parent_name] + bone_p.child.append(bone) + bone.parent = bone_p + + self.forward_bvh(self.root) + for bone in self.bones: + if len(bone.child) == 0: + child_vals = [str(node) for node in mocap.get_joint(bone.name).children] + if "End Site" in child_vals: + end_site_idx = child_vals.index("End Site") + end_site_offset = mocap.get_joint(bone.name).children[end_site_idx]["OFFSET"] + bone.end = bone.pos + np.array([float(x) for x in end_site_offset]) * self.len_scale + else: + pass + else: + bone.end = sum([bone_c.pos for bone_c in bone.child]) / len(bone.child) + + def forward_bvh(self, bone): + """Recursively accumulate absolute joint positions from local offsets.""" + if bone.parent: + bone.pos = bone.parent.pos + bone.offset + else: + bone.pos = bone.offset + for bone_c in bone.child: + self.forward_bvh(bone_c) + + +def load_bvh_animation( + fname: str, + skeleton: SkeletonBvh, + rot_order: Optional[str] = "native", + backend: Optional[str] = "np", + return_quat: Optional[bool] = False, + mocap: Optional["Bvh"] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Load motion channels from BVH into root translations and joint rotations. 
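+
+    For example (illustrative), ``rot_order="zyx"`` reorders the parsed Euler channels to
+    Z, Y, X before conversion, while the default ``"native"`` keeps the channel order
+    declared in the BVH file.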
+ + Args: + fname: Full path to the BVH file (ignored when *mocap* is given). + skeleton: Parsed neutral skeleton built from compatible BVH hierarchy. + rot_order: Euler order to use for conversion (`"native"` keeps BVH order). + backend: BVH parser backend (`"np"` or `"graph"`). + return_quat: If `True`, return quaternions instead of rotation matrices. + mocap: Pre-parsed :class:`Bvh` object. When provided the file is + not re-read from disk. + + Returns: + Root translations `(T, 3)` and joint rotations `(T, J, 3, 3)` or + `(T, J, 4)` when `return_quat=True`. + """ + if mocap is None: + with open(fname) as f: + mocap = Bvh(f.read(), backend=backend) + + # assume all joints are same ordering, load in with native ordering + root_channels = mocap.joint_channels(skeleton.root.name) + pos_channels = [channel for channel in root_channels if channel.endswith("position")] + rot_channels = [channel for channel in root_channels if channel.endswith("rotation")] + + root_trans = np.array(mocap.frames_joint_channels(skeleton.root.name, pos_channels)) + + effective_backend = mocap.backend + if effective_backend == "np": + # NOTE: assumes rot channel ordering is the same for all joints + joint_eulers = mocap.frames_joints_channels(skeleton.get_bones_names(), rot_channels) + joint_eulers = np.deg2rad(joint_eulers) + elif effective_backend == "graph": + joint_eulers = [] + for bone in skeleton.bones: + bone_channels = mocap.joint_channels(bone.name) + bone_rot_channels = [channel for channel in bone_channels if channel.endswith("rotation")] + assert bone_rot_channels == rot_channels, "Rotation channel ordering is not consistent across joints!" + # use native rotation order + euler = np.deg2rad(np.array(mocap.frames_joint_channels(bone.name, rot_channels))) + joint_eulers.append(euler) + joint_eulers = np.stack(joint_eulers, axis=1) + else: + raise ValueError(f"Unknown backend for BVH loading: {effective_backend}") + + if rot_order == "native": + rot_order = "" + for axis in rot_channels: + rot_order += axis[0] + else: + # need to reorder dims + ordered_joint_eulers = [] + for axis in rot_order: + i = rot_channels.index(axis + "rotation") + ordered_joint_eulers.append(joint_eulers[..., i]) + joint_eulers = np.stack(ordered_joint_eulers, axis=-1) + + rotations = Rotation.from_euler(rot_order, joint_eulers.reshape(-1, 3)) + if return_quat: + joint_rots = rotations.as_quat(scalar_first=True).reshape(joint_eulers.shape[:-1] + (4,)) + else: + joint_rots = rotations.as_matrix().reshape(joint_eulers.shape[:-1] + (3, 3)) + + return root_trans, joint_rots + + +def parse_bvh_motion(file_path_input: str, parse_neutral_joints: bool = False): + """Parse a BVH motion into tensors used by kimodo motion pipelines. + + Args: + file_path_input: Path to input BVH file. + parse_neutral_joints: If `True`, also return neutral joints in meters. + + Returns: + ``(local_rot_mats, root_trans, fps)`` or + ``(local_rot_mats, root_trans, fps, neutral_joints)`` when requested. 
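+
+    Example:
+        A minimal sketch (``clip.bvh`` is a placeholder path)::
+
+            local_rot_mats, root_trans, fps = parse_bvh_motion("clip.bvh")
+            # local_rot_mats: (T, J, 3, 3) torch tensor, root_trans: (T, 3) in meters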
+ """ + with open(file_path_input) as f: + mocap = Bvh(f.read(), backend="np") + + fps = 1.0 / mocap.frame_time + + skeletonBVH = SkeletonBvh() + exclude_bones = {"Root"} + skeletonBVH.load_from_bvh(file_path_input, exclude_bones=exclude_bones, mocap=mocap) + + root_trans, local_rot_mats = load_bvh_animation(file_path_input, skeletonBVH, mocap=mocap) + root_trans *= 0.01 # unit change: cm -> m + root_trans = torch.tensor(root_trans) + local_rot_mats = torch.tensor(local_rot_mats) + + # Don't parse neutral_joints here + # it is not actually needed right now: + # the skeleton is always the same, and saved in the folder + # carefull: the one saved in the folder it relative to the standard t_pose + # whereas the parsed one is not + if not parse_neutral_joints: + return local_rot_mats, root_trans, fps + + neutral_joints = skeletonBVH.get_neutral_joints() + neutral_joints *= 0.01 # unit change: cm -> m + # remove the root position of the skeleton + # (it is already "included" in the root_translation) + root_idx = 0 + neutral_joints = torch.tensor(neutral_joints - neutral_joints[root_idx]) + return local_rot_mats, root_trans, fps, neutral_joints diff --git a/kimodo/skeleton/definitions.py b/kimodo/skeleton/definitions.py new file mode 100644 index 0000000000000000000000000000000000000000..1bea10c13cdb0b44cae6675923a48d385244cb4c --- /dev/null +++ b/kimodo/skeleton/definitions.py @@ -0,0 +1,371 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Concrete skeleton definitions: SOMA, G1, SMPLX with joint names and hierarchy.""" + +from pathlib import Path + +import numpy as np +import torch + +from ..tools import ensure_batched +from .base import SkeletonBase + + +class SOMASkeleton77(SkeletonBase): + """High-detail 77-joint SOMA skeleton with full finger and toe chains.""" + + name = "somaskel77" + + right_foot_joint_names = [ + "RightFoot", + "RightToeBase", + "RightToeEnd", + ] # in order of chain + left_foot_joint_names = [ + "LeftFoot", + "LeftToeBase", + "LeftToeEnd", + ] # in order of chain + right_hand_joint_names = [ + "RightHand", + "RightHandThumb1", + "RightHandThumb2", + "RightHandThumb3", + "RightHandThumbEnd", + "RightHandIndex1", + "RightHandIndex2", + "RightHandIndex3", + "RightHandIndex4", + "RightHandIndexEnd", + "RightHandMiddle1", + "RightHandMiddle2", + "RightHandMiddle3", + "RightHandMiddle4", + "RightHandMiddleEnd", + "RightHandRing1", + "RightHandRing2", + "RightHandRing3", + "RightHandRing4", + "RightHandRingEnd", + "RightHandPinky1", + "RightHandPinky2", + "RightHandPinky3", + "RightHandPinky4", + "RightHandPinkyEnd", + ] # in order of chain + left_hand_joint_names = [ + "LeftHand", + "LeftHandThumb1", + "LeftHandThumb2", + "LeftHandThumb3", + "LeftHandThumbEnd", + "LeftHandIndex1", + "LeftHandIndex2", + "LeftHandIndex3", + "LeftHandIndex4", + "LeftHandIndexEnd", + "LeftHandMiddle1", + "LeftHandMiddle2", + "LeftHandMiddle3", + "LeftHandMiddle4", + "LeftHandMiddleEnd", + "LeftHandRing1", + "LeftHandRing2", + "LeftHandRing3", + "LeftHandRing4", + "LeftHandRingEnd", + "LeftHandPinky1", + "LeftHandPinky2", + "LeftHandPinky3", + "LeftHandPinky4", + "LeftHandPinkyEnd", + ] # in order of chain + + hip_joint_names = ["RightLeg", "LeftLeg"] # in order [right, left] + + bone_order_names_with_parents = [ + ("Hips", None), + ("Spine1", "Hips"), + ("Spine2", "Spine1"), + ("Chest", "Spine2"), + ("Neck1", "Chest"), + ("Neck2", "Neck1"), + ("Head", "Neck2"), + ("HeadEnd", "Head"), + ("Jaw", 
"Head"), + ("LeftEye", "Head"), + ("RightEye", "Head"), + ("LeftShoulder", "Chest"), + ("LeftArm", "LeftShoulder"), + ("LeftForeArm", "LeftArm"), + ("LeftHand", "LeftForeArm"), + ("LeftHandThumb1", "LeftHand"), + ("LeftHandThumb2", "LeftHandThumb1"), + ("LeftHandThumb3", "LeftHandThumb2"), + ("LeftHandThumbEnd", "LeftHandThumb3"), + ("LeftHandIndex1", "LeftHand"), + ("LeftHandIndex2", "LeftHandIndex1"), + ("LeftHandIndex3", "LeftHandIndex2"), + ("LeftHandIndex4", "LeftHandIndex3"), + ("LeftHandIndexEnd", "LeftHandIndex4"), + ("LeftHandMiddle1", "LeftHand"), + ("LeftHandMiddle2", "LeftHandMiddle1"), + ("LeftHandMiddle3", "LeftHandMiddle2"), + ("LeftHandMiddle4", "LeftHandMiddle3"), + ("LeftHandMiddleEnd", "LeftHandMiddle4"), + ("LeftHandRing1", "LeftHand"), + ("LeftHandRing2", "LeftHandRing1"), + ("LeftHandRing3", "LeftHandRing2"), + ("LeftHandRing4", "LeftHandRing3"), + ("LeftHandRingEnd", "LeftHandRing4"), + ("LeftHandPinky1", "LeftHand"), + ("LeftHandPinky2", "LeftHandPinky1"), + ("LeftHandPinky3", "LeftHandPinky2"), + ("LeftHandPinky4", "LeftHandPinky3"), + ("LeftHandPinkyEnd", "LeftHandPinky4"), + ("RightShoulder", "Chest"), + ("RightArm", "RightShoulder"), + ("RightForeArm", "RightArm"), + ("RightHand", "RightForeArm"), + ("RightHandThumb1", "RightHand"), + ("RightHandThumb2", "RightHandThumb1"), + ("RightHandThumb3", "RightHandThumb2"), + ("RightHandThumbEnd", "RightHandThumb3"), + ("RightHandIndex1", "RightHand"), + ("RightHandIndex2", "RightHandIndex1"), + ("RightHandIndex3", "RightHandIndex2"), + ("RightHandIndex4", "RightHandIndex3"), + ("RightHandIndexEnd", "RightHandIndex4"), + ("RightHandMiddle1", "RightHand"), + ("RightHandMiddle2", "RightHandMiddle1"), + ("RightHandMiddle3", "RightHandMiddle2"), + ("RightHandMiddle4", "RightHandMiddle3"), + ("RightHandMiddleEnd", "RightHandMiddle4"), + ("RightHandRing1", "RightHand"), + ("RightHandRing2", "RightHandRing1"), + ("RightHandRing3", "RightHandRing2"), + ("RightHandRing4", "RightHandRing3"), + ("RightHandRingEnd", "RightHandRing4"), + ("RightHandPinky1", "RightHand"), + ("RightHandPinky2", "RightHandPinky1"), + ("RightHandPinky3", "RightHandPinky2"), + ("RightHandPinky4", "RightHandPinky3"), + ("RightHandPinkyEnd", "RightHandPinky4"), + ("LeftLeg", "Hips"), + ("LeftShin", "LeftLeg"), + ("LeftFoot", "LeftShin"), + ("LeftToeBase", "LeftFoot"), + ("LeftToeEnd", "LeftToeBase"), + ("RightLeg", "Hips"), + ("RightShin", "RightLeg"), + ("RightFoot", "RightShin"), + ("RightToeBase", "RightFoot"), + ("RightToeEnd", "RightToeBase"), + ] + + @property + def relaxed_hands_rest_pose(self): + # lazy loading + if hasattr(self, "_relaxed_hands_rest_pose"): + return self._relaxed_hands_rest_pose + + relaxed_hands_pose_path = Path(self.folder) / "relaxed_hands_rest_pose.npy" + relaxed_hands_rest_pose = torch.from_numpy(np.load(relaxed_hands_pose_path)).squeeze() + self.register_buffer( + "_relaxed_hands_rest_pose", + relaxed_hands_rest_pose, + persistent=False, + ) + return self._relaxed_hands_rest_pose + + +class SOMASkeleton30(SkeletonBase): + """Compact 30-joint SOMA variant with reduced hand and end-effector detail.""" + + name = "somaskel30" + + right_foot_joint_names = [ + "RightFoot", + "RightToeBase", + ] # in order of chain + left_foot_joint_names = [ + "LeftFoot", + "LeftToeBase", + ] # in order of chain + right_hand_joint_names = [ + "RightHand", + "RightHandMiddleEnd", + ] # in order of chain + left_hand_joint_names = [ + "LeftHand", + "LeftHandMiddleEnd", + ] # in order of chain + + hip_joint_names = ["RightLeg", "LeftLeg"] # in order 
[right, left] + + bone_order_names_with_parents = [ + ("Hips", None), + ("Spine1", "Hips"), + ("Spine2", "Spine1"), + ("Chest", "Spine2"), + ("Neck1", "Chest"), + ("Neck2", "Neck1"), + ("Head", "Neck2"), + ("Jaw", "Head"), + ("LeftEye", "Head"), + ("RightEye", "Head"), + ("LeftShoulder", "Chest"), + ("LeftArm", "LeftShoulder"), + ("LeftForeArm", "LeftArm"), + ("LeftHand", "LeftForeArm"), + ("LeftHandThumbEnd", "LeftHand"), + ("LeftHandMiddleEnd", "LeftHand"), + ("RightShoulder", "Chest"), + ("RightArm", "RightShoulder"), + ("RightForeArm", "RightArm"), + ("RightHand", "RightForeArm"), + ("RightHandThumbEnd", "RightHand"), + ("RightHandMiddleEnd", "RightHand"), + ("LeftLeg", "Hips"), + ("LeftShin", "LeftLeg"), + ("LeftFoot", "LeftShin"), + ("LeftToeBase", "LeftFoot"), + ("RightLeg", "Hips"), + ("RightShin", "RightLeg"), + ("RightFoot", "RightShin"), + ("RightToeBase", "RightFoot"), + ] + + @property + def somaskel77(self): + # lazy loading + if not hasattr(self, "_somaskel77"): + self._somaskel77 = SOMASkeleton77() + return self._somaskel77 + + @ensure_batched(local_joint_rots_subset=4) + def to_SOMASkeleton77(self, local_joint_rots_subset: torch.Tensor): + # Converting from 30-joint to 77-joint to have relaxed hands + + device = local_joint_rots_subset.device + nF = len(local_joint_rots_subset) + local_joint_rots_mats = self.somaskel77.relaxed_hands_rest_pose.clone().to(device).repeat(nF, 1, 1, 1) + + skel_slice = self.get_skel_slice(self.somaskel77) + local_joint_rots_mats[:, skel_slice] = local_joint_rots_subset + return local_joint_rots_mats + + @ensure_batched(local_joint_rots_full=4) + def from_SOMASkeleton77(self, local_joint_rots_full: torch.Tensor) -> torch.Tensor: + """Extract the 30-joint subset from 77-joint local rotation data.""" + skel_slice = self.get_skel_slice(self.somaskel77) + return local_joint_rots_full[:, skel_slice] + + def output_to_SOMASkeleton77(self, output: dict) -> dict: + """Convert model output dict from somaskel30 to somaskel77. + + Expands local_rot_mats to 77 joints, re-runs FK for global_rot_mats and posed_joints. Foot + contacts are expanded from 4 channels to 6 (toe-end copies toe-base contact). 
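+
+        Contact layout sketch: ``[L_heel, L_toe, R_heel, R_toe]`` becomes
+        ``[L_heel, L_toe, L_toe_end, R_heel, R_toe, R_toe_end]``, each
+        ``*_toe_end`` entry being a copy of its ``*_toe`` neighbour.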
+ """ + local_rot_mats_77 = self.to_SOMASkeleton77(output["local_rot_mats"]) + root_positions = output["root_positions"] + global_rot_mats_77, posed_joints_77, _ = self.somaskel77.fk(local_rot_mats_77, root_positions) + out_77 = dict(output) + out_77["local_rot_mats"] = local_rot_mats_77 + out_77["global_rot_mats"] = global_rot_mats_77 + out_77["posed_joints"] = posed_joints_77 + + if "foot_contacts" in output: + fc = output["foot_contacts"] # [..., 4]: [L_heel, L_toe, R_heel, R_toe] + # -> [..., 6]: [L_heel, L_toe, L_toe_end, R_heel, R_toe, R_toe_end] + out_77["foot_contacts"] = torch.cat([fc[..., :2], fc[..., 1:2], fc[..., 2:4], fc[..., 3:4]], dim=-1) + + return out_77 + + +class G1Skeleton34(SkeletonBase): + """Unitree G1 skeleton with 32 articulated joints plus 2 toe endpoints.""" + + name = "g1skel34" + right_foot_joint_names = ["right_ankle_roll_skel", "right_toe_base"] + left_foot_joint_names = ["left_ankle_roll_skel", "left_toe_base"] + right_hand_joint_names = ["right_wrist_yaw_skel", "right_hand_roll_skel"] + left_hand_joint_names = ["left_wrist_yaw_skel", "left_hand_roll_skel"] + + hip_joint_names = [ + "right_hip_pitch_skel", + "left_hip_pitch_skel", + ] # used to calculate root orientation, only need 1 pair of hip joints + + bone_order_names_with_parents = [ + ("pelvis_skel", None), + ("left_hip_pitch_skel", "pelvis_skel"), + ("left_hip_roll_skel", "left_hip_pitch_skel"), + ("left_hip_yaw_skel", "left_hip_roll_skel"), + ("left_knee_skel", "left_hip_yaw_skel"), + ("left_ankle_pitch_skel", "left_knee_skel"), + ("left_ankle_roll_skel", "left_ankle_pitch_skel"), + ("left_toe_base", "left_ankle_roll_skel"), + ("right_hip_pitch_skel", "pelvis_skel"), + ("right_hip_roll_skel", "right_hip_pitch_skel"), + ("right_hip_yaw_skel", "right_hip_roll_skel"), + ("right_knee_skel", "right_hip_yaw_skel"), + ("right_ankle_pitch_skel", "right_knee_skel"), + ("right_ankle_roll_skel", "right_ankle_pitch_skel"), + ("right_toe_base", "right_ankle_roll_skel"), + ("waist_yaw_skel", "pelvis_skel"), + ("waist_roll_skel", "waist_yaw_skel"), + ("waist_pitch_skel", "waist_roll_skel"), + ("left_shoulder_pitch_skel", "waist_pitch_skel"), + ("left_shoulder_roll_skel", "left_shoulder_pitch_skel"), + ("left_shoulder_yaw_skel", "left_shoulder_roll_skel"), + ("left_elbow_skel", "left_shoulder_yaw_skel"), + ("left_wrist_roll_skel", "left_elbow_skel"), + ("left_wrist_pitch_skel", "left_wrist_roll_skel"), + ("left_wrist_yaw_skel", "left_wrist_pitch_skel"), + ("left_hand_roll_skel", "left_wrist_yaw_skel"), + ("right_shoulder_pitch_skel", "waist_pitch_skel"), + ("right_shoulder_roll_skel", "right_shoulder_pitch_skel"), + ("right_shoulder_yaw_skel", "right_shoulder_roll_skel"), + ("right_elbow_skel", "right_shoulder_yaw_skel"), + ("right_wrist_roll_skel", "right_elbow_skel"), + ("right_wrist_pitch_skel", "right_wrist_roll_skel"), + ("right_wrist_yaw_skel", "right_wrist_pitch_skel"), + ("right_hand_roll_skel", "right_wrist_yaw_skel"), + ] + + +class SMPLXSkeleton22(SkeletonBase): + """SMPL-X skeleton with body-only 22 joints.""" + + name = "smplx22" + right_foot_joint_names = ["right_ankle", "right_foot"] # in order of chain + left_foot_joint_names = ["left_ankle", "left_foot"] # in order of chain + right_hand_joint_names = ["right_wrist"] # in order of chain + left_hand_joint_names = ["left_wrist"] # in order of chain + hip_joint_names = ["right_hip", "left_hip"] # in order [right, left] + + bone_order_names_with_parents = [ + ("pelvis", None), + ("left_hip", "pelvis"), + ("right_hip", "pelvis"), + ("spine1", "pelvis"), 
+ ("left_knee", "left_hip"), + ("right_knee", "right_hip"), + ("spine2", "spine1"), + ("left_ankle", "left_knee"), + ("right_ankle", "right_knee"), + ("spine3", "spine2"), + ("left_foot", "left_ankle"), + ("right_foot", "right_ankle"), + ("neck", "spine3"), + ("left_collar", "spine3"), + ("right_collar", "spine3"), + ("head", "neck"), + ("left_shoulder", "left_collar"), + ("right_shoulder", "right_collar"), + ("left_elbow", "left_shoulder"), + ("right_elbow", "right_shoulder"), + ("left_wrist", "left_elbow"), + ("right_wrist", "right_elbow"), + ] diff --git a/kimodo/skeleton/kinematics.py b/kimodo/skeleton/kinematics.py new file mode 100644 index 0000000000000000000000000000000000000000..a408eeb35f43f5bf46f01643742b97ebd4639990 --- /dev/null +++ b/kimodo/skeleton/kinematics.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Forward-kinematics primitives for articulated skeletons.""" + +from typing import List + +import einops +import torch +import torch.nn.functional as F + +from ..tools import ensure_batched + + +@ensure_batched(local_joint_rots=4, root_positions=2) +def fk( + local_joint_rots: torch.Tensor, + root_positions: torch.Tensor, + skeleton, + root_positions_is_global: bool = True, +): + """Compute global joint rotations and positions from local rotations. + + Args: + local_joint_rots: Local rotation matrices with shape `(..., J, 3, 3)`. + root_positions: Root translations with shape `(..., 3)`. + skeleton: Skeleton object exposing `neutral_joints`, `joint_parents`, and + `root_idx`. + root_positions_is_global: If `True`, neutral joints are recentered so root + translations are interpreted in world space. + + Returns: + Tuple `(global_joint_rots, posed_joints, posed_joints_norootpos)`. + """ + device = local_joint_rots.device + dtype = local_joint_rots.dtype + + # If skeleton has baked rest (e.g. from XML), identity local = baked rest pose. + # So training/inference local rotations are in reference to XML rest *orientations*. + rest_local = getattr(skeleton, "rest_local_rots", None) + if rest_local is not None: + rest_local = rest_local.to(device=device, dtype=dtype) + local_joint_rots = torch.einsum("jmn,...jno->...jmo", rest_local, local_joint_rots) + + # Rest positions for FK. Must be consistent with rest_local: when local = identity, + # FK(rest_local, neutral_joints) should equal the XML rest pose positions. So + # neutral_joints are not necessarily the raw XML joint positions; they are the + # rest layout that, when rotated by rest_local, yields the XML rest positions. 
+ neutral_joints = skeleton.neutral_joints.to(device=device, dtype=dtype) + + if root_positions_is_global is True: + # Removing the pelvis offset from the neutral joints + # as the root positions does not depends on the pelvis offset of the skeleton + pelvis_offset = neutral_joints[skeleton.root_idx] + neutral_joints = neutral_joints - pelvis_offset + + # compute joint position and global rotations + joints = einops.repeat( + neutral_joints, + "j k -> b j k", + b=len(local_joint_rots), + ) + posed_joints_norootpos, global_joint_rots = batch_rigid_transform( + local_joint_rots, + joints, + skeleton.joint_parents, + skeleton.root_idx, + ) + # if root_positions_is_global is True: + # posed_joints_norootpos always start at zero + # otherwise it could start with the pelvis offset + + posed_joints = posed_joints_norootpos + root_positions[:, None] + return global_joint_rots, posed_joints, posed_joints_norootpos + + +def compute_idx_levels(parents): + """Group joint indices by hierarchy depth for level-wise FK updates. + + Args: + parents: Parent index tensor of shape `(J,)` with root parent `-1`. + + Returns: + List of index tensors, where each tensor contains joints at one depth. + """ + idx_levs = [[]] + lev_dicts = {0: -1} + for i in range(1, parents.shape[0]): + assert int(parents[i]) in lev_dicts + lev = lev_dicts[int(parents[i])] + 1 + if lev + 1 > len(idx_levs): + idx_levs.append([]) + idx_levs[lev].append(int(i)) + lev_dicts[int(i)] = lev + idx_levs = [torch.tensor(x).long() for x in idx_levs] + return idx_levs + + +def batch_rigid_transform(rot_mats, joints, parents, root_idx): + """Perform batch rigid transformation on a skeletal structure. + + Args: + rot_mats: Local rotation matrices for each joint: (B, J, 3, 3) + joints: Initial joint positions: (B, J, 3) + parents: Tensor indicating the parent of each joint: (J,) + root_idx (int): index of the root + + Returns: + Transformed joint positions after applying forward kinematics. + """ + + # Compute the hierarchical levels of joints based on their parent relationships + idx_levs = compute_idx_levels(parents) + + # Apply forward kinematics to transform the joints + return forward_kinematics(rot_mats, joints, parents, idx_levs, root_idx) + + +@torch.jit.script +def transform_mat(R, t): + """Creates a batch of transformation matrices. + + Args: + - R: Bx3x3 array of a batch of rotation matrices + - t: Bx3x1 array of a batch of translation vectors + Returns: + - T: Bx4x4 Transformation matrix + """ + # No padding left or right, only add an extra row + return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1.0)], dim=2) + + +@torch.jit.script +def forward_kinematics( + rot_mats, + joints, + parents: torch.Tensor, + idx_levs: List[torch.Tensor], + root_idx: int, +): + """Perform forward kinematics to compute posed joints and global rotation matrices. + + Args: + rot_mats: Local rotation matrices for each joint: (B, J, 3, 3) + joints: Initial joint positions: (B, J, 3) + parents: Tensor indicating the parent of each joint: (J,) + idx_levs: Tensors of joint indices grouped by depth in the kinematic tree. 
root_idx (int): index of the root
+    Returns:
+        Posed joints: (B, J, 3)
+        Global rotation matrices: (B, J, 3, 3)
+    """
+
+    # Add an extra dimension to joints
+    joints = torch.unsqueeze(joints, dim=-1)
+
+    # Compute relative joint positions
+    rel_joints = joints.clone()
+
+    mask_no_root = torch.ones(joints.shape[1], dtype=torch.bool)
+    mask_no_root[root_idx] = False
+    rel_joints[:, mask_no_root] -= joints[:, parents[mask_no_root]].clone()
+
+    # Compute initial transformation matrices
+    # (B, J, 4, 4)
+    transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)).reshape(
+        -1, joints.shape[1], 4, 4
+    )
+
+    # Initialize the root transformation matrices
+    transforms = torch.zeros_like(transforms_mat)
+    transforms[:, root_idx] = transforms_mat[:, root_idx]
+
+    # Compute global transformations level by level
+    for indices in idx_levs:
+        curr_res = torch.matmul(transforms[:, parents[indices]], transforms_mat[:, indices])
+        transforms[:, indices] = curr_res
+
+    # Extract posed joint positions from the transformation matrices
+    posed_joints = transforms[:, :, :3, 3]
+
+    # Extract global rotation matrices from the transformation matrices
+    global_rot_mat = transforms[:, :, :3, :3]
+
+    return posed_joints, global_rot_mat
diff --git a/kimodo/skeleton/registry.py b/kimodo/skeleton/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..19703f377059d96c8487905537a35f6aea45e68e
--- /dev/null
+++ b/kimodo/skeleton/registry.py
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Factory helpers for building predefined skeleton variants."""
+
+from pathlib import Path
+
+from kimodo.assets import SKELETONS_ROOT
+
+from .definitions import (
+    G1Skeleton34,
+    SMPLXSkeleton22,
+    SOMASkeleton30,
+    SOMASkeleton77,
+)
+
+
+def build_skeleton(nbjoints: int, assets_folder: str | Path = SKELETONS_ROOT):
+    """Instantiate a known skeleton class from its joint count.
+
+    Supported joint counts: 30 (SOMA compact), 34 (G1), 77 (SOMA full), 22 (SMPLX).
+
+    Args:
+        nbjoints: Number of joints expected in the skeleton representation.
+        assets_folder: Base skeleton-assets directory containing per-skeleton subfolders.
+
+    Returns:
+        A configured `SkeletonBase` subclass instance.
+
+    Raises:
+        ValueError: If `nbjoints` does not match a registered skeleton.
+    """
+    assets_folder = Path(assets_folder)
+    if nbjoints == 34:
+        return G1Skeleton34(assets_folder / "g1skel34")
+    elif nbjoints == 22:
+        return SMPLXSkeleton22(assets_folder / "smplx22")
+    elif nbjoints == 30:
+        return SOMASkeleton30(assets_folder / "somaskel30")
+    elif nbjoints == 77:
+        return SOMASkeleton77(assets_folder / "somaskel77")
+    else:
+        raise ValueError(f"Unrecognized skeleton: no registered variant has {nbjoints} joints.")
diff --git a/kimodo/skeleton/transforms.py b/kimodo/skeleton/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..d54e0ac8bbfefeea35fbc3aeb1f108db0d6fcadb
--- /dev/null
+++ b/kimodo/skeleton/transforms.py
@@ -0,0 +1,106 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +"""Rotation-space conversion utilities for skeleton motion data.""" + +import einops +import torch + +from ..tools import ensure_batched +from .kinematics import batch_rigid_transform + + +def global_rots_to_local_rots(global_joint_rots: torch.Tensor, skeleton): + """Convert global rotations to local rotations using a skeleton hierarchy. + + Args: + global_joint_rots: Global rotation matrices with shape `(..., J, 3, 3)`. + skeleton: Skeleton object exposing `joint_parents` and `root_idx`. + + Returns: + Local rotation matrices with the same leading shape as the input. + """ + # Doing big batch + global_joint_mats, ps = einops.pack( + [global_joint_rots], + "* nbjoints dim1 dim2", + ) + + # obtain back the local rotations from the new global rotations + parent_rot_mats = global_joint_mats[:, skeleton.joint_parents] + + parent_rot_mats[:, skeleton.root_idx] = torch.eye(3) # the root joint + parent_rot_mats_inv = parent_rot_mats.transpose(2, 3) + local_rot_mats = torch.einsum( + "T N m n, T N n o -> T N m o", + parent_rot_mats_inv, + global_joint_mats, + ) + [local_rot_mats] = einops.unpack(local_rot_mats, ps, "* nbjoints dim1 dim2") + return local_rot_mats + + +@ensure_batched(local_rot_mats=4) +def change_tpose(local_rot_mats: torch.Tensor, global_rot_offsets: torch.Tensor, skeleton): + """Re-express local rotations in another t_pose based on the global rotation offsets. + + Args: + local_rot_mats: Local rotation matrices with shape `(..., J, 3, 3)`. + global_rot_offsets: Global rotation offsets with shape `(..., J, 3, 3)`. + skeleton: Skeleton object exposing `joint_parents`, + `root_idx`, and `nbjoints`. + + Returns: + Tuple `(new_local_rot_mats, new_global_rot_mats)` in the standard frame. + """ + + device, dtype = local_rot_mats.device, local_rot_mats.dtype + global_rot_offsets = global_rot_offsets.to(device=device, dtype=dtype) + + root_idx = skeleton.root_idx + joint_parents = skeleton.joint_parents + # These are dummy joint positions, will not be used + neutral_joints = torch.ones((len(local_rot_mats), skeleton.nbjoints, 3), device=device, dtype=dtype) + + # get the old joint rotations in the same global space as the t-pose + # Note: the neutral joints we use here doesn't matter, because we are only using the global rotation outputs + _, global_rot_mats = batch_rigid_transform(local_rot_mats, neutral_joints, joint_parents, root_idx) # (T, N, 3, 3) + + # compute the desired joint rotations in the frame of the new t-pose + new_global_rot_mats = torch.einsum("T N m n, N o n -> T N m o", global_rot_mats, global_rot_offsets) + # convert back to local rotations + new_local_rot_mats = global_rots_to_local_rots(new_global_rot_mats, skeleton) + return new_local_rot_mats, new_global_rot_mats + + +@ensure_batched(local_rot_mats=4) +def to_standard_tpose(local_rot_mats: torch.Tensor, skeleton): + """Re-express local rotations in the skeleton's standard T-pose convention. + + Args: + local_rot_mats: Local rotation matrices with shape `(..., J, 3, 3)`. + skeleton: Skeleton object exposing `global_rot_offsets`, `joint_parents`, + `root_idx`, and `nbjoints`. + + Returns: + Tuple `(new_local_rot_mats, new_global_rot_mats)` in the standard frame. 
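+
+    Example:
+        Round-trip sketch (``from_standard_tpose`` below applies the inverse offsets)::
+
+            std_local, _ = to_standard_tpose(local_rot_mats, skeleton)
+            back_local, _ = from_standard_tpose(std_local, skeleton)
+            # back_local matches local_rot_mats up to numerical precision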
+ """ + global_rot_offsets = skeleton.global_rot_offsets + return change_tpose(local_rot_mats, global_rot_offsets, skeleton) + + +@ensure_batched(local_rot_mats=4) +def from_standard_tpose(local_rot_mats: torch.Tensor, skeleton): + """Re-express local rotations from the skeleton's standard T-pose convention to the original + formulation. + + Args: + local_rot_mats: Local rotation matrices with shape `(..., J, 3, 3)`. + skeleton: Skeleton object exposing `global_rot_offsets`, `joint_parents`, + `root_idx`, and `nbjoints`. + + Returns: + Tuple `(new_local_rot_mats, new_global_rot_mats)` in the standard frame. + """ + global_rot_offsets = skeleton.global_rot_offsets + global_rot_offsets_T = global_rot_offsets.mT # do the inverse transform + return change_tpose(local_rot_mats, global_rot_offsets_T, skeleton) diff --git a/kimodo/tools.py b/kimodo/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..50b1b8d19994d0550da286acb49e8143de2dfc85 --- /dev/null +++ b/kimodo/tools.py @@ -0,0 +1,360 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Shared utilities: validation decorator, batching, JSON I/O, seeding, tensor conversion.""" + +import inspect +import json +import math +import random +from collections.abc import Mapping, Sequence +from functools import wraps +from math import prod +from pathlib import Path +from typing import Any, Callable, Mapping, Optional, ParamSpec, TypeVar, Union + +import numpy as np +import torch + + +def validate(validator, save_args: bool = False, super_init: bool = False): + """Create a decorator function for validating user inputs. + + Args: + validator: the function to validate (pydantic dataclass) + save (bool): save all the attributes to the obj [args[0]] + super_init (bool): init parent with no arguments (useful for using save on a nn.Module) + + Returns: + decorator: the decorator function + """ + + def decorator(func): + @wraps(func) + def validated_func(*args, **kwargs): + conf = validator(**kwargs) + + if save_args: + assert len(args) != 0 + obj = args[0] + + if super_init: + # init the parent module + super(type(obj), obj).__init__() + + for key, val in conf.__dict__.items(): + setattr(obj, key, val) + return func(*args, conf) + + return validated_func + + return decorator + + +# Type alias for clarity +Tensor = Any + +P = ParamSpec("P") +R = TypeVar("R") + + +def ensure_batched(**spec: int) -> Callable[[Callable[P, R]], Callable[P, R]]: + """Decorator to flatten complex batch dimensions. + + Fixes included: + 1. Handles 1D tensors (tail_ndim=0) correctly without slicing errors. + 2. Skips .reshape() if the input is already purely flat (Optimization). 
+ """ + if not spec: + raise ValueError("At least one argument spec must be provided.") + + def decorator(fn: Callable[P, R]) -> Callable[P, R]: + sig = inspect.signature(fn) + + @wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + def _sequence_shape(name: str, value: Any) -> tuple[int, ...]: + if not isinstance(value, (list, tuple)): + return () + if len(value) == 0: + return (0,) + first_shape = _sequence_shape(name, value[0]) + for item in value[1:]: + item_shape = _sequence_shape(name, item) + if item_shape != first_shape: + raise ValueError(f"'{name}' must be a rectangular nested sequence, got ragged shape.") + return (len(value), *first_shape) + + def _shape_and_ndim(name: str, value: Any) -> tuple[tuple[int, ...], int]: + if hasattr(value, "shape") and hasattr(value, "ndim"): + shape = tuple(value.shape) + return shape, int(value.ndim) + if isinstance(value, (list, tuple)): + shape = _sequence_shape(name, value) + return shape, len(shape) + raise TypeError(f"'{name}' must be tensor-like or a nested list/tuple, got {type(value)}.") + + def _reshape_like(value: Any, shape: tuple[int, ...], name: str) -> Any: + if hasattr(value, "reshape"): + return value.reshape(*shape) + + if not isinstance(value, (list, tuple)): + raise TypeError(f"Cannot reshape '{name}' of type {type(value)}.") + + flat: list[Any] = [] + + def _flatten(x: Any) -> None: + if isinstance(x, (list, tuple)): + for item in x: + _flatten(item) + else: + flat.append(x) + + _flatten(value) + expected_size = prod(shape) if shape else 1 + if len(flat) != expected_size: + raise ValueError(f"Cannot reshape '{name}' with {len(flat)} elements into shape {shape}.") + + def _build(index: int, dims: tuple[int, ...]) -> tuple[Any, int]: + if not dims: + return flat[index], index + 1 + items = [] + for _ in range(dims[0]): + item, index = _build(index, dims[1:]) + items.append(item) + return items, index + + rebuilt, used = _build(0, shape) + if used != len(flat): + raise ValueError(f"Internal reshape error for '{name}': used {used}/{len(flat)} elements.") + if isinstance(value, tuple) and isinstance(rebuilt, list): + return tuple(rebuilt) + return rebuilt + + # --- 1. CANONICAL ARGUMENT --- + spec_items = list(spec.items()) + canonical_name = None + canonical_ndim = None + x0 = None + for name, ndim in spec_items: + candidate = bound.arguments.get(name, None) + if candidate is not None: + canonical_name = name + canonical_ndim = ndim + x0 = candidate + break + if canonical_name is None: + raise ValueError( + "All canonical candidates are None: " + ", ".join(f"'{name}'" for name, _ in spec_items) + ) + + # Calculate split between Batch dims and Feature dims + expected_tail_dims = canonical_ndim - 1 # e.g. 3 - 1 = 2 (Sequence, Feat) + x0_shape, x0_ndim = _shape_and_ndim(canonical_name, x0) + + # Validation + if x0_ndim < expected_tail_dims: + raise ValueError(f"'{canonical_name}' ndim={x0_ndim} < expected {expected_tail_dims} tail dims.") + + # --- LOGIC FIX 1: Handle 0 tail dims correctly --- + if expected_tail_dims == 0: + orig_batch_shape = x0_shape + tail_shape = () + else: + orig_batch_shape = x0_shape[:-expected_tail_dims] + tail_shape = x0_shape[-expected_tail_dims:] + + # Calculate flattened batch size + # If orig_batch_shape is () (scalar input), size is 1. 
+ B_flat = prod(orig_batch_shape) if orig_batch_shape else 1 + + # Determine if we added a fake batch dim (unbatched input) + is_unbatched_input = len(orig_batch_shape) == 0 + + # --- LOGIC FIX 2: Skip reshape if already flat (Optimization) --- + # If batch shape is already 1D (e.g. [2]), we don't need to reshape [2, 140, 5] -> [2, 140, 5] + is_already_flat = len(orig_batch_shape) == 1 + + if is_unbatched_input: + # (H, W) -> (1, H, W) + x0_batched = _reshape_like(x0, (1, *tail_shape), canonical_name) + elif is_already_flat: + # (B, H, W) -> Keep as is + x0_batched = x0 + else: + # (B1, B2, H, W) -> (B1*B2, H, W) + x0_batched = _reshape_like(x0, (B_flat, *tail_shape), canonical_name) + + bound.arguments[canonical_name] = x0_batched + + # --- 2. OTHER ARGUMENTS --- + for name, target_ndim in spec_items: + if name == canonical_name: + continue + val = bound.arguments.get(name, None) + if val is None: + continue + + arg_tail_dims = target_ndim - 1 # e.g. for lengths=1, tail=0 + val_shape, val_ndim = _shape_and_ndim(name, val) + + # Validate + if val_ndim < arg_tail_dims: + raise ValueError(f"'{name}' ndim={val_ndim} too small.") + + # --- Get Batch Shape (With 0-tail fix) --- + if arg_tail_dims == 0: + val_batch_shape = val_shape + val_tail_shape = () + else: + val_batch_shape = val_shape[:-arg_tail_dims] + val_tail_shape = val_shape[-arg_tail_dims:] + + # --- Check Mismatch --- + # Unbatched inputs must match unbatched canonical + if len(val_batch_shape) == 0: + if not is_unbatched_input: + raise ValueError(f"'{name}' is unbatched but canonical is batched.") + val_batched = _reshape_like(val, (1, *val_tail_shape), name) + else: + # Batched inputs must match canonical batch shape EXACTLY + if val_batch_shape != orig_batch_shape: + raise ValueError( + f"Batch dimensions mismatch! '{canonical_name}' has {orig_batch_shape}, " + f"but '{name}' has {val_batch_shape}." + ) + + # Optimization: Don't reshape if already flat + if is_already_flat: + val_batched = val + else: + val_batched = _reshape_like(val, (B_flat, *val_tail_shape), name) + + bound.arguments[name] = val_batched + + # --- 3. EXECUTION --- + out = fn(**bound.arguments) + + # --- 4. RESTORE --- + def restore(obj): + if isinstance(obj, Mapping): + return {k: restore(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return type(obj)(restore(x) for x in obj) + + if hasattr(obj, "shape"): + if obj.ndim == 0: + return obj + + # Verify batch dimension exists and wasn't reduced + if obj.shape[0] != B_flat: + return obj + + # If input was simple (B, ...), return simple (B, ...) + if is_already_flat: + return obj + + rest = obj.shape[1:] + + if is_unbatched_input: + assert obj.shape[0] == 1, "The batch size should be 1 for unbatched." 
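+                        # (1, ...) -> (...): mirror the unbatched input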
+ return obj[0] + + return obj.reshape(*orig_batch_shape, *rest) + return obj + + return restore(out) + + return wrapper + + return decorator + + +def to_numpy(obj): + """Recursively convert tensors in dicts/lists/tuples to numpy arrays; leave other types + unchanged.""" + if isinstance(obj, Mapping): + return {k: to_numpy(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return type(obj)(to_numpy(x) for x in obj) + if isinstance(obj, torch.Tensor): + return obj.cpu().numpy() + return obj + + +def to_torch(obj, device=None, dtype=None): + """Recursively convert numpy arrays in dicts/lists/tuples to torch tensors; optionally move to + device/dtype.""" + if isinstance(obj, Mapping): + return {k: to_torch(v, device, dtype) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return type(obj)(to_torch(x, device, dtype) for x in obj) + if isinstance(obj, np.ndarray): + obj = torch.from_numpy(obj) + if isinstance(obj, torch.Tensor): + if dtype is not None: + obj = obj.to(dtype=dtype) + if device is None: + return obj + return obj.to(device) + return obj + + +def seed_everything(seed: int, deterministic: bool = False) -> None: + """Seed all random number generators.""" + random.seed(seed) # for Python random module. + np.random.seed(seed) # for NumPy. + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True # for deterministic behavior. + torch.backends.cudnn.benchmark = False # if you want to make the behavior deterministic. + + +def load_json(path: Union[str, Path]) -> Any: + """Load a JSON file and return its contents. + + Args: + path (str | Path): Path to the JSON file. + + Returns: + Any: Parsed JSON content (dict, list, etc.). + + Raises: + FileNotFoundError: If the file does not exist. + ValueError: If the file is not valid JSON. + """ + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"JSON file not found: {path}") + + try: + with path.open("r", encoding="utf-8") as f: + return json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in file {path}: {e}") from e + + +def save_json(path: Union[str, Path], data: Any) -> None: + """Save data to a JSON file. + + Args: + path (str | Path): Path to the JSON file. + data (Any): Data to save (must be JSON serializable). + + Raises: + ValueError: If the data is not JSON serializable. + """ + path = Path(path) + + # Create parent directories if they don't exist + path.parent.mkdir(parents=True, exist_ok=True) + + try: + with path.open("w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + except (TypeError, ValueError) as e: + raise ValueError(f"Data is not JSON serializable: {e}") from e diff --git a/kimodo/viz/__init__.py b/kimodo/viz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91a4846a952064d01070de6329b3610c19ba2eb0 --- /dev/null +++ b/kimodo/viz/__init__.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Viser-based 3D visualization for skeletons and motion.""" + +from . 
import viser_utils +from .viser_utils import ( + Character, + CharacterMotion, + ConstraintSet, + EEJointsKeyframeSet, + FullbodyKeyframeSet, + GuiElements, + RootKeyframe2DSet, + SkeletonMesh, + WaypointMesh, + load_example_cases, +) + +__all__ = [ + "Character", + "CharacterMotion", + "ConstraintSet", + "EEJointsKeyframeSet", + "FullbodyKeyframeSet", + "GuiElements", + "RootKeyframe2DSet", + "SkeletonMesh", + "WaypointMesh", + "load_example_cases", + "viser_utils", +] diff --git a/kimodo/viz/constraint_ui.py b/kimodo/viz/constraint_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..1737714ffb22af447ed20f8de786ea4362956882 --- /dev/null +++ b/kimodo/viz/constraint_ui.py @@ -0,0 +1,1079 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Constraint visualization and frame indexing for the viz UI.""" + +from typing import List, Optional + +import numpy as np +import torch + +import viser +import viser.transforms as tf +from kimodo.motion_rep.smooth_root import get_smooth_root_pos +from kimodo.skeleton import SkeletonBase +from kimodo.tools import to_numpy, to_torch + +from .scene import SkeletonMesh, WaypointMesh + + +def update_interval(interval_start, interval_end, start_frame_idx, end_frame_idx): + """Updates an interval after removing the range from start_frame_idx to end_frame_idx.""" + # Calculate new range after removing [start_frame_idx, end_frame_idx] + # Case 1: Removal fully contains the interval -> delete entirely + if start_frame_idx <= interval_start and end_frame_idx >= interval_end: + return None, None # Already removed, don't recreate + # Case 2: Removal is at the start of interval -> shrink from start + elif start_frame_idx <= interval_start and end_frame_idx < interval_end: + new_start = end_frame_idx + 1 + new_end = interval_end + # Case 3: Removal is at the end of interval -> shrink from end + elif start_frame_idx > interval_start and end_frame_idx >= interval_end: + new_start = interval_start + new_end = start_frame_idx - 1 + # Case 4: Removal is in the middle -> keep the larger portion + else: # start_frame_idx > interval_start and end_frame_idx < interval_end + left_size = start_frame_idx - interval_start + right_size = interval_end - end_frame_idx + if left_size >= right_size: + new_start = interval_start + new_end = start_frame_idx - 1 + else: + new_start = end_frame_idx + 1 + new_end = interval_end + return new_start, new_end + + +class ConstraintSet: + def __init__( + self, + name: str, + server: viser.ViserServer, + skeleton: SkeletonBase, + display_name: Optional[str] = None, + ): + self.name = name + self.server = server + self.skeleton = skeleton + self.display_name = display_name if display_name is not None else name + + self.keyframes = dict() # frame_idx -> poses + self.frame2keyid = dict() # frame_idx -> list of keyframe ids at this frame + self.scene_elements = dict() # frame_idx -> meshes, labels, etc. 
+ self.interval_labels = dict() # (start_frame_idx, end_frame_idx) -> interval_label + self.labels_visible = True + + def set_label_visibility(self, visible: bool) -> None: + """Show or hide constraint labels without deleting them.""" + self.labels_visible = visible + for scene_data in self.scene_elements.values(): + label = scene_data.get("label") + if label is not None: + label.visible = visible + for interval_label in self.interval_labels.values(): + interval_label.visible = visible + + def set_overlay_visibility(self, only_frame: Optional[int] = None) -> None: + """Show all overlay elements, or only those at the given frame. + + Args: + only_frame: If None, show all overlays. If int, show only overlays at that frame. + """ + raise NotImplementedError("Subclasses must implement this method") + + def add_keyframe(self, keyframe_id: str, frame_idx: int, pose_data: torch.Tensor): + """Adds a single keyframe at the given frame with the given pose data. + + Args: + keyframe_id: str, id for the keyframe. Must be unique within the given frame_idx. + frame_idx: int, frame index to add the keyframe at + pose_data: torch.Tensor, e.g. full-body pose, EE pose, 2D root pose, etc. + """ + raise NotImplementedError("Subclasses must implement this method") + + def add_interval( + self, + interval_id: str, + start_frame_idx: int, + end_frame_idx: int, + pose_seq_data: torch.Tensor, + ): + """Adds a keyframe interval between the given start and end frames with the given pose data. + + Args: + interval_id: str, id for the interval. Must be unique within the given start_frame_idx and end_frame_idx. + start_frame_idx: int, start frame index of the interval + end_frame_idx: int, end frame index of the interval + pose_seq_data: torch.Tensor, data for constrained interval, e.g. full-body poses, EE poses, 2D root poses, etc. 
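+
+        Concrete subclasses implement this by adding one keyframe per frame in the
+        interval plus a single interval label (see ``FullbodyKeyframeSet.add_interval``).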
+ """ + raise NotImplementedError("Subclasses must implement this method") + + def _add_interval_label(self, start_frame_idx: int, end_frame_idx: int): + """ + Adds an interval label between the given start and end frames + Args: + start_frame_idx: int, start frame index of the interval + end_frame_idx: int, end frame index of the interval + """ + mid = int((start_frame_idx + end_frame_idx) / 2) + interval_label_pos = self._get_label_pos(mid) + interval_label = self.server.scene.add_label( + name=f"/{self.name}/interval_label_{start_frame_idx}_{end_frame_idx}", + text=f"{self.display_name} @ [{start_frame_idx}, {end_frame_idx}]", + position=interval_label_pos, + font_size_mode="screen", + font_screen_scale=0.7, + anchor="center-center", + ) + interval_label.visible = self.labels_visible + self.interval_labels[(start_frame_idx, end_frame_idx)] = interval_label + + def remove_keyframe(self, keyframe_id: str, frame_idx: int): + """ + Removes a keyframe at the given frame + Args: + keyframe_id: str, id for the keyframe to remove + frame_idx: int, frame index to remove the keyframe at + """ + raise NotImplementedError("Subclasses must implement this method") + + def remove_interval(self, interval_id: str, start_frame_idx: int, end_frame_idx: int): + """ + Removes an interval between the given start and end frames + Args: + interval_id: str, id for the interval to remove + start_frame_idx: int, start frame index of the interval + end_frame_idx: int, end frame index of the interval + """ + raise NotImplementedError("Subclasses must implement this method") + + def _get_label_pos(self, frame_idx: int): + """ + Returns the position of where to place the displayed label for the given frame index + Args: + frame_idx: int, frame index to get the label position for + """ + raise NotImplementedError("Subclasses must implement this method") + + def _remove_interval_and_update_label(self, interval_id: str, start_frame_idx: int, end_frame_idx: int): + """ + Removes an interval between the given start and end frames and updates the label + Args: + start_frame_idx: int, start frame index of the interval + end_frame_idx: int, end frame index of the interval + """ + for frame_idx in range(start_frame_idx, end_frame_idx + 1): + self.remove_keyframe(interval_id, frame_idx) + + # Update interval labels that overlap with the removed range + intervals_to_update = [] + for (interval_start, interval_end), label in list(self.interval_labels.items()): + # Check if intervals overlap + if interval_start <= end_frame_idx and interval_end >= start_frame_idx: + intervals_to_update.append((interval_start, interval_end, label)) + + for interval_start, interval_end, label in intervals_to_update: + # Remove old label from scene and dict + self.server.scene.remove_by_name(label.name) + del self.interval_labels[(interval_start, interval_end)] + + new_start, new_end = update_interval(interval_start, interval_end, start_frame_idx, end_frame_idx) + + if new_start is None or new_end is None: + continue + + # Create updated label with new range + if new_start <= new_end: + # Position label at midpoint - these keyframes are guaranteed to exist + # since the new range is outside the removal range + mid_frame = (new_start + new_end) // 2 + label_pos = self._get_label_pos(mid_frame) + new_label = self.server.scene.add_label( + name=f"/{self.name}/interval_label_{new_start}_{new_end}", + text=f"{self.display_name} @ [{new_start}, {new_end}]", + position=label_pos, + font_size_mode="screen", + font_screen_scale=0.7, + anchor="center-center", 
+                )
+                new_label.visible = self.labels_visible
+                self.interval_labels[(new_start, new_end)] = new_label
+
+    def get_constraint_info(self, device: Optional[str] = None):
+        """Returns constraint information for generation (torch) or UI (numpy)."""
+        raise NotImplementedError("Subclasses must implement this method")
+
+    def get_frame_idx(self):
+        """Returns all constrained frame indices in the set."""
+        return list(self.keyframes.keys())
+
+    def clear(self, frame_idx: Optional[int] = None):
+        """
+        Clears all keyframes and intervals from the constraint set
+        Args:
+            frame_idx: int, single frame index to clear; if None, everything is cleared
+        """
+        raise NotImplementedError("Subclasses must implement this method")
+
+
+def build_constraint_set_table_markdown(constraint_list: List[ConstraintSet]):
+    markdown = "| Track | Frames |\n"
+    markdown += "|------|----------|\n"
+
+    # One row per constraint set, listing its constrained frames in ascending order
+    for constraint in constraint_list:
+        frame_info = constraint.get_frame_idx()
+        if len(frame_info) > 0:
+            frame_info = ", ".join([str(frame) for frame in sorted(frame_info)])
+        else:
+            frame_info = "-"
+        markdown += f"| {constraint.display_name} | {frame_info} |\n"
+
+    return markdown
+
+
+class FullbodyKeyframeSet(ConstraintSet):
+    def __init__(
+        self,
+        name: str,
+        server: viser.ViserServer,
+        skeleton: SkeletonBase,
+        display_name: Optional[str] = None,
+    ):
+        super().__init__(name, server, skeleton, display_name=display_name)
+
+    def add_keyframe(
+        self,
+        keyframe_id: str,
+        frame_idx: int,
+        joints_pos: torch.Tensor | np.ndarray,
+        joints_rot: torch.Tensor | np.ndarray,
+        viz_label: bool = True,
+        exists_ok: bool = False,
+    ):
+        """Adds a single full-body keyframe at the given frame or updates the existing one at this
+        frame. Note if a keyframe already exists at this frame, it will be updated to the given
+        pose.
+
+        Args:
+            keyframe_id: str, id for the keyframe. Must be unique within the given frame_idx.
+ frame_idx: int, frame index to add the keyframe at + joints_pos: torch.Tensor, [J, 3] joints positions to add the keyframe at + """ + # create/update scene elements + if frame_idx in self.keyframes: + skeleton_mesh = self.scene_elements[frame_idx]["skeleton_mesh"] + skeleton_mesh.set_pose(to_torch(joints_pos)) + if viz_label and "label" in self.scene_elements[frame_idx]: + label = self.scene_elements[frame_idx]["label"] + label.position = to_numpy(joints_pos)[self.skeleton.root_idx] + label.visible = self.labels_visible + else: + # create skeleton to visualize the full-body constraint + skeleton_mesh = SkeletonMesh( + f"/{self.name}/skeleton_{frame_idx}", + self.server, + self.skeleton, + joint_color=(255, 235, 0), + bone_color=(255, 0, 0), + starting_joints_pos=to_torch(joints_pos), + ) + self.scene_elements[frame_idx] = { + "skeleton_mesh": skeleton_mesh, + } + if viz_label: + label = self.server.scene.add_label( + name=f"/{self.name}/label_{frame_idx}", + text=f"{self.display_name} @ {frame_idx}", + position=to_numpy(joints_pos)[self.skeleton.root_idx], + font_size_mode="screen", + font_screen_scale=0.7, + anchor="center-center", + ) + label.visible = self.labels_visible + self.scene_elements[frame_idx]["label"] = label + + # set/update data + self.keyframes[frame_idx] = { + "joints_pos": to_numpy(joints_pos), + "joints_rot": to_numpy(joints_rot), + } + + if frame_idx not in self.frame2keyid: + self.frame2keyid[frame_idx] = [] + + if keyframe_id in self.frame2keyid[frame_idx]: + if not exists_ok: + raise AssertionError("keyframe_id already exists in this frame!") + else: + self.frame2keyid[frame_idx].append(keyframe_id) + + def add_interval( + self, + interval_id: str, + start_frame_idx: int, + end_frame_idx: int, + joints_pos: torch.Tensor, + joints_rot: torch.Tensor, + ): + """Adds a full-body keyframe interval between the given start and end frames. 
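+        Each frame in the interval becomes an individual keyframe sharing
+        ``interval_id``; a single interval label is added at the midpoint.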
+ + Args: + start_frame_idx: int, start frame index of the interval + end_frame_idx: int, end frame index of the interval + joints_pos: torch.Tensor, [T, J, 3] joints positions within the interval + """ + assert joints_pos.shape[0] == end_frame_idx - start_frame_idx + 1 + for frame_idx in range(start_frame_idx, end_frame_idx + 1): + rel_idx = frame_idx - start_frame_idx + self.add_keyframe( + interval_id, + frame_idx, + joints_pos[rel_idx], + joints_rot[rel_idx], + viz_label=False, + ) + + # add separate interval label + self._add_interval_label(start_frame_idx, end_frame_idx) + + def remove_keyframe(self, keyframe_id: str, frame_idx: int): + if frame_idx not in self.keyframes: + return + if keyframe_id not in self.frame2keyid[frame_idx]: + return + self.frame2keyid[frame_idx].remove(keyframe_id) + if len(self.frame2keyid[frame_idx]) == 0: + del self.frame2keyid[frame_idx] + self.clear(frame_idx) + + def _get_label_pos(self, frame_idx: int): + return self.keyframes[frame_idx]["joints_pos"][self.skeleton.root_idx] + + def remove_interval(self, interval_id: str, start_frame_idx: int, end_frame_idx: int): + self._remove_interval_and_update_label(interval_id, start_frame_idx, end_frame_idx) + + def get_constraint_info(self, device: Optional[str] = None): + all_joints_pos = [] + all_joints_rot = [] + for v in self.keyframes.values(): + joints_pos = to_torch(v["joints_pos"], device=device) + joints_rot = to_torch(v["joints_rot"], device=device) + if len(joints_pos.shape) == 2: + all_joints_pos.append(joints_pos[None]) + else: + all_joints_pos.append(joints_pos) + if len(joints_rot.shape) == 3: + all_joints_rot.append(joints_rot[None]) + else: + all_joints_rot.append(joints_rot) + + all_joints_pos = torch.cat(all_joints_pos, dim=0) if len(all_joints_pos) > 0 else None + all_joints_rot = torch.cat(all_joints_rot, dim=0) if len(all_joints_rot) > 0 else None + + return { + "frame_idx": self.get_frame_idx(), + "joints_pos": all_joints_pos, + "joints_rot": all_joints_rot, + } + + def clear(self, frame_idx: Optional[int] = None): + frame_idx_list = list(self.keyframes.keys()) if frame_idx is None else [frame_idx] + for fidx in frame_idx_list: + self.scene_elements[fidx]["skeleton_mesh"].clear() + if "ee_rotation_axes" in self.scene_elements[fidx]: + self.server.scene.remove_by_name(self.scene_elements[fidx]["ee_rotation_axes"].name) + if "label" in self.scene_elements[fidx]: + self.server.scene.remove_by_name(self.scene_elements[fidx]["label"].name) + + self.keyframes.pop(fidx) + self.scene_elements.pop(fidx) + self.frame2keyid.pop(fidx, None) + + if frame_idx is None: + # clear all interval labels if clearing all keyframes + for interval_label in list(self.interval_labels.values()): + self.server.scene.remove_by_name(interval_label.name) + self.interval_labels.clear() + self.frame2keyid.clear() + + def set_overlay_visibility(self, only_frame: Optional[int] = None) -> None: + show_all = only_frame is None + for fidx, scene_data in self.scene_elements.items(): + visible = show_all or fidx == only_frame + scene_data["skeleton_mesh"].set_visibility(visible) + label = scene_data.get("label") + if label is not None: + label.visible = visible and self.labels_visible + for interval_label in self.interval_labels.values(): + interval_label.visible = show_all and self.labels_visible + + +class EEJointsKeyframeSet(ConstraintSet): + def __init__( + self, + name: str, + server: viser.ViserServer, + skeleton: SkeletonBase, + display_name: Optional[str] = None, + ): + super().__init__(name, server, skeleton, 
display_name=display_name) + + # frame_idx -> list of (keyframe_id, joint_names) at this frame + self.frame2keyid = dict() + + def create_scene_elements( + self, + frame_idx: int, + joints_pos: torch.Tensor | np.ndarray, + joints_rot: Optional[torch.Tensor | np.ndarray], + joint_names: List[str], + viz_label: bool = True, + ): + # create skeleton to visualize the full-body constraint + ee_joint_indices = [] + ee_gizmo_indices = [] + constrained_bone_idx = [] + for joint_name in joint_names: + if joint_name == "Hips": + continue + elif joint_name in ["LeftHand", "RightHand", "LeftFoot", "RightFoot"]: + expanded_joint_names = { + "LeftHand": self.skeleton.left_hand_joint_names, + "RightHand": self.skeleton.right_hand_joint_names, + "LeftFoot": self.skeleton.left_foot_joint_names, + "RightFoot": self.skeleton.right_foot_joint_names, + }[joint_name] + ee_joint_indices.extend([self.skeleton.bone_order_names_index[joint] for joint in expanded_joint_names]) + if len(expanded_joint_names) > 1: + ee_gizmo_indices.extend( + [self.skeleton.bone_order_names_index[joint] for joint in expanded_joint_names[:1]] + ) + constrained_bone_idx.extend( + [self.skeleton.bone_order_names_index[joint] - 1 for joint in expanded_joint_names[1:]] + ) + else: + raise ValueError(f"Invalid joint name: {joint_name}") + + # de-duplicate while preserving order + ee_joint_indices = list(dict.fromkeys(ee_joint_indices)) + ee_gizmo_indices = list(dict.fromkeys(ee_gizmo_indices)) + constrained_bone_idx = list(dict.fromkeys(constrained_bone_idx)) + + constrained_idx = [self.skeleton.root_idx] + ee_joint_indices + + constrained_idx = np.array(constrained_idx) + constrained_bone_idx = np.array(constrained_bone_idx) + + # create skeleton to visualize the full-body constraint + joint_color = np.full((self.skeleton.nbjoints, 3), (220, 220, 220)) + bone_color = np.full((self.skeleton.nbjoints - 1, 3), (220, 220, 220)) + # color constrained joints differently + joint_color[constrained_idx] = (255, 0, 0) + bone_color[constrained_bone_idx] = (255, 0, 0) + skeleton_mesh = SkeletonMesh( + f"/{self.name}/skeleton_{frame_idx}", + self.server, + self.skeleton, + joint_color=joint_color, + bone_color=bone_color, + starting_joints_pos=to_torch(joints_pos), + ) + + self.scene_elements[frame_idx] = { + "skeleton_mesh": skeleton_mesh, + } + joints_pos_np = to_numpy(joints_pos) + joints_rot_np = to_numpy(joints_rot) if joints_rot is not None else None + if joints_rot_np is not None and len(ee_gizmo_indices) > 0: + ee_axes = self.server.scene.add_batched_axes( + f"/{self.name}/ee_rot_axes_{frame_idx}", + batched_wxyzs=tf.SO3.from_matrix(joints_rot_np[ee_gizmo_indices]).wxyz, + batched_positions=joints_pos_np[ee_gizmo_indices], + axes_length=0.07, + axes_radius=0.007, + ) + self.scene_elements[frame_idx]["ee_rotation_axes"] = ee_axes + if viz_label: + label = self.server.scene.add_label( + name=f"/{self.name}/label_{frame_idx}", + text=f"{self.display_name} @ {frame_idx}", + position=joints_pos_np[self.skeleton.root_idx] + np.array([0.0, 0.05, 0.0]), + font_size_mode="screen", + font_screen_scale=0.7, + anchor="bottom-center", + ) + label.visible = self.labels_visible + self.scene_elements[frame_idx]["label"] = label + + def add_keyframe( + self, + keyframe_id: str, + frame_idx: int, + joints_pos: torch.Tensor | np.ndarray, + joints_rot: torch.Tensor | np.ndarray, + joint_names: List[str], + end_effector_type: str, + viz_label: bool = True, + exists_ok: bool = False, + ): + """Adds a single EE keyframe at the given frame or updates the existing one 
+        """Adds a single EE keyframe at the given frame, or updates the existing one at
+        this frame.
+
+        Args:
+            keyframe_id: str, id for the keyframe. Must be unique within the given frame_idx.
+            frame_idx: int, frame index to add the keyframe at
+            joints_pos: torch.Tensor, [J, 3] joints positions to add the keyframe at
+            joints_rot: torch.Tensor, [J, 3, 3] joints rotation matrices to add the keyframe at
+            joint_names: List[str], names of the joints to constrain at this keyframe
+            end_effector_type: str (or set of str), type tag(s) stored with the keyframe
+            viz_label: bool, whether to visualize the label for the keyframe
+            exists_ok: bool, if True, an existing keyframe_id at this frame is updated instead of raising
+        """
+        need_create_viz = True
+        joint_names_input = joint_names
+
+        if not isinstance(end_effector_type, set):
+            end_effector_type = set([end_effector_type])
+
+        # create/update scene elements
+        if frame_idx in self.keyframes:
+            if joint_names != self.keyframes[frame_idx]["joint_names"]:
+                # merge together with existing constraint if needed
+                joint_names = set(joint_names)
+                joint_names.update(set(self.keyframes[frame_idx]["joint_names"]))
+                joint_names = list(joint_names)
+                end_effector_type.update(self.keyframes[frame_idx]["end_effector_type"])
+                # need to re-create viz elements
+                self.clear(frame_idx)
+            else:
+                need_create_viz = False
+                # overwrite the pose with the latest one
+                skeleton_mesh = self.scene_elements[frame_idx]["skeleton_mesh"]
+                skeleton_mesh.set_pose(to_torch(joints_pos))
+                if "ee_rotation_axes" in self.scene_elements[frame_idx]:
+                    ee_gizmo_indices = []
+                    for joint_name in joint_names:
+                        if joint_name == "Hips":
+                            continue
+                        elif joint_name in [
+                            "LeftHand",
+                            "RightHand",
+                            "LeftFoot",
+                            "RightFoot",
+                        ]:
+                            expanded_joint_names = {
+                                "LeftHand": self.skeleton.left_hand_joint_names,
+                                "RightHand": self.skeleton.right_hand_joint_names,
+                                "LeftFoot": self.skeleton.left_foot_joint_names,
+                                "RightFoot": self.skeleton.right_foot_joint_names,
+                            }[joint_name]
+                            if len(expanded_joint_names) > 0:
+                                ee_gizmo_indices.extend(
+                                    [self.skeleton.bone_order_names_index[joint] for joint in expanded_joint_names[:1]]
+                                    # take only the base joint of the end effector (to avoid clutter)
+                                )
+                        else:
+                            raise ValueError(f"Invalid joint name: {joint_name}")
+                    ee_gizmo_indices = list(dict.fromkeys(ee_gizmo_indices))
+                    if len(ee_gizmo_indices) > 0:
+                        ee_axes = self.scene_elements[frame_idx]["ee_rotation_axes"]
+                        joints_pos_np = to_numpy(joints_pos)
+                        joints_rot_np = to_numpy(joints_rot)
+                        ee_axes.batched_positions = joints_pos_np[ee_gizmo_indices]
+                        ee_axes.batched_wxyzs = tf.SO3.from_matrix(joints_rot_np[ee_gizmo_indices]).wxyz
+                if viz_label and "label" in self.scene_elements[frame_idx]:
+                    label = self.scene_elements[frame_idx]["label"]
+                    label.position = to_numpy(joints_pos)[self.skeleton.root_idx]
+                    label.visible = self.labels_visible
+
+        if need_create_viz:
+            self.create_scene_elements(frame_idx, joints_pos, joints_rot, joint_names, viz_label=viz_label)
+
+        # set/update data
+        self.keyframes[frame_idx] = {
+            "joints_pos": to_numpy(joints_pos),
+            "joints_rot": to_numpy(joints_rot),
+            "joint_names": joint_names,
+            "end_effector_type": end_effector_type,
+        }
+
+        if frame_idx not in self.frame2keyid:
+            self.frame2keyid[frame_idx] = []
+
+        known_keyframe_ids = {k: idx for idx, (k, _) in enumerate(self.frame2keyid[frame_idx])}
+
+        if keyframe_id in known_keyframe_ids.keys():
+            if not exists_ok:
+                raise AssertionError("keyframe_id already exists in this frame!")
+            idx = known_keyframe_ids[keyframe_id]
+            # override the previously existing keyframe
+            self.frame2keyid[frame_idx][idx] = (keyframe_id, joint_names_input)
+        else:
+            # track which subset of joints are constrained by this keyframe_id
+            self.frame2keyid[frame_idx].append((keyframe_id, joint_names_input))
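+
+    # Usage sketch (the tensors and the "position" tag are hypothetical): constrain
+    # both hands at frame 40, then extend the same keyframe to the feet. The second
+    # call merges the joint sets and rebuilds the overlay instead of raising,
+    # because exists_ok=True:
+    #
+    #   ee_set.add_keyframe("kf0", 40, joints_pos, joints_rot,
+    #                       ["LeftHand", "RightHand"], "position")
+    #   ee_set.add_keyframe("kf0", 40, joints_pos, joints_rot,
+    #                       ["LeftFoot", "RightFoot"], "position", exists_ok=True)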
+    def add_interval(
+        self,
+        interval_id: str,
+        start_frame_idx: int,
+        end_frame_idx: int,
+        joints_pos: torch.Tensor | np.ndarray,
+        joints_rot: torch.Tensor | np.ndarray,
+        joint_names: List[str],
+        end_effector_type: str,
+    ):
+        """Adds an interval of EE keyframes between the given start and end frames
+        (inclusive), updating any existing keyframes in that range.
+
+        Args:
+            interval_id: str, id for the interval. Must be unique within the given start_frame_idx and end_frame_idx.
+            start_frame_idx: int, start frame index to add the interval at
+            end_frame_idx: int, end frame index to add the interval at
+            joints_pos: torch.Tensor, [T, J, 3] joints positions to add the interval at
+            joints_rot: torch.Tensor, [T, J, 3, 3] joints rotation matrices to add the interval at
+            joint_names: List[str], names of the joints to add for the entire interval
+            end_effector_type: str, type tag applied to every keyframe in the interval
+        """
+        num_frames = end_frame_idx - start_frame_idx + 1
+        joints_pos_np = to_numpy(joints_pos)
+        joints_rot_np = to_numpy(joints_rot)
+        assert joints_pos_np.shape[0] == num_frames
+        assert joints_rot_np.shape[0] == num_frames
+
+        for frame_idx in range(start_frame_idx, end_frame_idx + 1):
+            rel_idx = frame_idx - start_frame_idx
+            self.add_keyframe(
+                interval_id,
+                frame_idx,
+                joints_pos_np[rel_idx],
+                joints_rot_np[rel_idx],
+                joint_names,
+                end_effector_type,
+                viz_label=False,
+            )
+        self._add_interval_label(start_frame_idx, end_frame_idx)
+
+    def remove_keyframe(self, keyframe_id: str, frame_idx: int):
+        """Removes a keyframe at the given frame, or updates the existing one at this frame
+        by removing the joints associated with keyframe_id.
+
+        Args:
+            keyframe_id: str, id for the keyframe to remove. This determines which joints to remove.
+            frame_idx: int, frame index to remove the keyframe at
+        """
+        if frame_idx not in self.keyframes:
+            return
+
+        remaining_joint_names = set()
+        delete_idx = None
+        for i, (keyid, joint_names) in enumerate(self.frame2keyid[frame_idx]):
+            if keyid == keyframe_id:
+                delete_idx = i
+            else:
+                remaining_joint_names.update(joint_names)
+        if delete_idx is None:
+            # this keyframe_id is not in the specified frame
+            return
+
+        self.frame2keyid[frame_idx].pop(delete_idx)
+        if len(remaining_joint_names) == 0:
+            # no more keyframes in this frame, clear the frame
+            del self.frame2keyid[frame_idx]
+            self.clear(frame_idx)
+            return
+
+        # only deleting part of the keyframe (potentially some subset of joints):
+        # delete the old visualization and add a new one with the updated joint set
+        new_joint_names = list(remaining_joint_names)
+        self.clear(frame_idx, scene_elements_only=True)
+        joints_pos = self.keyframes[frame_idx]["joints_pos"]
+        joints_rot = self.keyframes[frame_idx]["joints_rot"]
+        self.create_scene_elements(frame_idx, joints_pos, joints_rot, new_joint_names)
+        self.keyframes[frame_idx]["joint_names"] = new_joint_names
+
+    def _get_label_pos(self, frame_idx: int):
+        return self.keyframes[frame_idx]["joints_pos"][self.skeleton.root_idx]
+
+    def remove_interval(self, interval_id: str, start_frame_idx: int, end_frame_idx: int):
+        self._remove_interval_and_update_label(interval_id, start_frame_idx, end_frame_idx)
+
+    def get_constraint_info(self, device: Optional[str] = None):
+        all_joints_pos = []
+        all_joints_rot = []
+        all_joints_names = []
+        all_end_effector_type = []
+        for v in self.keyframes.values():
+            joints_pos = to_torch(v["joints_pos"], device=device)
+            joints_rot = to_torch(v["joints_rot"], device=device)
+            if len(joints_pos.shape) == 2:
+                all_joints_pos.append(joints_pos[None])
+            else:
+                all_joints_pos.append(joints_pos)
+            if len(joints_rot.shape) == 3:
+                all_joints_rot.append(joints_rot[None])
+            else:
all_joints_rot.append(joints_rot) + all_joints_names.append(v["joint_names"]) + all_end_effector_type.append(v["end_effector_type"]) + + all_joints_pos = torch.cat(all_joints_pos, dim=0) if len(all_joints_pos) > 0 else None + all_joints_rot = torch.cat(all_joints_rot, dim=0) if len(all_joints_rot) > 0 else None + + return { + "frame_idx": self.get_frame_idx(), + "joints_pos": all_joints_pos, + "joints_rot": all_joints_rot, + "joint_names": all_joints_names, + "end_effector_type": all_end_effector_type, + } + + def clear(self, frame_idx: Optional[int] = None, scene_elements_only: bool = False): + frame_idx_list = list(self.keyframes.keys()) if frame_idx is None else [frame_idx] + for fidx in frame_idx_list: + self.scene_elements[fidx]["skeleton_mesh"].clear() + if "ee_rotation_axes" in self.scene_elements[fidx]: + self.server.scene.remove_by_name(self.scene_elements[fidx]["ee_rotation_axes"].name) + if "label" in self.scene_elements[fidx]: + self.server.scene.remove_by_name(self.scene_elements[fidx]["label"].name) + self.scene_elements.pop(fidx) + if not scene_elements_only: + self.keyframes.pop(fidx) + + if frame_idx is None: + # clear all interval labels if clearing all keyframes + for interval_label in list(self.interval_labels.values()): + self.server.scene.remove_by_name(interval_label.name) + self.interval_labels.clear() + + def set_overlay_visibility(self, only_frame: Optional[int] = None) -> None: + show_all = only_frame is None + for fidx, scene_data in self.scene_elements.items(): + visible = show_all or fidx == only_frame + scene_data["skeleton_mesh"].set_visibility(visible) + if "ee_rotation_axes" in scene_data: + scene_data["ee_rotation_axes"].visible = visible + label = scene_data.get("label") + if label is not None: + label.visible = visible and self.labels_visible + for interval_label in self.interval_labels.values(): + interval_label.visible = show_all and self.labels_visible + + +class RootKeyframe2DSet(ConstraintSet): + def __init__( + self, + name: str, + server: viser.ViserServer, + skeleton: SkeletonBase, + display_name: Optional[str] = None, + ): + super().__init__(name, server, skeleton, display_name=display_name) + self.dense_path = False + self.smooth_path = True + self.line_segments = None # visualization of dense path + self.interval_line_segments = {} + + def add_keyframe( + self, + keyframe_id: str, + frame_idx: int, + root_pos: torch.Tensor | np.ndarray, + viz_label: bool = True, + update_path: bool = True, + viz_waypoint: bool = True, + exists_ok: bool = False, + ): + """Adds a single 2D root keyframe at the given frame or updates the existing one at this + frame. + + Args: + keyframe_id: str, id for the keyframe. Must be unique within the given frame_idx. 
+ frame_idx: int, frame index to add the keyframe at + root_pos: torch.Tensor, [3] root position to add the keyframe at, y entry (index 1) should be 0 + viz_label: bool, whether to visualize the label for the keyframe + """ + root_pos_np = to_numpy(root_pos) + if frame_idx not in self.scene_elements: + self.scene_elements[frame_idx] = {} + + scene_data = self.scene_elements[frame_idx] + if frame_idx in self.keyframes: + waypoint = scene_data.get("waypoint") + if waypoint is not None: + waypoint.update_position(root_pos_np) + elif viz_waypoint: + waypoint = WaypointMesh( + f"/{self.name}/waypoint_{frame_idx}", + self.server, + position=root_pos_np, + ) + scene_data["waypoint"] = waypoint + + label = scene_data.get("label") + if viz_label and label is not None: + label.position = root_pos_np + label.visible = self.labels_visible + elif viz_label and label is None: + label = self.server.scene.add_label( + name=f"/{self.name}/label_{frame_idx}", + text=f"{self.display_name} @ {frame_idx}", + position=root_pos_np, + font_size_mode="screen", + font_screen_scale=0.7, + anchor="bottom-left", + ) + label.visible = self.labels_visible + scene_data["label"] = label + else: + if viz_waypoint: + waypoint = WaypointMesh( + f"/{self.name}/waypoint_{frame_idx}", + self.server, + position=root_pos_np, + ) + scene_data["waypoint"] = waypoint + if viz_label: + label = self.server.scene.add_label( + name=f"/{self.name}/label_{frame_idx}", + text=f"{self.display_name} @ {frame_idx}", + position=root_pos_np, + font_size_mode="screen", + font_screen_scale=0.7, + anchor="bottom-left", + ) + label.visible = self.labels_visible + scene_data["label"] = label + + # set/update data + self.keyframes[frame_idx] = root_pos_np + if frame_idx not in self.frame2keyid: + self.frame2keyid[frame_idx] = [] + + if keyframe_id in self.frame2keyid[frame_idx]: + if not exists_ok: + raise AssertionError("keyframe_id already exists in this frame!") + else: + self.frame2keyid[frame_idx].append(keyframe_id) + + # need to update path visualization + if self.line_segments is not None and update_path: + self.update_line_segments() + + def add_interval( + self, + interval_id: str, + start_frame_idx: int, + end_frame_idx: int, + root_pos: torch.Tensor | np.ndarray, + ): + """Adds an interval of 2D root keyframes between the given start and end frames. + + Args: + interval_id: str, id for the interval. Must be unique within the given start_frame_idx and end_frame_idx. 
+            start_frame_idx: int, start frame index to add the interval at
+            end_frame_idx: int, end frame index to add the interval at
+            root_pos: torch.Tensor, [T, 3] root positions to add the interval at
+        """
+        root_pos_np = to_numpy(root_pos)
+        assert root_pos_np.shape[0] == end_frame_idx - start_frame_idx + 1
+        if root_pos_np.shape[0] >= 2:
+            points = np.zeros((root_pos_np.shape[0] - 1, 2, 3))
+            points[:, 0] = root_pos_np[:-1]
+            points[:, 1] = root_pos_np[1:]
+            if interval_id in self.interval_line_segments:
+                self.server.scene.remove_by_name(self.interval_line_segments[interval_id].name)
+            self.interval_line_segments[interval_id] = self.server.scene.add_line_segments(
+                name=f"/{self.name}/interval_{interval_id}_line",
+                points=points,
+                colors=(255, 0, 0),
+                line_width=5.0,
+            )
+
+        for frame_idx in range(start_frame_idx, end_frame_idx + 1):
+            rel_idx = frame_idx - start_frame_idx
+            self.add_keyframe(
+                interval_id,
+                frame_idx,
+                root_pos_np[rel_idx],
+                viz_label=False,
+                update_path=False,
+                viz_waypoint=False,
+            )
+        self._add_interval_label(start_frame_idx, end_frame_idx)
+        if self.line_segments is not None:
+            self.update_line_segments()
+
+    def set_smooth_path(self, smooth_path: bool):
+        self.smooth_path = smooth_path
+        if self.line_segments is not None:
+            self.update_line_segments()
+
+    def set_dense_path(self, dense_path: bool):
+        """If dense_path is True, makes the path dense by interpolating between added
+        keyframes.
+
+        Args:
+            dense_path: bool, whether to make the path dense
+        """
+        self.dense_path = dense_path
+        if self.dense_path:
+            # visualize dense path with line segments
+            self.line_segments = self.server.scene.add_line_segments(
+                name=f"/{self.name}/line_segments",
+                points=np.zeros((1, 2, 3)),
+                colors=(255, 0, 0),
+                line_width=5.0,
+            )
+            self.update_line_segments()
+        else:
+            if self.line_segments is not None:
+                self.server.scene.remove_by_name(self.line_segments.name)
+                self.line_segments = None
+
+    def interpolate_path(self, t: np.ndarray):
+        """Interpolates the path between the given frame indices.
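+
+        For example (values hypothetical), with keyframes at frames 0, 10, and 20,
+        ``interpolate_path(np.arange(0, 21))`` returns a (21, 3) array: x and z are
+        interpolated between the keyframed root positions, and y is zero (prior to
+        the optional smoothing pass).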
+
+        Args:
+            t: np.ndarray, frame indices to interpolate at
+        """
+        from scipy.interpolate import interp1d
+
+        cur_info = self._get_sparse_constraint_info()
+        frame_idx = cur_info["frame_idx"]
+        all_root_pos = cur_info["root_pos"]
+
+        x = all_root_pos[:, 0]
+        z = all_root_pos[:, 2]
+
+        # NOTE: always interpolate linearly; smoothing (when enabled) is applied below
+        # via get_smooth_root_pos rather than by switching to quadratic interpolation.
+        kind = "linear"
+
+        interp_x = interp1d(frame_idx, x, kind=kind)
+        interp_z = interp1d(frame_idx, z, kind=kind)
+
+        x_new = interp_x(t)
+        z_new = interp_z(t)
+
+        path3d = np.stack([x_new, np.zeros_like(x_new), z_new], axis=1)
+
+        if self.smooth_path and len(frame_idx) >= 3:
+            path3d = get_smooth_root_pos(torch.from_numpy(path3d[None]))[0].numpy()
+        return path3d
+
+    def update_line_segments(self):
+        if len(self.keyframes) < 2:
+            return
+
+        t = np.array(sorted(self.get_frame_idx()))
+        if self.smooth_path:
+            # more points for smoothed curve
+            t = np.linspace(t[0], t[-1], 100)
+
+        path3d = self.interpolate_path(t)
+
+        points = np.zeros((len(path3d) - 1, 2, 3))
+        points[:, 0] = path3d[:-1]
+        points[:, 1] = path3d[1:]
+
+        self.line_segments.points = points
+
+    def remove_keyframe(self, keyframe_id: str, frame_idx: int):
+        if frame_idx not in self.keyframes:
+            return
+        if keyframe_id not in self.frame2keyid[frame_idx]:
+            return
+        self.frame2keyid[frame_idx].remove(keyframe_id)
+        if len(self.frame2keyid[frame_idx]) == 0:
+            del self.frame2keyid[frame_idx]
+            self.clear(frame_idx)
+        if self.line_segments is not None:
+            self.update_line_segments()
+
+    def _get_label_pos(self, frame_idx: int):
+        return self.keyframes[frame_idx]
+
+    def remove_interval(self, interval_id: str, start_frame_idx: int, end_frame_idx: int):
+        if interval_id in self.interval_line_segments:
+            self.server.scene.remove_by_name(self.interval_line_segments[interval_id].name)
+            del self.interval_line_segments[interval_id]
+        self._remove_interval_and_update_label(interval_id, start_frame_idx, end_frame_idx)
+
+    def _get_sparse_constraint_info(self):
+        all_root_pos = []
+        for v in self.keyframes.values():
+            v_np = to_numpy(v)
+            if len(v_np.shape) == 1:
+                all_root_pos.append(v_np[None])
+            else:
+                all_root_pos.append(v_np)
+        if len(all_root_pos) > 0:
+            all_root_pos = np.concatenate(all_root_pos, axis=0)
+        else:
+            all_root_pos = None
+        return {
+            "frame_idx": self.get_frame_idx(),
+            "root_pos": all_root_pos,
+        }
+
+    def get_constraint_info(self, device: Optional[str] = None):
+        if not self.dense_path or len(self.keyframes) == 0:
+            info = self._get_sparse_constraint_info()
+            return {
+                "frame_idx": info["frame_idx"],
+                "root_pos": to_torch(info["root_pos"], device=device, dtype=torch.float32),
+            }
+        else:
+            frame_idx_list = self.get_frame_idx()
+            min_frame_idx = min(frame_idx_list)
+            max_frame_idx = max(frame_idx_list)
+            t = np.arange(min_frame_idx, max_frame_idx + 1)
+            path3d = self.interpolate_path(t)
+            return {
+                "frame_idx": t.tolist(),
+                "root_pos": to_torch(path3d, device=device, dtype=torch.float32),
+            }
+
+    def clear(self, frame_idx: Optional[int] = None):
+        frame_idx_list = list(self.keyframes.keys()) if frame_idx is None else [frame_idx]
+        for fidx in frame_idx_list:
+            scene_data = self.scene_elements.get(fidx, {})
+            waypoint = scene_data.get("waypoint")
+            if waypoint is not None:
+                waypoint.clear()
+            label = scene_data.get("label")
+            if label is not None:
+                self.server.scene.remove_by_name(label.name)
+
+            self.keyframes.pop(fidx)
+            self.scene_elements.pop(fidx)
+
+        if frame_idx is None:
+            # clear all interval labels if clearing all keyframes
+            for interval_label in
list(self.interval_labels.values()): + self.server.scene.remove_by_name(interval_label.name) + self.interval_labels.clear() + + # clear line segments if turning off dense path + if self.line_segments is not None: + self.server.scene.remove_by_name(self.line_segments.name) + self.line_segments = None + + for interval_line in list(self.interval_line_segments.values()): + self.server.scene.remove_by_name(interval_line.name) + self.interval_line_segments.clear() + + def set_overlay_visibility(self, only_frame: Optional[int] = None) -> None: + show_all = only_frame is None + for fidx, scene_data in self.scene_elements.items(): + visible = show_all or fidx == only_frame + waypoint = scene_data.get("waypoint") + if waypoint is not None: + waypoint.set_visible(visible) + label = scene_data.get("label") + if label is not None: + label.visible = visible and self.labels_visible + if self.line_segments is not None: + self.line_segments.visible = show_all + for line_handle in self.interval_line_segments.values(): + line_handle.visible = show_all + for interval_label in self.interval_labels.values(): + interval_label.visible = show_all and self.labels_visible + + +# +# GUI Elements that need to be tracked diff --git a/kimodo/viz/coords.py b/kimodo/viz/coords.py new file mode 100644 index 0000000000000000000000000000000000000000..6152e075d472785f886f2d57514ee646346e7176 --- /dev/null +++ b/kimodo/viz/coords.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Pure numpy coordinate/rotation helpers for viz.""" + +import numpy as np + + +def skew(v: np.ndarray) -> np.ndarray: + """Skew-symmetric matrix for cross products: skew(v) @ x == np.cross(v, x).""" + vx, vy, vz = float(v[0]), float(v[1]), float(v[2]) + return np.array([[0.0, -vz, vy], [vz, 0.0, -vx], [-vy, vx, 0.0]], dtype=np.float64) + + +def rotation_matrix_from_two_vec(v_from: np.ndarray, v_to: np.ndarray, eps: float = 1e-8) -> np.ndarray: + """Return R such that R @ v_from ~= v_to (both treated as 3D vectors). + + Uses a Rodrigues-style construction, with special handling for near-parallel and near-opposite + vectors for numerical stability. + """ + a = np.asarray(v_from, dtype=np.float64).reshape(3) + b = np.asarray(v_to, dtype=np.float64).reshape(3) + na = np.linalg.norm(a) + nb = np.linalg.norm(b) + if na < eps or nb < eps: + return np.eye(3, dtype=np.float64) + a = a / na + b = b / nb + + c = float(np.clip(np.dot(a, b), -1.0, 1.0)) # cos(theta) + if c > 1.0 - eps: + return np.eye(3, dtype=np.float64) + if c < -1.0 + eps: + # 180 deg rotation about any axis orthogonal to a: + # R = -I + 2 * uu^T, where u is a unit axis orthogonal to a. + axis_seed = np.array([1.0, 0.0, 0.0], dtype=np.float64) + if abs(float(np.dot(a, axis_seed))) > 0.9: + axis_seed = np.array([0.0, 1.0, 0.0], dtype=np.float64) + u = np.cross(a, axis_seed) + u = u / np.linalg.norm(u).clip(min=eps) + return -np.eye(3, dtype=np.float64) + 2.0 * np.outer(u, u) + + v = np.cross(a, b) + s2 = float(np.dot(v, v)) # ||v||^2 == sin^2(theta) + K = skew(v) + # R = I + K + K^2 * ((1 - c) / s^2) + return np.eye(3, dtype=np.float64) + K + (K @ K) * ((1.0 - c) / s2) diff --git a/kimodo/viz/g1_rig.py b/kimodo/viz/g1_rig.py new file mode 100644 index 0000000000000000000000000000000000000000..9a400b353411f4c7137aaf708e01b4e99e29d108 --- /dev/null +++ b/kimodo/viz/g1_rig.py @@ -0,0 +1,431 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""G1 robot rig: mesh loading, joint mapping, and viser scene setup for G1 skeleton."""
+
+import os
+import xml.etree.ElementTree as ET
+from typing import Any, Optional, Tuple
+
+import numpy as np
+import trimesh
+
+import viser
+import viser.transforms as tf
+from kimodo.assets import skeleton_asset_path
+from kimodo.skeleton import G1Skeleton34
+
+# MuJoCo (z-up, x-forward) -> kimodo (y-up, z-forward)
+MUJOCO_TO_KIMODO = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], dtype=np.float64)
+
+G1_MESH_JOINT_MAP = {
+    "pelvis_skel": ["pelvis.STL", "pelvis_contour_link.STL"],
+    "left_hip_pitch_skel": ["left_hip_pitch_link.STL"],
+    "left_hip_roll_skel": ["left_hip_roll_link.STL"],
+    "left_hip_yaw_skel": ["left_hip_yaw_link.STL"],
+    "left_knee_skel": ["left_knee_link.STL"],
+    "left_ankle_pitch_skel": ["left_ankle_pitch_link.STL"],
+    "left_ankle_roll_skel": ["left_ankle_roll_link.STL"],
+    "right_hip_pitch_skel": ["right_hip_pitch_link.STL"],
+    "right_hip_roll_skel": ["right_hip_roll_link.STL"],
+    "right_hip_yaw_skel": ["right_hip_yaw_link.STL"],
+    "right_knee_skel": ["right_knee_link.STL"],
+    "right_ankle_pitch_skel": ["right_ankle_pitch_link.STL"],
+    "right_ankle_roll_skel": ["right_ankle_roll_link.STL"],
+    "waist_yaw_skel": ["waist_yaw_link_rev_1_0.STL", "waist_yaw_link.STL"],
+    "waist_roll_skel": ["waist_roll_link_rev_1_0.STL", "waist_roll_link.STL"],
+    "waist_pitch_skel": [
+        "torso_link_rev_1_0.STL",
+        "torso_link.STL",
+        "logo_link.STL",
+        "head_link.STL",
+    ],
+    "left_shoulder_pitch_skel": ["left_shoulder_pitch_link.STL"],
+    "left_shoulder_roll_skel": ["left_shoulder_roll_link.STL"],
+    "left_shoulder_yaw_skel": ["left_shoulder_yaw_link.STL"],
+    "left_elbow_skel": ["left_elbow_link.STL"],
+    "left_wrist_roll_skel": ["left_wrist_roll_link.STL"],
+    "left_wrist_pitch_skel": ["left_wrist_pitch_link.STL"],
+    "left_wrist_yaw_skel": ["left_wrist_yaw_link.STL", "left_rubber_hand.STL"],
+    "right_shoulder_pitch_skel": ["right_shoulder_pitch_link.STL"],
+    "right_shoulder_roll_skel": ["right_shoulder_roll_link.STL"],
+    "right_shoulder_yaw_skel": ["right_shoulder_yaw_link.STL"],
+    "right_elbow_skel": ["right_elbow_link.STL"],
+    "right_wrist_roll_skel": ["right_wrist_roll_link.STL"],
+    "right_wrist_pitch_skel": ["right_wrist_pitch_link.STL"],
+    "right_wrist_yaw_skel": ["right_wrist_yaw_link.STL", "right_rubber_hand.STL"],
+}
+
+# Joint axis/limits from g1.xml (used by exports, e.g. MujocoQposConverter)
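+# For example, a hinge declared in MuJoCo with axis "0 1 0" maps to
+# MUJOCO_TO_KIMODO @ [0, 1, 0] = [1, 0, 0] in kimodo space, i.e. axis index 0 (x).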
+_G1_JOINT_AXIS_INDEX_CACHE: Optional[dict[str, int]] = None
+_G1_JOINT_LIMITS_CACHE: Optional[dict[str, tuple[float, float]]] = None
+
+
+def _get_g1_joint_axis_indices() -> dict[str, int]:
+    """Return a map from G1 joint names to a single rotation axis index."""
+    global _G1_JOINT_AXIS_INDEX_CACHE
+    if _G1_JOINT_AXIS_INDEX_CACHE is not None:
+        return _G1_JOINT_AXIS_INDEX_CACHE
+
+    xml_path = str(skeleton_asset_path("g1skel34", "xml", "g1.xml"))
+    if not os.path.exists(xml_path):
+        _G1_JOINT_AXIS_INDEX_CACHE = {}
+        return _G1_JOINT_AXIS_INDEX_CACHE
+
+    tree = ET.parse(xml_path)
+    root = tree.getroot()
+
+    joint_axes = {}
+    for xml_class in tree.findall(".//default"):
+        if "class" not in xml_class.attrib:
+            continue
+        joint_nodes = xml_class.findall("joint")
+        if joint_nodes:
+            joint_axes[xml_class.get("class")] = joint_nodes[0].get("axis")
+
+    axis_indices_by_name: dict[str, int] = {}
+    worldbody = root.find("worldbody")
+    if worldbody is None:
+        # mirror _get_g1_joint_limits: tolerate a malformed XML instead of crashing
+        _G1_JOINT_AXIS_INDEX_CACHE = {}
+        return _G1_JOINT_AXIS_INDEX_CACHE
+    for joint in worldbody.findall(".//joint"):
+        axis_str = joint.get("axis") or joint_axes.get(joint.get("class"))
+        if axis_str is None:
+            continue
+        axis_vals = np.array([float(x) for x in axis_str.split()], dtype=np.float64)
+        if not np.any(axis_vals):
+            continue
+        axis_kimodo = MUJOCO_TO_KIMODO @ axis_vals
+        axis_idx = int(np.argmax(np.abs(axis_kimodo)))
+        joint_name = joint.get("name")
+        if not joint_name:
+            continue
+        axis_indices_by_name[joint_name.replace("_joint", "_skel")] = axis_idx
+
+    _G1_JOINT_AXIS_INDEX_CACHE = axis_indices_by_name
+    return _G1_JOINT_AXIS_INDEX_CACHE
+
+
+def _get_g1_joint_limits() -> dict[str, tuple[float, float]]:
+    """Return a map from G1 joint names to (min, max) angle limits in radians."""
+    global _G1_JOINT_LIMITS_CACHE
+    if _G1_JOINT_LIMITS_CACHE is not None:
+        return _G1_JOINT_LIMITS_CACHE
+
+    xml_path = str(skeleton_asset_path("g1skel34", "xml", "g1.xml"))
+    if not os.path.exists(xml_path):
+        _G1_JOINT_LIMITS_CACHE = {}
+        return _G1_JOINT_LIMITS_CACHE
+
+    tree = ET.parse(xml_path)
+    root = tree.getroot()
+
+    class_ranges: dict[str, tuple[float, float]] = {}
+    for xml_class in tree.findall(".//default"):
+        class_name = xml_class.get("class")
+        if not class_name:
+            continue
+        joint_nodes = xml_class.findall("joint")
+        if not joint_nodes:
+            continue
+        range_str = joint_nodes[0].get("range")
+        if not range_str:
+            continue
+        range_vals = [float(x) for x in range_str.split()]
+        if len(range_vals) != 2:
+            continue
+        class_ranges[class_name] = (range_vals[0], range_vals[1])
+
+    joint_limits: dict[str, tuple[float, float]] = {}
+    worldbody = root.find("worldbody")
+    if worldbody is None:
+        _G1_JOINT_LIMITS_CACHE = {}
+        return _G1_JOINT_LIMITS_CACHE
+
+    for joint in worldbody.findall(".//joint"):
+        range_str = joint.get("range") or class_ranges.get(joint.get("class"))
+        if range_str is None:
+            continue
+        if isinstance(range_str, tuple):
+            joint_range = range_str
+        else:
+            range_vals = [float(x) for x in range_str.split()]
+            if len(range_vals) != 2:
+                continue
+            joint_range = (range_vals[0], range_vals[1])
+        joint_name = joint.get("name")
+        if not joint_name:
+            continue
+        joint_limits[joint_name.replace("_joint", "_skel")] = joint_range
+
+    _G1_JOINT_LIMITS_CACHE = joint_limits
+    return _G1_JOINT_LIMITS_CACHE
+
+
+_G1_JOINT_F2Q_DATA_CACHE: Optional[dict[str, dict[str, Any]]] = None
+
+
+def get_g1_joint_f2q_data(
+    skeleton: G1Skeleton34,
+) -> dict[str, dict[str, Any]]:
+    """Return per-hinge-joint f2q data for correct 1-DoF + limits in offset space.
+
+    Each entry is for a G1 hinge joint (by name) and contains:
+    - "offset_f2q": (3, 3) matrix such that R_f2q = offset_f2q @ R_local (kimodo).
+ - "axis_f2q": (3,) unit axis in f2q space; angle = dot(axis_angle(R_f2q), axis_f2q). + - "rest_dof_axis_angle": angle (rad) at T-pose in f2q space; MuJoCo q = angle_f2q - this. + + Limits from the XML apply to q = angle_f2q - rest_dof_axis_angle. + """ + global _G1_JOINT_F2Q_DATA_CACHE + if _G1_JOINT_F2Q_DATA_CACHE is not None: + return _G1_JOINT_F2Q_DATA_CACHE + + from kimodo.exports.mujoco import MujocoQposConverter + + converter = MujocoQposConverter(skeleton) + # converter: _rot_offsets_f2q[kimodo_idx], _mujoco_joint_axis_values_f2q_space[hinge_idx], + # _rest_dofs_axis_angle[hinge_idx], _kimodo_indices_to_mujoco_indices[kimodo_idx] = hinge_idx+1 or 0 + out: dict[str, dict[str, Any]] = {} + for j in range(skeleton.nbjoints): + mujoco_one_based = converter._kimodo_indices_to_mujoco_indices[j].item() + if mujoco_one_based <= 0: + continue + hinge_idx = mujoco_one_based - 1 + joint_name = skeleton.bone_order_names[j] + offset_f2q = converter._rot_offsets_f2q[j].detach().cpu().numpy().astype(np.float64) + axis_f2q = converter._mujoco_joint_axis_values_f2q_space[hinge_idx].detach().cpu().numpy().astype(np.float64) + n = np.linalg.norm(axis_f2q) + if n > 1e-10: + axis_f2q = axis_f2q / n + rest_dof = float(converter._rest_dofs_axis_angle[hinge_idx].detach().cpu().numpy()) + out[joint_name] = { + "offset_f2q": offset_f2q, + "axis_f2q": axis_f2q, + "rest_dof_axis_angle": rest_dof, + } + _G1_JOINT_F2Q_DATA_CACHE = out + return out + + +# ----------------------------------------------------------------------------- +# Mesh loading cache (shared across G1 rig instances; each rig gets its own scene meshes) +# ----------------------------------------------------------------------------- +_G1_MESH_DATA_CACHE: dict[str, list[dict]] = {} + + +def _load_g1_mesh_data( + mesh_dir: str, + skeleton: G1Skeleton34, +) -> list[dict]: + """Load STL meshes and XML transforms once per mesh_dir; shared across rig instances.""" + if mesh_dir in _G1_MESH_DATA_CACHE: + return _G1_MESH_DATA_CACHE[mesh_dir] + + mesh_geom_cache = G1MeshRig._mesh_geom_cache + mesh_transform_cache = G1MeshRig._mesh_transform_cache + + # Load XML-derived transforms (cached inside _get_mesh_local_transforms_impl) + mesh_file_transforms = _get_mesh_local_transforms_impl(mesh_dir, mesh_transform_cache) + + data_list: list[dict] = [] + for joint_name, mesh_files in G1_MESH_JOINT_MAP.items(): + if joint_name not in skeleton.bone_index: + continue + joint_idx = skeleton.bone_index[joint_name] + for mesh_file in mesh_files: + mesh_path = os.path.join(mesh_dir, mesh_file) + if not os.path.exists(mesh_path): + continue + vertices, faces = _get_mesh_geometry_impl(mesh_file, mesh_path, mesh_dir, mesh_geom_cache) + if vertices is None: + continue + geom_pos, geom_rot = mesh_file_transforms.get( + mesh_file, + (np.zeros(3, dtype=np.float64), np.eye(3, dtype=np.float64)), + ) + data_list.append( + { + "mesh_file": mesh_file, + "vertices": vertices, + "faces": faces, + "joint_idx": joint_idx, + "geom_pos": geom_pos.copy(), + "geom_rot": geom_rot.copy(), + } + ) + + _G1_MESH_DATA_CACHE[mesh_dir] = data_list + return data_list + + +def _get_mesh_geometry_impl( + mesh_file: str, + mesh_path: str, + mesh_dir: str, + mesh_geom_cache: dict, +) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]: + """Load one STL; result cached per mesh_dir and shared across rigs.""" + cached = mesh_geom_cache.get(mesh_dir) + if cached is not None and mesh_file in cached: + vertices, faces = cached[mesh_file] + return vertices.copy(), faces.copy() + + mesh = 
trimesh.load_mesh(mesh_path, process=True) + if isinstance(mesh, trimesh.Scene): + mesh = trimesh.util.concatenate(mesh.dump()) + vertices = mesh.vertices @ MUJOCO_TO_KIMODO.T + faces = mesh.faces + + if mesh_dir not in mesh_geom_cache: + mesh_geom_cache[mesh_dir] = {} + mesh_geom_cache[mesh_dir][mesh_file] = (vertices, faces) + return vertices.copy(), faces.copy() + + +def _get_mesh_local_transforms_impl( + mesh_dir: str, + mesh_transform_cache: dict, +) -> dict[str, tuple[np.ndarray, np.ndarray]]: + """Parse g1.xml once per mesh_dir; result shared across G1 rig instances.""" + cached = mesh_transform_cache.get(mesh_dir) + if cached is not None: + return {mesh_file: (pos.copy(), rot.copy()) for mesh_file, (pos, rot) in cached.items()} + + xml_path = os.path.abspath(os.path.join(mesh_dir, "..", "..", "xml", "g1.xml")) + if not os.path.exists(xml_path): + return {} + tree = ET.parse(xml_path) + root = tree.getroot() + + mesh_file_to_mesh_name = {} + for mesh in root.findall(".//asset/mesh"): + mesh_name = mesh.get("name") + mesh_file = mesh.get("file") + if mesh_name and mesh_file: + mesh_file_to_mesh_name[mesh_file] = mesh_name + + mesh_name_to_transform = {} + for geom in root.findall(".//geom"): + mesh_name = geom.get("mesh") + if mesh_name is None: + continue + pos = geom.get("pos") + quat = geom.get("quat") + if pos is None: + geom_pos = np.zeros(3, dtype=np.float64) + else: + geom_pos = np.array([float(x) for x in pos.split()], dtype=np.float64) + if quat is None: + geom_rot = np.eye(3, dtype=np.float64) + else: + wxyz = np.array([float(x) for x in quat.split()], dtype=np.float64) + geom_rot = tf.SO3(wxyz=wxyz).as_matrix() + mesh_name_to_transform[mesh_name] = (geom_pos, geom_rot) + + mesh_file_transforms = {} + for mesh_file, mesh_name in mesh_file_to_mesh_name.items(): + geom_pos, geom_rot = mesh_name_to_transform.get( + mesh_name, + (np.zeros(3, dtype=np.float64), np.eye(3, dtype=np.float64)), + ) + geom_pos = MUJOCO_TO_KIMODO @ geom_pos + geom_rot = MUJOCO_TO_KIMODO @ geom_rot @ MUJOCO_TO_KIMODO.T + mesh_file_transforms[mesh_file] = (geom_pos, geom_rot) + + mesh_transform_cache[mesh_dir] = {mf: (pos.copy(), rot.copy()) for mf, (pos, rot) in mesh_file_transforms.items()} + return mesh_file_transforms + + +class G1MeshRig: + """Rig for G1 STL meshes. + + Each instance has its own scene meshes (so clear() only removes one character). Loading is + shared: STL files and g1.xml are cached per mesh_dir via _load_g1_mesh_data() and the class- + level _mesh_*_cache dicts. 
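+
+    Example (paths hypothetical): two rigs sharing the load caches:
+
+        rig_a = G1MeshRig("char_a", server, skeleton, mesh_dir, color=(200, 60, 60))
+        rig_b = G1MeshRig("char_b", server, skeleton, mesh_dir, color=(60, 60, 200))
+        rig_a.clear()  # removes only char_a's meshes; caches and rig_b are untouched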
+ """ + + _mesh_geom_cache: dict[str, dict[str, tuple[np.ndarray, np.ndarray]]] = {} + _mesh_transform_cache: dict[str, dict[str, tuple[np.ndarray, np.ndarray]]] = {} + + def __init__( + self, + name: str, + server: viser.ViserServer | viser.ClientHandle, + skeleton: G1Skeleton34, + mesh_dir: str, + color: Tuple[int, int, int], + ): + self.server = server + self.skeleton = skeleton + self.mesh_dir = mesh_dir + self.color = color + self.mesh_handles: list[viser.SceneHandle] = [] + self.mesh_items: list[dict[str, object]] = [] + self._defer_initial_visibility = True + + data_list = _load_g1_mesh_data(mesh_dir, skeleton) + + for item in data_list: + mesh_file = item["mesh_file"] + vertices = item["vertices"] + faces = item["faces"] + joint_idx = item["joint_idx"] + geom_pos = item["geom_pos"] + geom_rot = item["geom_rot"] + + handle = self.server.scene.add_mesh_simple( + f"/{name}/g1_mesh/{os.path.splitext(mesh_file)[0]}", + vertices=vertices, + faces=faces, + opacity=None, + color=self.color, + wireframe=False, + visible=not self._defer_initial_visibility, + ) + self.mesh_handles.append(handle) + self.mesh_items.append( + { + "handle": handle, + "joint_idx": joint_idx, + "geom_pos": geom_pos, + "geom_rot": geom_rot, + } + ) + + if self._defer_initial_visibility: + for handle in self.mesh_handles: + handle.visible = True + + def set_visibility(self, visible: bool) -> None: + for handle in self.mesh_handles: + handle.visible = visible + + def set_opacity(self, opacity: float) -> None: + for handle in self.mesh_handles: + handle.opacity = opacity + + def set_wireframe(self, wireframe: bool) -> None: + for handle in self.mesh_handles: + handle.wireframe = wireframe + + def set_color(self, color: Tuple[int, int, int]) -> None: + self.color = color + for handle in self.mesh_handles: + handle.color = color + + def set_pose(self, joints_pos: np.ndarray, joints_rot: np.ndarray) -> None: + for item in self.mesh_items: + handle = item["handle"] + joint_idx = item["joint_idx"] + geom_pos = item["geom_pos"] + geom_rot = item["geom_rot"] + + joint_pos = joints_pos[joint_idx] + joint_rot = joints_rot[joint_idx] + mesh_pos = joint_pos + joint_rot @ geom_pos + mesh_rot = joint_rot @ geom_rot + + handle.position = mesh_pos + handle.wxyz = tf.SO3.from_matrix(mesh_rot).wxyz + + def clear(self) -> None: + for handle in self.mesh_handles: + self.server.scene.remove_by_name(handle.name) + self.mesh_handles = [] + self.mesh_items = [] diff --git a/kimodo/viz/gui.py b/kimodo/viz/gui.py new file mode 100644 index 0000000000000000000000000000000000000000..0c269ea02ad984fc792430935c95b4e8b373c590 --- /dev/null +++ b/kimodo/viz/gui.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""GUI element handles for the demo app.""" + +from dataclasses import dataclass + +import viser + + +@dataclass +class GuiElements: + gui_play_pause_button: viser.GuiInputHandle + gui_next_frame_button: viser.GuiInputHandle + gui_prev_frame_button: viser.GuiInputHandle + gui_generate_button: viser.GuiInputHandle + gui_model_fps: viser.GuiInputHandle[int] + gui_timeline: viser.GuiInputHandle[int] + gui_viz_skeleton_checkbox: viser.GuiInputHandle[bool] + gui_viz_foot_contacts_checkbox: viser.GuiInputHandle[bool] + gui_viz_skinned_mesh_checkbox: viser.GuiInputHandle[bool] + gui_viz_skinned_mesh_opacity_slider: viser.GuiInputHandle[float] + gui_camera_fov_slider: viser.GuiInputHandle[float] + + # generation controls + gui_duration_slider: viser.GuiInputHandle[float] + gui_num_samples_slider: viser.GuiInputHandle[int] + gui_cfg_checkbox: viser.GuiCheckboxHandle + gui_cfg_text_weight_slider: viser.GuiInputHandle[float] + gui_cfg_constraint_weight_slider: viser.GuiInputHandle[float] + gui_diffusion_steps_slider: viser.GuiInputHandle[int] + gui_seed: viser.GuiInputHandle[int] + gui_postprocess_checkbox: viser.GuiCheckboxHandle + gui_root_margin: viser.GuiInputHandle[float] + gui_real_robot_rotations_checkbox: viser.GuiInputHandle[bool] + # appearance + gui_dark_mode_checkbox: viser.GuiCheckboxHandle + + # which skinning method to use for SOMA + gui_use_soma_layer_checkbox: viser.GuiCheckboxHandle diff --git a/kimodo/viz/playback.py b/kimodo/viz/playback.py new file mode 100644 index 0000000000000000000000000000000000000000..8bfb677ab5bf71f38feca9108cf543d58055e084 --- /dev/null +++ b/kimodo/viz/playback.py @@ -0,0 +1,719 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Playback and motion editing: CharacterMotion.""" + +from typing import Callable, Literal, Optional + +import numpy as np +import torch + +import viser.transforms as tf +from kimodo.skeleton import ( + G1Skeleton34, + SOMASkeleton30, + SOMASkeleton77, + batch_rigid_transform, + global_rots_to_local_rots, +) +from kimodo.tools import to_numpy, to_torch + +from .g1_rig import ( + _get_g1_joint_axis_indices, + _get_g1_joint_limits, + get_g1_joint_f2q_data, +) +from .scene import Character + + +class CharacterMotion: + def __init__( + self, + character: Character, + joints_pos: torch.Tensor, + joints_rot: torch.Tensor, + foot_contacts: Optional[torch.Tensor] = None, + ): + self.character = character + self.server = character.server + self.skeleton = character.skeleton + self.name = character.name + + # [T, J, 3] global joint positions + self.joints_pos = joints_pos + # [T, J, 3, 3] global joint rotation matrices + self.joints_rot = joints_rot + assert joints_pos.shape[0] == joints_rot.shape[0] + # keep track of local rots as well for convenience during pose editing + self.joints_local_rot = global_rots_to_local_rots(joints_rot, self.skeleton) + + self.length = joints_pos.shape[0] + self.cur_frame_idx = None + + self.foot_contacts = foot_contacts + if foot_contacts is not None: + assert foot_contacts.shape[0] == self.length + + self.precompute_mesh_info() + + # gizmos for pose editing + self.root_translation_gizmo = None + self.updating_root_translation_gizmo = False + self.joint_gizmos = None + self.updating_joint_gizmos = False + self.gizmo_space: Literal["world", "local"] = "local" + self._drag_start_world_rot: list = [] + self._joint_gizmo_dragging: list[bool] = [] + + def precompute_mesh_info(self): + if self.character.skeleton_mesh is not None: + print("Caching skeleton mesh info...") + self.character.skeleton_mesh.precompute_mesh_info(self.joints_pos) + if self.character.skinned_mesh is not None: + print("Caching skinning info...") + self.character.precompute_skinning(self.joints_pos, self.joints_rot) + + def set_frame(self, idx: int): + """Sets the pose of the character to the given frame index.""" + idx = min(idx, self.length - 1) # clamp to last frame + cur_foot_contacts = self.foot_contacts[idx] if self.foot_contacts is not None else None + self.character.set_pose( + self.joints_pos[idx], + self.joints_rot[idx], + frame_idx=idx, + foot_contacts=cur_foot_contacts, + ) + self.cur_frame_idx = idx + + # update gizmos if frame has changed due to playback + cur_root_pos = self.joints_pos[self.cur_frame_idx, self.skeleton.root_idx].clone() + cur_root_pos[1] = 0.0 + if self.root_translation_gizmo is not None and not self.updating_root_translation_gizmo: + self.root_translation_gizmo.position = cur_root_pos.cpu().numpy() + if self.joint_gizmos is not None: + for i, joint_gizmo in enumerate(self.joint_gizmos): + # Do not push wxyz/position while this gizmo is being dragged; + # otherwise the client receives e.g. identity and the gizmo snaps back. 
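+                    # (In "world" mode the gizmo's wxyz encodes a drag delta from
+                    # identity, while "local" mode mirrors the joint's world rotation;
+                    # see add_joint_gizmos below.)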
+ if not self.updating_joint_gizmos and not self._joint_gizmo_dragging[i]: + joint_gizmo.position = self.joints_pos[self.cur_frame_idx, i].cpu().numpy() + if self.gizmo_space == "world": + joint_gizmo.wxyz = (1.0, 0.0, 0.0, 0.0) + else: + joint_gizmo.wxyz = tf.SO3.from_matrix(self.joints_rot[self.cur_frame_idx, i].cpu().numpy()).wxyz + + def update_pose_at_frame( + self, + frame_idx: int, + joints_pos: Optional[torch.Tensor] = None, + joints_rot: Optional[torch.Tensor] = None, + joints_local_rot: Optional[torch.Tensor] = None, + foot_contacts: Optional[torch.Tensor] = None, + ): + """Overwrites one or more of the pose components at the given frame. + + If only a subset of joints_pos, joints_rot, or joints_local_rot are provided, the other + components will be updated with FK. + """ + if joints_pos is not None: + joints_pos = to_torch(joints_pos, device=self.joints_pos.device, dtype=self.joints_pos.dtype) + self.joints_pos[frame_idx] = joints_pos + if joints_local_rot is None and joints_rot is None: + raise NotImplementedError("No IK to update joint rotations accordingly.") + if joints_rot is not None: + joints_rot = to_torch(joints_rot, device=self.joints_rot.device, dtype=self.joints_rot.dtype) + self.joints_rot[frame_idx] = joints_rot + if joints_local_rot is None: + # update local rots from global rots + self.joints_local_rot[frame_idx] = global_rots_to_local_rots(joints_rot, self.skeleton) + if joints_pos is None: + # need to update with FK + new_posed_joints, _ = batch_rigid_transform( + self.joints_local_rot[frame_idx : frame_idx + 1], + self.skeleton.neutral_joints[None].to(self.joints_local_rot.device), + self.skeleton.joint_parents.to(self.joints_local_rot.device), + self.skeleton.root_idx, + ) + new_posed_joints = ( + new_posed_joints[0] + + self.joints_pos[frame_idx, self.skeleton.root_idx : self.skeleton.root_idx + 1] + - self.skeleton.neutral_joints[[self.skeleton.root_idx]] + ) + self.joints_pos[frame_idx] = new_posed_joints + if joints_local_rot is not None: + joints_local_rot = to_torch(joints_local_rot, device=self.joints_local_rot.device).to( + dtype=self.joints_local_rot.dtype + ) + self.joints_local_rot[frame_idx] = joints_local_rot + if joints_rot is None or joints_pos is None: + # need to update with FK + new_posed_joints, new_global_rots = batch_rigid_transform( + self.joints_local_rot[frame_idx : frame_idx + 1], + self.skeleton.neutral_joints[None].to(self.joints_local_rot.device), + self.skeleton.joint_parents.to(self.joints_local_rot.device), + self.skeleton.root_idx, + ) + new_posed_joints = ( + new_posed_joints[0] + + self.joints_pos[frame_idx, self.skeleton.root_idx : self.skeleton.root_idx + 1] + - self.skeleton.neutral_joints[[self.skeleton.root_idx]] + ) + if joints_rot is None: + self.joints_rot[frame_idx] = new_global_rots[0] + if joints_pos is None: + self.joints_pos[frame_idx] = new_posed_joints + if foot_contacts is not None: + foot_contacts = to_torch(foot_contacts, device=self.foot_contacts.device).to(dtype=self.foot_contacts.dtype) + self.foot_contacts[frame_idx] = foot_contacts + + if self.character.skeleton_mesh is not None: + self.character.skeleton_mesh.update_mesh_info_cache(self.joints_pos[frame_idx], frame_idx) + if self.character.skinned_mesh is not None: + self.character.update_skinning_cache(self.joints_pos[frame_idx], self.joints_rot[frame_idx], frame_idx) + + def clear(self): + self.character.clear() + + # + # Editing helpers + # + def get_current_projected_root_pos(self) -> np.ndarray: + """Get the projected root position on the ground 
at the current frame.""" + root_pos = self.joints_pos[self.cur_frame_idx, self.skeleton.root_idx].clone() + root_pos[1] = 0.0 + return to_numpy(root_pos) + + def get_projected_root_pos(self, start_frame_idx: int, end_frame_idx: int = None) -> np.ndarray: + """If requested frames are out of range, simply pads with the last frame to get expected + length.""" + if end_frame_idx is None: + expected_len = 1 + else: + expected_len = end_frame_idx - start_frame_idx + 1 + if start_frame_idx >= self.length: + start_frame_idx = self.length - 1 + if end_frame_idx is None or expected_len == 1: + root_pos = self.joints_pos[start_frame_idx, self.skeleton.root_idx].clone() + root_pos[1] = 0.0 + return to_numpy(root_pos) + else: + if end_frame_idx >= self.length: + end_frame_idx = self.length - 1 + root_pos = self.joints_pos[start_frame_idx : end_frame_idx + 1, self.skeleton.root_idx].clone() + root_pos[:, 1] = 0.0 + if root_pos.shape[0] < expected_len: + # pad with the last root position + root_pos = torch.cat( + [ + root_pos, + root_pos[-1:].repeat(expected_len - root_pos.shape[0], 1), + ], + dim=0, + ) + return to_numpy(root_pos) + + def set_projected_root_pos_path( + self, + root_pos_path: np.ndarray | torch.Tensor, + min_frame_idx: int = None, + max_frame_idx: int = None, + ): + """Sets the projected root position path for the character motion. Can set only a subset of + the path by providing min_frame_idx and max_frame_idx. If not provided, will set the full + path. + + Args: + root_pos_path: torch.Tensor, [T, 2] projected root positions + min_frame_idx: int, optional, minimum frame index to set the path at + max_frame_idx: int, optional, maximum frame index to set the path at + """ + if min_frame_idx is not None or max_frame_idx is not None: + assert ( + min_frame_idx is not None and max_frame_idx is not None + ), "min_frame_idx and max_frame_idx must be provided if setting path at specific frames" + if min_frame_idx >= self.length: + # both are out of bounds + return + max_frame_idx = min(max_frame_idx, self.length - 1) + root_pos_path = root_pos_path[min_frame_idx : max_frame_idx + 1] + else: + assert root_pos_path.shape[0] == self.length + min_frame_idx = 0 + max_frame_idx = self.length - 1 + + cur_joints_pos = self.joints_pos.clone()[min_frame_idx : max_frame_idx + 1] + root_pos_tensor = to_torch(root_pos_path, device=cur_joints_pos.device, dtype=cur_joints_pos.dtype) + diff = root_pos_tensor - cur_joints_pos[:, self.skeleton.root_idx, [0, 2]] + cur_joints_pos[:, :, [0, 2]] += diff.unsqueeze(1) + for frame_idx in range(min_frame_idx, max_frame_idx + 1): + rel_idx = frame_idx - min_frame_idx + self.update_pose_at_frame( + frame_idx, + joints_pos=cur_joints_pos[rel_idx], + joints_rot=self.joints_rot[frame_idx], + joints_local_rot=self.joints_local_rot[frame_idx], + ) + # update immediately to show changes + self.set_frame(self.cur_frame_idx) + + def get_joints_pos(self, start_frame_idx: int, end_frame_idx: int = None) -> np.ndarray: + """If requested frames are out of range, simply pads with the last frame to get expected + length.""" + if end_frame_idx is None: + expected_len = 1 + else: + expected_len = end_frame_idx - start_frame_idx + 1 + if start_frame_idx >= self.length: + start_frame_idx = self.length - 1 + if end_frame_idx is None or expected_len == 1: + return to_numpy(self.joints_pos[start_frame_idx].clone()) + else: + if end_frame_idx >= self.length: + end_frame_idx = self.length - 1 + return_joints_pos = self.joints_pos[start_frame_idx : end_frame_idx + 1].clone() + if 
return_joints_pos.shape[0] < expected_len: + # pad with the last pose + return_joints_pos = torch.cat( + [ + return_joints_pos, + return_joints_pos[-1:].repeat(expected_len - return_joints_pos.shape[0], 1, 1), + ], + dim=0, + ) + return to_numpy(return_joints_pos) + + def get_joints_rot(self, start_frame_idx: int, end_frame_idx: int = None) -> np.ndarray: + """If requested frames are out of range, simply pads with the last frame to get expected + length.""" + if end_frame_idx is None: + expected_len = 1 + else: + expected_len = end_frame_idx - start_frame_idx + 1 + if start_frame_idx >= self.length: + start_frame_idx = self.length - 1 + if end_frame_idx is None or expected_len == 1: + return to_numpy(self.joints_rot[start_frame_idx].clone()) + else: + if end_frame_idx >= self.length: + end_frame_idx = self.length - 1 + return_joints_rot = self.joints_rot[start_frame_idx : end_frame_idx + 1].clone() + if return_joints_rot.shape[0] < expected_len: + # pad with the last pose + return_joints_rot = torch.cat( + [ + return_joints_rot, + return_joints_rot[-1:].repeat(expected_len - return_joints_rot.shape[0], 1, 1, 1), + ], + dim=0, + ) + return to_numpy(return_joints_rot) + + def get_current_joints_pos(self) -> torch.Tensor: + return self.joints_pos[self.cur_frame_idx].clone() + + def get_current_joints_rot(self) -> torch.Tensor: + return self.joints_rot[self.cur_frame_idx].clone() + + def add_root_translation_gizmo( + self, + constraints: dict, + on_2d_root_drag_end: Optional[Callable[[], None]] = None, + on_drag_start: Optional[Callable[[], None]] = None, + ): + """Create and initialize gizmo to control the root translation. + + When the user drags the root 2D gizmo, path updates are skipped until release. Optional + on_2d_root_drag_end is called when the drag ends (e.g. to refresh dense path). on_drag_start + is called when the drag begins (e.g. to snapshot state for undo). 
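+
+        Example (handler names are hypothetical):
+
+            motion.add_root_translation_gizmo(
+                constraint_sets,
+                on_2d_root_drag_end=refresh_dense_path,
+                on_drag_start=snapshot_for_undo,
+            )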
+        """
+        # TODO: could also allow rotation around y-axis
+        self.root_translation_gizmo = self.server.scene.add_transform_controls(
+            f"/{self.name}/gizmo_root_translation",
+            scale=0.5,
+            line_width=2.5,
+            active_axes=(True, False, True),  # only allow translation on xz plane
+            disable_axes=False,
+            disable_sliders=False,
+            disable_rotations=True,
+            depth_test=False,  # render even when occluded
+        )
+        init_position = self.get_current_projected_root_pos()
+        self.root_translation_gizmo.position = init_position
+
+        @self.root_translation_gizmo.on_drag_start
+        def _(_):
+            if on_drag_start is not None:
+                on_drag_start()
+
+        @self.root_translation_gizmo.on_update
+        def _(_):
+            self.updating_root_translation_gizmo = True
+            # translate to gizmo position
+            new_root_pos = to_torch(
+                self.root_translation_gizmo.position,
+                device=self.joints_pos.device,
+            ).to(dtype=self.joints_pos.dtype)
+            cur_joints_pos = self.joints_pos[self.cur_frame_idx].clone()
+            root_diff = new_root_pos - cur_joints_pos[self.skeleton.root_idx]
+            root_diff[1] = 0.0  # don't change height
+            cur_joints_pos += root_diff[None]
+            self.update_pose_at_frame(
+                self.cur_frame_idx,
+                joints_pos=cur_joints_pos,
+                joints_rot=self.joints_rot[self.cur_frame_idx],
+                joints_local_rot=self.joints_local_rot[self.cur_frame_idx],
+            )
+
+            self.updating_root_translation_gizmo = False
+            # update immediately to show user changes
+            self.set_frame(self.cur_frame_idx)
+            # update the 2D waypoint constraints as well if there is one
+            if "2D Root" in constraints:
+                root_2d_constraints = constraints["2D Root"]
+                # if there is a constraint at that frame, we want to update it
+                frame_idx = self.cur_frame_idx
+                if frame_idx in root_2d_constraints.keyframes:
+                    for keyframe_id in root_2d_constraints.frame2keyid[frame_idx]:
+                        # add will modify the existing constraint
+                        # update_path=False during drag to avoid lag; path refreshes on_drag_end
+                        root_2d_constraints.add_keyframe(
+                            keyframe_id,
+                            frame_idx,
+                            root_pos=new_root_pos,
+                            exists_ok=True,
+                            update_path=False,
+                        )
+            if "Full-Body" in constraints:
+                full_body_constraints = constraints["Full-Body"]
+                # if there is a constraint at that frame, we want to update it
+                frame_idx = self.cur_frame_idx
+                if frame_idx in full_body_constraints.keyframes:
+                    current_dict = full_body_constraints.keyframes[frame_idx]
+                    for keyframe_id in full_body_constraints.frame2keyid[frame_idx]:
+                        # add will modify the existing constraint
+                        full_body_constraints.add_keyframe(
+                            keyframe_id,
+                            frame_idx,
+                            joints_pos=cur_joints_pos,
+                            joints_rot=current_dict["joints_rot"],
+                            exists_ok=True,
+                        )
+            if "End-Effectors" in constraints:
+                end_effector_constraints = constraints["End-Effectors"]
+                # if there is a constraint at that frame, we want to update it
+                frame_idx = self.cur_frame_idx
+                if frame_idx in end_effector_constraints.keyframes:
+                    current_dict = end_effector_constraints.keyframes[frame_idx]
+                    for keyframe_id, _ in end_effector_constraints.frame2keyid[frame_idx]:
+                        # add will modify the existing constraint
+                        end_effector_constraints.add_keyframe(
+                            keyframe_id,
+                            frame_idx,
+                            joints_pos=cur_joints_pos,
+                            joints_rot=current_dict["joints_rot"],
+                            joint_names=current_dict["joint_names"],
+                            end_effector_type=current_dict["end_effector_type"],
+                            exists_ok=True,
+                        )
+
+        @self.root_translation_gizmo.on_drag_end
+        def _on_drag_end(_):
+            # Refresh path visualization and dense path after release.
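+            # (Per-frame updates were skipped during the drag via update_path=False,
+            # so this single refresh keeps the interaction responsive.)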
+ if "2D Root" in constraints: + root_2d = constraints["2D Root"] + if root_2d.line_segments is not None: + root_2d.update_line_segments() + if on_2d_root_drag_end is not None: + on_2d_root_drag_end() + + def add_joint_gizmos( + self, + constraints: dict, + space: Literal["world", "local"] = "local", + on_drag_start: Optional[Callable[[], None]] = None, + ): + # Remove existing joint gizmos first so the client gets remove then add, + # avoiding in-place update that can briefly show duplicate gizmos. + if self.joint_gizmos is not None: + for joint_gizmo in self.joint_gizmos: + self.server.scene.remove_by_name(joint_gizmo.name) + self.joint_gizmos = None + + self.joint_gizmos = [] + self.gizmo_space = space + # For world mode: store joint world rotation at drag start to compose with + # PivotControls' cumulative-from-identity drag rotation. + self._drag_start_world_rot = [None] * self.skeleton.nbjoints + # Skip pushing wxyz/position in set_frame while a gizmo is being dragged, + # so the client does not receive "snap back" (e.g. identity for world mode). + self._joint_gizmo_dragging = [False] * self.skeleton.nbjoints + + joint_axis_indices = None + joint_limits = None + joint_f2q_data = None + hidden_gizmo_joints = None + if isinstance(self.skeleton, G1Skeleton34): + joint_axis_indices = _get_g1_joint_axis_indices() + joint_limits = _get_g1_joint_limits() + joint_f2q_data = get_g1_joint_f2q_data(self.skeleton) + hidden_gizmo_joints = { + "left_hand_roll_skel", + "right_hand_roll_skel", + "left_toe_base", + "right_toe_base", + } + elif isinstance(self.skeleton, SOMASkeleton77): + skel30_names = {name for name, _ in SOMASkeleton30.bone_order_names_with_parents} + hidden_gizmo_joints = {name for name in self.skeleton.bone_order_names if name not in skel30_names} + hidden_gizmo_joints |= { + "RightHandThumbEnd", + "RightHandMiddleEnd", + "LeftHandThumbEnd", + "LeftHandMiddleEnd", + "LeftEye", + "RightEye", + "Jaw", + } + elif isinstance(self.skeleton, SOMASkeleton30): + hidden_gizmo_joints = { + "RightHandThumbEnd", + "RightHandMiddleEnd", + "LeftHandThumbEnd", + "LeftHandMiddleEnd", + "LeftEye", + "RightEye", + "Jaw", + } + + if space == "world": + # World mode: gizmo rings stay scene-axis-aligned (identity). + joints_wxyzs = np.tile( + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64), + (self.skeleton.nbjoints, 1), + ) + else: + # Local mode: gizmo shows joint world rotation so rings follow the joint. + joints_wxyzs = tf.SO3.from_matrix(self.joints_rot[self.cur_frame_idx].cpu().numpy()).wxyz + for joint_idx in range(self.skeleton.nbjoints): + disable_axes = True # by default, only rotation controls + disable_sliders = True + if joint_idx == self.skeleton.root_idx: + disable_axes = False # allow translation for root + disable_sliders = False + active_axes = (True, True, True) + if joint_axis_indices is not None: + joint_name = self.skeleton.bone_order_names[joint_idx] + axis_idx = joint_axis_indices.get(joint_name) + if axis_idx is not None: + # PivotControls shows rotation handles when a plane is active. + # To allow rotation about one axis, enable the other two axes. 
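+                    # For example, axis_idx == 0 (a hinge about local x) gives
+                    # active_axes == (False, True, True): enabling y and z keeps the
+                    # y-z plane active, which exposes exactly the rotation ring about x.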
+ active_axes = ( + axis_idx != 0, + axis_idx != 1, + axis_idx != 2, + ) + joint_visible = True + if hidden_gizmo_joints is not None: + joint_name = self.skeleton.bone_order_names[joint_idx] + joint_visible = joint_name not in hidden_gizmo_joints + cur_joint_gizmo = self.server.scene.add_transform_controls( + f"/{self.name}/gizmo_joint_{joint_idx}", + scale=0.075, + line_width=4.0, + active_axes=active_axes, + disable_axes=disable_axes, + disable_sliders=disable_sliders, + disable_rotations=False, + depth_test=False, # render even when occluded + position=self.joints_pos[self.cur_frame_idx, joint_idx].cpu().numpy(), + wxyz=joints_wxyzs[joint_idx], + visible=joint_visible, + space=space, + ) + self.joint_gizmos.append(cur_joint_gizmo) + + def set_callback_in_closure(i: int) -> None: + @cur_joint_gizmo.on_drag_start + def _on_drag_start(_) -> None: + if on_drag_start is not None: + on_drag_start() + self._joint_gizmo_dragging[i] = True + if self.gizmo_space == "world": + self._drag_start_world_rot[i] = self.joints_rot[self.cur_frame_idx, i].clone().cpu().numpy() + + @cur_joint_gizmo.on_drag_end + def _on_drag_end(_) -> None: + self._joint_gizmo_dragging[i] = False + # Force-sync so the client always receives the reset (viser setter skips on allclose). + # Use self.joint_gizmos[i] (not cur_joint_gizmo) to avoid the + # closure-in-loop bug: cur_joint_gizmo would point to the last handle. + gizmo = self.joint_gizmos[i] + gizmo.sync_position(self.joints_pos[self.cur_frame_idx, i].cpu().numpy()) + if self.gizmo_space == "world": + gizmo.sync_wxyz((1.0, 0.0, 0.0, 0.0)) + else: + gizmo.sync_wxyz(tf.SO3.from_matrix(self.joints_rot[self.cur_frame_idx, i].cpu().numpy()).wxyz) + self.set_frame(self.cur_frame_idx) + + @cur_joint_gizmo.on_update + def _(_) -> None: + self.updating_joint_gizmos = True + new_local_joint_rots = self.joints_local_rot[self.cur_frame_idx].clone() + # Gizmo parent is identity; client sends rotation as wxyz. + # World mode: wxyz is cumulative from identity, compose with + # stored initial world rotation. Local mode: wxyz is new world rotation. + gizmo_rot_mat = tf.SO3(self.joint_gizmos[i].wxyz).as_matrix() + if self.gizmo_space == "world" and self._drag_start_world_rot[i] is not None: + new_world_rot_mat = gizmo_rot_mat @ self._drag_start_world_rot[i] + else: + new_world_rot_mat = gizmo_rot_mat + parent_idx = self.skeleton.joint_parents[i].item() + if parent_idx >= 0: + R_parent_world = self.joints_rot[self.cur_frame_idx, parent_idx].detach().cpu().numpy() + new_local_rot_mat_np = (R_parent_world.T @ new_world_rot_mat).astype(np.float32) + else: + new_local_rot_mat_np = new_world_rot_mat.astype(np.float32) + new_local_rot = tf.SO3.from_matrix(new_local_rot_mat_np) + joint_name = self.skeleton.bone_order_names[i] + if joint_f2q_data is not None and joint_name in joint_f2q_data: + # G1 hinge: use offset (f2q) space so 1-DoF and limits match the robot. + # R_f2q = offset_f2q @ R_local; angle_f2q = dot(axis_angle(R_f2q), axis_f2q); + # MuJoCo q = angle_f2q - rest_dof; limits apply to q. + f2q = joint_f2q_data[joint_name] + offset_f2q = f2q["offset_f2q"] + axis_f2q = f2q["axis_f2q"] + rest_dof = f2q["rest_dof_axis_angle"] + R_local = new_local_rot_mat_np.astype(np.float64) + R_f2q = offset_f2q @ R_local + rotvec = tf.SO3.from_matrix(R_f2q).log() + angle_f2q = float(np.dot(rotvec, axis_f2q)) + # Keep angle continuous relative to current pose. 
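+                    # e.g. if current_angle_f2q is +3.0 rad and the raw angle_f2q is
+                    # -3.2 rad, the unwrap maps it to -3.2 + 2*pi ≈ +3.08 rad, staying
+                    # near the current pose instead of flipping through zero.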
+ current_R_f2q = offset_f2q @ ( + self.joints_local_rot[self.cur_frame_idx, i].detach().cpu().numpy().astype(np.float64) + ) + current_angle_f2q = float(np.dot(tf.SO3.from_matrix(current_R_f2q).log(), axis_f2q)) + two_pi = 2.0 * np.pi + angle_f2q = angle_f2q + two_pi * np.round((current_angle_f2q - angle_f2q) / two_pi) + q = angle_f2q - rest_dof + if joint_limits is not None: + joint_limit = joint_limits.get(joint_name) + if joint_limit is not None: + q = float(np.clip(q, joint_limit[0], joint_limit[1])) + angle_f2q = q + rest_dof + R_f2q_new = tf.SO3.exp(angle_f2q * axis_f2q).as_matrix() + new_local_rot_mat_np = (offset_f2q.T @ R_f2q_new).astype(np.float32) + elif joint_axis_indices is not None: + axis_idx = joint_axis_indices.get(joint_name) + if axis_idx is not None: + rotvec = new_local_rot.log() + axis = np.zeros(3, dtype=np.float64) + axis[axis_idx] = 1.0 + angle = float(rotvec[axis_idx]) + # Keep angle continuous relative to current pose. + current_rot = tf.SO3.from_matrix( + self.joints_local_rot[self.cur_frame_idx, i].detach().cpu().numpy() + ) + current_angle = float(current_rot.log()[axis_idx]) + two_pi = 2.0 * np.pi + angle = angle + two_pi * np.round((current_angle - angle) / two_pi) + if joint_limits is not None: + joint_limit = joint_limits.get(joint_name) + if joint_limit is not None: + angle = float(np.clip(angle, joint_limit[0], joint_limit[1])) + new_local_rot_mat_np = tf.SO3.exp(angle * axis).as_matrix() + new_local_rot_mat = torch.tensor(new_local_rot_mat_np).to(new_local_joint_rots.device) + new_local_joint_rots[i] = new_local_rot_mat + + self.update_pose_at_frame( + self.cur_frame_idx, + joints_local_rot=new_local_joint_rots, + ) + + # handle root translation separately + cur_joints_pos = self.joints_pos[self.cur_frame_idx].clone() + if i == self.skeleton.root_idx: + new_root_pos = to_torch( + self.joint_gizmos[i].position, + device=self.joints_pos.device, + ).to(dtype=self.joints_pos.dtype) + root_diff = new_root_pos - self.joints_pos[self.cur_frame_idx, i] + if torch.norm(root_diff) > 1e-3: + # the root translation has been changed + # translate to gizmo position + cur_joints_pos += root_diff[None] + self.update_pose_at_frame( + self.cur_frame_idx, + joints_pos=cur_joints_pos, + joints_rot=self.joints_rot[self.cur_frame_idx], + joints_local_rot=self.joints_local_rot[self.cur_frame_idx], + ) + + # update immediately to show user changes. Keep updating_joint_gizmos + # True so set_frame does not overwrite gizmo wxyz mid-drag. 
+ self.set_frame(self.cur_frame_idx) + self.updating_joint_gizmos = False + + if i == self.skeleton.root_idx: + # update the 2D waypoint constraints as well if there is one + if "2D Root" in constraints: + root_2d_contraints = constraints["2D Root"] + # if there is a constraint at that frame, we want to update it + frame_idx = self.cur_frame_idx + if frame_idx in root_2d_contraints.keyframes: + new_root_pos[1] = 0.0 # force y to 0 + for keyframe_id in root_2d_contraints.frame2keyid[frame_idx]: + # add will modify the existing constraint + root_2d_contraints.add_keyframe( + keyframe_id, + frame_idx, + root_pos=new_root_pos, + exists_ok=True, + update_path=False, + ) + + if "Full-Body" in constraints: + full_body_constraints = constraints["Full-Body"] + # if there is a constraint at that frame, we want to update it + frame_idx = self.cur_frame_idx + if frame_idx in full_body_constraints.keyframes: + for keyframe_id in full_body_constraints.frame2keyid[frame_idx]: + # add will modify the existing constraint + full_body_constraints.add_keyframe( + keyframe_id, + frame_idx, + joints_pos=self.joints_pos[frame_idx], + joints_rot=self.joints_rot[frame_idx], + exists_ok=True, + ) + if "End-Effectors" in constraints: + end_effector_constraints = constraints["End-Effectors"] + # if there is a constraint at that frame, we want to update it + frame_idx = self.cur_frame_idx + if frame_idx in end_effector_constraints.keyframes: + current_dict = end_effector_constraints.keyframes[frame_idx] + for keyframe_id, _ in end_effector_constraints.frame2keyid[frame_idx]: + # add will modify the existing constraint + end_effector_constraints.add_keyframe( + keyframe_id, + frame_idx, + joints_pos=self.joints_pos[frame_idx], + joints_rot=self.joints_rot[frame_idx], + joint_names=current_dict["joint_names"], + end_effector_type=current_dict["end_effector_type"], + exists_ok=True, + ) + + set_callback_in_closure(joint_idx) + + def clear_all_gizmos(self): + self.updating_root_translation_gizmo = True + self.updating_joint_gizmos = True + if self.root_translation_gizmo is not None: + self.server.scene.remove_by_name(self.root_translation_gizmo.name) + self.root_translation_gizmo = None + if self.joint_gizmos is not None: + for joint_gizmo in self.joint_gizmos: + self.server.scene.remove_by_name(joint_gizmo.name) + self.joint_gizmos = None + self._drag_start_world_rot = [] + self._joint_gizmo_dragging = [] + self.updating_root_translation_gizmo = False + self.updating_joint_gizmos = False diff --git a/kimodo/viz/scene.py b/kimodo/viz/scene.py new file mode 100644 index 0000000000000000000000000000000000000000..3836066659babeb5118376a0e0cb0a699641d622 --- /dev/null +++ b/kimodo/viz/scene.py @@ -0,0 +1,574 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Viser scene entities: waypoints, skeleton mesh, and character.""" + +import os +import traceback +from pathlib import Path +from typing import Optional, Tuple + +import numpy as np +import torch +import trimesh + +import viser +import viser.transforms as tf +from kimodo.skeleton import ( + G1Skeleton34, + SkeletonBase, + SMPLXSkeleton22, + SOMASkeleton30, + SOMASkeleton77, +) + +from .coords import rotation_matrix_from_two_vec +from .g1_rig import ( + G1MeshRig, +) +from .smplx_skin import SMPLXSkin +from .soma_skin import SOMASkin + + +class WaypointMesh: + def __init__( + self, + name: str, + server: viser.ViserServer, + position: np.ndarray, + heading: Optional[np.ndarray] = None, + color: Optional[Tuple[int, int, int]] = (255, 0, 0), + ): + self.server = server + + sphere = trimesh.creation.icosphere(subdivisions=3, radius=0.025) + annulus = trimesh.creation.annulus(r_min=0.1, r_max=0.2, height=0.005) + + z_to_y_up = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]]) + annulus_vertices = annulus.vertices @ z_to_y_up + + self.sphere = self.server.scene.add_mesh_simple( + name=f"{name}/sphere", + vertices=sphere.vertices, + faces=sphere.faces, + position=position, + color=color, + ) + self.annulus = self.server.scene.add_mesh_simple( + name=f"{name}/annulus", + vertices=annulus_vertices, + faces=annulus.faces, + position=position, + color=color, + ) + + self.arrow_base = None + self.arrow_head = None + if heading is not None: + assert heading.shape == (2,), "Heading must be a 2D vector" + heading = 0.3 * (heading / np.linalg.norm(heading)) + heading_3d = np.array([heading[0], 0, heading[1]]) + arrow_base = trimesh.creation.cylinder(radius=0.01, height=0.3) + arrow_head = trimesh.creation.cone(radius=0.03, height=0.075) + arrow_base_vertices = arrow_base.vertices + arrow_head_vertices = arrow_head.vertices + self.arrow_base = self.server.scene.add_mesh_simple( + name=f"{name}/arrow_base", + vertices=arrow_base_vertices, + faces=arrow_base.faces, + position=position + (heading_3d / 2), + color=color, + ) + self.arrow_head = self.server.scene.add_mesh_simple( + name=f"{name}/arrow_head", + vertices=arrow_head_vertices, + faces=arrow_head.faces, + position=position + heading_3d, + color=color, + ) + + def update_position(self, position: np.ndarray, heading: Optional[np.ndarray] = None): + self.sphere.position = position + self.annulus.position = position + if heading is not None: + assert heading.shape == (2,), "Heading must be a 2D vector" + heading = 0.3 * (heading / np.linalg.norm(heading)) + heading_3d = np.array([heading[0], 0, heading[1]]) + if self.arrow_base is not None: + self.arrow_base.position = position + (heading_3d / 2) + if self.arrow_head is not None: + self.arrow_head.position = position + heading_3d + + def clear(self): + self.server.scene.remove_by_name(self.sphere.name) + self.server.scene.remove_by_name(self.annulus.name) + if self.arrow_base is not None: + self.server.scene.remove_by_name(self.arrow_base.name) + if self.arrow_head is not None: + self.server.scene.remove_by_name(self.arrow_head.name) + + def set_visible(self, visible: bool) -> None: + self.sphere.visible = visible + self.annulus.visible = visible + if self.arrow_base is not None: + self.arrow_base.visible = visible + if self.arrow_head is not None: + self.arrow_head.visible = visible + + +class SkeletonMesh: + def __init__( + self, + name: str, + server: viser.ViserServer, + skeleton: SkeletonBase, + joint_color: Optional[Tuple[float, float, float] | np.ndarray] = ( 
+ 255, + 235, + 0, + ), + bone_color: Optional[Tuple[float, float, float] | np.ndarray] = ( + 27, + 106, + 0, + ), + starting_joints_pos: Optional[torch.Tensor] = None, + ): + """ + name: str, name of the skeleton mesh + server: viser.ViserServer, server to add the skeleton mesh to + skeleton: SkeletonBase, skeleton to visualize + joint_color: Optional[Tuple[float, float, float] | np.ndarray], color of the joints + bone_color: Optional[Tuple[float, float, float] | np.ndarray], color of the bones + starting_joints_pos: Optional[torch.Tensor], starting joint positions + """ + self.server = server + self.skeleton = skeleton + joint_mesh = trimesh.creation.icosphere(subdivisions=3, radius=0.02) + bone_mesh = trimesh.creation.cylinder(radius=0.01, height=1.0) + + init_joints_pos = skeleton.neutral_joints.clone() + self.num_joints = init_joints_pos.shape[0] + num_bones = self.num_joints - 1 + non_root_bones = [ + joint_name + for joint_name, parent_name in self.skeleton.bone_order_names_with_parents + if parent_name is not None + ] + self.bone_to_idx = {bone_name: idx for idx, bone_name in enumerate(non_root_bones)} + + # initialize meshes + init_joints_wxyzs = np.concatenate([np.ones((self.num_joints, 1)), np.zeros((self.num_joints, 3))], axis=1) + if isinstance(joint_color, tuple): + self.joint_colors = np.full((self.num_joints, 3), joint_color) + elif isinstance(joint_color, np.ndarray): + assert joint_color.shape == ( + self.num_joints, + 3, + ), "Joint colors must be (J, 3)" + self.joint_colors = joint_color + joint_scales = np.ones((self.num_joints, 3)) + hand_roots = {"LeftHand", "RightHand"} + finger_joint_names = set(skeleton.left_hand_joint_names + skeleton.right_hand_joint_names) - hand_roots + for jname in finger_joint_names: + if jname in skeleton.bone_index: + joint_scales[skeleton.bone_index[jname]] = 0.6 + self.joint_scales = joint_scales + + self.joints_batched_mesh = server.scene.add_batched_meshes_simple( + f"{name}/joints", + vertices=joint_mesh.vertices, + faces=joint_mesh.faces, + batched_wxyzs=init_joints_wxyzs, + batched_positions=np.zeros((self.num_joints, 3)), + batched_scales=joint_scales, + batched_colors=self.joint_colors, + ) + init_bones_wxyzs = np.concatenate([np.ones((num_bones, 1)), np.zeros((num_bones, 3))], axis=1) + if isinstance(bone_color, tuple): + bone_color = np.full((num_bones, 3), bone_color) + elif isinstance(bone_color, np.ndarray): + assert bone_color.shape == (num_bones, 3), "Bone colors must be (J-1, 3)" + bone_color = bone_color + self.bones_batched_mesh = server.scene.add_batched_meshes_simple( + f"{name}/bones", + vertices=bone_mesh.vertices, + faces=bone_mesh.faces, + batched_wxyzs=init_bones_wxyzs, + batched_positions=np.zeros((num_bones, 3)), + batched_scales=np.ones((num_bones, 3)), + batched_colors=bone_color, + ) + + self.mesh_info_cache = None + + if starting_joints_pos is not None: + self.set_pose(starting_joints_pos) + else: + if isinstance(skeleton, SOMASkeleton77): + skel30 = SOMASkeleton30(load=True) + min_height = skel30.neutral_joints[:, 1].min().item() + else: + min_height = init_joints_pos[:, 1].min().item() + init_joints_pos[:, 1] -= min_height # move to be on ground + self.set_pose(init_joints_pos) + + def compute_single_pose(self, joints_pos: np.ndarray): + """Compute the mesh for a single frame. + + joints_pos: [J, 3] global joint positions. 
+ """ + new_batched_positions = np.zeros((self.skeleton.nbjoints - 1, 3)) + new_batched_wxyzs = np.zeros((self.skeleton.nbjoints - 1, 4)) + new_batched_scales = np.ones((self.skeleton.nbjoints - 1, 3)) + for joint_name, parent_name in self.skeleton.bone_order_names_with_parents: + if parent_name is None: + continue + joint_idx = self.skeleton.bone_index[joint_name] + parent_idx = self.skeleton.bone_index[parent_name] + joint_pos = joints_pos[joint_idx] + parent_pos = joints_pos[parent_idx] + + bone_pos = (joint_pos + parent_pos) / 2.0 + bone_scale = np.linalg.norm(joint_pos - parent_pos) + if bone_scale < 1e-8: + bone_wxyz = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) + else: + bone_dir = (joint_pos - parent_pos) / bone_scale + R = rotation_matrix_from_two_vec(np.array([0.0, 0.0, 1.0], dtype=np.float64), bone_dir) + bone_wxyz = tf.SO3.from_matrix(R).wxyz + + bone_idx = self.bone_to_idx[joint_name] + new_batched_positions[bone_idx] = bone_pos + new_batched_wxyzs[bone_idx] = bone_wxyz + new_batched_scales[bone_idx] = np.array([1.0, 1.0, bone_scale], dtype=float) + + return new_batched_positions, new_batched_wxyzs, new_batched_scales + + def precompute_mesh_info(self, joints_pos: torch.Tensor): + """Precompute the meshes for all frames at once. + + joints_pos: [T, J, 3]. + """ + joints_pos = joints_pos.cpu().numpy() + num_frames = joints_pos.shape[0] + self.mesh_info_cache = { + "positions": np.zeros((num_frames, self.skeleton.nbjoints - 1, 3)), + "wxyzs": np.zeros((num_frames, self.skeleton.nbjoints - 1, 4)), + "scales": np.ones((num_frames, self.skeleton.nbjoints - 1, 3)), + } + for i in range(num_frames): + new_batched_positions, new_batched_wxyzs, new_batched_scales = self.compute_single_pose(joints_pos[i]) + self.mesh_info_cache["positions"][i] = new_batched_positions + self.mesh_info_cache["wxyzs"][i] = new_batched_wxyzs + self.mesh_info_cache["scales"][i] = new_batched_scales + + def update_mesh_info_cache(self, joints_pos: torch.Tensor, frame_idx: int): + """Update the mesh info cache for the given frame.""" + assert self.mesh_info_cache is not None + new_batched_positions, new_batched_wxyzs, new_batched_scales = self.compute_single_pose( + joints_pos.cpu().numpy() + ) + self.mesh_info_cache["positions"][frame_idx] = new_batched_positions + self.mesh_info_cache["wxyzs"][frame_idx] = new_batched_wxyzs + self.mesh_info_cache["scales"][frame_idx] = new_batched_scales + + def set_pose( + self, + joints_pos: torch.Tensor, + foot_contacts: Optional[torch.Tensor] = None, + frame_idx: Optional[int] = None, + ): + """Set pose from [J, 3] global joint positions.""" + self.cur_joints_pos = joints_pos + joints_pos = joints_pos.cpu().numpy() + + if self.mesh_info_cache is not None: + assert frame_idx is not None + new_batched_positions = self.mesh_info_cache["positions"][frame_idx] + new_batched_wxyzs = self.mesh_info_cache["wxyzs"][frame_idx] + new_batched_scales = self.mesh_info_cache["scales"][frame_idx] + else: + new_batched_positions, new_batched_wxyzs, new_batched_scales = self.compute_single_pose(joints_pos) + + self.bones_batched_mesh.batched_positions = new_batched_positions + self.bones_batched_mesh.batched_wxyzs = new_batched_wxyzs + self.bones_batched_mesh.batched_scales = new_batched_scales + self.joints_batched_mesh.batched_positions = joints_pos + + if foot_contacts is not None: + cur_joint_colors = self.joint_colors.copy() + foot_contacts = foot_contacts.bool().cpu().numpy().astype(bool) + foot_joints = np.array(self.skeleton.foot_joint_idx, dtype=int) + contact_idx = 
foot_joints[foot_contacts] + cur_joint_colors[contact_idx] = (255, 0, 0) + self.joints_batched_mesh.batched_colors = cur_joint_colors + else: + self.joints_batched_mesh.batched_colors = self.joint_colors + + def set_visibility(self, visible: bool): + self.joints_batched_mesh.visible = visible + self.bones_batched_mesh.visible = visible + + def get_pose(self) -> np.ndarray: + return self.cur_joints_pos + + def clear(self): + names = [mesh.name for mesh in [self.joints_batched_mesh, self.bones_batched_mesh]] + for name in names: + self.server.scene.remove_by_name(name) + + +LIGHT_THEME = dict( + mesh=(152, 189, 255), +) + +DARK_THEME = dict( + mesh=(100, 135, 195), +) + +SKIN_CACHE = {} + + +class Character: + def __init__( + self, + name: str, + server: viser.ViserServer | viser.ClientHandle, + skeleton: SkeletonBase, + create_skeleton_mesh: bool = True, + create_skinned_mesh: bool = True, + visible_skeleton: bool = False, + visible_skinned_mesh: bool = True, + skinned_mesh_opacity: float = 1.0, + show_foot_contacts: bool = True, + dark_mode: bool = False, + mesh_mode: Optional[str] = None, + gui_use_soma_layer_checkbox: Optional[viser.GuiCheckboxHandle] = None, + ): + self.server = server + self.name = name + self.skeleton = skeleton + self.cur_joints_pos = None + self.cur_joints_rot = None + self.cur_foot_contacts = None + + self.skeleton_mesh = None + self.show_foot_contacts = show_foot_contacts + if create_skeleton_mesh: + self.skeleton_mesh = SkeletonMesh(f"/{name}/skeleton", server, skeleton) + self.cur_joints_pos = self.skeleton_mesh.get_pose() + self.skeleton_mesh.set_visibility(visible_skeleton) + + self.skinned_mesh = None + self.skin = None + self.mesh_mode = mesh_mode + self.g1_mesh_rig = None + if create_skinned_mesh: + if isinstance(self.skeleton, (SOMASkeleton30, SOMASkeleton77)) and mesh_mode in [ + "soma_skin", + "soma_layer_skin", + ]: + if mesh_mode in SKIN_CACHE: + # already okay + pass + else: + if mesh_mode == "soma_layer_skin": + try: + # try importing the lib + from .soma_layer_skin import SOMASkin as SOMASkin_SOMA + + if mesh_mode not in SKIN_CACHE: + SKIN_CACHE[mesh_mode] = SOMASkin_SOMA(self.skeleton) + + except (ModuleNotFoundError, FileNotFoundError) as e: + if isinstance(e, ModuleNotFoundError): + msg = "SOMA layer skin is unavailable: the soma package is not installed." + else: + msg = "SOMA layer skin is unavailable: SOMA asset files are missing." 
+ traceback.print_exc() + if hasattr(self.server, "add_notification"): + self.server.add_notification( + "SOMA layer skin unavailable", + msg, + auto_close_seconds=5.0, + with_close_button=True, + ) + if gui_use_soma_layer_checkbox is not None: + gui_use_soma_layer_checkbox.value = False + mesh_mode = "soma_skin" + + # another if, in case mesh_mode changed + if mesh_mode == "soma_skin" and mesh_mode not in SKIN_CACHE: + SKIN_CACHE[mesh_mode] = SOMASkin(self.skeleton) + + self.skin = SKIN_CACHE[mesh_mode] + self.skinned_mesh = server.scene.add_mesh_simple( + f"/{name}/simple_skinned", + vertices=self.skin.bind_vertices.cpu().numpy(), + faces=self.skin.faces.cpu().numpy(), + opacity=None, + color=LIGHT_THEME["mesh"] if not dark_mode else DARK_THEME["mesh"], + wireframe=False, + visible=False, + ) + self.skinned_verts_cache = None + + bind_pos = self.skeleton.neutral_joints.clone() + if isinstance(self.skeleton, SOMASkeleton77): + skel30 = SOMASkeleton30(load=True) + min_height = skel30.neutral_joints[:, 1].min().item() + else: + min_height = bind_pos[:, 1].min().item() + bind_pos[:, 1] -= min_height + bind_pos[:, 1] += 0.02 + bind_rotmat = torch.eye(3, device=bind_pos.device).repeat(bind_pos.shape[0], 1, 1) + self.set_pose(bind_pos, bind_rotmat) + self.skinned_mesh.visible = True + self.set_skinned_mesh_visibility(visible_skinned_mesh) + self.set_skinned_mesh_opacity(skinned_mesh_opacity) + elif isinstance(self.skeleton, SMPLXSkeleton22) and mesh_mode == "smplx_skin": + if mesh_mode not in SKIN_CACHE: + SKIN_CACHE[mesh_mode] = SMPLXSkin(self.skeleton) + self.skin = SKIN_CACHE[mesh_mode] + self.skinned_mesh = server.scene.add_mesh_simple( + f"/{name}/simple_skinned", + vertices=self.skin.bind_vertices.cpu().numpy(), + faces=self.skin.faces.cpu().numpy(), + opacity=None, + color=LIGHT_THEME["mesh"] if not dark_mode else DARK_THEME["mesh"], + wireframe=False, + visible=False, + ) + self.skinned_verts_cache = None + + bind_pos = self.skeleton.neutral_joints.clone() + min_height = bind_pos[:, 1].min().item() + bind_pos[:, 1] -= min_height + bind_rotmat = torch.eye(3, device=bind_pos.device).repeat(bind_pos.shape[0], 1, 1) + self.set_pose(bind_pos, bind_rotmat) + self.skinned_mesh.visible = True + self.set_skinned_mesh_visibility(visible_skinned_mesh) + self.set_skinned_mesh_opacity(skinned_mesh_opacity) + elif isinstance(self.skeleton, G1Skeleton34) and mesh_mode == "g1_stl": + g1_mesh_dir = Path(self.skeleton.folder) / "meshes/g1" + if not os.path.exists(g1_mesh_dir): + raise ValueError(f"G1 mesh directory not found: {g1_mesh_dir}") + self.g1_mesh_rig = G1MeshRig( + name, + server, + self.skeleton, + str(g1_mesh_dir), + DARK_THEME["mesh"] if dark_mode else LIGHT_THEME["mesh"], + ) + init_joints_rot = self.skeleton.rest_pose_local_rot.clone() + init_global_joint_rots, _, init_joints_pos = self.skeleton.fk( + init_joints_rot, + torch.zeros(3, device=init_joints_rot.device, dtype=init_joints_rot.dtype), + ) + min_height = init_joints_pos[:, 1].min().item() + init_joints_pos[:, 1] -= min_height + self.set_pose(init_joints_pos, init_global_joint_rots) + self.set_skinned_mesh_visibility(visible_skinned_mesh) + self.set_skinned_mesh_opacity(skinned_mesh_opacity) + else: + raise ValueError( + "Unsupported mesh mode for skeleton type: " + f"{type(self.skeleton).__name__} with mesh_mode={mesh_mode}" + ) + + def change_theme(self, is_dark_mode): + color = DARK_THEME["mesh"] if is_dark_mode else LIGHT_THEME["mesh"] + if self.skinned_mesh is not None: + self.skinned_mesh.color = color + if self.g1_mesh_rig 
is not None: + self.g1_mesh_rig.set_color(color) + + def set_skeleton_visibility(self, visible: bool): + if self.skeleton_mesh is not None: + self.skeleton_mesh.set_visibility(visible) + + def set_show_foot_contacts(self, show: bool, frame_idx: Optional[int] = None): + self.show_foot_contacts = show + if self.skeleton_mesh is not None and self.cur_joints_pos is not None: + fc = self.cur_foot_contacts if show else None + self.skeleton_mesh.set_pose(self.cur_joints_pos, foot_contacts=fc, frame_idx=frame_idx) + + def set_skinned_mesh_visibility(self, visible: bool): + if self.skinned_mesh is not None: + self.skinned_mesh.visible = visible + if self.g1_mesh_rig is not None: + self.g1_mesh_rig.set_visibility(visible) + + def set_skinned_mesh_opacity(self, opacity: float): + if self.skinned_mesh is not None: + self.skinned_mesh.opacity = opacity + if self.g1_mesh_rig is not None: + self.g1_mesh_rig.set_opacity(opacity) + + def set_skinned_mesh_wireframe(self, wireframe: bool): + if self.skinned_mesh is not None: + self.skinned_mesh.wireframe = wireframe + if self.g1_mesh_rig is not None: + self.g1_mesh_rig.set_wireframe(wireframe) + + def precompute_skinning(self, joints_pos: torch.Tensor, joints_rot: torch.Tensor, chunk_size: int = 512): + """Precompute skinning for all frames, processing in chunks to avoid OOM. + + joints_pos: [T, J, 3], joints_rot: [T, J, 3, 3]. + """ + assert self.skin is not None + T = joints_pos.shape[0] + if T <= chunk_size: + self.skinned_verts_cache = self.skin.skin(joints_rot, joints_pos, rot_is_global=True).cpu().numpy() + else: + chunks = [] + for start in range(0, T, chunk_size): + end = min(start + chunk_size, T) + verts = self.skin.skin(joints_rot[start:end], joints_pos[start:end], rot_is_global=True).cpu().numpy() + chunks.append(verts) + self.skinned_verts_cache = np.concatenate(chunks, axis=0) + + def update_skinning_cache(self, joints_pos: torch.Tensor, joints_rot: torch.Tensor, frame_idx: int): + """Update skinning cache for one frame.""" + if self.skinned_verts_cache is None: + return + new_skinned_verts = self.skin.skin(joints_rot[None], joints_pos[None], rot_is_global=True)[0].cpu().numpy() + self.skinned_verts_cache[frame_idx] = new_skinned_verts + + def set_pose( + self, + joints_pos: torch.Tensor, + joints_rot: torch.Tensor, + foot_contacts: Optional[torch.Tensor] = None, + frame_idx: Optional[int] = None, + ): + if self.skeleton_mesh is not None: + self.cur_foot_contacts = foot_contacts + display_fc = foot_contacts if self.show_foot_contacts else None + self.skeleton_mesh.set_pose(joints_pos, foot_contacts=display_fc, frame_idx=frame_idx) + + if self.skinned_mesh is not None: + if self.skinned_verts_cache is not None: + assert frame_idx is not None + skinned_verts = self.skinned_verts_cache[frame_idx] + else: + skinned_verts = self.skin.skin(joints_rot[None], joints_pos[None], rot_is_global=True)[0].cpu().numpy() + self.skinned_mesh.vertices = skinned_verts + if self.g1_mesh_rig is not None: + joints_pos_np = joints_pos.detach().cpu().numpy() + joints_rot_np = joints_rot.detach().cpu().numpy() + self.g1_mesh_rig.set_pose(joints_pos_np, joints_rot_np) + + self.cur_joints_pos = joints_pos + self.cur_joints_rot = joints_rot + + def get_pose(self) -> torch.Tensor: + return self.cur_joints_pos, self.cur_joints_rot + + def clear(self): + if self.skeleton_mesh is not None: + self.skeleton_mesh.clear() + if self.skinned_mesh is not None: + self.server.scene.remove_by_name(self.skinned_mesh.name) + if self.g1_mesh_rig is not None: + self.g1_mesh_rig.clear() 
diff --git a/kimodo/viz/smplx_skin.py b/kimodo/viz/smplx_skin.py new file mode 100644 index 0000000000000000000000000000000000000000..710c873230b2a2c893bcfd366e3fa17841e50b64 --- /dev/null +++ b/kimodo/viz/smplx_skin.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""SMPL-X skinning and joint mapping for visualization.""" + +import os +import warnings +from pathlib import Path + +import numpy as np +import torch + +from kimodo.geometry import axis_angle_to_matrix +from kimodo.skeleton import SMPLXSkeleton22, batch_rigid_transform + +SKIN_NAME = "SMPLX_NEUTRAL.npz" +BETA_NAME = "beta.npy" +MEAN_HANDS_NAME = "mean_hands.npy" + +SMPLX_BODY_JOINT_NAME_MAP = { + "pelvis": "Pelvis", + "left_hip": "L_Hip", + "right_hip": "R_Hip", + "spine1": "Spine1", + "left_knee": "L_Knee", + "right_knee": "R_Knee", + "spine2": "Spine2", + "left_ankle": "L_Ankle", + "right_ankle": "R_Ankle", + "spine3": "Spine3", + "left_foot": "L_Foot", + "right_foot": "R_Foot", + "neck": "Neck", + "left_collar": "L_Collar", + "right_collar": "R_Collar", + "head": "Head", + "left_shoulder": "L_Shoulder", + "right_shoulder": "R_Shoulder", + "left_elbow": "L_Elbow", + "right_elbow": "R_Elbow", + "left_wrist": "L_Wrist", + "right_wrist": "R_Wrist", +} + +# SMPL-X hand pose order (15 joints per hand) matching SMPL-X index order. +SMPLX_HAND_JOINT_ORDER = [ + "Index1", + "Index2", + "Index3", + "Middle1", + "Middle2", + "Middle3", + "Pinky1", + "Pinky2", + "Pinky3", + "Ring1", + "Ring2", + "Ring3", + "Thumb1", + "Thumb2", + "Thumb3", +] + +SMPLX_FACE_JOINT_NAMES = ["Jaw", "L_Eye", "R_Eye"] + + +class SMPLXSkin: + def __init__( + self, + skeleton, + use_mean_hands=True, + ): + skel_dir = Path(skeleton.folder) + skin_data_path = skel_dir / SKIN_NAME + + if not skin_data_path.exists(): + raise FileExistsError( + f"You should download the {SKIN_NAME} from the smplx website, and put it there: {skin_data_path}" + ) + + beta_path = skel_dir / BETA_NAME + mean_hands_path = skel_dir / MEAN_HANDS_NAME + + self.skeleton = skeleton + assert isinstance(skeleton, SMPLXSkeleton22), "SMPLXSkin only supports SMPLXSkeleton22" + assert skeleton.neutral_joints is not None, "SMPLXSkeleton22 must have neutral joints instantiated" + + device = skeleton.neutral_joints.device + with warnings.catch_warnings(): + # Ignore legacy object-dtype warning emitted while unpickling old SMPL-X assets. + warnings.filterwarnings( + "ignore", + message=r"dtype\(\): align should be passed as Python or NumPy boolean.*", + category=Warning, + module=r"numpy\.lib\._format_impl", + ) + # np.load on .npz is lazy; materialize all fields while filter is active. 
+ with np.load(skin_data_path, allow_pickle=True) as skin_npz: + skin_data = {key: skin_npz[key] for key in skin_npz.files} + + joint2num = skin_data["joint2num"] + if isinstance(joint2num, np.ndarray): + joint2num = joint2num.item() + self.full_joint_count = int(skin_data["weights"].shape[1]) + kintree_table = np.array(skin_data["kintree_table"], dtype=np.int64) + parents = kintree_table[0].copy() + parents[parents > 1_000_000_000] = -1 + self.full_joint_parents = torch.tensor(parents, device=device, dtype=torch.long) + root_candidates = np.where(parents == -1)[0] + self.full_root_idx = int(root_candidates[0]) if root_candidates.size else 0 + self.joint_regressor = torch.tensor( + np.array(skin_data["J_regressor"], dtype=np.float32), + device=device, + dtype=torch.float, + ) + + rig_joint_names = [] + rig_joint_indices = [] + for joint_name in self.skeleton.bone_order_names: + mapped_name = SMPLX_BODY_JOINT_NAME_MAP.get(joint_name) + if mapped_name is None or mapped_name not in joint2num: + raise ValueError(f"Missing SMPL-X joint mapping for '{joint_name}'") + rig_joint_names.append(mapped_name) + rig_joint_indices.append(int(joint2num[mapped_name])) + self.body_joint_indices = np.array(rig_joint_indices, dtype=np.int64) + + # Prepare mean hand pose rotations for joints not produced by the model. + if use_mean_hands and mean_hands_path is not None and os.path.exists(mean_hands_path): + mean_hands = np.array(np.load(mean_hands_path), dtype=np.float32) + else: + mean_hands = np.zeros(90, dtype=np.float32) + if mean_hands.shape[0] != 90: + raise ValueError(f"Expected mean_hands shape (90,), got {mean_hands.shape}") + mean_hands = mean_hands.reshape(30, 3) + mean_hands_rotmats = axis_angle_to_matrix(torch.tensor(mean_hands, device=device, dtype=torch.float)) + left_hand_joint_names = [f"L_{name}" for name in SMPLX_HAND_JOINT_ORDER] + right_hand_joint_names = [f"R_{name}" for name in SMPLX_HAND_JOINT_ORDER] + left_indices = [joint2num[name] for name in left_hand_joint_names] + right_indices = [joint2num[name] for name in right_hand_joint_names] + self.hand_joint_indices = np.array(left_indices + right_indices, dtype=np.int64) + self.mean_hand_rotmats = mean_hands_rotmats + face_indices = [joint2num[name] for name in SMPLX_FACE_JOINT_NAMES if name in joint2num] + self.face_joint_indices = np.array(face_indices, dtype=np.int64) + self.mean_face_rotmats = torch.eye(3, device=device).repeat(len(self.face_joint_indices), 1, 1) + + # bind_rig_transform: [J, 4, 4] + # bind_vertices: [V, 3] + # faces: [F, 3] + # lbs indices, lbs weights: [V, W] (W = number of joints) + v_template = np.array(skin_data["v_template"], dtype=np.float32) + faces = np.array(skin_data["f"], dtype=np.int64) + weights = np.array(skin_data["weights"], dtype=np.float32) + + shapedirs = np.array(skin_data["shapedirs"], dtype=np.float32) + posedirs = np.array(skin_data["posedirs"], dtype=np.float32) + + if beta_path is not None and os.path.exists(beta_path): + betas = np.array(np.load(beta_path), dtype=np.float32) + else: + betas = np.zeros(300, dtype=np.float32) + + num_shape_coeffs = shapedirs.shape[2] # 400 = 300 + 100 (shape + expression) + if betas.shape[0] < num_shape_coeffs: + betas = np.pad(betas, (0, num_shape_coeffs - betas.shape[0]), mode="constant") + elif betas.shape[0] > num_shape_coeffs: + betas = betas[:num_shape_coeffs] + + v_shaped = v_template + np.tensordot(shapedirs, betas, axes=[2, 0]) + self.v_shaped = torch.tensor(v_shaped, device=device, dtype=torch.float) + self.posedirs = torch.tensor(posedirs, 
device=device, dtype=torch.float) + self.joint_rest = torch.einsum("jv,vc->jc", self.joint_regressor, self.v_shaped) + + # Align SMPL-X body rest joints to the model skeleton rest pose. + body_rest = self.skeleton.neutral_joints.to(device=device, dtype=torch.float) + if body_rest.shape[0] == self.body_joint_indices.shape[0]: + # Treat mismatches as a warning and align to the skeleton pose anyway. + max_delta = (self.joint_rest[self.body_joint_indices] - body_rest).abs().max() + if max_delta > 1e-6: + print( + "Warning: SMPL-X rest pose mismatch (max_delta=" + f"{max_delta:.2e}); aligning to skeleton neutral joints." + ) + self.joint_rest[self.body_joint_indices] = body_rest + + # Renormalize weights to avoid numerical issues. + weight_sums = weights.sum(axis=1, keepdims=True) + zero_mask = weight_sums[:, 0] < 1e-8 + weights = weights / np.clip(weight_sums, 1e-8, None) + if np.any(zero_mask): + weights[zero_mask, :] = 0.0 + weights[zero_mask, self.full_root_idx] = 1.0 + + joint_indices = np.arange(self.full_joint_count, dtype=np.int64) + lbs_indices = np.tile(joint_indices[None, :], (v_template.shape[0], 1)) + + bind_rig_np = np.zeros((self.full_joint_count, 4, 4), dtype=np.float32) + bind_rig_np[:, 3, 3] = 1.0 + bind_rig_np[:, :3, :3] = np.eye(3, dtype=np.float32) + bind_rig_np[:, :3, 3] = self.joint_rest.detach().cpu().numpy() + + self.bind_rig_transform = torch.from_numpy(bind_rig_np).to(device=device, dtype=torch.float) + bind_rig_inv_np = np.linalg.inv(bind_rig_np) + self.bind_rig_transform_inv = torch.from_numpy(bind_rig_inv_np).to(device=device, dtype=torch.float) + self.bind_vertices = torch.tensor(v_shaped, device=device, dtype=torch.float) + self.faces = torch.tensor(faces, device=device, dtype=torch.long) + self.lbs_indices = torch.tensor(lbs_indices, device=device, dtype=torch.long) + self.lbs_weights = torch.tensor(weights, device=device, dtype=torch.float) + + # double check the rig matches expected skeleton order + for sname, rname in zip(self.skeleton.bone_order_names, rig_joint_names): + mapped_name = SMPLX_BODY_JOINT_NAME_MAP.get(sname) + if mapped_name != rname: + raise ValueError(f"MISMATCH in skinning rig: expected='{mapped_name}' vs rig='{rname}'") + + def lbs(self, posed_transform, bind_vertices=None): + bind_rig_transform_inv = self.bind_rig_transform_inv + if bind_vertices is None: + bind_vertices = self.bind_vertices + lbs_weights = self.lbs_weights + # posed_transform: [B, F, J, 4, 4] or [B, J, 4, 4] or [J, 4, 4] + # unsqueeze to match posed_transform batch dims + batch_dims = posed_transform.shape[:-3] + if bind_vertices.dim() == 2: + for _ in batch_dims: + bind_vertices = bind_vertices.unsqueeze(0) + elif bind_vertices.dim() == 3: + if len(batch_dims) == 1: + if bind_vertices.shape[0] != batch_dims[0]: + bind_vertices = bind_vertices.unsqueeze(0) + elif len(batch_dims) > 1: + for _ in range(len(batch_dims) - 1): + bind_vertices = bind_vertices.unsqueeze(0) + for _ in batch_dims: + bind_rig_transform_inv = bind_rig_transform_inv.unsqueeze(0) + lbs_weights = lbs_weights.unsqueeze(0) + # bind_rig_transform_inv: [..., J, 4, 4] + # bind_vertices: [..., V, 3] + # lbs_weights: [..., V, W] + + affine_mat = (posed_transform @ bind_rig_transform_inv)[..., :3, :] # [..., J, 3, 4] + vs = ( + affine_mat[..., self.lbs_indices, :, :] + @ torch.concat([bind_vertices, torch.ones_like(bind_vertices[..., 0:1])], dim=-1)[..., None, :, None] + ) # [..., V, W, 3, 1] + ws = lbs_weights[..., None, None] + resv = (vs * ws).sum(dim=-3).squeeze(-1) # [..., V, 3] + return resv + + def 
skin(self, joint_rotmat, joint_pos, rot_is_global=False): + """ + joint_rotmat: [T, J, 3, 3] local or global joint rotation matrices + joint_pos: [T, J, 3] global joint positions + rot_is_global: bool, if True, joint_rotmat is global rotation matrices, + otherwise it is local rotation matrices and FK is performed internally + """ + nF, nJ = joint_pos.shape[:2] + device = joint_rotmat.device + + # import ipdb; ipdb.set_trace() + if rot_is_global: + if joint_rotmat.shape[1] == self.full_joint_count: + local_rotmat_full = joint_rotmat.clone() + parents = self.full_joint_parents.to(device) + parent_rot_mats = local_rotmat_full[:, parents] + parent_rot_mats[:, self.full_root_idx] = torch.eye(3, device=device) + parent_rot_mats_inv = parent_rot_mats.transpose(2, 3) + local_rotmat_full = torch.einsum( + "T N m n, T N n o -> T N m o", + parent_rot_mats_inv, + local_rotmat_full, + ) + else: + local_rotmat = self.skeleton.global_rots_to_local_rots(joint_rotmat) + else: + local_rotmat = joint_rotmat + + if rot_is_global and joint_rotmat.shape[1] == self.full_joint_count: + full_local = local_rotmat_full + else: + full_local = torch.eye(3, device=device).reshape(1, 1, 3, 3).repeat(nF, self.full_joint_count, 1, 1) + full_local[:, self.body_joint_indices] = local_rotmat + if self.mean_hand_rotmats is not None: + full_local[:, self.hand_joint_indices] = self.mean_hand_rotmats[None] + if self.mean_face_rotmats is not None: + full_local[:, self.face_joint_indices] = self.mean_face_rotmats[None] + pose_feature = (full_local[:, 1:] - torch.eye(3, device=device)[None, None]).reshape(nF, -1) + + pose_offsets = torch.einsum("vcp,tp->tvc", self.posedirs, pose_feature) + v_posed = self.v_shaped[None] + pose_offsets + joints_rest = self.joint_rest[None].repeat(nF, 1, 1) + posed_joints, global_joint_rots = batch_rigid_transform( + full_local, + joints_rest, + self.full_joint_parents.to(device), + self.full_root_idx, + ) + # remove the skeleton offset of the root joint + root_trans = joint_pos[:, self.skeleton.root_idx] - self.skeleton.neutral_joints[0:1] + posed_joints = posed_joints + root_trans[:, None, :] + + fk_transform = torch.eye(4, device=device)[None, None].repeat(nF, self.full_joint_count, 1, 1) + fk_transform[..., :3, :3] = global_joint_rots + fk_transform[..., :3, 3] = posed_joints + + vertices = self.lbs(fk_transform, bind_vertices=v_posed) + return vertices diff --git a/kimodo/viz/soma_layer_skin.py b/kimodo/viz/soma_layer_skin.py new file mode 100644 index 0000000000000000000000000000000000000000..63747eba3e2c347184377ad6251635173fd7c6b6 --- /dev/null +++ b/kimodo/viz/soma_layer_skin.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""SOMA layer-based skinning for visualization (SOMASkeleton30 / SOMASkeleton77).""" + +from pathlib import Path + +import numpy as np +import torch +from huggingface_hub import snapshot_download +from soma import SomaLayer as SOMALayer + +from kimodo.assets import SOMA_ASSETS_ROOT +from kimodo.skeleton import SOMASkeleton30, SOMASkeleton77, global_rots_to_local_rots + +SOMA_MHR_NEUTRAL_PATH = "somaskel30/soma_base_fit_mhr_params.npz" + + +class SOMASkin: + def __init__( + self, + skeleton, + ): + self.skeleton = skeleton + + assert isinstance( + skeleton, (SOMASkeleton30, SOMASkeleton77) + ), "SOMASkin currently only supports SOMASkeleton30 or SOMASkeleton77" + assert skeleton.neutral_joints is not None, "The skeleton must have neutral joints instantiated" + + device = skeleton.neutral_joints.device + device = "cpu" + self.device = device + + self._soma_model = SOMALayer( + identity_model_type="mhr", + device=device, + ) + self.faces = self._soma_model.faces + + neutral_mhr_path = Path(skeleton.folder).parent / SOMA_MHR_NEUTRAL_PATH + neutral_mhr = np.load(neutral_mhr_path) + + # one time call to prepare the identity + self.soma_identity = torch.from_numpy(neutral_mhr["identity_params"]) + self.scale_params = torch.from_numpy(neutral_mhr["scale_params"]) + self._soma_model.prepare_identity(self.soma_identity.to(device), scale_params=self.scale_params.to(device)) + + # dummy output to get bind_vertices + transl = torch.zeros(1, 3, device=device) + + self._full_skeleton = SOMASkeleton77() + self.skel_slice = self.skeleton.get_skel_slice(self._full_skeleton) + + self.bind_vertices = self.soma_model_pose( + self._full_skeleton.relaxed_hands_rest_pose[None], + transl=transl, + pose2rot=False, + )["vertices"][0] + + def soma_model_pose(self, *args, **kwargs): + with torch.inference_mode(): + return self._soma_model.pose(*args, **kwargs) + + def skin(self, joint_rotmat, joint_pos, rot_is_global=False): + """ + joint_rotmat: [T, J, 3, 3] local or global joint rotation matrices + joint_pos: [T, J, 3] global joint positions + rot_is_global: bool, if True, joint_rotmat is global rotation matrices, otherwise it is local rotation matrices and FK is performed internally + """ + + nF, nJ = joint_pos.shape[:2] + + if rot_is_global: + local_joint_rots_mats_subset = global_rots_to_local_rots(joint_rotmat, self.skeleton) + else: + local_joint_rots_mats_subset = joint_rotmat + + if nJ != self._full_skeleton.nbjoints: + local_joint_rots_mats = self.skeleton.to_SOMASkeleton77(local_joint_rots_mats_subset) + else: + local_joint_rots_mats = local_joint_rots_mats_subset + + # remove the skeleton offset of the root joint + transl = joint_pos[:, self.skeleton.root_idx] - self.skeleton.neutral_joints[0:1] + + output = self.soma_model_pose( + local_joint_rots_mats.to(device=self.device, dtype=torch.float32), + transl=transl.to(device=self.device, dtype=torch.float32), + pose2rot=False, + ) + return output["vertices"] diff --git a/kimodo/viz/soma_skin.py b/kimodo/viz/soma_skin.py new file mode 100644 index 0000000000000000000000000000000000000000..befea6697b1a30e010aab1260bfdfceac0d424c4 --- /dev/null +++ b/kimodo/viz/soma_skin.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""SOMA skeleton skinning for visualization (SOMASkeleton30 / SOMASkeleton77).""" + +from pathlib import Path + +import numpy as np +import torch + +from kimodo.skeleton import ( + SOMASkeleton30, + SOMASkeleton77, + batch_rigid_transform, + global_rots_to_local_rots, +) + +# Skin for SOMASkeleton77 +SKEL_PATH = "somaskel77" +SKIN_NAME = "skin_standard.npz" + + +class SOMASkin: + def __init__(self, skeleton): + skel_path = Path(skeleton.folder).parent / SKEL_PATH + skin_data_path = skel_path / SKIN_NAME + + self.skeleton_input = skeleton + assert isinstance( + skeleton, (SOMASkeleton30, SOMASkeleton77) + ), "SOMASkin currently only supports SOMASkeleton30 or SOMASkeleton77" + assert skeleton.neutral_joints is not None, "The skeleton must have neutral joints instantiated" + device = skeleton.neutral_joints.device + + # the skin is always the 77-joint skeleton + # if user is using the 30-joint skeleton, we will pad it when skinning is called + self.skeleton_skin = SOMASkeleton77(skel_path).to(device) + + # bind_rig_transform: [R, 4, 4] + # bind_vertices: [V, 3] + # faces: [F, 3] + # lbs indices, lbs weights: [V, W] (W = max (num joints vertice is related to), in our case W=5) + skin_data = np.load(skin_data_path) + bind_rig_np = np.array(skin_data["bind_rig_transform"], dtype=np.float32) + self.bind_rig_transform = torch.from_numpy(bind_rig_np).to(device=device, dtype=torch.float) + # Precompute the inverse in numpy to avoid torch lazy evaluation issues + bind_rig_inv_np = np.linalg.inv(bind_rig_np) + self.bind_rig_transform_inv = torch.from_numpy(bind_rig_inv_np).to(device=device, dtype=torch.float) + self.bind_vertices = torch.tensor(skin_data["bind_vertices"], device=device, dtype=torch.float) + self.faces = torch.tensor(skin_data["faces"], device=device, dtype=torch.long) + self.lbs_indices = torch.tensor(skin_data["lbs_indices"], device=device, dtype=torch.long) + self.lbs_weights = torch.tensor(skin_data["lbs_weights"], device=device, dtype=torch.float) + + # double check the rig matches expected skeleton + rig_joint_names = list(skin_data["rig_joint_names"]) # list(str) : [R] + for sname, rname in zip(self.skeleton_skin.bone_order_names, rig_joint_names): + if sname != rname: + raise ValueError(f"MISMATCH in skinnging rig: expected='{sname}' vs rig='{rname}'") + + def lbs(self, posed_transform): + bind_rig_transform_inv = self.bind_rig_transform_inv + bind_vertices = self.bind_vertices + lbs_weights = self.lbs_weights + # posed_transform: [B, F, J, 4, 4] or [B, J, 4, 4] or [J, 4, 4] + # unsqueeze to match posed_transform dim + for _ in range(posed_transform.dim() - 3): + bind_rig_transform_inv = bind_rig_transform_inv.unsqueeze(0) + bind_vertices = bind_vertices.unsqueeze(0) + lbs_weights = lbs_weights.unsqueeze(0) + # bind_rig_transform_inv: [..., R, 4, 4] + # bind_vertices: [..., V, 3] + # lbs_weights: [..., V, W] + + affine_mat = (posed_transform @ bind_rig_transform_inv)[..., :3, :] # [..., J, 3, 4] + vs = ( + affine_mat[..., self.lbs_indices, :, :] + @ torch.concat([bind_vertices, torch.ones_like(bind_vertices[..., 0:1])], dim=-1)[..., None, :, None] + ) # [..., V, W, 3, 1] + ws = lbs_weights[..., None, None] + resv = (vs * ws).sum(dim=-3).squeeze(-1) # [..., V, 3] + return resv + + def skin(self, joint_rotmat, joint_pos, rot_is_global=False): + """ + joint_rotmat: [T, J, 3, 3] local or global joint rotation matrices + joint_pos: [T, J, 3] global joint positions + rot_is_global: bool, if True, joint_rotmat is global rotation matrices, 
otherwise it is local rotation matrices and FK is performed internally + """ + nF, nJ = joint_pos.shape[:2] + device = joint_rotmat.device + + if nJ != self.skeleton_skin.nbjoints: + assert nJ == 30, "SOMASkin currently only supports 30-joint or 77-joint skeletons" + + # make sure we have local joint rotations + if rot_is_global: + local_joint_rots_mats_subset = global_rots_to_local_rots(joint_rotmat, self.skeleton_input) + else: + local_joint_rots_mats_subset = joint_rotmat + + local_joint_rots_mats = self.skeleton_input.to_SOMASkeleton77(local_joint_rots_mats_subset) + + # FK to get the global joint pos and rot + neutral_joints_seq = self.skeleton_skin.neutral_joints[None].repeat((nF, 1, 1)).to(device) + new_joint_pos, joint_rotmat = batch_rigid_transform( + local_joint_rots_mats, + neutral_joints_seq, + self.skeleton_skin.joint_parents.to(device), + self.skeleton_skin.root_idx, + ) + joint_pos = new_joint_pos + joint_pos[:, self.skeleton_input.root_idx : self.skeleton_input.root_idx + 1] + nJ = self.skeleton_skin.nbjoints + rot_is_global = True + + # prepare full transformation matrices + fk_transform = torch.eye(4, device=device)[None, None].repeat(nF, nJ, 1, 1) + fk_transform[..., :3, 3] = joint_pos + if rot_is_global: + fk_transform[..., :3, :3] = joint_rotmat + else: + neutral_joints_seq = self.skeleton_skin.neutral_joints[None].repeat((nF, 1, 1)).to(device) + # FK to get the global rotations + _, global_joint_rotmat = batch_rigid_transform( + joint_rotmat, + neutral_joints_seq, + self.skeleton_skin.joint_parents.to(device), + self.skeleton_skin.root_idx, + ) + fk_transform[..., :3, :3] = global_joint_rotmat + + vertices = self.lbs(fk_transform) + return vertices diff --git a/kimodo/viz/viser_utils.py b/kimodo/viz/viser_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..66fc919d61402152bb472f3aa8fcd842b2d18674 --- /dev/null +++ b/kimodo/viz/viser_utils.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Viser-based 3D viz: re-exports from viz submodules for backward compatibility.""" + +import os + +from .constraint_ui import ( + ConstraintSet, + EEJointsKeyframeSet, + FullbodyKeyframeSet, + RootKeyframe2DSet, + build_constraint_set_table_markdown, + update_interval, +) +from .gui import GuiElements +from .playback import CharacterMotion +from .scene import ( + DARK_THEME, + LIGHT_THEME, + SKIN_CACHE, + Character, + SkeletonMesh, + WaypointMesh, +) + + +def load_example_cases(examples_base_dir): + """List subdirectories of examples_base_dir as a name -> path dict.""" + example_dirs = os.listdir(examples_base_dir) + example_names = sorted([d for d in example_dirs if os.path.isdir(os.path.join(examples_base_dir, d))]) + return {name: os.path.join(examples_base_dir, name) for name in example_names} + + +__all__ = [ + "Character", + "CharacterMotion", + "ConstraintSet", + "DARK_THEME", + "EEJointsKeyframeSet", + "FullbodyKeyframeSet", + "GuiElements", + "LIGHT_THEME", + "RootKeyframe2DSet", + "SKIN_CACHE", + "SkeletonMesh", + "WaypointMesh", + "build_constraint_set_table_markdown", + "load_example_cases", + "update_interval", +]