import torch
import torch.nn as nn

from .dit_models import TimestepEmbedder, LabelEmbedder, DiTBlock, get_2d_sincos_pos_embed


class DiTwoEmbedder(nn.Module):
    """
    Diffusion model with a Transformer backbone that operates directly on ViT
    token latents rather than on spatial latents.
    """

    def __init__(
        self,
        input_size=224,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        # The patch size is fixed at 14 to match the upstream ViT tokenizer
        # (e.g., ViT-L/14); tokenization happens outside this module.
        self.patch_size = 14
        self.num_heads = num_heads

        self.t_embedder = TimestepEmbedder(hidden_size)
        if num_classes > 0:
            self.y_embedder = LabelEmbedder(num_classes, hidden_size,
                                            class_dropout_prob)
        else:
            self.y_embedder = None

        self.num_patches = (input_size // self.patch_size) ** 2

        # Fixed (non-learnable) sin-cos positional embedding over the token grid:
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches,
                                                  hidden_size),
                                      requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        ])

        self.initialize_weights()

    def initialize_weights(self):
        # Basic initialization for all linear layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed with a 2D sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.num_patches ** 0.5))
        self.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize the label embedding table:
        if self.y_embedder is not None:
            nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize the timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers so each block starts near identity:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, t, y=None):
        """
        Forward pass of DiT.
        x: (N, T, D) tensor of ViT token latents (T = num_patches, D = hidden_size)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = x + self.pos_embed  # (N, T, D)
        t = self.t_embedder(t)  # (N, D)

        if self.y_embedder is not None:
            assert y is not None
            y = self.y_embedder(y, self.training)  # (N, D)
            c = t + y  # combined conditioning signal
        else:
            c = t

        for block in self.blocks:
            x = block(x, c)  # (N, T, D)

        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT that also batches the unconditional forward pass
        for classifier-free guidance. The first half of the batch is assumed
        to hold conditional samples and the second half their unconditional
        duplicates.
        """
        half = x[:len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # Mirroring the DiT recipe, guidance is applied only to the leading
        # slice along dim 1 (the first three channels in the original spatial
        # layout); the remainder is passed through unchanged.
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    def forward_with_cfg_unconditional(self, x, t, y=None, cfg_scale=None):
        """
        Plain forward pass without classifier-free guidance; cfg_scale is
        accepted only for interface compatibility and is ignored.
        """
        return self.forward(x, t, y)
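
# Shape sketch (illustrative): with a ViT-L/14 encoder at 224x224 resolution
# the token grid is 16x16, so DiTwoEmbedder maps token latents of shape
# (N, 256, hidden_size), timesteps (N,), and optional labels (N,) to features
# of the same shape (N, 256, hidden_size); no output head is applied here.

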
class DiTwoEmbedderLongSkipConnection(nn.Module):
    """DiT with long skip connections in the style of U-ViT (CVPR 2023)."""

    def __init__(
        self,
        input_size=224,
        patch_size=14,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.t_embedder = TimestepEmbedder(hidden_size)
        if num_classes > 0:
            self.y_embedder = LabelEmbedder(num_classes, hidden_size,
                                            class_dropout_prob)
        else:
            self.y_embedder = None

        self.num_patches = (input_size // patch_size) ** 2

        # Fixed (non-learnable) sin-cos positional embedding over the token grid:
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches,
                                                  hidden_size),
                                      requires_grad=False)

        # U-ViT layout: in-blocks, a middle block, and out-blocks, where each
        # out-block receives a long skip from its mirrored in-block.
        self.in_blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth // 2)
        ])

        self.mid_block = DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)

        self.out_blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth // 2)
        ])

        # Fuse each long skip by concatenating it with the current features
        # and projecting back to hidden_size, following U-ViT's skip fusion.
        self.skip_linears = nn.ModuleList([
            nn.Linear(2 * hidden_size, hidden_size)
            for _ in range(depth // 2)
        ])

        self.initialize_weights()

    def initialize_weights(self):
        # Basic initialization for all linear layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed with a 2D sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.num_patches ** 0.5))
        self.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize the label embedding table:
        if self.y_embedder is not None:
            nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize the timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in every DiT block (in, mid, and
        # out) so each block starts near identity:
        for block in [*self.in_blocks, self.mid_block, *self.out_blocks]:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, t, y=None):
        """
        Forward pass of DiT with long skip connections.
        x: (N, T, D) tensor of ViT token latents (T = num_patches, D = hidden_size)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = x + self.pos_embed  # (N, T, D)
        t = self.t_embedder(t)  # (N, D)

        if self.y_embedder is not None:
            assert y is not None
            y = self.y_embedder(y, self.training)  # (N, D)
            c = t + y  # combined conditioning signal
        else:
            c = t

        # Encoder half: record features for the long skips.
        skips = []
        for blk in self.in_blocks:
            x = blk(x, c)
            skips.append(x)

        x = self.mid_block(x, c)

        # Decoder half: fuse the mirrored skip, then run the block.
        for blk, skip_linear in zip(self.out_blocks, self.skip_linears):
            x = skip_linear(torch.cat([x, skips.pop()], dim=-1))
            x = blk(x, c)

        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT that also batches the unconditional forward pass
        for classifier-free guidance. The first half of the batch is assumed
        to hold conditional samples and the second half their unconditional
        duplicates.
        """
        half = x[:len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # Mirroring the DiT recipe, guidance is applied only to the leading
        # slice along dim 1 (the first three channels in the original spatial
        # layout); the remainder is passed through unchanged.
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    def forward_with_cfg_unconditional(self, x, t, y=None, cfg_scale=None):
        """
        Plain forward pass without classifier-free guidance; cfg_scale is
        accepted only for interface compatibility and is ignored.
        """
        return self.forward(x, t, y)


def DiT_woembed_S(**kwargs):
    return DiTwoEmbedder(depth=12, hidden_size=384, num_heads=6, **kwargs)


def DiT_woembed_B(**kwargs):
    return DiTwoEmbedder(depth=12, hidden_size=768, num_heads=12, **kwargs)


def DiT_woembed_L(**kwargs):
    return DiTwoEmbedder(depth=24, hidden_size=1024, num_heads=16, **kwargs)


DiT_woembed_models = {
    'DiT-wo-S': DiT_woembed_S,
    'DiT-wo-B': DiT_woembed_B,
    'DiT-wo-L': DiT_woembed_L,
}
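

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the training pipeline. It
    # assumes 224x224 inputs tokenized with patch size 14 (a 16x16 = 256-token
    # grid) and token width equal to the model hidden size. Because of the
    # relative import above, run it as a module, e.g.
    # `python -m <package>.<this_module>` (module path illustrative).
    model = DiT_woembed_S()
    n, num_tokens, hidden_size = 2, 256, 384
    x = torch.randn(n, num_tokens, hidden_size)  # dummy ViT token latents
    t = torch.randint(0, 1000, (n,))             # dummy diffusion timesteps
    y = torch.randint(0, 1000, (n,))             # dummy class labels
    out = model(x, t, y)
    print(out.shape)  # expected: torch.Size([2, 256, 384])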