| from abc import ABC
|
|
|
| import torch
|
| import torch.nn.functional as F
|
|
|
| from VietTTS.flow.decoder import Decoder
|
|
|
|
|
class BASECFM(torch.nn.Module, ABC):
    """Base class for conditional flow matching (CFM) models.

    Subclasses must assign ``self.estimator``. It is called as
    ``estimator(x, mask, mu, t, spks, cond)`` while sampling (``solve_euler``)
    and as ``estimator(y, mask, mu, t, spks)`` while training (``compute_loss``).
    It must return a tensor shaped like its first argument: the predicted
    vector field dx/dt.
    """

    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        # sigma_min is optional in the config; fall back to a small default.
        self.sigma_min = getattr(cfm_params, "sigma_min", 1e-4)

        # Placeholder: subclasses install the real estimator.
        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Sample by integrating the learned ODE from Gaussian noise (t=0) to t=1.

        Args:
            mu (torch.Tensor): output of encoder,
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output mask,
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of Euler steps; must be >= 1.
            temperature (float, optional): scale applied to the initial noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (optional): extra conditioning, forwarded unchanged to the estimator.

        Returns:
            sample: generated mel-spectrogram,
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        # Build the time grid in mu's dtype so the estimator sees consistent dtypes.
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """Fixed-step Euler integration over the time grid ``t_span``.

        Args:
            x (torch.Tensor): initial state (random noise at t_span[0]).
            t_span (torch.Tensor): time grid, shape (n_timesteps + 1,); needs >= 2 points.
            mu (torch.Tensor): output of encoder,
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output mask,
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding,
                shape: (batch_size, spk_emb_dim)
            cond (optional): extra conditioning, forwarded unchanged to the estimator.

        Returns:
            torch.Tensor: the state at t_span[-1], shaped like ``x``.
        """
        t, dt = t_span[0], t_span[1] - t_span[0]
        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
            x = x + dt * dphi_dt
            t = t + dt
            # Re-derive the next step from the grid (not a fixed dt) so non-uniform
            # grids are followed and rounding in the running t does not accumulate.
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        # Only the final state is needed; intermediate states are not retained.
        return x

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Compute the conditional flow matching loss.

        Args:
            x1 (torch.Tensor): target,
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask,
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder,
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (optional): accepted for interface compatibility; not passed to the estimator here.

        Returns:
            loss: MSE between predicted and target flow over valid (masked) frames,
                normalised by (number of valid frames * n_feats).
            y: the sampled point on the path between noise and x1,
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, _ = mu.shape

        # One random time per batch element, broadcast over channels and frames.
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # view(-1) keeps t 1-D even when b == 1 (squeeze() would yield a 0-dim tensor).
        pred = self.estimator(y, mask, mu, t.view(-1), spks)
        # Mask both sides: the denominator counts only valid frames, so padded
        # frames must not contribute to the numerator either.
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
|
|
|
|
|
class CFM(BASECFM):
    """Flow matching model whose vector-field estimator is a ``Decoder``.

    Args:
        in_channels: feature channel count; passed to ``BASECFM`` as ``n_feats``.
        out_channel: output channel count of the decoder.
        cfm_params: solver configuration forwarded to ``BASECFM``.
        decoder_params: extra keyword arguments for ``Decoder``.
        n_spks: number of speakers.
        spk_emb_dim: speaker embedding size.
    """

    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        # The decoder input widens by spk_emb_dim only for multi-speaker setups
        # (presumably the speaker embedding is concatenated to the input — confirm in Decoder).
        if n_spks > 1:
            decoder_in = in_channels + spk_emb_dim
        else:
            decoder_in = in_channels + 0
        self.estimator = Decoder(
            in_channels=decoder_in,
            out_channels=out_channel,
            **decoder_params,
        )
