Spaces:
Runtime error
Runtime error
File size: 6,025 Bytes
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Classifier-free guidance wrapper for the denoiser at sampling time."""
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
CFG_TYPES = ["nocfg", "regular", "separated"]


class ClassifierFreeGuidedModel(nn.Module):
    """Wrapper around denoiser to use classifier-free guidance at sampling time.

    Depending on the guidance type, the wrapped model is evaluated once on a
    batch that stacks the conditional and "dropped-condition" variants of the
    inputs along dim 0, and the resulting chunks are blended with the guidance
    weight(s).
    """

    def __init__(self, model: nn.Module, cfg_type: Optional[str] = "separated"):
        """Wrap the denoiser for classifier-free guidance.

        Args:
            model (nn.Module): denoiser called as
                ``model(x, x_pad_mask, text_feat, text_feat_pad_mask, timesteps,
                first_heading_angle=..., motion_mask=..., observed_motion=...)``.
            cfg_type (str): default guidance type, one of CFG_TYPES
                ('nocfg', 'regular', 'separated'). Can be overridden per call
                in ``forward``.

        Raises:
            AssertionError: if ``cfg_type`` is not in CFG_TYPES.
        """
        super().__init__()
        self.model = model
        assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}"
        self.cfg_type_default = cfg_type

    @staticmethod
    def _tile(tensor: Optional[torch.Tensor], copies: int) -> Optional[torch.Tensor]:
        """Stack ``copies`` copies of ``tensor`` along dim 0; ``None`` passes through."""
        if tensor is None:
            return None
        return torch.concatenate([tensor] * copies, dim=0)

    def forward(
        self,
        cfg_weight: Union[float, Tuple[float, float]],
        x: torch.Tensor,
        x_pad_mask: torch.Tensor,
        text_feat: torch.Tensor,
        text_feat_pad_mask: torch.Tensor,
        timesteps: torch.Tensor,
        first_heading_angle: Optional[torch.Tensor] = None,
        motion_mask: Optional[torch.Tensor] = None,
        observed_motion: Optional[torch.Tensor] = None,
        cfg_type: Optional[str] = None,
    ) -> torch.Tensor:
        """Run the wrapped denoiser with the requested guidance scheme.

        Guidance formulas:
            * ``nocfg``: plain model output, ``cfg_weight`` is ignored.
            * ``regular``: ``out_uncond + w * (out_text_and_constraint - out_uncond)``
            * ``separated``: ``out_uncond + w_text * (out_text - out_uncond)
              + w_constraint * (out_constraint - out_uncond)``

        Args:
            cfg_weight (float): guidance weight float or tuple of floats with
                (text, constraint) weights if using separated cfg
            x (torch.Tensor): [B, T, dim_motion] current noisy motion
            x_pad_mask (torch.Tensor): [B, T] attention mask, positions with
                True are allowed to attend, False are not
            text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts
            text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask,
                positions with True are allowed to attend, False are not
            timesteps (torch.Tensor): [B,] current denoising step
            first_heading_angle (torch.Tensor, optional): batch-first tensor
                forwarded unchanged (only repeated) to every guidance branch.
            motion_mask (torch.Tensor, optional): batch-first constraint mask;
                replaced by zeros in the branches where the constraint is dropped.
                NOTE(review): exact shape/dtype is defined by the wrapped model.
            observed_motion (torch.Tensor, optional): batch-first tensor
                forwarded unchanged (only repeated) to every guidance branch.
            cfg_type (str, optional): one of CFG_TYPES; defaults to the type
                chosen at construction.

        Returns:
            torch.Tensor: same size as input x

        Raises:
            AssertionError: if ``cfg_type`` is invalid or ``cfg_weight`` does
                not match the guidance type.
        """
        if cfg_type is None:
            cfg_type = self.cfg_type_default
        assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}"

        if cfg_type == "nocfg":
            return self.model(
                x,
                x_pad_mask,
                text_feat,
                text_feat_pad_mask,
                timesteps,
                first_heading_angle=first_heading_angle,
                motion_mask=motion_mask,
                observed_motion=observed_motion,
            )

        # Dropped conditions are built with zeros_like rather than `0 * tensor`:
        # it keeps the dtype (a bool mask stays bool instead of being promoted
        # to int64) and cannot leak NaN/Inf from the conditional inputs into
        # the unconditional branch.
        if cfg_type == "regular":
            assert isinstance(cfg_weight, (float, int)), "cfg_weight must be a single float for regular CFG"
            # Branch order along dim 0: [text + constraint, unconditional]
            n_branches = 2
            text_feat = torch.concatenate([text_feat, torch.zeros_like(text_feat)], dim=0)
            text_feat_pad_mask = torch.concatenate(
                [text_feat_pad_mask, torch.zeros_like(text_feat_pad_mask)], dim=0
            )
            if motion_mask is not None:
                motion_mask = torch.concatenate([motion_mask, torch.zeros_like(motion_mask)], dim=0)
        elif cfg_type == "separated":
            assert len(cfg_weight) == 2, "cfg_weight must be a tuple of two floats for separated CFG"
            # Branch order along dim 0: [text only, constraint only, unconditional]
            n_branches = 3
            no_text = torch.zeros_like(text_feat)
            no_text_mask = torch.zeros_like(text_feat_pad_mask)
            text_feat = torch.concatenate([text_feat, no_text, no_text], dim=0)
            text_feat_pad_mask = torch.concatenate([text_feat_pad_mask, no_text_mask, no_text_mask], dim=0)
            if motion_mask is not None:
                no_constraint = torch.zeros_like(motion_mask)
                motion_mask = torch.concatenate([no_constraint, motion_mask, no_constraint], dim=0)
        else:
            raise ValueError(f"Invalid cfg_type: {cfg_type}")

        # batched conditional and uncond pass together
        out_all = self.model(
            self._tile(x, n_branches),
            self._tile(x_pad_mask, n_branches),
            text_feat,
            text_feat_pad_mask,
            self._tile(timesteps, n_branches),
            first_heading_angle=self._tile(first_heading_angle, n_branches),
            motion_mask=motion_mask,
            observed_motion=self._tile(observed_motion, n_branches),
        )

        if cfg_type == "regular":
            out, out_uncond = torch.chunk(out_all, 2)
            return out_uncond + (cfg_weight * (out - out_uncond))

        out_text, out_constraint, out_uncond = torch.chunk(out_all, 3)
        return out_uncond + (cfg_weight[0] * (out_text - out_uncond)) + (cfg_weight[1] * (out_constraint - out_uncond))
|