# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Transformer backbone: padding, masking, and encoder stack for the denoiser."""
import logging
from typing import Optional, Union
import torch
from omegaconf import ListConfig
from pydantic.dataclasses import dataclass
from torch import Tensor, nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from kimodo.tools import validate
log = logging.getLogger(__name__)
def pad_x_and_mask_to_fixed_size(x: Tensor, mask: Tensor, size: int):
    """Pad the features x and the mask along the time dimension to a fixed size.
    Args:
        x (torch.Tensor): [B, T, D] feature tensor
        mask (torch.Tensor): [B, T] padding mask
        size (int): target sequence length
    Returns:
        torch.Tensor: [B, size, D]
        torch.Tensor: [B, size]
    """
batch_size, cur_max_size, dim = x.shape[0], x.shape[1], x.shape[2]
if cur_max_size == size:
# already padded to this size, probably in the collate function
return x, mask
    if cur_max_size > size:
        # This case should have been handled in the collate function;
        # useful as a safety check at test time.
        log.warning("The size of the tensor is larger than the maximum size. Cropping the input..")
        x = x[:, :size]
        mask = mask[:, :size]
        cur_max_size = size
new_x = torch.zeros(
(batch_size, size, dim),
dtype=x.dtype,
device=x.device,
)
new_x[:, :cur_max_size] = x
# same for the mask
new_mask = torch.zeros(
(batch_size, size),
dtype=mask.dtype,
device=mask.device,
)
new_mask[:, :cur_max_size] = mask
return new_x, new_mask
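# A minimal usage sketch for pad_x_and_mask_to_fixed_size; the shapes below are
# illustrative assumptions, not values from a real config:
#
#     x = torch.randn(2, 7, 16)                   # [B=2, T=7, D=16]
#     mask = torch.ones(2, 7, dtype=torch.bool)   # all 7 steps are valid
#     x_pad, mask_pad = pad_x_and_mask_to_fixed_size(x, mask, size=10)
#     assert x_pad.shape == (2, 10, 16) and mask_pad.shape == (2, 10)
#     assert not mask_pad[:, 7:].any()            # padded steps are masked out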
@dataclass(frozen=True, config=dict(extra="forbid", arbitrary_types_allowed=True))
class TransformerEncoderBlockConfig:
"""Configuration for the transformer encoder backbone."""
# input features dimension
input_dim: int
# output features dimension
output_dim: int
    # skeleton object; must expose an `nbjoints` attribute
    skeleton: object
# dimension of the text embeddings
llm_shape: Union[list[int], ListConfig]
# mask the text or not
use_text_mask: bool
# latent dimension of the model
latent_dim: int
# dimension of the feedforward network in transformer
ff_size: int
# num layers in transformer
num_layers: int
# num heads in transformer
num_heads: int
# activation in transformer
activation: str
# dropout rate for the transformer
dropout: float
# dropout rate for the positional embeddings
pe_dropout: float
# use norm first or not
norm_first: bool = False
# artificially extend the number of text tokens
num_text_tokens_override: Optional[int] = None
# Input first heading angle
input_first_heading_angle: bool = False
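# A hedged configuration sketch: the skeleton stand-in and every value below
# are illustrative assumptions, not defaults shipped with the package.
#
#     class _FakeSkeleton:
#         nbjoints = 24
#
#     conf = TransformerEncoderBlockConfig(
#         input_dim=135,
#         output_dim=135,
#         skeleton=_FakeSkeleton(),
#         llm_shape=[77, 768],  # 77 text tokens of dimension 768
#         use_text_mask=True,
#         latent_dim=512,
#         ff_size=1024,
#         num_layers=8,
#         num_heads=4,
#         activation="gelu",
#         dropout=0.1,
#         pe_dropout=0.1,
#     )
#     model = TransformerEncoderBlock(conf)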
class TransformerEncoderBlock(nn.Module):
    """Transformer encoder backbone conditioning motion on text and timestep."""

    @validate(TransformerEncoderBlockConfig, save_args=True, super_init=True)
    def __init__(self, conf):
        # `validate` (with save_args=True, super_init=True) stores the config
        # fields on `self` and calls super().__init__() before this body runs.
        self.nbjoints = self.skeleton.nbjoints
llm_dim = self.llm_shape[-1]
self.embed_text = nn.Linear(llm_dim, self.latent_dim)
self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.pe_dropout)
# maximum number of tokens
self.num_text_tokens = self.llm_shape[0]
if self.num_text_tokens_override is not None:
self.num_text_tokens = self.num_text_tokens_override
self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
self.input_linear = nn.Linear(self.input_dim, self.latent_dim)
self.output_linear = nn.Linear(self.latent_dim, self.output_dim)
self.linear_first_heading_angle = nn.Linear(2, self.latent_dim)
trans_enc_layer = TransformerEncoderLayer(
d_model=self.latent_dim,
nhead=self.num_heads,
dim_feedforward=self.ff_size,
dropout=self.dropout,
activation=self.activation,
batch_first=True,
norm_first=self.norm_first,
)
self.seqTransEncoder = TransformerEncoder(
trans_enc_layer,
num_layers=self.num_layers,
enable_nested_tensor=False,
)
def forward(
self,
x: Tensor,
x_pad_mask: torch.Tensor,
text_feat: torch.Tensor,
text_feat_pad_mask: torch.Tensor,
timesteps: Tensor,
first_heading_angle: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
x (torch.Tensor): [B, T, dim_motion] current noisy motion
x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not
text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts
text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not
timesteps (torch.Tensor): [B,] current denoising step
Returns:
torch.Tensor: [B, T, output_dim]
"""
batch_size = len(x)
x = self.input_linear(x) # [B, T, D]
# Pad the text tokens + mask to always have the same size == self.num_text_tokens
# done here if it was not done in the collate function
if self.num_text_tokens is not None:
text_feat, text_feat_pad_mask = pad_x_and_mask_to_fixed_size(
text_feat,
text_feat_pad_mask,
self.num_text_tokens,
)
# Encode the text features and the time information
emb_text = self.embed_text(text_feat) # [B, max_text_len, D]
emb_time = self.embed_timestep(timesteps) # [B, 1, D]
        # Create mask for the time token
        time_mask = torch.ones((batch_size, 1), dtype=torch.bool, device=x.device)
        # Create the prefix features (text, time, ...): [B, max_text_len + 1 + ..., D]
        prefix_feats = torch.cat((emb_text, emb_time), dim=1)
        # Legacy behavior: when the text mask is unused, attend to all text tokens
        if not self.use_text_mask:
            text_feat_pad_mask = torch.ones(
                (batch_size, emb_text.shape[1]),
                dtype=torch.bool,
                device=x.device,
            )
        prefix_mask = torch.cat((text_feat_pad_mask, time_mask), dim=1)
        # Add the first heading angle to the prefix, if requested
        if self.input_first_heading_angle:
            assert first_heading_angle is not None, "The first heading angle is mandatory for this model"
            # encode the angle as (cos(angle), sin(angle))
            first_heading_angle_feats = torch.stack(
                [
                    torch.cos(first_heading_angle),
                    torch.sin(first_heading_angle),
                ],
                dim=-1,
            )
            first_heading_angle_feats = self.linear_first_heading_angle(first_heading_angle_feats)
            first_heading_angle_feats = first_heading_angle_feats[:, None]  # add a token dimension for cat
            first_heading_angle_mask = torch.ones(
                (batch_size, 1),
                dtype=torch.bool,
                device=x.device,
            )
            prefix_feats = torch.cat((prefix_feats, first_heading_angle_feats), dim=1)
            prefix_mask = torch.cat((prefix_mask, first_heading_angle_mask), dim=1)
        # Number of prefix tokens before the motion sequence
        pose_start_ind = prefix_feats.shape[1]
        # Concatenate prefix and x: [B, len(prefix) + T, D]
        xseq = torch.cat((prefix_feats, x), dim=1)
        # Concatenate the masks and negate them (PyTorch expects True == padded): [B, len(prefix) + T]
        src_key_padding_mask = ~torch.cat((prefix_mask, x_pad_mask), dim=1)
# Add positional encoding
xseq = self.sequence_pos_encoder(xseq)
        # Run the transformer and keep only the motion positions
        if isinstance(self.seqTransEncoder, nn.TransformerEncoder):
            assert not self.seqTransEncoder.use_nested_tensor, "The nested-tensor fast path should be disabled due to a bug!"
output = self.seqTransEncoder(
xseq,
src_key_padding_mask=src_key_padding_mask,
)
output = output[:, pose_start_ind:] # [B, T, D]
output = self.output_linear(output) # [B, T, OD]
return output
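# Shape walkthrough for a forward pass, continuing the hedged config sketch
# above (all sizes are illustrative assumptions):
#
#     B, T = 2, 60
#     x = torch.randn(B, T, conf.input_dim)
#     x_pad_mask = torch.ones(B, T, dtype=torch.bool)
#     text_feat = torch.randn(B, 20, 768)   # padded to 77 tokens inside forward
#     text_feat_pad_mask = torch.ones(B, 20, dtype=torch.bool)
#     timesteps = torch.randint(0, 1000, (B,))
#     out = model(x, x_pad_mask, text_feat, text_feat_pad_mask, timesteps)
#     assert out.shape == (B, T, conf.output_dim)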
class PositionalEncoding(nn.Module):
"""Non-learned positional encoding."""
    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_len: int = 5000,
    ):
        """
        Args:
            d_model (int): input dim
            dropout (float): dropout probability applied to the output (default 0.1)
            max_len (int): maximum sequence length (default 5000)
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Note: torch.exp() and math.log() are replaced with torch.pow()
        # because MKL's exp() and ln() throw floating point exceptions on certain CPUs
        # (see the corresponding commit and MR)
        div_term = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model)
        # equivalent to: div_term = torch.exp(
        #     torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        # )
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # [1, T, D]
self.register_buffer("pe", pe, persistent=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply positional encoding to input sequence.
Args:
x (torch.Tensor): [B, T, D] input motion sequence
Returns:
torch.Tensor: [B, T, D] input motion with PE added to it (and optionally dropout)
"""
x = x + self.pe[:, : x.shape[1], :]
return self.dropout(x)
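# Sanity sketch (not part of the module): the torch.pow form used above is
# algebraically identical to the classic exp/log form, since
# 10000 ** (-i / d) == exp(-(i / d) * ln(10000)):
#
#     import math
#     i = torch.arange(0, 8, 2).float()
#     a = torch.pow(10000.0, -i / 8)
#     b = torch.exp(i * (-math.log(10000.0) / 8))
#     assert torch.allclose(a, b)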
class TimestepEmbedder(nn.Module):
"""Encoder for diffusion step."""
def __init__(self, latent_dim: int, sequence_pos_encoder: PositionalEncoding):
"""
Args:
latent_dim (int): dim to encode to
sequence_pos_encoder (PositionalEncoding): the PE to use on timesteps
"""
super().__init__()
self.latent_dim = latent_dim
self.sequence_pos_encoder = sequence_pos_encoder
time_embed_dim = self.latent_dim
self.time_embed = nn.Sequential(
nn.Linear(self.latent_dim, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Embed timesteps by looking up their positional encoding and passing it through an MLP.
        Args:
            timesteps (torch.Tensor): [B]
        Returns:
            torch.Tensor: [B, 1, D]
        """
        return self.time_embed(self.sequence_pos_encoder.pe.transpose(0, 1)[timesteps])
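# Minimal usage sketch for TimestepEmbedder (dimensions are illustrative):
#
#     pe = PositionalEncoding(d_model=512, dropout=0.0)
#     embedder = TimestepEmbedder(latent_dim=512, sequence_pos_encoder=pe)
#     t = torch.tensor([0, 250, 999])   # [B=3] diffusion steps
#     emb = embedder(t)
#     assert emb.shape == (3, 1, 512)   # one conditioning token per sample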