Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class FFN(nn.Module):
    """
    Gated feed-forward network with a SiLU gate.

    Modified from: https://github.com/huggingface/transformers/blob/8ebfd84fa7f4d6c59f5059a439fad10ada26b3ff/src/transformers/models/llama/modeling_llama.py#L173

    One fused linear layer produces the "up" and "gate" halves in a single
    matmul; the result is ``down_proj(dropout(silu(gate) * up))``.

    Args:
        hidden_size: size of the last dimension of the input and output.
        intermediate_size: width of each of the two gated halves.
        p_dropout: dropout probability applied to the gated activations.
    """

    def __init__(self, hidden_size, intermediate_size, p_dropout=0.):
        super().__init__()
        # Fused projection: the first half of its output is `up`, the second half is `gate`.
        self.up_gate_proj = nn.Linear(hidden_size, 2 * intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(p_dropout)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the gated MLP along the last dimension of `x`; the output has the same shape as `x`."""
        fused = self.up_gate_proj(x)
        up, gate = torch.chunk(fused, 2, dim=-1)
        gated = self.act_fn(gate) * up
        gated = self.dropout(gated)
        return self.down_proj(gated)
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention with rotary positional embeddings on queries and keys.

    Args:
        hidden_size: model width; must be divisible by `n_heads`.
        n_heads: number of attention heads.
        p_dropout: attention dropout probability (only applied in training mode).
    """

    def __init__(self, hidden_size, n_heads, p_dropout=0.):
        super().__init__()
        assert hidden_size % n_heads == 0
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.head_dim = hidden_size // n_heads
        # Single fused projection producing q, k and v (no bias).
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        self.rotary_pe = RotaryPositionalEmbeddings(self.head_dim)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, t):
        """Reshape [b, l, h] -> [b, n_heads, l, head_dim]."""
        batch, seq_len, _ = t.shape
        return t.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, attn_mask=None, position_ids=None):
        """
        Args:
            x            : [b, l, h]
            attn_mask    : forwarded unchanged to `F.scaled_dot_product_attention`
            position_ids : forwarded unchanged to the rotary embedding module
        return the same shape as x
        """
        q, k, v = self.qkv(x).chunk(3, dim=-1)  # each [b, l, h]
        batch, seq_len, hidden = q.shape

        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

        # Positions are injected into queries and keys only; values are left untouched.
        q = self.rotary_pe(q, position_ids)
        k = self.rotary_pe(k, position_ids)

        # Attention dropout is disabled outside of training mode.
        dropout_p = self.p_dropout if self.training else 0
        context = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p)

        # [b, n_heads, l, head_dim] -> [b, l, h]
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, hidden)
        return self.out_proj(context)
class RMSNorm(nn.Module):
    """
    Root-mean-square normalisation over the last dimension with a learned per-feature scale.

    Modified from: https://docs.pytorch.org/torchtune/0.2/_modules/torchtune/modules/rms_norm.html#RMSNorm

    Args:
        dim: size of the last (normalised) dimension.
        eps: constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise `x` by its RMS along the last dimension, then multiply by `scale`."""
        # The statistics are computed in fp32, then cast back to the input dtype.
        as_fp32 = x.float()
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = (as_fp32 * inv_rms).type_as(x)
        return normed * self.scale
def modulate(x, shift, scale):
    """Apply an affine modulation: ``x * (1 + scale) + shift`` (standard broadcasting applies)."""
    # `scale` is an offset from identity, so scale == 0 and shift == 0 leave `x` unchanged.
    scaled = x * (1 + scale)
    return scaled + shift
# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    The conditioning vector is mapped to six modulation tensors (shift / scale / gate
    for the attention branch and for the MLP branch).
    NOTE(review): no zero-initialisation of the modulation layer happens in this
    class; presumably it is done by the code that builds the model -- confirm.
    """

    def __init__(self, hidden_size, intermediate_size, num_heads, p_dropout):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size)
        self.attn = MultiHeadAttention(hidden_size, num_heads, p_dropout)
        self.norm2 = RMSNorm(hidden_size)
        self.mlp = FFN(hidden_size, intermediate_size, p_dropout)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c, x_mask, attn_mask=None, position_ids=None):
        """
        Args:
            x : [b, l, h]
            c : [b, h]
            x_mask : [b, l, 1]
            attn_mask: [b, 1, l, l]
        return the same shape as x
        """
        # [b, 6h] -> [b, 1, 6h] -> six tensors of shape [b, 1, h]
        modulation = self.adaLN_modulation(c).unsqueeze(1)
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = modulation.chunk(6, dim=-1)

        x = x * x_mask

        # Attention branch: the gated residual is masked before being added.
        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        attn_out = self.attn(attn_in, attn_mask, position_ids)
        x = x + gate_msa * attn_out * x_mask

        # MLP branch: masking is applied once, on the final result.
        mlp_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp * self.mlp(mlp_in)
        return x * x_mask
class DiTFinalLayer(nn.Module):
    """
    Final conditioned normalisation: RMSNorm followed by a shift/scale modulation
    predicted from the conditioning vector.

    Modified from: https://github.com/facebookresearch/DiT/blob/ed81ce2229091fd4ecc9a223645f95cf379d582b/models.py#L125
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.norm = RMSNorm(hidden_size)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Normalise `x` and modulate it with the shift/scale derived from `c`; shape of `x` is kept."""
        # A singleton axis is inserted at dim 1 so the modulation broadcasts over that axis of `x`.
        params = self.adaLN_modulation(c).unsqueeze(1)
        shift, scale = params.chunk(2, dim=-1)
        return modulate(self.norm(x), shift, scale)
| class RotaryPositionalEmbeddings(nn.Module): | |
| """ | |
| Modified from: | |
| https://colab.research.google.com/drive/11SKfzvMotuvvXNqY9qBpsD2RQX1PK7rP?usp=sharing#scrollTo=XNeygwV2gEWH | |
| https://github.com/huggingface/transformers/blob/8ebfd84fa7f4d6c59f5059a439fad10ada26b3ff/src/transformers/models/llama/modeling_llama.py#L73 | |
| """ | |
| def __init__(self, d: int, base: int = 10_000): | |
| r""" | |
| * `d` is the number of features $d$ | |
| * `base` is the constant used for calculating $\Theta$ | |
| """ | |
| super().__init__() | |
| self.base = base | |
| self.d = int(d) | |
| self.cos_cached = None | |
| self.sin_cached = None | |
| def _build_cache(self, seq_len: int, device: torch.device): | |
| r""" | |
| Cache $\cos$ and $\sin$ values | |
| """ | |
| # Return if cache is already built | |
| if self.cos_cached is not None and seq_len <= self.cos_cached.shape[0]: | |
| return | |
| # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ | |
| theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(device) | |
| # Create position indexes `[0, 1, ..., seq_len - 1]` | |
| seq_idx = torch.arange(seq_len, device=device).float().to(device) | |
| # Calculate the product of position index and $\theta_i$ | |
| idx_theta = torch.einsum("n,d->nd", seq_idx, theta) | |
| # Concatenate so that for row $m$ we have | |
| # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ | |
| idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) | |
| # Cache them | |
| self.cos_cached = idx_theta2.cos()[:, None, None, :] | |
| self.sin_cached = idx_theta2.sin()[:, None, None, :] | |
| def _neg_half(self, x: torch.Tensor): | |
| # $\frac{d}{2}$ | |
| d_2 = self.d // 2 | |
| # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ | |
| return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) | |
| # [x_1, x_2,...x_d] -> [-x_d/2, ... -x_d, x_1, ... x_d/2] | |
| def forward(self, x: torch.Tensor, position_ids: torch.Tensor | None = None): | |
| # Cache $\cos$ and $\sin$ values | |
| x = x.permute(2, 0, 1, 3) # [b, n_heads, l, d] -> [l, b, n_heads, d] | |
| device = x.device | |
| if position_ids is None: | |
| l = x.shape[0] | |
| self._build_cache(l, device) | |
| cos = self.cos_cached[:l] | |
| sin = self.sin_cached[:l] # [l, 1, 1, d] | |
| else: | |
| max_pos = int(position_ids.max().item()) + 1 | |
| self._build_cache(max_pos, device) | |
| # cos_cached: [max_len, 1, 1, d] | |
| cos = self.cos_cached[position_ids].squeeze(3).squeeze(2) # [b, l, 1, 1, d] -> [b, l, d] | |
| sin = self.sin_cached[position_ids].squeeze(3).squeeze(2) | |
| cos = cos.permute(1, 0, 2)[:, :, None, :] # [b, l, d] -> [l, b, 1, d] | |
| sin = sin.permute(1, 0, 2)[:, :, None, :] | |
| # Split the features, we can choose to apply rotary embeddings only to a partial set of features. | |
| x_rope, x_pass = x[..., : self.d], x[..., self.d :] | |
| # Calculate | |
| # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ | |
| neg_half_x = self._neg_half(x_rope) | |
| x_rope = x_rope * cos + neg_half_x * sin # [l, b, n_heads, d] | |
| return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3) # [l, b, n_heads, d] -> [b, n_heads, l, d] | |