Spaces:
Runtime error
Runtime error
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| from collections import defaultdict | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| import viser | |
| from kimodo.constraints import ( | |
| TYPE_TO_CLASS, | |
| FullBodyConstraintSet, | |
| Root2DConstraintSet, | |
| ) | |
| from kimodo.exports.mujoco import apply_g1_real_robot_projection | |
| from kimodo.skeleton import G1Skeleton34, SOMASkeleton30 | |
| from kimodo.tools import seed_everything | |
| from .embedding_cache import CachedTextEncoder | |
| from .state import ClientSession, ModelBundle | |
def _select_joint_targets(constraint_info, valid_idx, skel_slice, device):
    """Gather joint positions/rotations of the kept keyframes for the model skeleton.

    Args:
        constraint_info: Mapping returned by a track's ``get_constraint_info``;
            must contain ``"joints_pos"`` and ``"joints_rot"``.
        valid_idx: Positions (along the first axis) of the keyframes to keep.
        skel_slice: Optional index applied on the second axis to keep only the
            joints the model knows about; ``None`` keeps every joint.
        device: Device the selected tensors are moved to.

    Returns:
        A ``(joints_pos, joints_rot)`` tuple restricted to ``valid_idx`` and,
        when ``skel_slice`` is given, to the selected joints.
    """
    joints_pos = constraint_info["joints_pos"][valid_idx].to(device)
    joints_rot = constraint_info["joints_rot"][valid_idx].to(device)
    if skel_slice is not None:
        joints_pos = joints_pos[:, skel_slice]
        joints_rot = joints_rot[:, skel_slice]
    return joints_pos, joints_rot