|
|
|
|
|
class ConditionalCFM(BASECFM):
    """Conditional flow matching with an optional cosine time schedule and
    classifier-free guidance (CFG) at both training and inference time.

    Args:
        in_channels: feature channel count; passed to ``BASECFM`` as ``n_feats``.
        cfm_params: config object providing ``solver``, ``t_scheduler``,
            ``training_cfg_rate`` and ``inference_cfg_rate`` (``sigma_min`` optional).
        n_spks: number of speakers, forwarded to ``BASECFM``.
        spk_emb_dim: speaker embedding size, forwarded to ``BASECFM``.
        estimator: a ``torch.nn.Module``, or an object with an ONNX Runtime-style
            ``run(output_names, input_feed)`` method (see ``forward_estimator``).
    """

    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        self.estimator = estimator

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Sample by integrating the learned ODE from Gaussian noise (t=0) to t=1.

        Args:
            mu (torch.Tensor): output of encoder,
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output mask,
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of Euler steps; must be >= 1.
            temperature (float, optional): scale applied to the initial noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): extra conditioning forwarded to the estimator.

        Returns:
            sample: generated mel-spectrogram,
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            # Warp the uniform grid with t = 1 - cos(pi/2 * s): small steps near t=0.
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """Fixed-step Euler integration over ``t_span``, with optional CFG.

        When ``inference_cfg_rate`` (w) > 0, each step also evaluates the estimator
        with ``mu``, ``spks`` and ``cond`` zeroed (when present), and uses
        ``(1 + w) * conditional - w * unconditional`` as the velocity.

        Args:
            x (torch.Tensor): initial state (random noise at t_span[0]).
            t_span (torch.Tensor): time grid, shape (n_timesteps + 1,); needs >= 2 points.
            mu (torch.Tensor): output of encoder,
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output mask,
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding,
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): extra conditioning forwarded to the estimator.

        Returns:
            torch.Tensor: the state at t_span[-1], shaped like ``x``.
        """
        t, dt = t_span[0], t_span[1] - t_span[0]
        # The estimator receives t as a 1-element tensor rather than a 0-dim scalar.
        t = t.unsqueeze(dim=0)
        cfg_rate = self.inference_cfg_rate

        for step in range(1, len(t_span)):
            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)

            if cfg_rate > 0:
                # Unconditional pass: drop all conditioning, keeping None inputs as None.
                cfg_dphi_dt = self.forward_estimator(
                    x, mask,
                    torch.zeros_like(mu), t,
                    torch.zeros_like(spks) if spks is not None else None,
                    torch.zeros_like(cond) if cond is not None else None,
                )
                dphi_dt = (1.0 + cfg_rate) * dphi_dt - cfg_rate * cfg_dphi_dt

            x = x + dt * dphi_dt
            t = t + dt
            # Re-derive the step from the grid so non-uniform (e.g. cosine) grids are
            # followed and rounding in the running t does not accumulate.
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        # Only the final state is needed; intermediate states are not retained.
        return x

    def forward_estimator(self, x, mask, mu, t, spks, cond):
        """Evaluate the vector-field estimator once.

        A ``torch.nn.Module`` estimator is called directly. Any other estimator is
        treated as an ONNX Runtime-style session (presumably
        ``onnxruntime.InferenceSession`` — confirm). All inputs are copied to CPU
        NumPy arrays, the first output is taken, and it is converted back to a
        tensor with ``x``'s dtype and device.

        Raises:
            ValueError: on the non-Module path when ``spks`` or ``cond`` is None,
                since every named input must be supplied as an array.
        """
        if isinstance(self.estimator, torch.nn.Module):
            # Call the module itself (not .forward) so registered hooks run.
            return self.estimator(x, mask, mu, t, spks, cond)

        if spks is None or cond is None:
            raise ValueError("non-Module estimator requires both 'spks' and 'cond' tensors")
        ort_inputs = {
            'x': x.cpu().numpy(),
            'mask': mask.cpu().numpy(),
            'mu': mu.cpu().numpy(),
            't': t.cpu().numpy(),
            'spks': spks.cpu().numpy(),
            'cond': cond.cpu().numpy()
        }
        output = self.estimator.run(None, ort_inputs)[0]
        return torch.tensor(output, dtype=x.dtype, device=x.device)

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Compute the conditional flow matching loss, with optional CFG dropout.

        Args:
            x1 (torch.Tensor): target,
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask,
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder,
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): extra conditioning. Defaults to None.

        Returns:
            loss: MSE between predicted and target flow over valid (masked) frames,
                normalised by (number of valid frames * n_feats).
            y: the sampled point on the path between noise and x1,
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, _ = mu.shape

        # One random time per batch element, optionally warped by the cosine schedule.
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)

        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        if self.training_cfg_rate > 0:
            # Each sample keeps its conditioning with probability 1 - training_cfg_rate;
            # dropped samples get zeroed mu / spks / cond (optional inputs stay None).
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            if spks is not None:
                spks = spks * cfg_mask.view(-1, 1)
            if cond is not None:
                cond = cond * cfg_mask.view(-1, 1, 1)

        # view(-1) keeps t 1-D even when b == 1 (squeeze() would yield a 0-dim tensor).
        pred = self.estimator(y, mask, mu, t.view(-1), spks, cond)
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
|
|
|