GibbsTTS / models /model.py
ydqmkkx's picture
update
0afe769
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.estimator import Decoder
from models.utils import interp_table, random_mask, sequence_mask, pad_nested_tensor, logits_top_p, gumbel_sample
class GibbsTTS_Model(nn.Module):
    """Discrete-token TTS model: a ``Decoder`` estimator plus a distance-based noising path.

    Training (``forward``) corrupts the target tokens by sampling from
    ``softmax(-dist * beta(t))`` around each ground-truth token and trains the
    estimator to predict the clean tokens with a weighted cross-entropy.
    Inference (``synthesize``) starts from uniformly random tokens and
    repeatedly applies ``solver`` to jump tokens towards the estimator's
    prediction.

    Buffers ``dist_matrix``, ``beta`` and ``beta_dt`` are allocated with
    ``torch.empty`` here, i.e. uninitialized.
    NOTE(review): presumably they are filled from a checkpoint or by external
    setup code before use -- confirm against the training/inference scripts.
    """

    def __init__(self, configs):
        """Build the estimator and allocate the lookup buffers.

        Args:
            configs: object exposing ``n_vocab``, ``special_codebook_size``,
                ``codebook_size``, ``quantizers_num`` and ``t_grid_size``;
                it is also passed as-is to ``Decoder``.
        """
        super().__init__()
        self.n_vocab = configs.n_vocab
        # special_codebook_size = 1, pad_id: self.codebook_size
        self.special_codebook_size = configs.special_codebook_size
        self.codebook_size = configs.codebook_size
        self.quantizers_num = configs.quantizers_num
        self.estimator = Decoder(configs)

        # [quantizers_num, codebook_size + special_codebook_size, codebook_size]
        self.register_buffer(
            "dist_matrix",
            torch.empty(
                configs.quantizers_num,
                configs.codebook_size + configs.special_codebook_size,
                configs.codebook_size,
            ),
        )
        # Tables over a time grid, read through ``interp_table``.
        self.register_buffer("beta", torch.empty(configs.t_grid_size))
        self.register_buffer("beta_dt", torch.empty(configs.t_grid_size))

    @staticmethod
    def _codebook_distances(dist_matrix, tokens):
        """Gather, for every token, its distance row from its own codebook.

        Args:
            dist_matrix: tensor indexed as ``[codebook, token_id, k]``.
            tokens: integer tensor ``[b, l, c]`` where the last axis
                enumerates codebooks ``0..c-1``.

        Returns:
            Tensor ``[b, l, c, k]`` with
            ``out[i, j, q] == dist_matrix[q, tokens[i, j, q]]``.
        """
        num_codebooks = tokens.shape[-1]
        codebook_ids = torch.arange(num_codebooks, device=tokens.device).view(1, 1, num_codebooks)
        # Broadcast advanced indexing: [1, 1, c] against [b, l, c] -> [b, l, c, k].
        return dist_matrix[codebook_ids, tokens]

    def forward(self, texts, text_lengths, tokens, token_lengths, lang):
        """Compute the training loss for one batch.

        Args:
            texts: text ids, padded here via ``pad_nested_tensor`` (left-padded
                with 0); a start token ``n_vocab - 1`` is appended on the right.
            text_lengths: per-sample text lengths (incremented by 1 for the
                appended start token).
            tokens: target codec tokens, padded here with ``codebook_size``
                as pad id and squeezed on dim 1 to ``[b, l, c]``.
            token_lengths: per-sample token lengths.
            lang: forwarded unchanged to the estimator.

        Returns:
            ``({"dfm_loss": loss}, None)`` where ``loss`` is a scalar tensor.
        """
        # Input preparation needs no gradients; only the estimator call and
        # the loss below are tracked.
        with torch.no_grad():
            texts = pad_nested_tensor(texts, padding_value=0, left_padded=True)
            device = texts.device
            start_token = torch.full(
                (texts.shape[0], 1), self.n_vocab - 1, dtype=texts.dtype, device=device
            )
            texts = torch.cat([texts, start_token], dim=-1)
            text_lengths = text_lengths + 1

            # pad_id: self.codebook_size
            x_1 = pad_nested_tensor(tokens, padding_value=self.codebook_size).squeeze(1)
            b, l, c = x_1.shape
            mask_start_ids = random_mask(token_lengths)

            # One random time per sample; sample noisy tokens around x_1.
            t = torch.rand(b, device=device)
            beta = interp_table(t, self.beta)[:, None, None, None]
            dist = self._codebook_distances(self.dist_matrix, x_1)  # [b, l, c, k]
            x_t = gumbel_sample(-dist * beta, dim=-1)

            # Positions at or after mask_start_ids are predicted; earlier
            # positions keep their ground-truth tokens.
            positions = torch.arange(l, device=device).unsqueeze(0).expand(b, -1)
            pred_mask = positions >= mask_start_ids.unsqueeze(1)  # [b, l], bool
            x_t = torch.where(pred_mask.unsqueeze(-1), x_t, x_1)

            # Linearly decreasing per-codebook loss weight: 1, 1 - 1/c, ...
            weights_c = 1 - torch.arange(c, device=device) / c  # [c]
            mask_left = sequence_mask(text_lengths, left_padded=True)
            mask_right = sequence_mask(token_lengths)  # [b, l]
            mask = torch.cat([mask_left, mask_right], dim=1)
            weights = (pred_mask & mask_right).float()[:, :, None] * weights_c[None, None, :]  # [b, l, c]

        logits = self.estimator(t, x_t, texts, mask, pred_mask, lang)  # [b, l, c, k]
        dfm_loss = F.cross_entropy(
            logits.float().reshape(-1, self.codebook_size),
            x_1.reshape(-1),
            ignore_index=self.codebook_size,
            reduction='none',
        ).reshape(b, l, c)  # [b, l, c]
        # Non-zero weight sums are >= 1/c, so the clamp only guards the
        # all-zero case (no valid predicted position), which would give NaN.
        dfm_loss = (dfm_loss * weights).sum() / weights.sum().clamp_min(1e-8)
        return {"dfm_loss": dfm_loss}, None

    # Reference: earlier first-order solver, kept commented out.
    # def first_order_ctmc_solver(self, t, h, x_t, logits):
    # b, l, c = x_t.shape
    # temp_flat = torch.arange(self.quantizers_num, device=x_t.device).view(1, 1, self.quantizers_num)
    # idx_flat = temp_flat.expand(b, l, c).reshape(-1)
    # x_1 = gumbel_sample(logits, dim=-1)
    # beta = interp_table(t, self.beta)[:, None, None, None]
    # beta_dt = interp_table(t, self.beta_dt)[:, None, None, None]
    # dist_matrix = getattr(self, f"dist_matrix")
    # dist_flat = dist_matrix[idx_flat, x_1.reshape(-1)]
    # dist = dist_flat.view(b, l, c, -1) # [b, l, c, k]
    # d = torch.gather(dist, -1, x_t.unsqueeze(-1)) - dist
    # p_t = F.softmax(- dist * beta, dim=-1) # [b, l, c, k]
    # u = p_t * beta_dt * d.clamp_min(0)
    # intensity = u.sum(dim=-1)
    # jump_prob = 1. - torch.exp(-h * intensity)
    # mask_jump = (torch.rand_like(x_t.to(u.dtype)) <= jump_prob) & (intensity > 0)
    # if mask_jump.any():
    # probs = u[mask_jump]
    # x_t[mask_jump] = torch.multinomial(probs, 1).squeeze(-1)
    # return x_t

    def solver(self, t, h, x_t, logits):
        """Advance the token state by one step of size ``h`` from time ``t``.

        A clean-token guess ``x_1`` is sampled from ``logits``; each position
        then jumps, with some probability, to a token that is closer (in
        ``dist_matrix`` distance) to its ``x_1``.

        Args:
            t: current time, 1-D tensor.
            h: step size, 1-D tensor.
            x_t: integer tokens ``[b, l, c]``. **Modified in place.**
            logits: estimator logits ``[b, l, c, k]``.

        Returns:
            The same ``x_t`` tensor after the jumps were applied.
        """
        x_1 = gumbel_sample(logits, dim=-1)

        beta = interp_table(t, self.beta)[:, None, None, None]
        beta_dt = interp_table(t, self.beta_dt)[:, None, None, None]
        beta_next = interp_table(t + h, self.beta)[:, None, None, None]

        dist = self._codebook_distances(self.dist_matrix, x_1)  # [b, l, c, k]
        dist_cur = dist.gather(-1, x_t.unsqueeze(-1)).squeeze(-1)  # distance of current token
        delta = dist_cur.unsqueeze(-1) - dist  # > 0 for candidates closer to x_1

        # Jump rates: only moves that reduce the distance get positive rate.
        p_t = F.softmax(-dist * beta, dim=-1)
        u = p_t * (beta_dt * delta).clamp_min(0)
        intensity = u.sum(dim=-1)

        # Expected distance under the path marginal at t + h, and how far the
        # current token still is from it.
        p_next = F.softmax(-dist * beta_next, dim=-1)
        dist_target = (p_next * dist).sum(dim=-1)
        need = dist_cur - dist_target
        # Rate-weighted mean distance reduction achieved by a single jump.
        progress = (u * delta).sum(dim=-1) / intensity.clamp_min(1e-8)

        # First-order jump probability, used as fallback.
        q_base = 1.0 - torch.exp(-h * intensity)
        # Jump probability whose expected progress equals the needed one;
        # used only when it is a valid probability (finite, within [0, 1]).
        q_match = need / progress
        feasible = torch.isfinite(q_match) & (q_match >= 0) & (q_match <= 1)
        jump_prob = torch.where(feasible, q_match, q_base)

        # intensity > 0 guarantees each selected row of ``u`` has positive mass.
        mask_jump = (torch.rand_like(jump_prob) <= jump_prob) & (intensity > 0)
        if mask_jump.any():
            probs = u[mask_jump]
            x_t[mask_jump] = torch.multinomial(probs, 1).squeeze(-1)
        return x_t

    @torch.no_grad()
    def synthesize(self, texts, lang, length, prompt_token, n_timesteps, temperature, top_p, rescale_cfg, cfg):
        """Generate ``length`` token frames after ``prompt_token``.

        Runs without autograd: only integer token tensors are returned.

        Args:
            texts: already padded text ids ``[b, text_len]``; a start token
                ``n_vocab - 1`` is appended here.
            lang: forwarded unchanged to ``estimator.infer``.
            length: number of positions to generate after the prompt.
            prompt_token: integer prompt tokens ``[b, prompt_l, c]``.
            n_timesteps: number of solver steps on a uniform grid over [0, 1].
            temperature: divisor applied to the top-p filtered logits.
            top_p: threshold passed to ``logits_top_p``.
            rescale_cfg, cfg: forwarded unchanged to ``estimator.infer``.

        Returns:
            dict with ``"x"`` (final generated tokens ``[b, length, c]``) and
            ``"xs"`` (list with a snapshot of the generated tokens per step).
        """
        device = texts.device
        start_token = torch.full(
            (texts.shape[0], 1), self.n_vocab - 1, dtype=texts.dtype, device=device
        )
        texts = torch.cat([texts, start_token], dim=-1)

        b, prompt_l, c = prompt_token.shape
        l = prompt_l + length
        # Start from uniformly random tokens, then pin the prompt region.
        x_t = torch.randint(size=(b, l, c), high=self.codebook_size, device=device)
        x_t[:, :prompt_l, :] = prompt_token

        # All sequences have full length here, so the mask is built for
        # 2 * b rows of identical length.
        # NOTE(review): the factor 2 presumably matches a doubled
        # (conditional + unconditional) batch inside ``estimator.infer`` -- confirm.
        mask = sequence_mask(
            torch.tensor(2 * b * [texts.shape[-1] + l], device=device),
            left_padded=True,
        ).unsqueeze(-1).float()

        ts = torch.linspace(0, 1, steps=n_timesteps + 1, device=device)
        xs = []
        for step in range(n_timesteps):
            t = ts[step].unsqueeze(0)
            h = ts[step + 1].unsqueeze(0) - t
            logits = self.estimator.infer(
                t, x_t, prompt_l, texts, lang, mask, rescale_cfg=rescale_cfg, cfg=cfg
            )[:, prompt_l:, :, :]

            if step == n_timesteps - 1:
                # Last step: take the estimator's most likely tokens directly.
                x_t[:, prompt_l:, :] = logits.argmax(dim=-1)
                xs.append(x_t[:, prompt_l:, :].clone())
                break

            logits = logits_top_p(logits, top_p) / temperature
            # The slice is a view, so ``solver`` updates ``x_t`` in place; the
            # assignment keeps the data flow explicit.
            x_t[:, prompt_l:, :] = self.solver(t, h, x_t[:, prompt_l:, :], logits)
            xs.append(x_t[:, prompt_l:, :].clone())

        return {
            "x": x_t[:, prompt_l:, :],
            "xs": xs
        }