Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from models.estimator import Decoder | |
| from models.utils import interp_table, random_mask, sequence_mask, pad_nested_tensor, logits_top_p, gumbel_sample | |
class GibbsTTS_Model(nn.Module):
    """Text-to-speech model over multi-codebook discrete audio tokens.

    Training (`forward`) noises the clean tokens by sampling from
    ``softmax(-dist * beta(t))`` around each clean code, asks the `Decoder`
    estimator to predict the clean codes, and returns a weighted
    cross-entropy loss.  Sampling (`synthesize`) starts from uniformly random
    codes and repeatedly applies `solver` using the estimator's logits.

    Buffers ``dist_matrix``, ``beta`` and ``beta_dt`` are allocated
    uninitialised here; they are presumably filled from a checkpoint or by
    the caller before use -- TODO confirm.
    """

    def __init__(self, configs):
        """Build the estimator and allocate the lookup-table buffers.

        Args:
            configs: object exposing ``n_vocab``, ``codebook_size``,
                ``special_codebook_size``, ``quantizers_num`` and
                ``t_grid_size``; it is also passed unchanged to `Decoder`.
        """
        super().__init__()
        self.n_vocab = configs.n_vocab
        # special_codebook_size = 1, pad_id: self.codebook_size
        self.special_codebook_size = configs.special_codebook_size
        self.codebook_size = configs.codebook_size
        self.quantizers_num = configs.quantizers_num
        self.estimator = Decoder(configs)
        # [quantizer, source code (incl. special/pad ids), target code]
        self.register_buffer(
            "dist_matrix",
            torch.empty(
                configs.quantizers_num,
                configs.codebook_size + configs.special_codebook_size,
                configs.codebook_size,
            ),
        )
        # Schedule tables sampled on a grid of `t_grid_size` points; they are
        # read through `interp_table` at continuous times.
        self.register_buffer("beta", torch.empty(configs.t_grid_size))
        self.register_buffer("beta_dt", torch.empty(configs.t_grid_size))

    @staticmethod
    def _lookup_distances(dist_matrix, codes):
        """Return ``out[b, l, q, :] = dist_matrix[q, codes[b, l, q], :]``.

        Args:
            dist_matrix: tensor indexed as ``[quantizer, code, k]``.
            codes: integer tensor ``[b, l, c]`` whose last axis enumerates
                the quantizers.

        Returns:
            Tensor of shape ``[b, l, c, k]``.
        """
        quantizer_ids = torch.arange(codes.shape[-1], device=codes.device)
        # `quantizer_ids` ([c]) broadcasts against `codes` ([b, l, c]).
        return dist_matrix[quantizer_ids, codes]

    @staticmethod
    def _weighted_mean(values, weights, eps=1e-8):
        """Return ``sum(values * weights) / sum(weights)``.

        The denominator is clamped to ``eps`` so that an all-zero `weights`
        tensor yields 0 instead of NaN (0 / 0).
        """
        return (values * weights).sum() / weights.sum().clamp_min(eps)

    def forward(self, texts, text_lengths, tokens, token_lengths, lang):
        """Compute the training loss for one batch.

        Args:
            texts: text ids, padded here via `pad_nested_tensor`
                (left-padded with 0).
            text_lengths: per-sample text lengths (before the start token
                appended here).
            tokens: audio codes, padded here with ``self.codebook_size``;
                after ``squeeze(1)`` they are unpacked as ``[b, l, c]``.
            token_lengths: per-sample audio token lengths.
            lang: forwarded unchanged to the estimator.

        Returns:
            ``({"dfm_loss": loss}, None)`` where ``loss`` is a scalar tensor.
        """
        # NOTE(review): the source's indentation was lost; the no_grad block
        # is taken to cover the data preparation only.  Nothing inside reads
        # `self.estimator`; `beta`/`dist_matrix` are buffers.
        with torch.no_grad():
            texts = pad_nested_tensor(texts, padding_value=0, left_padded=True)
            device = texts.device
            # Append the last vocabulary id to every text as a start token.
            start_token = torch.full(
                (texts.shape[0], 1),
                self.n_vocab - 1,
                dtype=texts.dtype,
                device=device,
            )
            texts = torch.cat([texts, start_token], dim=-1)
            text_lengths = text_lengths + 1

            # pad_id: self.codebook_size
            x_1 = pad_nested_tensor(
                tokens, padding_value=self.codebook_size
            ).squeeze(1)
            b, l, c = x_1.shape
            mask_start_ids = random_mask(token_lengths)

            # One random time per sample; noise every code around its clean
            # value with logits -dist * beta(t).
            t = torch.rand(b, device=device)
            beta = interp_table(t, self.beta)[:, None, None, None]
            dist = self._lookup_distances(self.dist_matrix, x_1)  # [b, l, c, k]
            x_t = gumbel_sample(-dist * beta, dim=-1)

            # Positions at or after `mask_start_ids` are predicted; earlier
            # positions keep their clean codes.
            positions = torch.arange(l, device=device).unsqueeze(0).expand(b, -1)
            pred_mask = positions >= mask_start_ids.unsqueeze(1)  # [b, l], bool
            x_t = torch.where(pred_mask.unsqueeze(-1), x_t, x_1)

            # Per-quantizer weights 1, 1 - 1/c, ..., 1/c.
            weights_c = 1 - torch.arange(c, device=device) / c  # [c]
            mask_left = sequence_mask(text_lengths, left_padded=True)
            mask_right = sequence_mask(token_lengths)  # [b, l]
            mask = torch.cat([mask_left, mask_right], dim=1)
            weights = (
                (pred_mask & mask_right).float()[:, :, None]
                * weights_c[None, None, :]
            )  # [b, l, c]

        logits = self.estimator(t, x_t, texts, mask, pred_mask, lang)  # [b, l, c, k]
        token_loss = F.cross_entropy(
            logits.float().reshape(-1, self.codebook_size),
            x_1.reshape(-1),
            ignore_index=self.codebook_size,
            reduction='none',
        ).reshape(b, l, c)  # [b, l, c]
        dfm_loss = self._weighted_mean(token_loss, weights)
        return {"dfm_loss": dfm_loss}, None

    # Disabled alternative to `solver` (plain first-order CTMC step):
    #
    # def first_order_ctmc_solver(self, t, h, x_t, logits):
    #     b, l, c = x_t.shape
    #     temp_flat = torch.arange(self.quantizers_num, device=x_t.device).view(1, 1, self.quantizers_num)
    #     idx_flat = temp_flat.expand(b, l, c).reshape(-1)
    #     x_1 = gumbel_sample(logits, dim=-1)
    #     beta = interp_table(t, self.beta)[:, None, None, None]
    #     beta_dt = interp_table(t, self.beta_dt)[:, None, None, None]
    #     dist_matrix = getattr(self, f"dist_matrix")
    #     dist_flat = dist_matrix[idx_flat, x_1.reshape(-1)]
    #     dist = dist_flat.view(b, l, c, -1)  # [b, l, c, k]
    #     d = torch.gather(dist, -1, x_t.unsqueeze(-1)) - dist
    #     p_t = F.softmax(- dist * beta, dim=-1)  # [b, l, c, k]
    #     u = p_t * beta_dt * d.clamp_min(0)
    #     intensity = u.sum(dim=-1)
    #     jump_prob = 1. - torch.exp(-h * intensity)
    #     mask_jump = (torch.rand_like(x_t.to(u.dtype)) <= jump_prob) & (intensity > 0)
    #     if mask_jump.any():
    #         probs = u[mask_jump]
    #         x_t[mask_jump] = torch.multinomial(probs, 1).squeeze(-1)
    #     return x_t

    def solver(self, t, h, x_t, logits):
        """Advance the current codes by one step of size `h` from time `t`.

        A clean-code guess ``x_1`` is sampled from `logits`; each position
        then jumps with some probability to a code drawn proportionally to
        its jump rate ``u``.  The jump probability is ``need / progress``
        when that ratio lies in [0, 1], otherwise ``1 - exp(-h * intensity)``.

        Args:
            t: current time, shape ``[1]`` or ``[b]``.
            h: step size, broadcastable against `t`.
            x_t: current integer codes ``[b, l, c]``.  **Modified in place**
                and also returned.
            logits: estimator logits ``[b, l, c, k]``.

        Returns:
            `x_t` after the jumps have been applied.
        """
        x_1 = gumbel_sample(logits, dim=-1)
        beta = interp_table(t, self.beta)[:, None, None, None]
        beta_dt = interp_table(t, self.beta_dt)[:, None, None, None]
        beta_next = interp_table(t + h, self.beta)[:, None, None, None]

        dist = self._lookup_distances(self.dist_matrix, x_1)  # [b, l, c, k]
        dist_cur = dist.gather(-1, x_t.unsqueeze(-1)).squeeze(-1)
        # Positive where a candidate code is closer to x_1 than the current one.
        delta = dist_cur.unsqueeze(-1) - dist

        p_t = F.softmax(-dist * beta, dim=-1)
        u = p_t * (beta_dt * delta).clamp_min(0)  # per-candidate jump rate
        intensity = u.sum(dim=-1)

        # Expected distance to x_1 under the distribution at time t + h.
        p_next = F.softmax(-dist * beta_next, dim=-1)
        dist_target = (p_next * dist).sum(dim=-1)
        need = dist_cur - dist_target
        # Expected distance reduction of a single jump drawn from `u`.
        progress = (u * delta).sum(dim=-1) / intensity.clamp_min(1e-8)

        q_base = 1.0 - torch.exp(-h * intensity)
        # May be inf/NaN where `progress` is 0; filtered by `feasible`.
        q_match = need / progress
        feasible = torch.isfinite(q_match) & (q_match >= 0) & (q_match <= 1)
        jump_prob = torch.where(feasible, q_match, q_base)

        # `intensity > 0` guarantees every multinomial row has positive mass.
        mask_jump = (torch.rand_like(jump_prob) <= jump_prob) & (intensity > 0)
        if mask_jump.any():
            probs = u[mask_jump]
            x_t[mask_jump] = torch.multinomial(probs, 1).squeeze(-1)
        return x_t

    def synthesize(self, texts, lang, length, prompt_token, n_timesteps, temperature, top_p, rescale_cfg, cfg):
        """Generate `length` new token frames after `prompt_token`.

        Args:
            texts: already padded text ids ``[b, text_len]``; a start token
                (``n_vocab - 1``) is appended here.
            lang: forwarded to the estimator.
            length: number of frames to generate.
            prompt_token: integer codes ``[b, prompt_l, c]`` kept fixed.
            n_timesteps: number of solver steps; the last step takes the
                arg-max of the logits instead of calling `solver`.
            temperature: divisor applied to the top-p filtered logits.
            top_p: forwarded to `logits_top_p`.
            rescale_cfg, cfg: forwarded to ``estimator.infer``.

        Returns:
            ``{"x": final codes [b, length, c], "xs": list with a copy of the
            generated codes after every step}``.
        """
        device = texts.device
        start_token = torch.full(
            (texts.shape[0], 1), self.n_vocab - 1, dtype=texts.dtype, device=device
        )
        texts = torch.cat([texts, start_token], dim=-1)

        b, prompt_l, c = prompt_token.shape
        l = prompt_l + length
        # Uniformly random initial codes; the prompt region is overwritten.
        x_t = torch.randint(size=(b, l, c), high=self.codebook_size, device=device)
        x_t[:, :prompt_l, :] = prompt_token

        # 2 * b rows: presumably conditional + unconditional batches for
        # classifier-free guidance inside `estimator.infer` -- TODO confirm.
        total_lengths = torch.full(
            (2 * b,), texts.shape[-1] + l, dtype=torch.long, device=device
        )
        mask = sequence_mask(total_lengths, left_padded=True).unsqueeze(-1).float()

        ts = torch.linspace(0, 1, steps=n_timesteps + 1, device=device)
        xs = []
        for step in range(n_timesteps):
            t = ts[step].unsqueeze(0)
            h = ts[step + 1].unsqueeze(0) - t
            logits = self.estimator.infer(
                t, x_t, prompt_l, texts, lang, mask, rescale_cfg=rescale_cfg, cfg=cfg
            )[:, prompt_l:, :, :]
            if step == n_timesteps - 1:
                # Final step: commit to the most likely code.
                x_t[:, prompt_l:, :] = logits.argmax(dim=-1)
            else:
                logits = logits_top_p(logits, top_p) / temperature
                x_t[:, prompt_l:, :] = self.solver(t, h, x_t[:, prompt_l:, :], logits)
            xs.append(x_t[:, prompt_l:, :].clone())
        return {
            "x": x_t[:, prompt_l:, :],
            "xs": xs
        }