File size: 6,025 Bytes
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Classifier-free guidance wrapper for the denoiser at sampling time."""

from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

# Supported guidance modes:
#   "nocfg"     - plain forward pass of the wrapped denoiser, no guidance.
#   "regular"   - one conditional (text + constraint) branch and one unconditional branch.
#   "separated" - text-only, constraint-only and unconditional branches with separate weights.
CFG_TYPES = ["nocfg", "regular", "separated"]


class ClassifierFreeGuidedModel(nn.Module):
    """Wrapper around denoiser to use classifier-free guidance at sampling time.

    All guidance branches are evaluated in a single batched call to the wrapped
    model (inputs are stacked along dim 0) and the outputs are recombined with
    the guidance weight(s).
    """

    def __init__(self, model: nn.Module, cfg_type: Optional[str] = "separated"):
        """Wrap the denoiser for classifier-free guidance.

        Args:
            model (nn.Module): the denoiser to wrap; it is called with the same
                positional/keyword arguments as ``forward`` minus ``cfg_weight``
                and ``cfg_type``.
            cfg_type (str): default guidance mode, one of ``CFG_TYPES``
                ('nocfg', 'regular', 'separated'). Can be overridden per call.
        """
        super().__init__()
        self.model = model
        assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}"
        self.cfg_type_default = cfg_type

    @staticmethod
    def _tile(tensor: Optional[torch.Tensor], copies: int) -> Optional[torch.Tensor]:
        """Repeat ``tensor`` ``copies`` times along dim 0 (``None`` passes through)."""
        if tensor is None:
            return None
        return torch.cat([tensor] * copies, dim=0)

    @staticmethod
    def _stack_branches(tensor: Optional[torch.Tensor], keep: Tuple[bool, ...]) -> Optional[torch.Tensor]:
        """Stack one copy of ``tensor`` per branch along dim 0, zeroing branches where ``keep`` is False.

        ``torch.zeros_like`` is used instead of ``0 * tensor`` so that the dtype
        is preserved (``0 * bool_tensor`` promotes to int64) and so that
        non-finite entries do not leak into the blanked branch (``0 * inf`` is NaN).
        """
        if tensor is None:
            return None
        blank = torch.zeros_like(tensor)
        return torch.cat([tensor if kept else blank for kept in keep], dim=0)

    def forward(
        self,
        cfg_weight: Union[float, Tuple[float, float]],
        x: torch.Tensor,
        x_pad_mask: torch.Tensor,
        text_feat: torch.Tensor,
        text_feat_pad_mask: torch.Tensor,
        timesteps: torch.Tensor,
        first_heading_angle: Optional[torch.Tensor] = None,
        motion_mask: Optional[torch.Tensor] = None,
        observed_motion: Optional[torch.Tensor] = None,
        cfg_type: Optional[str] = None,
    ) -> torch.Tensor:
        """Run the wrapped denoiser with classifier-free guidance.

        Args:
            cfg_weight (float or tuple): guidance weight. A single number for
                'regular' CFG, a pair ``(text, constraint)`` of weights for
                'separated' CFG. Ignored when the mode is 'nocfg'.
            x (torch.Tensor): [B, T, dim_motion] current noisy motion
            x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not
            text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts
            text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not
            timesteps (torch.Tensor): [B,] current denoising step
            first_heading_angle (torch.Tensor, optional): batch-first tensor forwarded
                to the denoiser; replicated unchanged for every guidance branch.
            motion_mask (torch.Tensor, optional): batch-first constraint mask forwarded
                to the denoiser; zeroed in the branches that drop the constraint.
                (Exact shape is defined by the wrapped model - not visible here.)
            observed_motion (torch.Tensor, optional): batch-first tensor forwarded to
                the denoiser; replicated unchanged for every guidance branch.
            cfg_type (str, optional): guidance mode for this call, one of
                ``CFG_TYPES``. Defaults to the mode given at construction.

        Returns:
            torch.Tensor: same size as input x

        Raises:
            AssertionError: if ``cfg_type`` is not in ``CFG_TYPES`` or
                ``cfg_weight`` does not match the selected mode.
            ValueError: if ``cfg_type`` is invalid and assertions are disabled.
        """

        if cfg_type is None:
            cfg_type = self.cfg_type_default

        assert cfg_type in CFG_TYPES, f"Invalid cfg_type: {cfg_type}"

        if cfg_type == "nocfg":
            # No guidance: a single plain pass, cfg_weight is not used.
            return self.model(
                x,
                x_pad_mask,
                text_feat,
                text_feat_pad_mask,
                timesteps,
                first_heading_angle=first_heading_angle,
                motion_mask=motion_mask,
                observed_motion=observed_motion,
            )

        # For each branch (stacked along dim 0): is the text / the constraint mask kept?
        if cfg_type == "regular":
            assert isinstance(cfg_weight, (float, int)), "cfg_weight must be a single float for regular CFG"
            # Branches: [text + constraint, unconditional]
            keep_text = (True, False)
            keep_constraint = (True, False)
        elif cfg_type == "separated":
            assert len(cfg_weight) == 2, "cfg_weight must be a tuple of two floats for separated CFG"
            # Branches: [text only, constraint only, unconditional]
            keep_text = (True, False, False)
            keep_constraint = (False, True, False)
        else:
            # Only reachable when assertions are stripped (python -O).
            raise ValueError(f"Invalid cfg_type: {cfg_type}")

        num_branches = len(keep_text)

        # batched conditional and uncond pass together
        out_cond_uncond = self.model(
            self._tile(x, num_branches),
            self._tile(x_pad_mask, num_branches),
            self._stack_branches(text_feat, keep_text),
            # Text padding mask is blanked together with the text features.
            self._stack_branches(text_feat_pad_mask, keep_text),
            self._tile(timesteps, num_branches),
            first_heading_angle=self._tile(first_heading_angle, num_branches),
            motion_mask=self._stack_branches(motion_mask, keep_constraint),
            observed_motion=self._tile(observed_motion, num_branches),
        )

        if cfg_type == "regular":
            # out_uncond + w * (out_text_and_constraint - out_uncond)
            out, out_uncond = torch.chunk(out_cond_uncond, 2)
            out_new = out_uncond + (cfg_weight * (out - out_uncond))
        else:
            # out_uncond + w_text * (out_text - out_uncond) + w_constraint * (out_constraint - out_uncond)
            out_text, out_constraint, out_uncond = torch.chunk(out_cond_uncond, 3)
            out_new = (
                out_uncond + (cfg_weight[0] * (out_text - out_uncond)) + (cfg_weight[1] * (out_constraint - out_uncond))
            )

        return out_new