import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
import math
from .udit import UDiT
from .utils.span_mask import compute_mask_indices
|
|
|
class EmbeddingCFG(nn.Module):
    """
    Handles label dropout for classifier-free guidance.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.cfg_embedding = nn.Parameter(
            torch.randn(in_channels) / in_channels ** 0.5)

    def token_drop(self, condition, condition_mask, cfg_prob):
        """
        Drops labels to enable classifier-free guidance.
        """
        b, t, device = condition.shape[0], condition.shape[1], condition.device
        drop_ids = torch.rand(b, device=device) < cfg_prob
        uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t)
        condition = torch.where(drop_ids[:, None, None], uncond, condition)
        if condition_mask is not None:
            condition_mask[drop_ids] = False
            condition_mask[drop_ids, 0] = True
        return condition, condition_mask

    def forward(self, condition, condition_mask, cfg_prob=0.0):
        if condition_mask is not None:
            condition_mask = condition_mask.clone()
        if cfg_prob > 0:
            condition, condition_mask = self.token_drop(condition,
                                                        condition_mask,
                                                        cfg_prob)
        return condition, condition_mask
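
# Usage sketch (hedged; shapes and sizes below are illustrative only): EmbeddingCFG
# acts on a sequence-shaped condition of shape (B, T, C). Dropped rows are replaced
# by the learned unconditional embedding, and their padding mask keeps a single
# valid slot at position 0.
#
#   cfg = EmbeddingCFG(in_channels=512)
#   cond = torch.randn(4, 77, 512)
#   cond_mask = torch.ones(4, 77, dtype=torch.bool)
#   cond, cond_mask = cfg(cond, cond_mask, cfg_prob=0.1)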
|
|
|
|
class DiscreteCFG(nn.Module):
    """
    Handles condition dropout for classifier-free guidance on discrete
    (token-id) conditions: dropped rows are zeroed and marked with `replace_id`.
    """
    def __init__(self, replace_id=2):
        super().__init__()
        self.replace_id = replace_id

    def forward(self, context, context_mask, cfg_prob):
        context = context.clone()
        if context_mask is not None:
            context_mask = context_mask.clone()
        if cfg_prob > 0:
            # sample the drop mask on the same device as the condition
            cfg_mask = torch.rand(len(context), device=context.device) < cfg_prob
            if torch.any(cfg_mask):
                context[cfg_mask] = 0
                context[cfg_mask, 0] = self.replace_id
                if context_mask is not None:
                    context_mask[cfg_mask] = False
                    context_mask[cfg_mask, 0] = True
        return context, context_mask
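
# Usage sketch (hedged; shapes are illustrative): for token-id conditions of shape
# (B, T), e.g. text tokens, a dropped row becomes all zeros except position 0,
# which holds `replace_id`.
#
#   dcfg = DiscreteCFG(replace_id=2)
#   tokens = torch.randint(3, 1000, (4, 77))
#   tok_mask = torch.ones(4, 77, dtype=torch.bool)
#   tokens, tok_mask = dcfg(tokens, tok_mask, cfg_prob=0.1)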
|
|
|
|
class CFGModel(nn.Module):
    def __init__(self, context_dim, backbone):
        super().__init__()
        self.model = backbone
        self.context_cfg = EmbeddingCFG(context_dim)

    def forward(self, x, timesteps,
                context, x_mask=None, context_mask=None,
                cfg_prob=0.0):
        # apply condition dropout for classifier-free guidance
        context, context_mask = self.context_cfg(context, context_mask, cfg_prob)
        x = self.model(x=x, timesteps=timesteps,
                       context=context,
                       x_mask=x_mask, context_mask=context_mask)
        return x
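
# Usage sketch (hedged; sizes are illustrative and backbone-dependent): CFGModel
# only wraps condition dropout around an existing diffusion backbone, so x,
# timesteps, and context follow whatever shapes the wrapped backbone expects.
#
#   backbone = UDiT(...)   # constructor arguments are model-specific
#   net = CFGModel(context_dim=512, backbone=backbone)
#   out = net(x, timesteps, context, x_mask=x_mask,
#             context_mask=context_mask, cfg_prob=0.1)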
|
|
|
|
class ConcatModel(nn.Module):
    def __init__(self, backbone, in_dim, stride=[]):
        super().__init__()
        self.model = backbone

        # each stage shortens the context length by `s` and doubles its channels
        self.downsample_layers = nn.ModuleList()
        for s in stride:
            downsample_layer = nn.Conv1d(
                in_dim,
                in_dim * 2,
                kernel_size=2 * s,
                stride=s,
                padding=math.ceil(s / 2),
            )
            self.downsample_layers.append(downsample_layer)
            in_dim = in_dim * 2

        self.context_cfg = EmbeddingCFG(in_dim)

    def forward(self, x, timesteps,
                context, x_mask=None,
                cfg=False, cfg_prob=0.0):
        # context: (B, C, L); downsample it to the temporal resolution of x
        for downsample_layer in self.downsample_layers:
            context = downsample_layer(context)

        # EmbeddingCFG expects (B, T, C), so transpose around the dropout call
        context = context.transpose(1, 2)
        context, _ = self.context_cfg(context, None,
                                      cfg_prob=cfg_prob if cfg else 0.0)
        context = context.transpose(1, 2)

        # concatenate the condition with the input along the channel dimension
        assert context.shape[-1] == x.shape[-1]
        x = torch.cat([context, x], dim=1)
        x = self.model(x=x, timesteps=timesteps,
                       context=None, x_mask=x_mask, context_mask=None)
        return x
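
# Usage sketch (hedged): with stride=[2, 2], a context of shape (B, in_dim, 4 * L)
# is reduced to roughly (B, 4 * in_dim, L) and concatenated channel-wise with x,
# so the wrapped backbone must be built for the combined channel count.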
|
|
|
|
class MaskDiT(nn.Module):
    def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs):
        super().__init__()
        self.model = UDiT(**kwargs)
        self.mae = mae
        if self.mae:
            out_channel = kwargs.pop('out_chans', None)
            self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
            self.mae_prob = mae_prob
            self.mask_ratio = mask_ratio
            self.mask_span = mask_span

    def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
        B, D, L = gt.shape
        if mae_mask_infer is None:
            # sample per-example span masks over the time axis
            mask_ratios = mask_ratios.cpu().numpy()
            mask = compute_mask_indices(shape=[B, L],
                                        padding_mask=None,
                                        mask_prob=mask_ratios,
                                        mask_length=self.mask_span,
                                        mask_type="static",
                                        mask_other=0.0,
                                        min_masks=1,
                                        no_overlap=False,
                                        min_space=0,)
            mask = mask.unsqueeze(1).expand_as(gt)
        else:
            mask = mae_mask_infer
            mask = mask.expand_as(gt)
        # replace masked positions of the ground truth with the learned mask embedding
        gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
        return gt, mask.type_as(gt)
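
    # Note (hedged): random_masking edits `gt` in place, filling masked spans with
    # `mask_embed`, and returns it together with the 0/1 span mask (1 = masked)
    # broadcast to the full (B, D, L) shape.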
|
|
    def forward(self, x, timesteps, context,
                x_mask=None, context_mask=None, cls_token=None,
                gt=None, mae_mask_infer=None,
                forward_model=True):
        mae_mask = torch.ones_like(x)
        if self.mae:
            if gt is not None:
                B, D, L = gt.shape
                mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device)
                gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer)
                # during training, apply MAE conditioning only to a random subset of the batch
                if mae_mask_infer is None:
                    mae_batch = torch.rand(B) < self.mae_prob
                    gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch]
                    mae_mask[~mae_batch] = 1.0
            else:
                B, D, L = x.shape
                gt = self.mask_embed.view(1, D, 1).expand_as(x)
            # stack noisy input, (masked) ground truth, and the mask channel
            x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)

        if forward_model:
            x = self.model(x=x, timesteps=timesteps, context=context,
                           x_mask=x_mask, context_mask=context_mask,
                           cls_token=cls_token)
        return x, mae_mask
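
# Usage sketch (hedged; UDiT constructor arguments are model-specific and elided):
#
#   model = MaskDiT(mae=True, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10,
#                   out_chans=..., ...)   # remaining kwargs are forwarded to UDiT
#   # training: x is the noisy latent (B, C, L), gt the clean latent of the same shape
#   pred, mae_mask = model(x, timesteps, context, x_mask=x_mask,
#                          context_mask=context_mask, gt=gt)
#   # inference: pass a precomputed boolean mask via mae_mask_infer to control
#   # which spans are treated as unknown.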
|
|