Buckets:

rydlrKE's picture
download
raw
6.03 kB
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Classifier-free guidance wrapper for the denoiser at sampling time."""
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
# Supported classifier-free guidance modes:
#   "nocfg"     - plain conditional pass, no guidance combination
#   "regular"   - one weight applied to the full (text + constraint) conditioning
#   "separated" - independent weights for text and motion-constraint conditioning
CFG_TYPES = ["nocfg", "regular", "separated"]


class ClassifierFreeGuidedModel(nn.Module):
    """Wrapper around a denoiser to apply classifier-free guidance at sampling time."""

    def __init__(self, model: nn.Module, cfg_type: Optional[str] = "separated"):
        """Wrap the denoiser for classifier-free guidance.

        Args:
            model (nn.Module): denoiser whose call signature matches the one used in `forward`.
            cfg_type (str): default guidance mode, one of CFG_TYPES; may be overridden per call.
        """
        super().__init__()
        self.model = model
        assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}"
        self.cfg_type_default = cfg_type

    def forward(
        self,
        cfg_weight: Union[float, Tuple[float, float]],
        x: torch.Tensor,
        x_pad_mask: torch.Tensor,
        text_feat: torch.Tensor,
        text_feat_pad_mask: torch.Tensor,
        timesteps: torch.Tensor,
        first_heading_angle: Optional[torch.Tensor] = None,
        motion_mask: Optional[torch.Tensor] = None,
        observed_motion: Optional[torch.Tensor] = None,
        cfg_type: Optional[str] = None,
    ) -> torch.Tensor:
        """Run the denoiser with classifier-free guidance.

        Args:
            cfg_weight (float | tuple): guidance weight; a single float for 'regular',
                a (text_weight, constraint_weight) pair for 'separated', ignored for 'nocfg'
            x (torch.Tensor): [B, T, dim_motion] current noisy motion
            x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not
            text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts
            text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not
            timesteps (torch.Tensor): [B,] current denoising step
            first_heading_angle (torch.Tensor, optional): per-sample heading conditioning, passed through to the model
            motion_mask (torch.Tensor, optional): mask selecting observed (constraint) motion positions
            observed_motion (torch.Tensor, optional): observed motion used as a constraint
            cfg_type (str, optional): guidance mode override, one of CFG_TYPES; defaults to the
                value given at construction time
        Returns:
            torch.Tensor: same size as input x
        """
        if cfg_type is None:
            cfg_type = self.cfg_type_default
        assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}"

        if cfg_type == "nocfg":
            # Single conditional pass, no guidance arithmetic.
            return self.model(
                x,
                x_pad_mask,
                text_feat,
                text_feat_pad_mask,
                timesteps,
                first_heading_angle=first_heading_angle,
                motion_mask=motion_mask,
                observed_motion=observed_motion,
            )

        if cfg_type == "regular":
            assert isinstance(cfg_weight, (float, int)), "cfg_weight must be a single float for regular CFG"
            # Batch the conditional and unconditional passes together:
            #   out = out_uncond + w * (out_text_and_constraint - out_uncond)
            # The unconditional half zeroes the text features and the motion mask.
            text_feat = torch.concatenate([text_feat, torch.zeros_like(text_feat)], dim=0)
            if motion_mask is not None:
                motion_mask = torch.concatenate([motion_mask, torch.zeros_like(motion_mask)], dim=0)
            if observed_motion is not None:
                observed_motion = torch.concatenate([observed_motion, observed_motion], dim=0)
            if first_heading_angle is not None:
                first_heading_angle = torch.concatenate([first_heading_angle, first_heading_angle], dim=0)
            out_cond_uncond = self.model(
                torch.concatenate([x, x], dim=0),
                torch.concatenate([x_pad_mask, x_pad_mask], dim=0),
                text_feat,
                # Disable attention to text in the unconditional half.
                torch.concatenate([text_feat_pad_mask, torch.zeros_like(text_feat_pad_mask)], dim=0),
                torch.concatenate([timesteps, timesteps], dim=0),
                first_heading_angle=first_heading_angle,
                motion_mask=motion_mask,
                observed_motion=observed_motion,
            )
            out, out_uncond = torch.chunk(out_cond_uncond, 2)
            return out_uncond + (cfg_weight * (out - out_uncond))

        if cfg_type == "separated":
            assert len(cfg_weight) == 2, "cfg_weight must be a tuple of two floats for separated CFG"
            # Three batched passes: text-only, constraint-only, unconditional:
            #   out = out_uncond + w_text * (out_text - out_uncond)
            #                    + w_constraint * (out_constraint - out_uncond)
            # Text features are live only in the first chunk; the motion-constraint
            # mask is live only in the middle chunk.
            text_feat = torch.concatenate(
                [text_feat, torch.zeros_like(text_feat), torch.zeros_like(text_feat)], dim=0
            )
            if motion_mask is not None:
                motion_mask = torch.concatenate(
                    [torch.zeros_like(motion_mask), motion_mask, torch.zeros_like(motion_mask)], dim=0
                )
            if observed_motion is not None:
                observed_motion = torch.concatenate([observed_motion, observed_motion, observed_motion], dim=0)
            if first_heading_angle is not None:
                first_heading_angle = torch.concatenate(
                    [first_heading_angle, first_heading_angle, first_heading_angle],
                    dim=0,
                )
            out_cond_uncond = self.model(
                torch.concatenate([x, x, x], dim=0),
                torch.concatenate([x_pad_mask, x_pad_mask, x_pad_mask], dim=0),
                text_feat,
                # Attention to text enabled only in the first (text) chunk.
                torch.concatenate(
                    [
                        text_feat_pad_mask,
                        torch.zeros_like(text_feat_pad_mask),
                        torch.zeros_like(text_feat_pad_mask),
                    ],
                    dim=0,
                ),
                torch.concatenate([timesteps, timesteps, timesteps], dim=0),
                first_heading_angle=first_heading_angle,
                motion_mask=motion_mask,
                observed_motion=observed_motion,
            )
            out_text, out_constraint, out_uncond = torch.chunk(out_cond_uncond, 3)
            return (
                out_uncond
                + (cfg_weight[0] * (out_text - out_uncond))
                + (cfg_weight[1] * (out_constraint - out_uncond))
            )

        # Unreachable thanks to the assert above; kept as a defensive guard.
        raise ValueError(f"Invalid cfg_type: {cfg_type}")

Xet Storage Details

Size:
6.03 kB
·
Xet hash:
64e22baaf694a8eaf02b582cf90f1cd5d5d55421d126e4284129bf3deb5cca27

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.