Spaces:
Running on Zero
Running on Zero
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from models.diffusion_transformer import DiTBlock, DiTFinalLayer | |
| # reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py | |
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions / timesteps.

    Each input value ``v`` is mapped to ``[sin(s*v*f_0), ..., sin(s*v*f_{k-1}),
    cos(s*v*f_0), ..., cos(s*v*f_{k-1})]`` where ``k = dim // 2``, ``s`` is
    ``scale`` and the frequencies ``f_i`` decay geometrically from 1 down to
    1/10000.

    Args:
        dim: Size of the produced embedding; must be even.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        """Embed ``x``.

        Args:
            x: 0-d or 1-d tensor of values to embed (a 0-d tensor is treated
                as a batch of one).
            scale: Multiplier applied to ``x`` before the sin/cos.

        Returns:
            Tensor of shape ``[len(x), dim]``: sines first, cosines second.
        """
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        # With dim == 2 there is a single frequency and (half_dim - 1) is 0;
        # clamp the divisor so that case no longer raises ZeroDivisionError.
        # For every dim >= 4 the divisor is unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device).float() * -log_step)
        emb = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((emb.sin(), emb.cos()), dim=-1)
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to an embedding.

    Args:
        in_channels: Size of the last dimension of the input.
        out_channels: Size of the last dimension of the output.
        intermediate_size: Width of the hidden layer.
    """

    def __init__(self, in_channels, out_channels, intermediate_size):
        super().__init__()
        # The position of each stage fixes its parameter name
        # ("layer.0.*" and "layer.2.*"), so the order must not change.
        stages = (
            nn.Linear(in_channels, intermediate_size),
            nn.SiLU(inplace=True),
            nn.Linear(intermediate_size, out_channels),
        )
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP and return the projected tensor."""
        projected = self.layer(x)
        return projected