def compute_model_constraints_lst(
    session: ClientSession,
    model_bundle: ModelBundle,
    num_frames: int,
    device: str,
):
    """Compute the list of model constraints from the constraint tracks held by the session.

    Each track in ``session.constraints`` is converted into one or more
    constraint-set objects built on the model skeleton. Keyframes whose frame
    index is ``>= num_frames`` are dropped, and a track left with no keyframe
    contributes nothing.

    Args:
        session: Client session providing ``motions``, ``constraints`` and
            ``skeleton``.
        model_bundle: Bundle whose ``model.skeleton`` is the skeleton the
            constraints are expressed on.
        num_frames: Total number of generated frames; upper bound (exclusive)
            for constraint frame indices.
        device: Device on which constraint tensors are placed.

    Returns:
        A list of constraint sets (``Root2DConstraintSet``,
        ``FullBodyConstraintSet`` and/or the classes mapped by
        ``TYPE_TO_CLASS``); empty when the session has no constraints.

    Raises:
        AssertionError: If the session does not hold exactly one motion.
        ValueError: If a track name is not one of ``"2D Root"``,
            ``"Full-Body"`` or ``"End-Effectors"``.
    """
    assert len(session.motions) == 1, "Only one motion allowed for constrained generation"
    if not session.constraints:
        return []

    model_skeleton = model_bundle.model.skeleton
    # For SOMA, UI uses somaskel77; extract 30-joint subset for the model
    use_skel_slice = isinstance(model_skeleton, SOMASkeleton30) and session.skeleton.nbjoints != model_skeleton.nbjoints
    skel_slice = model_skeleton.get_skel_slice(session.skeleton) if use_skel_slice else None

    # When the 2D root track describes a dense path, its full root trajectory is
    # also attached to the full-body / end-effector constraints below.
    # Guard the lookup so a constraint map without a "2D Root" track does not
    # raise KeyError.
    dense_smooth_root_pos_2d = None
    if "2D Root" in session.constraints and session.constraints["2D Root"].dense_path:
        # Columns 0 and 2 of root_pos (presumably the ground-plane x/z — confirm).
        dense_root_info = session.constraints["2D Root"].get_constraint_info(device=device)
        dense_smooth_root_pos_2d = dense_root_info["root_pos"][:, [0, 2]]

    model_constraints = []
    for track_name, constraint in session.constraints.items():
        constraint_info = constraint.get_constraint_info(device=device)

        # Drop any constraints outside the generation range.
        valid_idx = []
        valid_frame_idx = []
        for i, fi in enumerate(constraint_info["frame_idx"]):
            if fi < num_frames:
                valid_idx.append(i)
                valid_frame_idx.append(fi)
        if not valid_frame_idx:
            continue
        frame_indices = torch.tensor(valid_frame_idx)

        if track_name == "2D Root":
            # same as "smooth_root_2d"
            smooth_root_pos_2d = constraint_info["root_pos"][valid_idx][:, [0, 2]].to(device)
            model_constraints.append(
                Root2DConstraintSet(
                    model_skeleton,
                    frame_indices,
                    smooth_root_pos_2d,
                )
            )
        elif track_name == "Full-Body":
            constraint_joints_pos, constraint_joints_rot = _select_joint_targets(
                constraint_info, valid_idx, skel_slice, device
            )
            smooth_root_pos_2d = None
            if dense_smooth_root_pos_2d is not None:
                smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices]
            model_constraints.append(
                FullBodyConstraintSet(
                    model_skeleton,
                    frame_indices,
                    constraint_joints_pos,
                    constraint_joints_rot,
                    smooth_root_2d=smooth_root_pos_2d,
                )
            )
        elif track_name == "End-Effectors":
            constraint_joints_pos, constraint_joints_rot = _select_joint_targets(
                constraint_info, valid_idx, skel_slice, device
            )
            # Set membership keeps this filter linear in the number of keyframes.
            valid_idx_set = set(valid_idx)
            end_effector_type_set_lst = [
                end_effector_type_set
                for i, end_effector_type_set in enumerate(constraint_info["end_effector_type"])
                if i in valid_idx_set
            ]

            # Regroup the end-effector data by constraint class: one constraint
            # set is emitted per class, covering every keyframe that uses it.
            cls_idx = defaultdict(list)
            for idx, end_effector_type_set in enumerate(end_effector_type_set_lst):
                for end_effector_type in end_effector_type_set:
                    cls_idx[TYPE_TO_CLASS[end_effector_type]].append(idx)

            for cls, lst_idx in cls_idx.items():
                frame_indices_cls = frame_indices[lst_idx]
                smooth_root_pos_2d = None
                if dense_smooth_root_pos_2d is not None:
                    smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices_cls]
                model_constraints.append(
                    cls(
                        model_skeleton,
                        frame_indices_cls,
                        constraint_joints_pos[lst_idx],
                        constraint_joints_rot[lst_idx],
                        smooth_root_2d=smooth_root_pos_2d,
                    )
                )
        else:
            raise ValueError(f"Unsupported constraint type: {constraint.display_name}")

    return model_constraints
