import math
from typing import Optional

import numpy as np
import torch
from torch import nn

from .activations import get_activation


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic
    Models.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param flip_sin_to_cos: if True, the cosine half of the embedding precedes the sine half.
    :param downscale_freq_shift: shift applied to the frequency denominator `half_dim - downscale_freq_shift`.
    :param scale: scaling factor applied to the embeddings.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x embedding_dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad the last dimension if embedding_dim is odd
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
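

# Minimal usage sketch (illustrative only; `_demo_timestep_embedding` is not part of
# the module API): embeds a hypothetical batch of 4 timesteps into 128 dimensions.
def _demo_timestep_embedding():
    timesteps = torch.tensor([0.0, 10.0, 500.0, 999.0])
    emb = get_timestep_embedding(timesteps, embedding_dim=128, flip_sin_to_cos=True)
    assert emb.shape == (4, 128)
    return emb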


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return: pos_embed: [grid_size*grid_size, embed_dim] or [extra_tokens+grid_size*grid_size, embed_dim]
    (with or without the extra tokens prepended for e.g. a cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed
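

# Illustrative sketch (not part of the module API): a 16x16 patch grid with a
# 768-dim embedding yields a (256, 768) array of fixed sinusoidal positions.
def _demo_2d_sincos_pos_embed():
    pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
    assert pos_embed.shape == (16 * 16, 768)
    return pos_embed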


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of the dimensions to encode grid_h, the other half for grid_w
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)
        return latent + self.pos_embed
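

# Illustrative sketch (not part of the module API): patchifying a hypothetical
# batch of two 224x224 RGB images with 16x16 patches gives 196 tokens per image.
def _demo_patch_embed():
    patchify = PatchEmbed(height=224, width=224, patch_size=16, in_channels=3, embed_dim=768)
    images = torch.randn(2, 3, 224, 224)
    tokens = patchify(images)
    assert tokens.shape == (2, 196, 768)
    return tokens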


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
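

# Illustrative sketch (not part of the module API): the sinusoidal projection from
# `get_timestep_embedding` is typically refined by this two-layer MLP.
def _demo_timestep_mlp():
    mlp = TimestepEmbedding(in_channels=128, time_embed_dim=512)
    t_proj = get_timestep_embedding(torch.tensor([0.0, 999.0]), embedding_dim=128)
    t_emb = mlp(t_proj)
    assert t_emb.shape == (2, 512)
    return t_emb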


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # keep `W` as an alias of `weight` for backwards compatibility
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

            self.weight = self.W

    def forward(self, x):
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi

        if self.flip_sin_to_cos:
            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
        else:
            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return out
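

# Illustrative sketch (not part of the module API): projects a hypothetical batch of
# positive noise levels (sigmas) into 2 * embedding_size random Fourier features.
def _demo_gaussian_fourier():
    proj = GaussianFourierProjection(embedding_size=256)
    sigmas = torch.tensor([0.1, 1.0, 10.0])
    features = proj(sigmas)
    assert features.shape == (3, 512)
    return features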


class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for
    the height and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        emb = self.emb(index)

        height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))

        # 1 x H x D -> 1 x H x 1 x D
        height_emb = height_emb.unsqueeze(2)

        width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))

        # 1 x W x D -> 1 x 1 x W x D
        width_emb = width_emb.unsqueeze(1)

        pos_emb = height_emb + width_emb

        # 1 x H x W x D -> 1 x (H*W) x D
        pos_emb = pos_emb.view(1, self.height * self.width, -1)

        emb = emb + pos_emb[:, : emb.shape[1], :]

        return emb
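

# Illustrative sketch (not part of the module API): embeds a hypothetical batch of
# VQ latent indices for a 32x32 latent grid with a 512-entry codebook.
def _demo_image_positional_embeddings():
    embed = ImagePositionalEmbeddings(num_embed=512, height=32, width=32, embed_dim=64)
    index = torch.randint(0, 512, (2, 32 * 32))
    tokens = embed(index)
    assert tokens.shape == (2, 1024, 64)
    return tokens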


class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        # reserve one extra embedding row for the "dropped label" id used by classifier-free guidance
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = torch.tensor(force_drop_ids == 1)
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
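

# Illustrative sketch (not part of the module API): in training mode, roughly 10% of
# labels are replaced with the reserved "dropped" class id for classifier-free guidance.
def _demo_label_embedding():
    embedder = LabelEmbedding(num_classes=1000, hidden_size=256, dropout_prob=0.1)
    labels = torch.randint(0, 1000, (8,))
    embeddings = embedder(labels)
    assert embeddings.shape == (8, 256)
    return embeddings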


class TextImageProjection(nn.Module):
    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        batch_size = text_embeds.shape[0]

        # image: project into a sequence of `num_image_text_embeds` tokens
        image_text_embeds = self.image_embeds(image_embeds)
        image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)

        # text
        text_embeds = self.text_proj(text_embeds)

        return torch.cat([image_text_embeds, text_embeds], dim=1)


class ImageProjection(nn.Module):
    def __init__(
        self,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 32,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        batch_size = image_embeds.shape[0]

        # image: project into a sequence of `num_image_text_embeds` tokens
        image_embeds = self.image_embeds(image_embeds)
        image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
        image_embeds = self.norm(image_embeds)
        return image_embeds


class CombinedTimestepLabelEmbeddings(nn.Module):
    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        class_labels = self.class_embedder(class_labels)  # (N, D)

        conditioning = timesteps_emb + class_labels  # (N, D)

        return conditioning
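

# Illustrative sketch (not part of the module API): combines a timestep and a class
# label into one conditioning vector, as used by class-conditional DiT-style models.
def _demo_combined_timestep_label():
    embedder = CombinedTimestepLabelEmbeddings(num_classes=1000, embedding_dim=768)
    timesteps = torch.tensor([10, 500])
    class_labels = torch.tensor([3, 7])
    conditioning = embedder(timesteps, class_labels, hidden_dtype=torch.float32)
    assert conditioning.shape == (2, 768)
    return conditioning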


class TextTimeEmbedding(nn.Module):
    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.pool(hidden_states)
        hidden_states = self.proj(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states


class TextImageTimeEmbedding(nn.Module):
    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        # text
        time_text_embeds = self.text_proj(text_embeds)
        time_text_embeds = self.text_norm(time_text_embeds)

        # image
        time_image_embeds = self.image_proj(image_embeds)

        return time_image_embeds + time_text_embeds


class ImageTimeEmbedding(nn.Module):
    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        # image
        time_image_embeds = self.image_proj(image_embeds)
        time_image_embeds = self.image_norm(time_image_embeds)
        return time_image_embeds


class ImageHintTimeEmbedding(nn.Module):
    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)
        # small conv net that downsamples the RGB hint image by 8x into 4 channels
        self.input_hint_block = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(256, 4, 3, padding=1),
        )

    def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
        # image
        time_image_embeds = self.image_proj(image_embeds)
        time_image_embeds = self.image_norm(time_image_embeds)
        hint = self.input_hint_block(hint)
        return time_image_embeds, hint


class AttentionPooling(nn.Module):
    """Pools a sequence of hidden states into a single vector with one multi-head attention step,
    using the sequence mean (plus a learned positional embedding) as the query token."""

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        bs, length, width = x.size()

        def shape(x):
            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
            x = x.transpose(1, 2)
            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
            x = x.transpose(1, 2)
            return x

        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, dim_per_head, 1)
        q = shape(self.q_proj(class_token))
        # (bs*n_heads, dim_per_head, length+1)
        k = shape(self.k_proj(x))
        v = shape(self.v_proj(x))

        # splitting the scale between q and k is more stable in float16 than dividing afterwards
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # (bs*n_heads, 1, length+1)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        # (bs*n_heads, dim_per_head, 1)
        a = torch.einsum("bts,bcs->bct", weight, v)

        # (bs*n_heads, dim_per_head, 1) --> (bs, 1, width)
        a = a.reshape(bs, -1, 1).transpose(1, 2)

        # keep only the pooled query position: (bs, width)
        return a[:, 0, :]
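

# Illustrative sketch (not part of the module API): pools a hypothetical sequence of
# 77 text-encoder states into one 768-dim vector (as used by TextTimeEmbedding).
def _demo_attention_pooling():
    pool = AttentionPooling(num_heads=8, embed_dim=768)
    hidden_states = torch.randn(2, 77, 768)
    pooled = pool(hidden_states)
    assert pooled.shape == (2, 768)
    return pooled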