# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Kimodo model: denoiser, text encoder, diffusion sampling, and post-processing."""

import logging
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from tqdm.auto import tqdm

from kimodo.constraints import FullBodyConstraintSet
from kimodo.motion_rep.feature_utils import compute_heading_angle, length_to_mask
from kimodo.postprocess import post_process_motion
from kimodo.sanitize import sanitize_texts
from kimodo.skeleton import SOMASkeleton30
from kimodo.tools import to_numpy

from .cfg import ClassifierFreeGuidedModel
from .diffusion import DDIMSampler, Diffusion

log = logging.getLogger(__name__)


class Kimodo(nn.Module):
    """Helper class for test time.

    Wraps a trained denoiser and text encoder with a DDIM diffusion sampler,
    classifier-free guidance, optional kinematic constraints, multi-prompt
    stitching, and post-processing of the generated motion.
    """

    def __init__(
        self,
        denoiser: nn.Module,
        text_encoder: nn.Module,
        num_base_steps: int,
        device: Optional[Union[str, torch.device]] = None,
        cfg_type: Optional[str] = "separated",
    ):
        """Build the inference wrapper.

        Args:
            denoiser: Trained denoising network; must expose ``motion_rep``.
            text_encoder: Text encoder producing (features, lengths).
            num_base_steps: Total diffusion timesteps of the base schedule.
            device: Device for sampling tensors and the modules.
            cfg_type: Classifier-free guidance strategy; ``None`` disables CFG.
        """
        super().__init__()
        self.denoiser = denoiser.eval()
        if cfg_type is None:
            cfg_type = "nocfg"
        # Add Classifier-free guidance to the model if needed
        self.denoiser = ClassifierFreeGuidedModel(self.denoiser, cfg_type=cfg_type)
        self.motion_rep = denoiser.motion_rep
        self.skeleton = self.motion_rep.skeleton
        self.fps = denoiser.motion_rep.fps
        self.diffusion = Diffusion(num_base_steps=num_base_steps)
        self.sampler = DDIMSampler(self.diffusion)
        self.text_encoder = text_encoder
        self.device = device  # for classifier-free guidance
        self.to(device)

    @property
    def output_skeleton(self):
        """Skeleton used for model output (somaskel77 for SOMA, else unchanged)."""
        if isinstance(self.skeleton, SOMASkeleton30):
            return self.skeleton.somaskel77
        return self.skeleton

    def train(self, mode: bool = True):
        """Delegate train/eval mode to the denoiser only.

        Default added to stay compatible with ``nn.Module.train()`` callers.
        """
        self.denoiser.train(mode)
        return self

    def eval(self):
        """Put the denoiser in eval mode; returns self for chaining."""
        self.denoiser.eval()
        return self

    def denoising_step(
        self,
        motion: torch.Tensor,
        pad_mask: torch.Tensor,
        text_feat: torch.Tensor,
        text_pad_mask: torch.Tensor,
        t: torch.Tensor,
        first_heading_angle: Optional[torch.Tensor],
        motion_mask: torch.Tensor,
        observed_motion: torch.Tensor,
        num_denoising_steps: torch.Tensor,
        cfg_weight: Union[float, Tuple[float, float]],
        guide_masks: Optional[Dict] = None,  # currently unused; kept for interface stability
        cfg_type: Optional[str] = None,
    ) -> torch.Tensor:
        """Single denoising step.

        Returns:
            torch.Tensor: [B, T, D] noisy motion input to t-1
        """
        # subsample timesteps
        # NOTE: do this at every step due to ONNX export, i.e. num_denoising_steps may
        # change dynamically when running the ONNX version, so we account for that.
        num_denoising_steps = num_denoising_steps[0]
        use_timesteps, map_tensor = self.diffusion.space_timesteps(num_denoising_steps)
        self.diffusion.calc_diffusion_vars(use_timesteps)
        # first compute initial clean prediction from denoiser
        t_map = map_tensor[t]
        with torch.inference_mode():
            pred_clean = self.denoiser(
                cfg_weight,
                motion,
                pad_mask,
                text_feat,
                text_pad_mask,
                t_map,
                first_heading_angle,
                motion_mask,
                observed_motion,
                cfg_type=cfg_type,
            )
        # sampler computes next step noisy motion
        x_tm1 = self.sampler(use_timesteps, motion, pred_clean, t)
        return x_tm1

    def _multiprompt(
        self,
        prompts: list[str],
        num_frames: int | list[int],
        num_denoising_steps: int,
        constraint_lst: Optional[list] = None,  # was mutable default []
        cfg_weight: Union[float, Tuple[float, float]] = (2.0, 2.0),  # was mutable default [2.0, 2.0]
        num_samples: Optional[int] = None,
        cfg_type: Optional[str] = None,
        return_numpy: bool = False,
        first_heading_angle: Optional[torch.Tensor] = None,
        # for transitioning
        num_transition_frames: int = 5,
        share_transition: bool = True,
        percentage_transition_override: float = 0.10,
        # for postprocess
        post_processing: bool = False,
        root_margin: float = 0.04,
        # progress bar
        progress_bar=tqdm,
    ) -> torch.Tensor:
        """Generate sequential motion segments (one per prompt) stitched with transitions.

        See ``__call__`` for the meaning of the parameters. Returns the same
        output dict as ``__call__``.
        """
        device = self.device
        texts = sanitize_texts(prompts)

        # Resolve num_samples BEFORE using it: the original code expanded
        # num_frames with range(num_samples) while num_samples could still be
        # None, which raised TypeError.
        tosqueeze = False
        if num_samples is None:
            num_samples = 1
            tosqueeze = True
        bs = num_samples
        if isinstance(num_frames, int):
            # same duration for all the segments
            # NOTE(review): expansion uses num_samples, not len(prompts); segments
            # beyond num_samples are dropped by zip below — confirm intended.
            num_frames = [num_frames for _ in range(num_samples)]
        if constraint_lst is None:
            constraint_lst = []

        # Generate one chunk at a time
        current_frame = 0
        generated_motions = []
        for idx, (text, num_frame) in enumerate(zip(texts, num_frames)):
            texts_bs = [text for _ in range(num_samples)]
            lengths = torch.tensor(
                [num_frame for _ in range(num_samples)],
                device=device,
            )
            is_first_motion = not generated_motions
            # filter the constraint_lst to only keep the relevant ones
            constraint_lst_base = [
                constraint.crop_move(current_frame, current_frame + num_frame)
                for constraint in constraint_lst
            ]  # this moves temporally but not spatially
            observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched(
                constraint_lst_base,
                lengths,
                to_normalize=False,  # don't normalize yet, it needs to be moved around
                device=device,
            )
            if not is_first_motion:
                prev_num_frame = num_frames[idx - 1]
                if share_transition:
                    # starting the transitioning earlier, to "share" the transition between A and B
                    # in any case, we still use "num_transition_frames" for conditioning
                    # we don't condition until the end of A
                    # we compute the number of frames of transition as a percentage of the last motion
                    nb_transition_frames = num_transition_frames + int(
                        prev_num_frame * percentage_transition_override
                    )
                else:
                    nb_transition_frames = num_transition_frames
                latest_motions = generated_motions.pop()
                # remove the transition part of A (will be put back afterward)
                generated_motions.append(latest_motions[:, :-nb_transition_frames])
                latest_frames = latest_motions[:, -nb_transition_frames:]
                last_output = self.motion_rep.inverse(
                    latest_frames,
                    is_normalized=False,
                    return_numpy=False,
                )
                smooth_root_2d = last_output["smooth_root_pos"][..., [0, 2]]
                # add constraints at the beginning to allow natural transitions
                constraint_lst_transition = []
                for batch_id in range(bs):
                    new_constraint = FullBodyConstraintSet(
                        self.skeleton,
                        torch.arange(num_transition_frames),
                        last_output["posed_joints"][batch_id, :num_transition_frames],
                        last_output["local_rot_mats"][batch_id, :num_transition_frames],
                        smooth_root_2d[batch_id, :num_transition_frames],
                    )
                    # new lists
                    constraint_lst_transition.append([new_constraint])
                transition_lengths = torch.tensor(
                    [nb_transition_frames for _ in range(num_samples)],
                    device=device,
                )
                observed_motion_transition, motion_mask_transition = (
                    self.motion_rep.create_conditions_from_constraints_batched(
                        constraint_lst_transition,
                        transition_lengths,
                        to_normalize=False,  # don't normalize yet
                        device=device,
                    )
                )
                # concatenate the observed motion / motion mask
                observed_motion = torch.cat([observed_motion_transition, observed_motion], dim=1)
                motion_mask = torch.cat([motion_mask_transition, motion_mask], dim=1)
                # we need to move each observed motion in the batch to the new starting points
                last_smooth_root_2d = smooth_root_2d[:, 0]
                observed_motion = self.motion_rep.translate_2d(
                    observed_motion, -last_smooth_root_2d
                )  # equivalent to: self.motion_rep.translate_2d_to_zero(observed_motion)
                # remove dummy values after moving
                observed_motion = observed_motion * motion_mask
                lengths = lengths + transition_lengths
                first_heading_angle = compute_heading_angle(
                    last_output["posed_joints"], self.skeleton
                )[:, 0]
            else:
                if first_heading_angle is None:
                    # Start at 0 angle, but this will change afterward
                    first_heading_angle = torch.tensor([0.0] * bs, device=device)
                else:
                    first_heading_angle = torch.as_tensor(first_heading_angle, device=device)
                    if first_heading_angle.numel() == 1:
                        first_heading_angle = first_heading_angle.repeat(bs)
            observed_motion = self.motion_rep.normalize(observed_motion)
            max_frames = max(lengths)
            motion_pad_mask = length_to_mask(lengths)
            motion = self._generate(
                texts_bs,
                max_frames,
                num_denoising_steps=num_denoising_steps,
                pad_mask=motion_pad_mask,
                first_heading_angle=first_heading_angle,
                motion_mask=motion_mask,
                observed_motion=observed_motion,
                cfg_weight=cfg_weight,
                cfg_type=cfg_type,
                progress_bar=progress_bar,  # was accepted but never forwarded
            )
            motion = self.motion_rep.unnormalize(motion)
            if not is_first_motion:
                motion_with_transition = self.motion_rep.translate_2d(
                    motion,
                    last_smooth_root_2d,
                )
                motion = motion_with_transition[:, num_transition_frames:]
                transition_frames = motion_with_transition[:, :num_transition_frames]
                # for sharing = True, the new motion contains the very last of A
                # linearly combine the previously generated transitions with the newly generated ones
                # so that we linearly go from previous gen to new gen
                alpha = torch.linspace(1, 0, num_transition_frames, device=device)[:, None]
                new_transition_frames = (
                    latest_frames[:, :num_transition_frames] * alpha + (1 - alpha) * transition_frames
                )
                # add new transition frames for A (merging with B prediction of the history)
                # for share_transition == True, this removes (does not add back) a small part of the end of A
                # the small last part of A has been re-generated by B
                generated_motions.append(new_transition_frames)
            generated_motions.append(motion)
            current_frame += num_frame
        generated_motions = torch.cat(generated_motions, dim=1)  # temporal axis (b, t, d)
        if tosqueeze:
            generated_motions = generated_motions[0]
        output = self.motion_rep.inverse(
            generated_motions,
            is_normalized=False,
            return_numpy=False,
        )
        # Apply post-processing if requested
        if post_processing:
            corrected = post_process_motion(
                output["local_rot_mats"],
                output["root_positions"],
                output["foot_contacts"],
                self.skeleton,
                constraint_lst,
                root_margin=root_margin,
            )
            output.update(corrected)
        # Convert SOMA output to somaskel77 for external API
        if isinstance(self.skeleton, SOMASkeleton30):
            output = self.skeleton.output_to_SOMASkeleton77(output)
        # Convert to numpy if requested
        if return_numpy:
            output = to_numpy(output)
        return output

    def __call__(
        self,
        prompts: str | list[str],
        num_frames: int | list[int],
        num_denoising_steps: int,
        multi_prompt: bool = False,
        constraint_lst: Optional[list] = None,  # was mutable default []
        cfg_weight: Union[float, Tuple[float, float]] = (2.0, 2.0),  # was mutable default [2.0, 2.0]
        num_samples: Optional[int] = None,
        cfg_type: Optional[str] = None,
        return_numpy: bool = False,
        first_heading_angle: Optional[torch.Tensor] = None,
        # for transitioning
        num_transition_frames: int = 5,
        share_transition: bool = True,
        percentage_transition_override: float = 0.10,
        # for postprocess
        post_processing: bool = False,
        root_margin: float = 0.04,
        # progress bar
        progress_bar=tqdm,
    ) -> dict:
        """Generate motion from text prompts and optional kinematic constraints.

        When a single prompt/num_frames pair is given, one motion is generated.
        Passing lists of prompts and/or num_frames produces a batch of independent
        motions. With ``multi_prompt=True``, the prompts are treated as sequential
        segments that are generated and stitched together with smooth transitions.

        Args:
            prompts: One or more text descriptions of the desired motion. A single
                string generates one sample; a list generates a batch (or
                sequential segments when ``multi_prompt=True``).
            num_frames: Duration of the generated motion in frames. Can be a
                single int applied to every prompt or a per-prompt list.
            num_denoising_steps: Number of DDIM denoising steps. More steps
                generally improve quality at the cost of speed.
            multi_prompt: If ``True``, treat ``prompts`` as an ordered sequence of
                segments and concatenate them with transitions.
            constraint_lst: Per-sample list of kinematic constraints (e.g.
                keyframe poses, end-effector targets, 2-D paths). Pass an empty
                list (or leave as ``None``) for unconstrained generation.
            cfg_weight: Classifier-free guidance scale(s). A two-element sequence
                ``[text_cfg, constraint_cfg]`` controls text and constraint
                guidance independently.
            num_samples: Number of samples to generate.
            cfg_type: Override the default CFG strategy set at init (e.g.
                ``"separated"``).
            return_numpy: If ``True``, convert all output tensors to numpy arrays.
            first_heading_angle: Initial body heading in radians. Shape ``(B,)``
                or scalar. Defaults to ``0`` (facing +Z).
            num_transition_frames: Number of overlapping frames used to blend
                consecutive segments in multi-prompt mode.
            share_transition: If ``True``, transition frames are shared between
                adjacent segments rather than appended.
            percentage_transition_override: Fraction of each segment's length that
                may be overridden by the transition blend.
            post_processing: If ``True``, apply post-processing (foot-skate
                cleanup and constraint enforcement).
            root_margin: Horizontal margin (in meters) used by the post-processor
                to determine when to correct root motion. When root deviates more
                than margin from the constraint, the post-processor will correct it.
            progress_bar: Callable wrapping an iterable to display progress
                (default: ``tqdm``). Pass a no-op to silence output.

        Returns:
            dict: A dictionary of motion tensors (or numpy arrays if
            ``return_numpy=True``) with the following keys:

            - ``local_rot_mats`` – Local joint rotations as rotation matrices.
            - ``global_rot_mats`` – Global joint rotations as rotation matrices.
            - ``posed_joints`` – Joint positions in world space.
            - ``root_positions`` – Root joint positions.
            - ``smooth_root_pos`` – Smoothed root trajectory.
            - ``foot_contacts`` – Boolean foot-contact labels [left heel, left
              toe, right heel, right toe].
            - ``global_root_heading`` – Root heading angle over time.
        """
        device = self.device
        if multi_prompt:
            # multi prompt generation
            return self._multiprompt(
                prompts,
                num_frames,
                num_denoising_steps,
                constraint_lst,
                cfg_weight,
                num_samples,
                cfg_type,
                return_numpy,
                first_heading_angle,
                num_transition_frames,
                share_transition,
                percentage_transition_override,
                post_processing,
                root_margin,
                progress_bar,
            )
        if constraint_lst is None:
            constraint_lst = []
        # Input checking
        tosqueeze = False
        if isinstance(prompts, list) and isinstance(num_frames, list):
            assert len(prompts) == len(num_frames), "The number of prompts should match the number of num_frames."
            num_samples = len(prompts)
        elif isinstance(prompts, list):
            num_samples = len(prompts)
            num_frames = [num_frames for _ in range(num_samples)]
        elif isinstance(num_frames, list):
            num_samples = len(num_frames)
            prompts = [prompts for _ in range(num_samples)]
        else:
            if num_samples is None:
                tosqueeze = True
                num_samples = 1
            prompts = [prompts for _ in range(num_samples)]
            num_frames = [num_frames for _ in range(num_samples)]
        bs = num_samples
        texts = sanitize_texts(prompts)
        lengths = torch.tensor(
            num_frames,
            device=device,
        )
        max_frames = max(lengths)
        motion_pad_mask = length_to_mask(lengths)
        if first_heading_angle is None:
            # Start at 0 angle
            first_heading_angle = torch.tensor([0.0] * bs, device=device)
        else:
            first_heading_angle = torch.as_tensor(first_heading_angle, device=device)
            if first_heading_angle.numel() == 1:
                first_heading_angle = first_heading_angle.repeat(bs)
        observed_motion, motion_mask = None, None
        if constraint_lst:
            observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched(
                constraint_lst,
                lengths,
                to_normalize=True,
                device=device,
            )
        motion = self._generate(
            texts,
            max_frames,
            num_denoising_steps=num_denoising_steps,
            pad_mask=motion_pad_mask,
            first_heading_angle=first_heading_angle,
            motion_mask=motion_mask,
            observed_motion=observed_motion,
            cfg_weight=cfg_weight,
            cfg_type=cfg_type,
            progress_bar=progress_bar,
        )
        if tosqueeze:
            motion = motion[0]
        output = self.motion_rep.inverse(
            motion,
            is_normalized=True,
            return_numpy=False,  # Keep as tensor for potential post-processing
        )
        # Apply post-processing if requested
        if post_processing:
            corrected = post_process_motion(
                output["local_rot_mats"],
                output["root_positions"],
                output["foot_contacts"],
                self.skeleton,
                constraint_lst,
                root_margin=root_margin,
            )
            # key frame outputs / foot contacts are not changed
            output.update(corrected)
        # Convert SOMA output to somaskel77 for external API
        if isinstance(self.skeleton, SOMASkeleton30):
            output = self.skeleton.output_to_SOMASkeleton77(output)
        # Convert to numpy if requested
        if return_numpy:
            output = to_numpy(output)
        return output

    def _generate(
        self,
        texts: List[str],
        max_frames: int,
        num_denoising_steps: int,
        pad_mask: torch.Tensor,
        first_heading_angle: Optional[torch.Tensor],
        motion_mask: torch.Tensor,
        observed_motion: torch.Tensor,
        cfg_weight: Union[float, Tuple[float, float]] = 2.0,
        text_feat: Optional[torch.Tensor] = None,
        text_pad_mask: Optional[torch.Tensor] = None,
        guide_masks: Optional[Dict] = None,
        cfg_type: Optional[str] = None,
        progress_bar=tqdm,
    ) -> torch.Tensor:
        """Sample full denoising loop.

        Args:
            texts (List[str]): batch of text prompts to use for sampling
                (if text_feat is not passed in)
        """
        device = self.device
        if text_feat is None:
            assert text_pad_mask is None
            log.info("Encoding text...")
            text_feat, text_length = self.text_encoder(texts)
            text_feat = text_feat.to(device)
            # handle empty string (set to zero)
            empty_text_mask = [len(text.strip()) == 0 for text in texts]
            text_feat[empty_text_mask] = 0
            # Create the pad mask for the text
            batch_size, maxlen = text_feat.shape[:2]
            tensor_text_length = torch.tensor(text_length, device=device)
            tensor_text_length[empty_text_mask] = 0
            text_pad_mask = (
                torch.arange(maxlen, device=device).expand(batch_size, maxlen)
                < tensor_text_length[:, None]
            )
        if motion_mask is not None:
            if motion_mask.dtype == torch.bool:
                motion_mask = 1 * motion_mask
        batch_size = text_feat.shape[0]
        # sample loop
        indices = list(range(num_denoising_steps))[::-1]
        shape = (batch_size, max_frames, self.motion_rep.motion_rep_dim)
        cur_mot = torch.randn(shape, device=self.device)
        num_denoising_steps = torch.tensor(
            [num_denoising_steps], device=self.device
        )  # this and t need to be tensor for onnx export
        # init diffusion with correct num steps before looping
        use_timesteps = self.diffusion.space_timesteps(num_denoising_steps[0])[0]
        self.diffusion.calc_diffusion_vars(use_timesteps)
        for i in progress_bar(indices):
            t = torch.tensor([i] * cur_mot.size(0), device=self.device)
            with torch.inference_mode():
                cur_mot = self.denoising_step(
                    cur_mot,
                    pad_mask,
                    text_feat,
                    text_pad_mask,
                    t,
                    first_heading_angle,
                    motion_mask,
                    observed_motion,
                    num_denoising_steps,
                    cfg_weight,
                    guide_masks=guide_masks,
                    cfg_type=cfg_type,
                )
        return cur_mot