def generate(
    *,
    client: viser.ClientHandle,
    session: ClientSession,
    model_bundle: ModelBundle,
    prompts: list[str],
    num_frames: list[int],
    num_samples: int,
    seed: int,
    diffusion_steps: int,
    cfg_weight: Optional[list[float]] = None,
    cfg_type: Optional[str] = None,
    postprocess_parameters: Optional[dict] = None,
    transitions_parameters: Optional[dict] = None,
    real_robot_rotations: bool = False,
    device: str,
    clear_motions,
    add_character_motion,
) -> None:
    """Run the model on the prompts and display the generated samples.

    Seeds the RNGs, converts the session constraints into model constraints,
    calls the model once for all prompts, then clears the client's current
    motions and adds one character motion per generated sample.

    Args:
        client: Viser client the result is displayed for.
        session: Client session holding the constraints and the display skeleton.
        model_bundle: Bundle whose ``model`` is called to generate the motion.
        prompts: Text prompts, passed to the model with ``multi_prompt=True``.
        num_frames: Frame counts passed to the model; their sum bounds the
            constraint frame indices.
        num_samples: Number of samples to generate and display.
        seed: Seed passed to ``seed_everything``.
        diffusion_steps: Number of diffusion steps passed to the model.
        cfg_weight: Guidance weights; defaults to ``[2.0, 2.0]`` when falsy.
        cfg_type: Guidance type forwarded to the model.
        postprocess_parameters: Extra keyword arguments forwarded to the model.
        transitions_parameters: Extra keyword arguments forwarded to the model;
            on a key clash these override ``postprocess_parameters``.
        real_robot_rotations: When true and the session skeleton is a
            ``G1Skeleton34``, the output is passed through
            ``apply_g1_real_robot_projection`` (with ``clamp_to_limits=True``)
            before display.
        device: Device used for the constraint tensors.
        clear_motions: Callback invoked as ``clear_motions(client_id)``.
        add_character_motion: Callback invoked once per sample as
            ``add_character_motion(client, skeleton, joints_pos, joints_rot,
            foot_contacts)``; ``foot_contacts`` is ``None`` when the model
            output has no ``"foot_contacts"`` entry.
    """
    # Function-scope import: keeps this block self-contained without touching
    # the file-level import list.
    from contextlib import nullcontext

    client_id = client.client_id
    total_frames = sum(num_frames)
    print(
        f"Generating {num_samples} samples for a total of {total_frames} frames with those prompt: {prompts} (client {client_id})"
    )
    seed_everything(seed)

    model_constraints = compute_model_constraints_lst(session, model_bundle, total_frames, device)

    cfg_weight = cfg_weight or [2.0, 2.0]
    postprocess_parameters = postprocess_parameters or {}
    transitions_parameters = transitions_parameters or {}

    # A cached text encoder needs the session bound for the duration of the
    # model call; any other encoder runs under a no-op context.
    encoder = getattr(model_bundle.model, "text_encoder", None)
    if isinstance(encoder, CachedTextEncoder):
        encoder_context = encoder.session_context(session)
    else:
        encoder_context = nullcontext()

    with encoder_context:
        pred_joints_output = model_bundle.model(
            prompts,
            num_frames,
            diffusion_steps,
            multi_prompt=True,
            constraint_lst=model_constraints,
            cfg_weight=cfg_weight,
            num_samples=num_samples,
            cfg_type=cfg_type,
            **(postprocess_parameters | transitions_parameters),
        )  # [B, T, motion_rep_dim]

    joints_pos = pred_joints_output["posed_joints"]  # [B, T, J, 3]
    joints_rot = pred_joints_output["global_rot_mats"]
    # Optional output: may be absent, in which case None is forwarded per sample.
    foot_contacts = pred_joints_output.get("foot_contacts")

    # Optionally project G1 to real robot DoF (1-DoF per joint, clamped) for display.
    if real_robot_rotations and isinstance(session.skeleton, G1Skeleton34):
        joints_pos, joints_rot = apply_g1_real_robot_projection(
            session.skeleton,
            pred_joints_output["posed_joints"],
            pred_joints_output["global_rot_mats"],
            clamp_to_limits=True,
        )

    # Display on characters (callbacks keep this module UI-agnostic).
    clear_motions(client_id)

    # Keep one sample centered at the origin so constraints align: the sample
    # at index num_samples // 2 gets a zero offset, the others are spread
    # along the first coordinate.
    spread_factor = 1.0  # meters
    center_idx = num_samples // 2
    x_trans = (np.arange(num_samples) - center_idx) * spread_factor

    for i in range(num_samples):
        cur_joints_pos = joints_pos[i]
        # NOTE: in-place shift; this also modifies the corresponding slice of joints_pos.
        cur_joints_pos[..., 0] += x_trans[i]
        cur_foot_contacts = foot_contacts[i] if foot_contacts is not None else None
        add_character_motion(
            client,
            session.skeleton,
            cur_joints_pos,
            joints_rot[i],
            cur_foot_contacts,
        )