class Decoder(nn.Module):
    """DiT-based decoder producing per-quantizer codebook logits.

    Text tokens and multi-quantizer code tokens are embedded, concatenated
    along the sequence axis (text first), and run through a stack of
    ``DiTBlock`` layers conditioned on a timestep + language vector. The
    output positions belonging to the code tokens are scored against the
    same ``token_embed`` table that embeds the inputs (tied weights).

    Args:
        configs: Object exposing ``hidden_size``, ``intermediate_size``,
            ``codebook_size``, ``special_codebook_size``, ``quantizers_num``,
            ``n_vocab``, ``n_lang``, ``n_heads``, ``n_layers``, ``dropout``
            and ``cfg_dropout``.
    """

    def __init__(self, configs):
        super().__init__()
        hidden = configs.hidden_size
        ffn_size = configs.intermediate_size
        self.codebook_size = configs.codebook_size
        n_quantizers = configs.quantizers_num
        half_hidden = hidden // 2

        # NOTE: registration order is kept as-is; it defines the state_dict
        # layout and the order in which default initialisers draw from the RNG.
        self.text_embed = nn.Embedding(configs.n_vocab, hidden)
        entries_per_quantizer = configs.codebook_size + configs.special_codebook_size
        self.token_embed = nn.Parameter(
            torch.empty(n_quantizers, entries_per_quantizer, hidden)
        )
        self.input_proj = nn.Linear(n_quantizers * hidden, hidden)

        # Timestep and language each contribute half of the conditioning vector.
        self.time_embeddings = SinusoidalPosEmb(half_hidden)
        self.time_mlp = TimestepEmbedding(half_hidden, half_hidden, ffn_size // 2)

        self.cfg_dropout = configs.cfg_dropout
        self.text_cfg_embed = nn.Parameter(torch.empty(hidden))
        self.token_cfg_embed = nn.Parameter(torch.empty(hidden))
        self.prompt_embed = nn.Parameter(torch.empty(hidden))
        self.lang_embed = nn.Embedding(configs.n_lang, half_hidden)

        layers = []
        for _ in range(configs.n_layers):
            layers.append(DiTBlock(hidden, ffn_size, configs.n_heads, configs.dropout))
        self.blocks = nn.ModuleList(layers)
        self.final_layer = DiTFinalLayer(hidden)

        self.init_weights()

    def init_weights(self):
        """Initialise all parameters.

        Linear/Conv1d weights get Xavier-uniform (biases zero) and embeddings
        get N(0, 0.02). The last adaLN modulation layer of every block and of
        the final layer is then zeroed, the time MLP's linear weights are
        re-drawn from N(0, 0.02), and the raw ``nn.Parameter`` tensors are
        drawn from N(0, 0.02).
        """
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv1d)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, std=0.02)

        # Zero the output of every adaLN modulation branch.
        for modulated in (*self.blocks, self.final_layer):
            last_linear = modulated.adaLN_modulation[-1]
            nn.init.zeros_(last_linear.weight)
            nn.init.zeros_(last_linear.bias)

        for module in self.time_mlp.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)

        for param in (
            self.token_embed,
            self.text_cfg_embed,
            self.token_cfg_embed,
            self.prompt_embed,
        ):
            nn.init.normal_(param, std=0.02)

    def _embed_tokens(self, x_t):
        """Embed codes ``x_t`` [b, l, c] per quantizer and project to [b, l, h]."""
        batch, frames, n_quantizers = x_t.shape
        quantizer_ids = torch.arange(n_quantizers, device=x_t.device)
        # Broadcast indexing: quantizer i looks up its own table -> [b, l, c, h].
        per_quantizer = self.token_embed[quantizer_ids, x_t]
        return self.input_proj(per_quantizer.reshape(batch, frames, -1))

    def _conditioning(self, t, lang):
        """Concatenate the timestep embedding and the language embedding."""
        time_vec = self.time_mlp(self.time_embeddings(t))
        lang_vec = self.lang_embed(lang)
        return torch.cat([time_vec, lang_vec], dim=-1)

    @staticmethod
    def _attention_bias(mask, dtype):
        """Build an additive bias [b, 1, L, L] from a ``[b, L, 1]`` mask.

        Pairs whose mask product is zero receive ``-finfo(dtype).max``; all
        other entries are zero.
        """
        pairwise = mask * mask.transpose(1, 2)
        blocked = pairwise == 0
        bias = torch.zeros_like(pairwise).masked_fill(blocked, -torch.finfo(dtype).max)
        return bias.unsqueeze(1)

    def _to_logits(self, x):
        """Score ``x`` [b, l, h] against the regular codebook entries -> [b, l, c, k]."""
        codebook = self.token_embed[:, :self.codebook_size, :]
        return torch.einsum("blh,ckh->blck", x, codebook)

    def forward(self, t, x_t, texts, mask, pred_mask, lang):
        """Training-style pass with random condition dropout.

        Args:
            t: Timesteps fed to the sinusoidal embedding.
            x_t: Integer codes of shape [b, l, c] (c = number of quantizers).
            texts: Text token ids; their embeddings are prepended to the
                code sequence.
            mask: Sequence mask; it is unsqueezed to ``[b, L, 1]`` here.
                Presumably L covers text plus code positions (to be
                confirmed against the caller).
            pred_mask: Mask of shape [b, l]; positions where it is False
                receive ``prompt_embed``.
            lang: Language ids for ``lang_embed``.

        Returns:
            Logits of shape [b, l, c, codebook_size].
        """
        batch, _, _ = x_t.shape
        text_h = self.text_embed(texts)

        x = self._embed_tokens(x_t)
        is_prompt = ~pred_mask[:, :, None]
        x = x + is_prompt.to(x.dtype) * self.prompt_embed

        # Per-sample condition dropout: the text is replaced entirely and the
        # prompt positions are replaced by the learned "unconditional" vectors.
        drop = (torch.rand(batch, device=x.device) < self.cfg_dropout)[:, None, None]
        text_h = torch.where(drop, self.text_cfg_embed, text_h)
        x = torch.where(drop & is_prompt, self.token_cfg_embed, x)

        x = torch.cat([text_h, x], dim=1)
        mask = mask.unsqueeze(-1).to(x.dtype)
        cond = self._conditioning(t.to(x.dtype), lang)
        attn_mask = self._attention_bias(mask, x.dtype)

        for block in self.blocks:
            x = block(x, cond, mask, attn_mask)

        # Drop the text positions before the output head.
        code_h = x[:, texts.shape[-1]:, :]
        return self._to_logits(self.final_layer(code_h, cond))

    def infer(self, t, x_t, prompt_l, texts, lang, mask, rescale_cfg, cfg):
        """Inference pass with classifier-free guidance.

        The conditional and unconditional branches are stacked along the
        batch axis, run through the blocks together, and combined as
        ``x + cfg * (x - x_uncond)``, followed by a std-matching rescale that
        is blended in with weight ``rescale_cfg``.

        Args:
            t: Timesteps fed to the sinusoidal embedding.
            x_t: Integer codes of shape [b, l, c].
            prompt_l: Number of leading positions treated as the prompt.
            texts: Text token ids.
            lang: Language ids for ``lang_embed``.
            mask: Mask used as-is for the attention bias and the blocks; it
                must already be 3-D so that ``mask.transpose(1, 2)`` works.
            rescale_cfg: Blend weight between rescaled and raw guided output.
            cfg: Guidance strength.

        Returns:
            Logits of shape [b, l, c, codebook_size].
        """
        batch, _, _ = x_t.shape

        text_h = self.text_embed(texts)
        uncond_text = self.text_cfg_embed[None, None, :].expand_as(text_h)
        text_h = torch.cat([text_h, uncond_text], dim=0)

        x = self._embed_tokens(x_t)
        x[:, :prompt_l, :] += self.prompt_embed
        cond = self._conditioning(t, lang)

        # Unconditional branch: prompt positions overwritten by the cfg vector.
        x_uncond = x.clone()
        x_uncond[:, :prompt_l, :] = self.token_cfg_embed
        stacked = torch.cat([x, x_uncond], dim=0)
        x = torch.cat([text_h, stacked], dim=1)

        attn_mask = self._attention_bias(mask, x.dtype)
        for block in self.blocks:
            x = block(x, cond, mask, attn_mask)

        code_h = x[:, texts.shape[-1]:, :]
        x = self.final_layer(code_h, cond)

        cond_out, uncond_out = torch.split(x, [batch, batch], dim=0)
        base_std = cond_out.std()
        guided = cond_out + cfg * (cond_out - uncond_out)
        # Bring the guided output back to the std of the conditional branch.
        rescaled = guided * base_std / guided.std()
        blended = rescale_cfg * rescaled + (1 - rescale_cfg) * guided
        return self._to_logits(blended)