movimento / kimodo /model /kimodo_model.py
Kimodo Bot
Add core kimodo package modules required by native demo
6d5047c
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Kimodo model: denoiser, text encoder, diffusion sampling, and post-processing."""
import logging
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from tqdm.auto import tqdm
from kimodo.constraints import FullBodyConstraintSet
from kimodo.motion_rep.feature_utils import compute_heading_angle, length_to_mask
from kimodo.postprocess import post_process_motion
from kimodo.sanitize import sanitize_texts
from kimodo.skeleton import SOMASkeleton30
from kimodo.tools import to_numpy
from .cfg import ClassifierFreeGuidedModel
from .diffusion import DDIMSampler, Diffusion
log = logging.getLogger(__name__)
class Kimodo(nn.Module):
    """Test-time wrapper bundling a denoiser, a text encoder and a sampler.

    The denoiser is wrapped in ``ClassifierFreeGuidedModel`` and paired with
    a ``Diffusion`` schedule and a ``DDIMSampler`` built on top of it.
    """

    def __init__(
        self,
        denoiser: nn.Module,
        text_encoder: nn.Module,
        num_base_steps: int,
        device: Optional[Union[str, torch.device]] = None,
        cfg_type: Optional[str] = "separated",
    ):
        """Assemble the generation pipeline around ``denoiser``.

        Args:
            denoiser: Denoising network. It must expose a ``motion_rep``
                attribute, which in turn provides ``skeleton`` and ``fps``.
            text_encoder: Module called on a list of prompts in ``_generate``.
            num_base_steps: Forwarded to ``Diffusion(num_base_steps=...)``.
            device: Device stored on ``self.device`` and passed to ``self.to``.
            cfg_type: Classifier-free guidance type handed to
                ``ClassifierFreeGuidedModel``. ``None`` is mapped to ``"nocfg"``.
        """
        super().__init__()

        # ``None`` is translated to the explicit "nocfg" type string.
        guidance_type = "nocfg" if cfg_type is None else cfg_type
        # The denoiser is put in eval mode and always wrapped for guidance.
        self.denoiser = ClassifierFreeGuidedModel(denoiser.eval(), cfg_type=guidance_type)

        # Cache the motion representation and the attributes read from it.
        motion_rep = denoiser.motion_rep
        self.motion_rep = motion_rep
        self.skeleton = motion_rep.skeleton
        self.fps = motion_rep.fps

        diffusion = Diffusion(num_base_steps=num_base_steps)
        self.diffusion = diffusion
        self.sampler = DDIMSampler(diffusion)

        self.text_encoder = text_encoder
        self.device = device
        self.to(device)
@property
def output_skeleton(self):
    """Return the skeleton that describes the model output.

    For a ``SOMASkeleton30`` this is its ``somaskel77`` attribute; any other
    skeleton is returned unchanged.
    """
    skel = self.skeleton
    return skel.somaskel77 if isinstance(skel, SOMASkeleton30) else skel
def train(self, mode: bool = True):
    """Set the training mode of the wrapped denoiser.

    Only ``self.denoiser`` is toggled by this override.

    Args:
        mode: ``True`` for training mode, ``False`` for evaluation mode.
            Defaults to ``True`` so that a bare ``model.train()`` call works,
            matching the signature of ``nn.Module.train``.

    Returns:
        ``self``, to allow call chaining.
    """
    self.denoiser.train(mode)
    return self
def eval(self):
    """Put the wrapped denoiser in evaluation mode and return ``self``."""
    wrapped = self.denoiser
    wrapped.eval()
    return self
def denoising_step(
    self,
    motion: torch.Tensor,
    pad_mask: torch.Tensor,
    text_feat: torch.Tensor,
    text_pad_mask: torch.Tensor,
    t: torch.Tensor,
    first_heading_angle: Optional[torch.Tensor],
    motion_mask: torch.Tensor,
    observed_motion: torch.Tensor,
    num_denoising_steps: torch.Tensor,
    cfg_weight: Union[float, Tuple[float, float]],
    guide_masks: Optional[Dict] = None,
    cfg_type: Optional[str] = None,
) -> torch.Tensor:
    """Run one denoising iteration: predict the clean motion, then step the sampler.

    Args:
        motion: Current noisy motion.
        pad_mask: Padding mask forwarded to the denoiser.
        text_feat: Text features forwarded to the denoiser.
        text_pad_mask: Text padding mask forwarded to the denoiser.
        t: Indices into the sub-sampled schedule; used to index the map
            returned by ``space_timesteps`` and passed to the sampler.
        first_heading_angle: Forwarded to the denoiser.
        motion_mask: Forwarded to the denoiser.
        observed_motion: Forwarded to the denoiser.
        num_denoising_steps: Tensor whose first element is the number of
            sampling steps.
        cfg_weight: Guidance weight(s), passed as the denoiser's first argument.
        guide_masks: Accepted for interface compatibility; not used here.
        cfg_type: Forwarded to the denoiser as a keyword argument.

    Returns:
        torch.Tensor: [B, T, D] noisy motion input to t-1
    """
    # The sub-sampled schedule is recomputed on every call: with the ONNX
    # export the number of sampling steps may change dynamically between calls.
    step_count = num_denoising_steps[0]
    use_timesteps, map_tensor = self.diffusion.space_timesteps(step_count)
    self.diffusion.calc_diffusion_vars(use_timesteps)

    # Positional inputs of the guided denoiser, with ``t`` translated through
    # the map returned alongside the schedule.
    denoiser_args = (
        cfg_weight,
        motion,
        pad_mask,
        text_feat,
        text_pad_mask,
        map_tensor[t],
        first_heading_angle,
        motion_mask,
        observed_motion,
    )
    with torch.inference_mode():
        clean_estimate = self.denoiser(*denoiser_args, cfg_type=cfg_type)

    # The sampler turns the clean estimate into the next (less noisy) motion.
    return self.sampler(use_timesteps, motion, clean_estimate, t)
def _multiprompt(
    self,
    prompts: list[str],
    num_frames: int | list[int],
    num_denoising_steps: int,
    constraint_lst: Optional[list] = [],
    cfg_weight: Optional[float] = [2.0, 2.0],
    num_samples: Optional[int] = None,
    cfg_type: Optional[str] = None,
    return_numpy: bool = False,
    first_heading_angle: Optional[torch.Tensor] = None,
    # for transitioning
    num_transition_frames: int = 5,
    share_transition: bool = True,
    percentage_transition_override=0.10,
    # for postprocess
    post_processing: bool = False,
    root_margin: float = 0.04,
    # progress bar
    progress_bar=tqdm,
) -> dict:
    """Generate consecutive prompt segments and stitch them into one motion.

    Segments are generated one at a time. From the second segment on, the end
    of the previously generated motion is turned into full-body constraints on
    the first ``num_transition_frames`` frames of the new segment, and those
    frames are linearly cross-faded from the previous generation to the new one.

    Args:
        prompts: One text per segment, in temporal order.
        num_frames: Segment length in frames. An int is applied to every
            segment; a list gives one length per segment.
        num_denoising_steps: Number of denoising steps used for each segment.
        constraint_lst: Constraints expressed on the whole sequence. Each one
            is cropped/shifted to the current segment with ``crop_move``.
            ``None`` is treated as an empty list. The list is never mutated.
        cfg_weight: Guidance weight(s) forwarded to ``_generate``.
        num_samples: Batch size. ``None`` generates a single sample and drops
            the batch axis before decoding.
        cfg_type: Guidance type forwarded to ``_generate``.
        return_numpy: If ``True``, the output goes through ``to_numpy``.
        first_heading_angle: Heading for the first segment (scalar or one
            value per sample); zeros when ``None``. Later segments use the
            heading computed from the previous segment's transition frames.
        num_transition_frames: Number of frames constrained and cross-faded
            at each segment boundary.
        share_transition: If ``True``, the new segment also re-generates an
            extra part of the end of the previous segment.
        percentage_transition_override: Fraction of the previous segment's
            length that is re-generated when ``share_transition`` is ``True``.
        post_processing: If ``True``, run ``post_process_motion`` on the output.
        root_margin: Forwarded to ``post_process_motion``.
        progress_bar: Callable wrapping the denoising iterable; forwarded to
            ``_generate``.

    Returns:
        dict: Output of ``self.motion_rep.inverse`` for the stitched motion,
        optionally post-processed, converted for SOMA skeletons and/or
        converted to numpy.
    """
    device = self.device

    # Resolve the batch size first: everything below (including ``bs``) must
    # see the defaulted value, not ``None``.
    tosqueeze = num_samples is None
    if tosqueeze:
        num_samples = 1
    bs = num_samples

    texts = list(sanitize_texts(prompts))
    if isinstance(num_frames, int):
        # same duration for all the segments (one entry per prompt)
        num_frames = [num_frames] * len(texts)

    if constraint_lst is None:
        constraint_lst = []

    # Generate one chunk at a time
    current_frame = 0
    generated_motions = []
    for idx, (text, num_frame) in enumerate(zip(texts, num_frames)):
        texts_bs = [text] * num_samples
        lengths = torch.tensor([num_frame] * num_samples, device=device)
        is_first_motion = not generated_motions

        # Keep only the part of each constraint relevant to this segment.
        # This moves the constraints temporally but not spatially.
        constraint_lst_base = [
            constraint.crop_move(current_frame, current_frame + num_frame) for constraint in constraint_lst
        ]
        observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched(
            constraint_lst_base,
            lengths,
            to_normalize=False,  # don't normalize yet, it needs to be moved around
            device=device,
        )

        if not is_first_motion:
            prev_num_frame = num_frames[idx - 1]
            if share_transition:
                # Start the transition earlier, to "share" it between A and B.
                # Only "num_transition_frames" frames are used for conditioning;
                # the extra frames (a percentage of the previous segment) are
                # not conditioned and get re-generated by B.
                nb_transition_frames = num_transition_frames + int(prev_num_frame * percentage_transition_override)
            else:
                nb_transition_frames = num_transition_frames

            latest_motions = generated_motions.pop()
            # remove the transition part of A (will be put back afterward)
            generated_motions.append(latest_motions[:, :-nb_transition_frames])
            latest_frames = latest_motions[:, -nb_transition_frames:]

            last_output = self.motion_rep.inverse(
                latest_frames,
                is_normalized=False,
                return_numpy=False,
            )
            smooth_root_2d = last_output["smooth_root_pos"][..., [0, 2]]

            # add constraints at the beginning to allow natural transitions
            constraint_lst_transition = []
            for batch_id in range(bs):
                new_constraint = FullBodyConstraintSet(
                    self.skeleton,
                    torch.arange(num_transition_frames),
                    last_output["posed_joints"][batch_id, :num_transition_frames],
                    last_output["local_rot_mats"][batch_id, :num_transition_frames],
                    smooth_root_2d[batch_id, :num_transition_frames],
                )
                # one constraint list per sample
                constraint_lst_transition.append([new_constraint])

            transition_lengths = torch.tensor([nb_transition_frames] * num_samples, device=device)
            observed_motion_transition, motion_mask_transition = (
                self.motion_rep.create_conditions_from_constraints_batched(
                    constraint_lst_transition,
                    transition_lengths,
                    to_normalize=False,  # don't normalize yet
                    device=device,
                )
            )

            # concatenate the observed motion / motion mask along time
            observed_motion = torch.cat([observed_motion_transition, observed_motion], axis=1)
            motion_mask = torch.cat([motion_mask_transition, motion_mask], axis=1)

            # move each observed motion in the batch to the new starting point
            last_smooth_root_2d = smooth_root_2d[:, 0]
            observed_motion = self.motion_rep.translate_2d(
                observed_motion, -last_smooth_root_2d
            )  # equivalent to: self.motion_rep.translate_2d_to_zero(observed_motion)
            # remove dummy values after moving
            observed_motion = observed_motion * motion_mask

            lengths = lengths + transition_lengths
            first_heading_angle = compute_heading_angle(last_output["posed_joints"], self.skeleton)[:, 0]
        else:
            if first_heading_angle is None:
                # Start at 0 angle, but this will change afterward
                first_heading_angle = torch.tensor([0.0] * bs, device=device)
            else:
                first_heading_angle = torch.as_tensor(first_heading_angle, device=device)
                if first_heading_angle.numel() == 1:
                    first_heading_angle = first_heading_angle.repeat(bs)

        observed_motion = self.motion_rep.normalize(observed_motion)
        max_frames = max(lengths)
        motion_pad_mask = length_to_mask(lengths)

        motion = self._generate(
            texts_bs,
            max_frames,
            num_denoising_steps=num_denoising_steps,
            pad_mask=motion_pad_mask,
            first_heading_angle=first_heading_angle,
            motion_mask=motion_mask,
            observed_motion=observed_motion,
            cfg_weight=cfg_weight,
            cfg_type=cfg_type,
            progress_bar=progress_bar,  # honour the caller's progress callable
        )
        motion = self.motion_rep.unnormalize(motion)

        if not is_first_motion:
            # bring the new segment back to where the previous one ended
            motion_with_transition = self.motion_rep.translate_2d(
                motion,
                last_smooth_root_2d,
            )
            motion = motion_with_transition[:, num_transition_frames:]
            transition_frames = motion_with_transition[:, :num_transition_frames]

            # With share_transition, the new motion contains the very end of A.
            # Linearly blend the previously generated transition with the newly
            # generated one, going from the previous generation to the new one.
            alpha = torch.linspace(1, 0, num_transition_frames, device=device)[:, None]
            new_transition_frames = (
                latest_frames[:, :num_transition_frames] * alpha + (1 - alpha) * transition_frames
            )
            # Add the blended transition frames back after A. With
            # share_transition, the remaining end of A is not added back: it
            # has been re-generated by B.
            generated_motions.append(new_transition_frames)

        generated_motions.append(motion)
        current_frame += num_frame

    generated_motions = torch.cat(generated_motions, axis=1)  # temporal axis (b, t, d)

    if tosqueeze:
        generated_motions = generated_motions[0]

    output = self.motion_rep.inverse(
        generated_motions,
        is_normalized=False,
        return_numpy=False,
    )

    # Apply post-processing if requested
    if post_processing:
        corrected = post_process_motion(
            output["local_rot_mats"],
            output["root_positions"],
            output["foot_contacts"],
            self.skeleton,
            constraint_lst,
            root_margin=root_margin,
        )
        output.update(corrected)

    # Convert SOMA output to somaskel77 for external API
    if isinstance(self.skeleton, SOMASkeleton30):
        output = self.skeleton.output_to_SOMASkeleton77(output)

    # Convert to numpy if requested
    if return_numpy:
        output = to_numpy(output)
    return output
def __call__(
    self,
    prompts: str | list[str],
    num_frames: int | list[int],
    num_denoising_steps: int,
    multi_prompt: bool = False,
    constraint_lst: Optional[list] = [],
    cfg_weight: Optional[float] = [2.0, 2.0],
    num_samples: Optional[int] = None,
    cfg_type: Optional[str] = None,
    return_numpy: bool = False,
    first_heading_angle: Optional[torch.Tensor] = None,
    # for transitioning
    num_transition_frames: int = 5,
    share_transition: bool = True,
    percentage_transition_override=0.10,
    # for postprocess
    post_processing: bool = False,
    root_margin: float = 0.04,
    # progress bar
    progress_bar=tqdm,
) -> dict:
    """Generate motion from text, optionally guided by kinematic constraints.

    A single prompt with a single ``num_frames`` yields one motion. When
    ``prompts`` and/or ``num_frames`` is a list, a batch of independent
    motions is produced. With ``multi_prompt=True`` the prompts are instead
    consecutive segments, generated in order and joined with smooth
    transitions.

    Args:
        prompts: Text description(s) of the motion. A string gives one
            sample; a list gives a batch (or the ordered segments when
            ``multi_prompt=True``).
        num_frames: Motion duration in frames, either one int shared by all
            prompts or one value per prompt.
        num_denoising_steps: Number of DDIM denoising steps. More steps
            generally trade speed for quality.
        multi_prompt: If ``True``, ``prompts`` is an ordered sequence of
            segments that are concatenated with transitions.
        constraint_lst: Per-sample list of kinematic constraints (e.g.
            keyframe poses, end-effector targets, 2-D paths). An empty list
            means unconstrained generation.
        cfg_weight: Classifier-free guidance scale(s). A two-element list
            ``[text_cfg, constraint_cfg]`` sets the text and constraint
            guidance independently.
        num_samples: Number of samples to generate. Overridden by the list
            length when ``prompts`` or ``num_frames`` is a list.
        cfg_type: Overrides the CFG strategy chosen at init
            (e.g. ``"separated"``).
        return_numpy: If ``True``, every output tensor is converted to a
            numpy array.
        first_heading_angle: Initial body heading in radians, scalar or of
            shape ``(B,)``. Defaults to ``0`` (facing +Z).
        num_transition_frames: Number of overlapping frames blended between
            consecutive segments in multi-prompt mode.
        share_transition: If ``True``, adjacent segments share their
            transition frames instead of having them appended.
        percentage_transition_override: Fraction of a segment's length that
            the transition blend is allowed to override.
        post_processing: If ``True``, run post-processing (foot-skate
            cleanup and constraint enforcement).
        root_margin: Horizontal margin (in meters) for the post-processor:
            the root is corrected once it deviates from the constraint by
            more than this margin.
        progress_bar: Callable that wraps an iterable to report progress
            (default: ``tqdm``). Pass a no-op to silence output.

    Returns:
        dict: Motion tensors (numpy arrays when ``return_numpy=True``) under
        the following keys:

        - ``local_rot_mats`` – Local joint rotations as rotation matrices.
        - ``global_rot_mats`` – Global joint rotations as rotation matrices.
        - ``posed_joints`` – Joint positions in world space.
        - ``root_positions`` – Root joint positions.
        - ``smooth_root_pos`` – Smoothed root trajectory.
        - ``foot_contacts`` – Boolean foot-contact labels [left heel, left toe, right heel, right toe].
        - ``global_root_heading`` – Root heading angle over time.
    """
    device = self.device

    if multi_prompt:
        # Sequential segments are delegated to the stitching routine.
        return self._multiprompt(
            prompts=prompts,
            num_frames=num_frames,
            num_denoising_steps=num_denoising_steps,
            constraint_lst=constraint_lst,
            cfg_weight=cfg_weight,
            num_samples=num_samples,
            cfg_type=cfg_type,
            return_numpy=return_numpy,
            first_heading_angle=first_heading_angle,
            num_transition_frames=num_transition_frames,
            share_transition=share_transition,
            percentage_transition_override=percentage_transition_override,
            post_processing=post_processing,
            root_margin=root_margin,
            progress_bar=progress_bar,
        )

    # Normalise ``prompts`` / ``num_frames`` to two lists of equal length.
    prompts_is_list = isinstance(prompts, list)
    frames_is_list = isinstance(num_frames, list)
    tosqueeze = False
    if prompts_is_list or frames_is_list:
        if prompts_is_list and frames_is_list:
            assert len(prompts) == len(num_frames), "The number of prompts should match the number of num_frames."
        # The list length dictates the batch size (prompts take precedence).
        num_samples = len(prompts) if prompts_is_list else len(num_frames)
        if not frames_is_list:
            num_frames = [num_frames] * num_samples
        if not prompts_is_list:
            prompts = [prompts] * num_samples
    else:
        # Scalar inputs: without an explicit batch size, generate a single
        # sample and drop the batch axis at the end.
        if num_samples is None:
            tosqueeze = True
            num_samples = 1
        prompts = [prompts] * num_samples
        num_frames = [num_frames] * num_samples

    bs = num_samples
    texts = sanitize_texts(prompts)

    lengths = torch.tensor(num_frames, device=device)
    max_frames = max(lengths)
    motion_pad_mask = length_to_mask(lengths)

    if first_heading_angle is not None:
        first_heading_angle = torch.as_tensor(first_heading_angle, device=device)
        # A single value is shared by the whole batch.
        if first_heading_angle.numel() == 1:
            first_heading_angle = first_heading_angle.repeat(bs)
    else:
        # Start at 0 angle
        first_heading_angle = torch.tensor([0.0] * bs, device=device)

    # Constraints are optional; without them the generation is unconditioned
    # on observed motion.
    observed_motion = motion_mask = None
    if constraint_lst:
        observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched(
            constraint_lst,
            lengths,
            to_normalize=True,
            device=device,
        )

    motion = self._generate(
        texts,
        max_frames,
        num_denoising_steps=num_denoising_steps,
        pad_mask=motion_pad_mask,
        first_heading_angle=first_heading_angle,
        motion_mask=motion_mask,
        observed_motion=observed_motion,
        cfg_weight=cfg_weight,
        cfg_type=cfg_type,
        progress_bar=progress_bar,
    )
    if tosqueeze:
        motion = motion[0]

    # Decode to tensors first so post-processing can operate on them.
    output = self.motion_rep.inverse(
        motion,
        is_normalized=True,
        return_numpy=False,
    )

    if post_processing:
        # key frame outputs / foot contacts are not changed
        output.update(
            post_process_motion(
                output["local_rot_mats"],
                output["root_positions"],
                output["foot_contacts"],
                self.skeleton,
                constraint_lst,
                root_margin=root_margin,
            )
        )

    # SOMA output is exposed as somaskel77 to the external API.
    if isinstance(self.skeleton, SOMASkeleton30):
        output = self.skeleton.output_to_SOMASkeleton77(output)

    return to_numpy(output) if return_numpy else output
def _generate(
    self,
    texts: List[str],
    max_frames: int,
    num_denoising_steps: int,
    pad_mask: torch.Tensor,
    first_heading_angle: Optional[torch.Tensor],
    motion_mask: torch.Tensor,
    observed_motion: torch.Tensor,
    cfg_weight: Optional[float] = 2.0,
    text_feat: Optional[torch.Tensor] = None,
    text_pad_mask: Optional[torch.Tensor] = None,
    guide_masks: Optional[Dict] = None,
    cfg_type: Optional[str] = None,
    progress_bar=tqdm,
) -> torch.Tensor:
    """Run the complete denoising loop, starting from Gaussian noise.

    Args:
        texts (List[str]): batch of text prompts, encoded here unless
            ``text_feat`` is supplied.
        max_frames: Number of frames of the sampled motion tensor.
        num_denoising_steps: Number of sampling iterations.
        pad_mask: Motion padding mask forwarded to every denoising step.
        first_heading_angle: Forwarded to every denoising step.
        motion_mask: Constraint mask, or ``None``. A boolean mask is
            converted to integers before use.
        observed_motion: Forwarded to every denoising step.
        cfg_weight: Guidance weight(s) forwarded to every denoising step.
        text_feat: Pre-computed text features; when given, ``texts`` is not
            encoded and ``text_pad_mask`` is used as passed.
        text_pad_mask: Padding mask matching ``text_feat``. Must be ``None``
            when ``text_feat`` is ``None``.
        guide_masks: Forwarded to every denoising step.
        cfg_type: Forwarded to every denoising step.
        progress_bar: Callable wrapping the list of step indices.

    Returns:
        torch.Tensor: the motion after the last denoising step, of shape
        ``(batch, max_frames, motion_rep_dim)``.
    """
    device = self.device

    if text_feat is None:
        assert text_pad_mask is None
        log.info("Encoding text...")
        text_feat, text_length = self.text_encoder(texts)
        text_feat = text_feat.to(device)

        # Empty (whitespace-only) prompts get zeroed features ...
        blank_prompt = [not text.strip() for text in texts]
        text_feat[blank_prompt] = 0

        # ... and a zero length, so their pad mask is False everywhere.
        batch_size, maxlen = text_feat.shape[:2]
        tensor_text_length = torch.tensor(text_length, device=device)
        tensor_text_length[blank_prompt] = 0
        positions = torch.arange(maxlen, device=device).expand(batch_size, maxlen)
        text_pad_mask = positions < tensor_text_length[:, None]

    # Boolean masks are turned into integer masks.
    if motion_mask is not None and motion_mask.dtype == torch.bool:
        motion_mask = 1 * motion_mask

    batch_size = text_feat.shape[0]

    # Steps are visited from the last index down to 0.
    step_indices = list(reversed(range(num_denoising_steps)))
    noise_shape = (batch_size, max_frames, self.motion_rep.motion_rep_dim)
    cur_mot = torch.randn(noise_shape, device=self.device)

    # Both the step count and ``t`` must be tensors for the ONNX export.
    steps_tensor = torch.tensor([num_denoising_steps], device=self.device)

    # Initialise the diffusion variables for this step count before looping.
    use_timesteps = self.diffusion.space_timesteps(steps_tensor[0])[0]
    self.diffusion.calc_diffusion_vars(use_timesteps)

    for step in progress_bar(step_indices):
        t = torch.tensor([step] * cur_mot.size(0), device=self.device)
        with torch.inference_mode():
            cur_mot = self.denoising_step(
                cur_mot,
                pad_mask,
                text_feat,
                text_pad_mask,
                t,
                first_heading_angle,
                motion_mask,
                observed_motion,
                steps_tensor,
                cfg_weight,
                guide_masks=guide_masks,
                cfg_type=cfg_type,
            )
    return cur_mot