def sanitize_model_name(name: str) -> str:
    """Return a filesystem-safe model name derived from *name*.

    The input is stripped of surrounding whitespace, reduced to its final
    path component, and every character other than word characters, ``-``
    and ``.`` is replaced by ``_``.

    Args:
        name: User-supplied model name, possibly containing directories.

    Returns:
        The sanitized name, or ``"unnamed_model"`` when nothing usable is
        left: an empty result, or a result made only of dots such as ``"."``
        or ``".."``.
    """
    name = os.path.basename(name.strip())
    name = re.sub(r'[^\w\-.]', '_', name)
    # "." and ".." survive the substitution above, but they name the current /
    # parent directory once joined onto a path, so they are not safe names.
    if not name.strip("."):
        return "unnamed_model"
    return name
# Beatrice v2 pretrained assets
BEATRICE_REPO = "fierce-cats/beatrice-trainer"
BEATRICE_PRETRAINED = {
    "phone_extractor": "assets/pretrained/122_checkpoint_03000000.pt",
    "pitch_estimator": "assets/pretrained/104_3_checkpoint_00300000.pt",
    "pretrained_model": "assets/pretrained/151_checkpoint_libritts_r_200_02750000.pt.gz",
}


def download_beatrice_asset(name: str) -> str:
    """Fetch one Beatrice v2 pretrained asset from the Hub.

    Args:
        name: Key of ``BEATRICE_PRETRAINED`` selecting the asset.

    Returns:
        Whatever ``hf_hub_download`` returns for the asset's repo path.

    Raises:
        ValueError: If *name* is not a key of ``BEATRICE_PRETRAINED``.
    """
    repo_path = BEATRICE_PRETRAINED.get(name)
    if repo_path is None:
        known = list(BEATRICE_PRETRAINED.keys())
        raise ValueError(f"Unknown asset: {name}. Available: {known}")

    logger.info(f"Downloading Beatrice asset {name} from {BEATRICE_REPO}...")
    return hf_hub_download(repo_id=BEATRICE_REPO, filename=repo_path)
def get_padding(kernel_size, dilation=1):
    """Return half of the dilated kernel extent, truncated toward zero.

    The extent is ``kernel_size * dilation - dilation``; the result is used
    by the residual blocks in this file as the ``padding`` of their 1-D
    convolutions.
    """
    extent = kernel_size * dilation - dilation
    half_extent = extent / 2
    return int(half_extent)
def slice_segments(x, ids_str, segment_size=4):
    """Cut one fixed-length window per batch item out of a 3-D tensor.

    Args:
        x: Tensor indexed as ``x[i, :, start:end]`` (batch first, sliced
            axis last).
        ids_str: Per-item start offsets, indexable by batch position.
        segment_size: Window length along the last axis.

    Returns:
        A tensor shaped like ``x[:, :, :segment_size]`` holding the windows.
    """
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def slice_segments2(x, ids_str, segment_size=4):
    """Cut one fixed-length window per batch item out of a 2-D tensor.

    Same as :func:`slice_segments` but for tensors indexed as
    ``x[i, start:end]``.
    """
    ret = torch.zeros_like(x[:, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut a randomly positioned window out of every batch item.

    Args:
        x: 3-D tensor; the window is taken along the last axis.
        x_lengths: Optional tensor of valid lengths per batch item. When
            omitted, every item is treated as spanning the full last axis.
        segment_size: Window length.

    Returns:
        ``(segments, ids_str)``: the sliced windows and the start offsets
        that were drawn.
    """
    b, _, t = x.size()
    if x_lengths is None:
        # torch.clamp only accepts tensors, so the full-length default has to
        # be materialised as a per-item tensor rather than the bare int ``t``.
        x_lengths = torch.full((b,), t, dtype=torch.long, device=x.device)
    # Exclusive upper bound for the start offset; at least 1 so offset 0 is
    # always allowed even when an item is shorter than the window.
    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=1)
    ids_str = (torch.rand([b], device=x.device) * ids_str_max.float()).long()
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): - super().__init__() - assert kernel_size % 2 == 1 - self.hidden_channels = hidden_channels - self.kernel_size = (kernel_size,) - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - self.p_dropout = float(p_dropout) - - self.in_layers = nn.ModuleList() - self.res_skip_layers = nn.ModuleList() - self.drop = nn.Dropout(float(p_dropout)) - - if gin_channels != 0: - cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) - self.cond_layer = weight_norm(cond_layer, name="weight") - - for i in range(n_layers): - dilation = dilation_rate ** i - padding = int((kernel_size * dilation - dilation) / 2) - in_layer = nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding) - in_layer = weight_norm(in_layer, name="weight") - self.in_layers.append(in_layer) - - if i < n_layers - 1: - res_skip_channels = 2 * hidden_channels - else: - res_skip_channels = hidden_channels - res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1) - res_skip_layer = weight_norm(res_skip_layer, name="weight") - self.res_skip_layers.append(res_skip_layer) - - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None): - output = torch.zeros_like(x) - n_channels_tensor = torch.IntTensor([self.hidden_channels]) - - if g is not None: - g = self.cond_layer(g) - - for i, (in_layer, res_skip_layer) in enumerate(zip(self.in_layers, self.res_skip_layers)): - x_in = in_layer(x) - if g is not None: - cond_offset = i * 2 * self.hidden_channels - g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] - else: - g_l = torch.zeros_like(x_in) - - acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) - acts = self.drop(acts) - res_skip_acts = res_skip_layer(acts) - - if i < self.n_layers - 1: - res_acts = res_skip_acts[:, :self.hidden_channels, :] - x 
class ResBlock1(nn.Module):
    """Residual block made of three (dilated conv, plain conv) pairs.

    Every pair applies leaky-ReLU -> dilated conv -> leaky-ReLU -> conv with
    dilation 1 and adds the result back onto its input. When a mask is given
    it is multiplied in before each convolution and after each residual add.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def build_conv(rate):
            # Weight-normalised conv with equal in/out channels and stride 1.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        # Only the first three dilation rates are used.
        self.convs1 = nn.ModuleList([build_conv(dilation[idx]) for idx in range(3)])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([build_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)
        self.lrelu_slope = LRELU_SLOPE

    def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None):
        has_mask = x_mask is not None
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = F.leaky_relu(x, self.lrelu_slope)
            if has_mask:
                branch = branch * x_mask
            branch = dilated_conv(branch)
            branch = F.leaky_relu(branch, self.lrelu_slope)
            if has_mask:
                branch = branch * x_mask
            branch = plain_conv(branch)
            x = branch + x
        if has_mask:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution in the block."""
        for conv in [*self.convs1, *self.convs2]:
            remove_weight_norm(conv)
class Flip(nn.Module):
    """Flow step that reverses the order of dimension 1 of its input.

    ``x_mask`` and ``g`` are accepted for signature compatibility with the
    other flow layers and are not used.
    """

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False):
        flipped = torch.flip(x, [1])
        if reverse:
            # Reverse pass: single-element zero placeholder on x's device.
            return flipped, torch.zeros([1], device=x.device)
        # Forward pass: one zero log-determinant per batch item, matching
        # x's dtype and device.
        zero_logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
        return flipped, zero_logdet
class MultiHeadAttention(nn.Module):
    """Multi-head attention over channel-first sequences with optional
    windowed relative-position embeddings and an optional proximal bias.

    Queries, keys, values and the output are produced by kernel-size-1
    ``Conv1d`` layers, so tensors are laid out as ``[batch, channels, time]``.
    """

    def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, proximal_bias=False, proximal_init=False):
        """
        Args:
            channels: Channel count of the query/key/value projections; must
                be divisible by ``n_heads``.
            out_channels: Channel count produced by the output projection.
            n_heads: Number of attention heads.
            p_dropout: Dropout probability applied to the attention weights.
            window_size: When not ``None``, relative-position embeddings for
                offsets in ``[-window_size, window_size]`` are created.
            heads_share: When true, one set of relative embeddings is shared
                by all heads; otherwise each head gets its own.
            proximal_bias: When true, ``-log1p(|i - j|)`` is added to the
                attention scores.
            proximal_init: When true, the key projection starts as a copy of
                the query projection.
        """
        super().__init__()
        assert channels % n_heads == 0
        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        # Channels handled by each individual head.
        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            # One embedding per relative offset: 2 * window_size + 1 rows.
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Attend from ``x`` (queries) to ``c`` (keys and values).

        Returns the output projection of the attended values; the attention
        weights computed by :meth:`attention` are discarded.
        """
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)
        x, _ = self.attention(q, k, v, mask=attn_mask)
        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention over ``[b, d, t]`` projections.

        Returns:
            ``(output, p_attn)`` where ``output`` is ``[b, d, t_t]`` and
            ``p_attn`` holds the post-dropout attention weights
            ``[b, n_heads, t_t, t_s]``.
        """
        b, d, t_s = key.size()
        t_t = query.size(2)
        # Split channels into heads: [b, d, t] -> [b, n_heads, t, k_channels].
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            # NOTE(review): the reshapes in the relative-position helpers
            # appear to require t_s == t_t (self-attention) -- confirm.
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            # Masked positions get a large negative score instead of -inf.
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        # Merge heads back: [b, n_heads, t_t, k_channels] -> [b, d, t_t].
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """Multiply ``x`` by ``y`` after adding a leading broadcast axis to ``y``."""
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x, y):
        """Multiply ``x`` by the transpose of ``y`` (leading broadcast axis added)."""
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Return the ``2 * length - 1`` embeddings covering a sequence of ``length``.

        The table is zero-padded along dimension 1 when ``length`` exceeds
        ``window_size + 1``; otherwise the centred slice is taken.
        """
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
        else:
            padded_relative_embeddings = relative_embeddings
        return padded_relative_embeddings[:, slice_start_position:slice_end_position]

    def _relative_position_to_absolute_position(self, x):
        """Re-index ``[b, h, l, 2l - 1]`` relative logits to ``[b, h, l, l]``.

        Done with a pad / flatten / pad / reshape sequence so that each row
        ends up shifted by its own position.
        """
        batch, heads, length, _ = x.size()
        # Append one column so each row has 2 * length entries.
        x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0])
        x_flat = x.view([batch, heads, length * 2 * length])
        # Pad the flat vector so it reshapes to (length + 1) rows of 2l - 1.
        x_flat = F.pad(x_flat, [0, int(length) - 1, 0, 0, 0, 0])
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """Re-index ``[b, h, l, l]`` weights to ``[b, h, l, 2l - 1]``.

        Reverses the index mapping of
        :meth:`_relative_position_to_absolute_position`.
        """
        batch, heads, length, _ = x.size()
        x = F.pad(x, [0, int(length) - 1, 0, 0, 0, 0, 0, 0])
        x_flat = x.view([batch, heads, int(length ** 2) + int(length * (length - 1))])
        # Prepend `length` zeros so the reshape below skews the rows.
        x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Return a ``[1, 1, length, length]`` bias of ``-log1p(|i - j|)``."""
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class Encoder(nn.Module):
    """Stack of ``n_layers`` attention + feed-forward layers.

    Each layer adds the (dropped-out) sub-layer output onto its input and
    applies a ``LayerNorm`` afterwards, first for attention and then for the
    feed-forward network. Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.n_layers = int(n_layers)
        self.drop = nn.Dropout(p_dropout)

        # The four parallel lists are registered up front, in this order.
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()

        for _ in range(self.n_layers):
            # Sub-modules are built layer by layer (attention, norm, FFN,
            # norm) so random initialisation happens in that sequence.
            attention = MultiHeadAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                window_size=window_size,
            )
            self.attn_layers.append(attention)
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            feed_forward = FFN(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        # Pairwise mask: outer product of the mask with itself.
        pair_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        hidden = x * x_mask
        layers = zip(self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2)
        for attend, attn_norm, feed_forward, ffn_norm in layers:
            update = self.drop(attend(hidden, hidden, pair_mask))
            hidden = attn_norm(hidden + update)
            update = self.drop(feed_forward(hidden, x_mask))
            hidden = ffn_norm(hidden + update)
        return hidden * x_mask
class TextEncoder256(nn.Module):
    """Prior encoder for 256-dimensional phone features.

    Projects phone features (plus an optional pitch embedding) to
    ``hidden_channels``, runs them through ``Encoder`` and emits the two
    ``out_channels``-wide halves of a final 1x1 convolution.
    """

    def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.emb_phone = nn.Linear(256, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        # The pitch embedding only exists for f0-conditioned models.
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)
        self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout))
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor):
        """Return ``(m, logs, x_mask)`` for the given features and lengths."""
        hidden = self.emb_phone(phone)
        if pitch is not None:
            hidden = hidden + self.emb_pitch(pitch)
        hidden = hidden * math.sqrt(self.hidden_channels)
        hidden = self.lrelu(hidden)
        # Swap dimension 1 with the last one for the Conv1d-based encoder.
        hidden = torch.transpose(hidden, 1, -1)

        valid = sequence_mask(lengths, hidden.size(2))
        x_mask = torch.unsqueeze(valid, 1).to(hidden.dtype)

        encoded = self.encoder(hidden * x_mask, x_mask)
        stats = self.proj(encoded) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
    """Chain of ``n_flows`` mean-only coupling layers, each followed by a ``Flip``.

    ``forward`` runs the chain front to back; with ``reverse=True`` it runs
    the same layers back to front in their reverse mode. Only the
    transformed tensor is returned; per-layer log-determinants are dropped.
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
        super().__init__()
        self.n_flows = n_flows
        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(Flip())

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False):
        if reverse:
            for layer in self.flows[::-1]:
                x, _ = layer.forward(x, x_mask, g=g, reverse=reverse)
            return x
        for layer in self.flows:
            x, _ = layer(x, x_mask, g=g, reverse=reverse)
        return x

    def remove_weight_norm(self):
        """Strip weight norm from the coupling layers (the even-indexed flows)."""
        for position in range(0, 2 * self.n_flows, 2):
            self.flows[position].remove_weight_norm()
class SineGen(nn.Module):
    """Generate sine-wave excitation (fundamental plus harmonics) from f0.

    Frames whose f0 exceeds ``voiced_threshold`` produce sine waves with a
    small additive noise; the remaining frames produce noise only.
    """

    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
        """
        Args:
            samp_rate: Sampling rate used to turn f0 into a per-sample phase step.
            harmonic_num: Number of overtones generated on top of the fundamental.
            sine_amp: Amplitude of the sine waves.
            noise_std: Noise scale used in voiced regions.
            voiced_threshold: f0 values above this count as voiced.
        """
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # Output channels: the fundamental plus every harmonic.
        self.dim = harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        """Return a float mask that is 1 where ``f0 > voiced_threshold``, else 0."""
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv.float()

    def forward(self, f0: torch.Tensor, upp: int):
        """Build the excitation signal, upsampled by a factor of ``upp``.

        Args:
            f0: Frame-level fundamental frequency.
                NOTE(review): the indexing below assumes a 2-D
                ``(batch, frames)`` tensor -- confirm against callers.
            upp: Upsampling factor from frames to output samples.

        Returns:
            ``(sine_waves, uv, noise)``, each upsampled along dimension 1.
            The whole computation runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            # Insert an axis at position 1 and swap it to the end.
            f0 = f0[:, None].transpose(1, 2)
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            f0_buf[:, :, 0] = f0[:, :, 0]
            # Channel idx + 1 holds the (idx + 2)-th multiple of the fundamental.
            for idx in range(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
            # Per-frame phase increment in cycles, wrapped into [0, 1).
            rad_values = (f0_buf / self.sampling_rate) % 1
            # Random initial phase per channel; channel 0 starts at phase 0.
            rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device)
            rand_ini[:, 0] = 0
            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
            # Accumulated phase at frame rate, scaled by upp and linearly
            # interpolated to the upsampled length.
            tmp_over_one = torch.cumsum(rad_values, 1)
            tmp_over_one *= upp
            tmp_over_one = F.interpolate(tmp_over_one.transpose(2, 1), scale_factor=float(upp), mode="linear", align_corners=True).transpose(2, 1)
            rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
            tmp_over_one %= 1
            # True where the wrapped phase decreases between neighbouring
            # samples, i.e. where it crossed an integer boundary.
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            # Subtract one full cycle at each wrap point.
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
            sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi)
            sine_waves = sine_waves * self.sine_amp
            uv = self._f02uv(f0)
            uv = F.interpolate(uv.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
            # Voiced samples use noise_std; unvoiced samples use sine_amp / 3.
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            # Sine is zeroed where unvoiced; noise is added everywhere.
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
class GeneratorNSF(nn.Module):
    """Upsampling decoder that mixes an f0-driven harmonic source into every stage.

    Each stage applies a transposed-conv upsampler, adds a strided
    convolution of the harmonic source, and averages a group of residual
    blocks.
    """

    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, is_half=False):
        """
        Args:
            initial_channel: Channel count of the input.
            resblock: ``"1"`` selects ``ResBlock1``; anything else ``ResBlock2``.
            resblock_kernel_sizes: Kernel size of each residual block in a stage.
            resblock_dilation_sizes: Dilation tuple of each residual block.
            upsample_rates: Stride of each upsampling stage.
            upsample_initial_channel: Channel count after ``conv_pre``; halved
                at every stage.
            upsample_kernel_sizes: Kernel size of each upsampling stage.
            gin_channels: Channel count of the conditioning input; 0 disables it.
            sr: Sampling rate handed to the harmonic source module.
            is_half: Forwarded to ``SourceModuleHnNSF``.
        """
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates))
        self.m_source = SourceModuleHnNSF(sampling_rate=sr, harmonic_num=0, is_half=is_half)
        self.noise_convs = nn.ModuleList()
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        resblock_class = ResBlock1 if resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count after this stage.
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
            if i + 1 < len(upsample_rates):
                # The source is at full output rate; stride it down by the
                # product of the remaining upsample rates.
                stride_f0 = math.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
            else:
                # Last stage: only a 1x1 channel projection is needed.
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock_class(ch, k, d))

        # `ch` is the channel count of the final stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
        # Total upsampling factor.
        self.upp = math.prod(upsample_rates)
        self.lrelu_slope = LRELU_SLOPE

    def forward(self, x, f0, g: Optional[torch.Tensor] = None):
        """Decode ``x`` with harmonic source ``f0`` and optional conditioning ``g``.

        Returns the ``tanh``-squashed single-channel output of ``conv_post``.
        """
        har_source, _, _ = self.m_source(f0, self.upp)
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)
        for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
            # Holds on every iteration: `ups` has at most num_upsamples entries.
            if i < self.num_upsamples:
                x = F.leaky_relu(x, self.lrelu_slope)
                x = ups(x)
                x_source = noise_convs(har_source)
                x = x + x_source
                xs = None
                # Indices of the residual blocks that belong to stage i.
                l = [i * self.num_kernels + j for j in range(self.num_kernels)]
                # NOTE(review): iterating all blocks and filtering by index,
                # rather than indexing directly, is presumably for TorchScript
                # compatibility -- confirm before simplifying.
                for j, resblock in enumerate(self.resblocks):
                    if j in l:
                        if xs is None:
                            xs = resblock(x)
                        else:
                            xs += resblock(x)
                x = xs / self.num_kernels
        # Uses F.leaky_relu's default slope here, not self.lrelu_slope.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from the upsamplers and residual blocks."""
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
sr=sr, is_half=kwargs.get("is_half", False)) - self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) - self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) - - @torch.jit.export - def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: torch.Tensor, nsff0: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): - g = self.emb_g(sid).unsqueeze(-1) - m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) - z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask - if rate is not None: - head = int(z_p.shape[2] * (1 - rate.item())) - z_p = z_p[:, :, head:] - x_mask = x_mask[:, :, head:] - nsff0 = nsff0[:, head:] - z = self.flow(z_p, x_mask, g=g, reverse=True) - o = self.dec(z * x_mask, nsff0, g=g) - return o, x_mask, (z, z_p, m_p, logs_p) - -class SynthesizerTrnMs768NSFsid(nn.Module): - """RVC v2 model with f0""" - def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, **kwargs): - super().__init__() - if isinstance(sr, str): - sr = sr2sr[sr] - self.segment_size = segment_size - self.gin_channels = gin_channels - self.spk_embed_dim = spk_embed_dim - self.enc_p = TextEncoder768(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) - self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs.get("is_half", False)) - self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, 
gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) - self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) - - def forward(self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds): - """Training forward pass""" - g = self.emb_g(ds).unsqueeze(-1) - m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) - z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) - z_p = self.flow(z, y_mask, g=g) - z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size) - pitchf = slice_segments2(pitchf, ids_slice, self.segment_size) - o = self.dec(z_slice, pitchf, g=g) - return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) - - @torch.jit.export - def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: torch.Tensor, nsff0: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): - g = self.emb_g(sid).unsqueeze(-1) - m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) - z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask - if rate is not None: - head = int(z_p.shape[2] * (1.0 - rate.item())) - z_p = z_p[:, :, head:] - x_mask = x_mask[:, :, head:] - nsff0 = nsff0[:, head:] - z = self.flow(z_p, x_mask, g=g, reverse=True) - o = self.dec(z * x_mask, nsff0, g=g) - return o, x_mask, (z, z_p, m_p, logs_p) - -class SynthesizerTrnMs256NSFsid_nono(nn.Module): - """RVC v1 model without f0""" - def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs): - super().__init__() - self.segment_size = segment_size - self.gin_channels = gin_channels - self.enc_p = TextEncoder256(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, 
float(p_dropout), f0=False) - self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) - self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) - self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) - - @torch.jit.export - def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): - g = self.emb_g(sid).unsqueeze(-1) - m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) - z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask - if rate is not None: - head = int(z_p.shape[2] * (1.0 - rate.item())) - z_p = z_p[:, :, head:] - x_mask = x_mask[:, :, head:] - z = self.flow(z_p, x_mask, g=g, reverse=True) - o = self.dec(z * x_mask, g=g) - return o, x_mask, (z, z_p, m_p, logs_p) - -class SynthesizerTrnMs768NSFsid_nono(nn.Module): - """RVC v2 model without f0""" - def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs): - super().__init__() - self.segment_size = segment_size - self.gin_channels = gin_channels - self.enc_p = TextEncoder768(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), f0=False) - self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) - self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, 
gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) - self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) - - @torch.jit.export - def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): - g = self.emb_g(sid).unsqueeze(-1) - m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) - z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask - if rate is not None: - head = int(z_p.shape[2] * (1.0 - rate.item())) - z_p = z_p[:, :, head:] - x_mask = x_mask[:, :, head:] - z = self.flow(z_p, x_mask, g=g, reverse=True) - o = self.dec(z * x_mask, g=g) - return o, x_mask, (z, z_p, m_p, logs_p) - -# ============================================================ -# DISCRIMINATOR - For training -# ============================================================ - -class DiscriminatorS(nn.Module): - def __init__(self, use_spectral_norm=False): - super().__init__() - norm_f = nn.utils.spectral_norm if use_spectral_norm else weight_norm - self.convs = nn.ModuleList([ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ]) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, 0.1) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - return x, fmap - -class DiscriminatorP(nn.Module): - def __init__(self, period, use_spectral_norm=False): - super().__init__() - self.period = period - norm_f = nn.utils.spectral_norm if use_spectral_norm else weight_norm - self.convs = nn.ModuleList([ - norm_f(nn.Conv2d(1, 32, (5, 1), (3, 
1), padding=(2, 0))), - norm_f(nn.Conv2d(32, 128, (5, 1), (3, 1), padding=(2, 0))), - norm_f(nn.Conv2d(128, 512, (5, 1), (3, 1), padding=(2, 0))), - norm_f(nn.Conv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0))), - norm_f(nn.Conv2d(1024, 1024, (5, 1), 1, padding=(2, 0))), - ]) - self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - b, c, t = x.shape - if t % self.period != 0: - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, 0.1) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - return x, fmap - -class MultiPeriodDiscriminator(nn.Module): - def __init__(self, use_spectral_norm=False): - super().__init__() - periods = [2, 3, 5, 7, 11, 17, 23, 37] # 8 periods for v2 pretrained (9 total discriminators) - self.discriminators = nn.ModuleList( - [DiscriminatorS(use_spectral_norm)] + - [DiscriminatorP(p, use_spectral_norm) for p in periods] - ) - - def forward(self, y, y_hat): - y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] - for d in self.discriminators: - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - y_d_gs.append(y_d_g) - fmap_rs.append(fmap_r) - fmap_gs.append(fmap_g) - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - -# ============================================================ -# TRAINING LOSSES -# ============================================================ - -def feature_loss(fmap_r, fmap_g): - loss = 0 - for dr, dg in zip(fmap_r, fmap_g): - for rl, gl in zip(dr, dg): - loss += torch.mean(torch.abs(rl.float().detach() - gl.float())) - return loss * 2 - -def discriminator_loss(disc_real_outputs, disc_generated_outputs): - loss = 0 - for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - loss += torch.mean((1 - dr.float()) ** 2) + torch.mean(dg.float() ** 2) - return loss - -def 
generator_loss(disc_outputs): - loss = 0 - for dg in disc_outputs: - loss += torch.mean((1 - dg.float()) ** 2) - return loss - -def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): - z_p, logs_q, m_p, logs_p, z_mask = [x.float() for x in [z_p, logs_q, m_p, logs_p, z_mask]] - kl = logs_p - logs_q - 0.5 + 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) - return torch.sum(kl * z_mask) / torch.sum(z_mask) - -# ============================================================ -# HUBERT EXTRACTION - Using torchaudio bundle -# ============================================================ - -# ContentVec model for v1 (256-dim) and HuBERT for v2 (768-dim) -_contentvec_model = None # For v1 models (256-dim output) -_hubert_model = None # For v2 models (768-dim output) -_hubert_bundle = None - -CONTENTVEC_REPO = "IAHispano/Applio" -CONTENTVEC_MODEL = "Resources/embedders/contentvec/pytorch_model.bin" -CONTENTVEC_CONFIG = "Resources/embedders/contentvec/config.json" - -def load_contentvec(): - """Load ContentVec model from HuggingFace for v1 models (256-dim output)""" - global _contentvec_model - if _contentvec_model is None: - try: - from transformers import HubertModel, HubertConfig - logger.info("Loading ContentVec model from HuggingFace...") - - # Download model files - model_path = hf_hub_download(repo_id=CONTENTVEC_REPO, filename=CONTENTVEC_MODEL) - config_path = hf_hub_download(repo_id=CONTENTVEC_REPO, filename=CONTENTVEC_CONFIG) - - # Create model with final_proj layer - class HubertModelWithFinalProj(HubertModel): - def __init__(self, config): - super().__init__(config) - self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) - - config = HubertConfig.from_pretrained(config_path) - _contentvec_model = HubertModelWithFinalProj(config) - state_dict = torch.load(model_path, map_location="cpu", weights_only=True) - _contentvec_model.load_state_dict(state_dict) - _contentvec_model.to(device).eval() - logger.info(f"ContentVec loaded: 
hidden={config.hidden_size}, proj={config.classifier_proj_size}") - except Exception as e: - logger.warning(f"Failed to load ContentVec: {e}, falling back to torchaudio HuBERT") - _contentvec_model = None - return _contentvec_model - -def load_hubert(): - """Load HuBERT model via torchaudio for v2 models (768-dim output)""" - global _hubert_model, _hubert_bundle - if _hubert_model is None: - import torchaudio - logger.info("Loading HuBERT model via torchaudio...") - _hubert_bundle = torchaudio.pipelines.HUBERT_BASE - _hubert_model = _hubert_bundle.get_model().to(device) - _hubert_model.eval() - logger.info("HuBERT model loaded") - return _hubert_model, _hubert_bundle - -def extract_hubert_features(audio: np.ndarray, sr: int = 16000, version: str = "v2") -> torch.Tensor: - """Extract ContentVec features from audio (same as Applio) - - v1 models: Use ContentVec with final_proj (256-dim) - v2 models: Use ContentVec without final_proj (768-dim) - """ - audio = audio.astype(np.float32) - if np.abs(audio).max() > 1.0: - audio = audio / np.abs(audio).max() - - inputs = torch.from_numpy(audio).unsqueeze(0).to(device) - - # Use ContentVec for ALL versions (same as Applio) - contentvec = load_contentvec() - if contentvec is not None: - with torch.no_grad(): - output = contentvec(inputs) - if version == "v1": - # v1: use final_proj for 256-dim - feats = contentvec.final_proj(output.last_hidden_state) - else: - # v2: use raw hidden state (768-dim) - feats = output.last_hidden_state - return feats - - # Fallback to torchaudio HuBERT if ContentVec not available - logger.warning("ContentVec not available, using torchaudio HuBERT (results may be degraded)") - hubert, bundle = load_hubert() - with torch.no_grad(): - features, _ = hubert.extract_features(inputs) - layer_idx = 11 if version == "v2" else 8 - feats = features[min(layer_idx, len(features)-1)] - - if version == "v1": - proj = nn.Linear(768, 256, bias=False).to(device) - with torch.no_grad(): - w = torch.zeros(256, 768) - 
for i in range(256): - w[i, i*3:(i+1)*3] = 1/3 - proj.weight.copy_(w) - feats = proj(feats) - - return feats - -# ============================================================ -# F0 EXTRACTION -# ============================================================ - -def extract_f0_pm(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0) -> Tuple[np.ndarray, np.ndarray]: - """Extract F0 using parselmouth (pm method)""" - import parselmouth - - p_len = audio.shape[0] // 160 + 1 - f0_min = 65 - f0_max = 1100 - - l_pad = int(np.ceil(1.5 / f0_min * 16000)) - r_pad = l_pad + 1 - - s = parselmouth.Sound(np.pad(audio, (l_pad, r_pad)), 16000).to_pitch_ac( - time_step=0.01, voicing_threshold=0.6, pitch_floor=f0_min, pitch_ceiling=f0_max, - ) - - f0 = s.selected_array["frequency"] - if len(f0) < p_len: - f0 = np.pad(f0, (0, p_len - len(f0))) - f0 = f0[:p_len] - f0 *= pow(2, f0_up_key / 12) - - return f0_to_coarse(f0) - -def extract_f0_harvest(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0) -> Tuple[np.ndarray, np.ndarray]: - """Extract F0 using pyworld harvest""" - import pyworld - from scipy import signal as scipy_signal - - f0, t = pyworld.harvest(audio.astype(np.double), fs=16000, f0_ceil=1100, f0_floor=50, frame_period=10) - f0 = scipy_signal.medfilt(f0, 3) - f0 *= pow(2, f0_up_key / 12) - - return f0_to_coarse(f0) - -def f0_to_coarse(f0: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """Convert f0 to coarse representation""" - f0_min = 50 - f0_max = 1100 - f0_mel_min = 1127 * np.log(1 + f0_min / 700) - f0_mel_max = 1127 * np.log(1 + f0_max / 700) - - f0bak = f0.copy() - f0_mel = 1127 * np.log(1 + f0 / 700) - f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1 - f0_mel[f0_mel <= 1] = 1 - f0_mel[f0_mel > 255] = 255 - f0_coarse = np.rint(f0_mel).astype(np.int32) - - return f0_coarse, f0bak - -# ============================================================ -# RMVPE F0 EXTRACTION (from Applio - IAHispano/Applio) -# 
============================================================ - -class RMVPE_ConvBlockRes(nn.Module): - def __init__(self, in_channels, out_channels, momentum=0.01): - super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False), - nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), - nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False), - nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), - ) - self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) if in_channels != out_channels else None - - def forward(self, x): - r = self.conv(x) - return r + self.shortcut(x) if self.shortcut else r + x - -class RMVPE_ResEncoderBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): - super().__init__() - self.conv = nn.ModuleList([RMVPE_ConvBlockRes(in_channels, out_channels, momentum)]) - for _ in range(n_blocks - 1): - self.conv.append(RMVPE_ConvBlockRes(out_channels, out_channels, momentum)) - self.kernel_size = kernel_size - if kernel_size is not None: - self.pool = nn.AvgPool2d(kernel_size=kernel_size) - - def forward(self, x): - for c in self.conv: - x = c(x) - return (x, self.pool(x)) if self.kernel_size is not None else x - -class RMVPE_Encoder(nn.Module): - def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): - super().__init__() - self.n_encoders = n_encoders - self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) - self.layers = nn.ModuleList() - for _ in range(n_encoders): - self.layers.append(RMVPE_ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum)) - in_channels = out_channels - out_channels *= 2 - in_size //= 2 - self.out_size = in_size - self.out_channel = out_channels - - def forward(self, x): - concat_tensors = [] - x = self.bn(x) - for layer in self.layers: - t, x = layer(x) - concat_tensors.append(t) - return x, 
concat_tensors - -class RMVPE_Intermediate(nn.Module): - def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): - super().__init__() - self.layers = nn.ModuleList([RMVPE_ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)]) - for _ in range(n_inters - 1): - self.layers.append(RMVPE_ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x - -class RMVPE_ResDecoderBlock(nn.Module): - def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): - super().__init__() - out_padding = (0, 1) if stride == (1, 2) else (1, 1) - self.conv1 = nn.Sequential( - nn.ConvTranspose2d(in_channels, out_channels, (3, 3), stride, (1, 1), out_padding, bias=False), - nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), - ) - self.conv2 = nn.ModuleList([RMVPE_ConvBlockRes(out_channels * 2, out_channels, momentum)]) - for _ in range(n_blocks - 1): - self.conv2.append(RMVPE_ConvBlockRes(out_channels, out_channels, momentum)) - - def forward(self, x, concat_tensor): - x = self.conv1(x) - x = torch.cat((x, concat_tensor), dim=1) - for c in self.conv2: - x = c(x) - return x - -class RMVPE_Decoder(nn.Module): - def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): - super().__init__() - self.layers = nn.ModuleList() - for _ in range(n_decoders): - out_channels = in_channels // 2 - self.layers.append(RMVPE_ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) - in_channels = out_channels - self.n_decoders = n_decoders - - def forward(self, x, concat_tensors): - for i in range(self.n_decoders): - x = self.layers[i](x, concat_tensors[-1 - i]) - return x - -class RMVPE_DeepUnet(nn.Module): - def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): - super().__init__() - self.encoder = RMVPE_Encoder(in_channels, 128, en_de_layers, kernel_size, 
n_blocks, en_out_channels) - self.intermediate = RMVPE_Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) - self.decoder = RMVPE_Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) - - def forward(self, x): - x, concat_tensors = self.encoder(x) - x = self.intermediate(x) - x = self.decoder(x, concat_tensors) - return x - -class RMVPE_BiGRU(nn.Module): - def __init__(self, input_features, hidden_features, num_layers): - super().__init__() - self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) - - def forward(self, x): - return self.gru(x)[0] - -RMVPE_N_MELS = 128 -RMVPE_N_CLASS = 360 - -class RMVPE_E2E(nn.Module): - def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): - super().__init__() - self.unet = RMVPE_DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) - self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) - if n_gru: - self.fc = nn.Sequential( - RMVPE_BiGRU(3 * 128, 256, n_gru), - nn.Linear(512, RMVPE_N_CLASS), nn.Dropout(0.25), nn.Sigmoid(), - ) - else: - self.fc = nn.Sequential(nn.Linear(3 * RMVPE_N_MELS, RMVPE_N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) - - def forward(self, mel): - mel = mel.transpose(-1, -2).unsqueeze(1) - x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) - return self.fc(x) - -class RMVPE_MelSpectrogram(nn.Module): - def __init__(self, n_mel_channels=128, sample_rate=16000, win_length=1024, hop_length=160, n_fft=None, mel_fmin=30, mel_fmax=8000, clamp=1e-5): - super().__init__() - from librosa.filters import mel as librosa_mel - n_fft = win_length if n_fft is None else n_fft - self.hann_window = {} - mel_basis = librosa_mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True) - self.register_buffer("mel_basis", torch.from_numpy(mel_basis).float()) - self.n_fft = 
n_fft - self.hop_length = hop_length - self.win_length = win_length - self.clamp = clamp - - def forward(self, audio, keyshift=0, speed=1, center=True): - factor = 2 ** (keyshift / 12) - n_fft_new = int(np.round(self.n_fft * factor)) - win_length_new = int(np.round(self.win_length * factor)) - hop_length_new = int(np.round(self.hop_length * speed)) - key = f"{keyshift}_{audio.device}" - if key not in self.hann_window: - self.hann_window[key] = torch.hann_window(win_length_new).to(audio.device) - fft = torch.stft(audio, n_fft=n_fft_new, hop_length=hop_length_new, win_length=win_length_new, - window=self.hann_window[key], center=center, return_complex=True) - magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) - if keyshift != 0: - size = self.n_fft // 2 + 1 - resize = magnitude.size(1) - if resize < size: - magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) - magnitude = magnitude[:, :size, :] * self.win_length / win_length_new - mel_output = torch.matmul(self.mel_basis, magnitude) - return torch.log(torch.clamp(mel_output, min=self.clamp)) - -_rmvpe_model = None - -def load_rmvpe(): - """Download and load RMVPE model for f0 extraction""" - global _rmvpe_model - if _rmvpe_model is None: - logger.info("Downloading RMVPE model...") - rmvpe_path = hf_hub_download(repo_id="IAHispano/Applio", filename="Resources/predictors/rmvpe.pt") - model = RMVPE_E2E(4, 1, (2, 2)) - ckpt = torch.load(rmvpe_path, map_location="cpu", weights_only=True) - model.load_state_dict(ckpt) - model.eval().to(device) - mel_extractor = RMVPE_MelSpectrogram().to(device) - cents_mapping = 20 * np.arange(RMVPE_N_CLASS) + 1997.3794084376191 - _rmvpe_model = (model, mel_extractor, np.pad(cents_mapping, (4, 4))) - logger.info("RMVPE model loaded") - return _rmvpe_model - -def extract_f0_rmvpe(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0, thred: float = 0.03) -> Tuple[np.ndarray, np.ndarray]: - """Extract F0 using RMVPE (best quality, neural network based)""" - model, mel_extractor, 
cents_mapping = load_rmvpe() - - audio_t = torch.from_numpy(audio).float().to(device).unsqueeze(0) - mel = mel_extractor(audio_t, center=True) - del audio_t - - # mel2hidden with chunking - with torch.no_grad(): - n_frames = mel.shape[-1] - mel_padded = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect") - chunks = [] - for start in range(0, mel_padded.shape[-1], 32000): - end = min(start + 32000, mel_padded.shape[-1]) - chunks.append(model(mel_padded[..., start:end])) - hidden = torch.cat(chunks, dim=1)[:, :n_frames].squeeze(0).cpu().numpy() - - # Decode hidden to f0 - center = np.argmax(hidden, axis=1) - salience = np.pad(hidden, ((0, 0), (4, 4))) - center += 4 - todo_salience = [] - todo_cents = [] - for idx in range(salience.shape[0]): - s, e = center[idx] - 4, center[idx] + 5 - todo_salience.append(salience[idx, s:e]) - todo_cents.append(cents_mapping[s:e]) - todo_salience = np.array(todo_salience) - todo_cents = np.array(todo_cents) - cents_pred = np.sum(todo_salience * todo_cents, 1) / np.sum(todo_salience, 1) - cents_pred[np.max(salience, axis=1) <= thred] = 0 - - f0 = 10 * (2 ** (cents_pred / 1200)) - f0[f0 == 10] = 0 - f0 *= pow(2, f0_up_key / 12) - - return f0_to_coarse(f0) - -# ============================================================ -# MODEL LOADING -# ============================================================ - -_model_cache = {} - -def load_rvc_model(model_path: str): - """Load RVC model and auto-detect version""" - if model_path in _model_cache: - return _model_cache[model_path] - - logger.info(f"Loading RVC model: {model_path}") - try: - cpt = torch.load(model_path, map_location="cpu", weights_only=True) - except Exception: - logger.warning("Model requires unsafe loading - may be an older format") - cpt = torch.load(model_path, map_location="cpu", weights_only=False) - - weight_key = None - for key in ["weight", "model", "state_dict", "net_g"]: - if key in cpt: - weight_key = key - break - if weight_key is None: - 
raise ValueError(f"Cannot find model weights. Keys: {list(cpt.keys())}") - - config = cpt.get("config", None) - if config is None: - config = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000] - logger.warning("No config found, using v2 defaults") - - version = cpt.get("version", "v1") - if_f0 = cpt.get("f0", 1) - - if weight_key in cpt: - emb_weight = cpt[weight_key].get("emb_g.weight") - if emb_weight is not None: - config[-3] = emb_weight.shape[0] - - sr = config[-1] if isinstance(config[-1], int) else 40000 - - if version == "v1": - model_class = SynthesizerTrnMs256NSFsid if if_f0 == 1 else SynthesizerTrnMs256NSFsid_nono - else: - model_class = SynthesizerTrnMs768NSFsid if if_f0 == 1 else SynthesizerTrnMs768NSFsid_nono - - model = model_class( - spec_channels=config[0], segment_size=config[1], inter_channels=config[2], - hidden_channels=config[3], filter_channels=config[4], n_heads=config[5], - n_layers=config[6], kernel_size=config[7], p_dropout=config[8], - resblock=config[9], resblock_kernel_sizes=config[10], - resblock_dilation_sizes=config[11], upsample_rates=config[12], - upsample_initial_channel=config[13], upsample_kernel_sizes=config[14], - spk_embed_dim=config[15], gin_channels=config[16], sr=sr, is_half=False - ) - - model.load_state_dict(cpt[weight_key], strict=False) - model.eval().to(device) - - _model_cache[model_path] = (model, sr, version, if_f0) - logger.info(f"Model loaded: version={version}, f0={if_f0}, sr={sr}") - - return model, sr, version, if_f0 - -# ============================================================ -# TRAINING - Simplified for CPU testing -# ============================================================ - -def spectrogram_torch(y, n_fft, hop_size, win_size, center=False): - """Compute spectrogram""" - hann_window = torch.hann_window(win_size).to(y.device) - y = F.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), 
mode='reflect').squeeze(1) - spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, - center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) - spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6) - return spec - -# Mel spectrogram for training loss -_mel_basis_cache = {} - -def spec_to_mel_torch(spec, n_fft=2048, num_mels=125, sampling_rate=40000, fmin=0, fmax=None): - """Convert spectrogram to mel spectrogram""" - from librosa.filters import mel as librosa_mel_fn - global _mel_basis_cache - - if fmax is None: - fmax = sampling_rate // 2 - - key = f"{n_fft}_{num_mels}_{sampling_rate}_{fmin}_{fmax}_{spec.dtype}_{spec.device}" - if key not in _mel_basis_cache: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - _mel_basis_cache[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) - - melspec = torch.matmul(_mel_basis_cache[key], spec) - melspec = torch.log(torch.clamp(melspec, min=1e-5)) # Log-amplitude - return melspec - -def preprocess_audio_for_training(audio_path: str, output_dir: str, target_sr: int = 40000, f0_method: str = "rmvpe"): - """Preprocess audio file for training - slice and extract features""" - import scipy.signal as signal - - os.makedirs(output_dir, exist_ok=True) - os.makedirs(f"{output_dir}/wavs", exist_ok=True) - os.makedirs(f"{output_dir}/hubert", exist_ok=True) - os.makedirs(f"{output_dir}/f0", exist_ok=True) - - logger.info(f"Preprocessing: {audio_path}") - - # Load and resample audio - audio, sr = librosa.load(audio_path, sr=target_sr, mono=True) - - # High-pass filter - bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=target_sr) - audio = signal.lfilter(bh, ah, audio) - - # Slice into chunks (3.7 seconds with 0.3 overlap) - chunk_size = int(3.7 * target_sr) - hop = int(3.4 * target_sr) - - chunks = [] - for i, start in enumerate(range(0, len(audio) - chunk_size, hop)): - chunk = audio[start:start + 
chunk_size] - # Normalize - max_val = np.abs(chunk).max() - if max_val > 0.01: # Skip silence - chunk = chunk / max_val * 0.9 - chunks.append((i, chunk)) - - if not chunks: - logger.warning("No valid audio chunks found") - return None - - logger.info(f"Created {len(chunks)} chunks") - - # Save chunks and extract features - manifest = [] - for idx, chunk in chunks: - # Save wav - wav_path = f"{output_dir}/wavs/{idx:04d}.wav" - sf.write(wav_path, chunk, target_sr) - - # Resample to 16k for HuBERT - chunk_16k = librosa.resample(chunk, orig_sr=target_sr, target_sr=16000) - - # Extract HuBERT features - feats = extract_hubert_features(chunk_16k, sr=16000, version="v2") - hubert_path = f"{output_dir}/hubert/{idx:04d}.npy" - np.save(hubert_path, feats.squeeze(0).cpu().numpy()) - - # Extract F0 - if f0_method == "rmvpe": - f0_coarse, f0 = extract_f0_rmvpe(chunk_16k, 16000, 0) - elif f0_method == "harvest": - f0_coarse, f0 = extract_f0_harvest(chunk_16k, 16000, 0) - else: - f0_coarse, f0 = extract_f0_pm(chunk_16k, 16000, 0) - f0_path = f"{output_dir}/f0/{idx:04d}.npy" - np.save(f0_path, np.stack([f0_coarse, f0], axis=0)) - - manifest.append(f"{idx:04d}") - - # Save manifest - with open(f"{output_dir}/manifest.txt", "w") as f: - f.write("\n".join(manifest)) - - logger.info(f"Preprocessing complete: {len(manifest)} samples") - return output_dir - -def train_rvc_generator( - data_dir: str, - output_dir: str, - epochs: int = 10, - batch_size: int = 2, - lr: float = 1e-5, # Lower LR prevents overfitting on small data - target_sr: int = 40000, - progress_callback=None -): - """Generator version of train_rvc - yields (epoch_msg, ckpt_path) tuples""" - logger.info(f"Starting training: {data_dir} -> {output_dir}") - os.makedirs(output_dir, exist_ok=True) - - # Load manifest - with open(f"{data_dir}/manifest.txt") as f: - samples = [l.strip() for l in f if l.strip()] - - if len(samples) < 1: - logger.error("No training samples found") - return None - - logger.info(f"Training with 
{len(samples)} samples") - - # Model config (v2 40k defaults) - config = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 1, 256, target_sr] - - # Create models (v2 only - 768-dim HuBERT features) - net_g = SynthesizerTrnMs768NSFsid( - spec_channels=config[0], segment_size=config[1], inter_channels=config[2], - hidden_channels=config[3], filter_channels=config[4], n_heads=config[5], - n_layers=config[6], kernel_size=config[7], p_dropout=config[8], - resblock=config[9], resblock_kernel_sizes=config[10], - resblock_dilation_sizes=config[11], upsample_rates=config[12], - upsample_initial_channel=config[13], upsample_kernel_sizes=config[14], - spk_embed_dim=config[15], gin_channels=config[16], sr=target_sr - ).to(train_device) - - net_d = MultiPeriodDiscriminator().to(train_device) - logger.info(f"Training on device: {train_device}") - - # Download and load pretrained weights (essential for good results) - sr_key = f"{target_sr // 1000}k" # e.g., "40k" - try: - pretrain_g_path = download_pretrained_rvc(f"f0G{sr_key}") - pretrain_d_path = download_pretrained_rvc(f"f0D{sr_key}") - load_pretrained_weights(net_g, pretrain_g_path) - load_pretrained_weights(net_d, pretrain_d_path) - except Exception as e: - logger.warning(f"Failed to load pretrained weights: {e}") - logger.warning("Training from scratch (results may be poor)") - - # Optimizers (after loading pretrained weights) - optim_g = torch.optim.AdamW(net_g.parameters(), lr=lr, betas=(0.8, 0.99)) - optim_d = torch.optim.AdamW(net_d.parameters(), lr=lr, betas=(0.8, 0.99)) - - # LR scheduler (matches Applio - exponential decay) - lr_decay = 0.999875 - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=lr_decay) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=lr_decay) - - net_g.train() - net_d.train() - - # Training loop - for epoch in range(epochs): - total_loss_g, total_loss_d = 0, 0 - 
np.random.shuffle(samples) - - for i in range(0, len(samples), batch_size): - batch_samples = samples[i:i+batch_size] - - # Load batch data - wavs, huberts, f0s = [], [], [] - for s in batch_samples: - wav, _ = librosa.load(f"{data_dir}/wavs/{s}.wav", sr=target_sr, mono=True) - hubert = np.load(f"{data_dir}/hubert/{s}.npy") - # Upsample 50Hz -> 100Hz using interpolation (same as inference) - hubert_t = torch.from_numpy(hubert).unsqueeze(0).permute(0, 2, 1) # (1, 768, seq) - hubert_t = F.interpolate(hubert_t, scale_factor=2, mode='linear', align_corners=False) - hubert = hubert_t.permute(0, 2, 1).squeeze(0).numpy() # (seq*2, 768) - f0_data = np.load(f"{data_dir}/f0/{s}.npy") - wavs.append(wav) - huberts.append(hubert) - f0s.append(f0_data) - - # Compute spectrogram first to get target length - max_wav_len = max(len(w) for w in wavs) - wav_batch = np.zeros((len(wavs), max_wav_len)) - for j, w in enumerate(wavs): - wav_batch[j, :len(w)] = w - wav_t = torch.FloatTensor(wav_batch).unsqueeze(1).to(train_device) - spec = spectrogram_torch(wav_t.squeeze(1), 2048, 400, 2048) - spec_len = spec.shape[2] # Target length for all features - - # Pad/truncate features to match spec length exactly - hubert_batch = np.zeros((len(huberts), spec_len, huberts[0].shape[1])) - f0_batch = np.zeros((len(f0s), spec_len)) - f0f_batch = np.zeros((len(f0s), spec_len)) - - for j, (h, f) in enumerate(zip(huberts, f0s)): - # Truncate or pad HuBERT to spec_len - h_len = min(h.shape[0], spec_len) - hubert_batch[j, :h_len] = h[:h_len] - # Truncate or pad F0 to spec_len - f0_len = min(f.shape[1], spec_len) - f0_batch[j, :f0_len] = f[0, :f0_len] - f0f_batch[j, :f0_len] = f[1, :f0_len] - - # To tensors - all features now have spec_len - hubert_t = torch.FloatTensor(hubert_batch).to(train_device) - f0_t = torch.LongTensor(f0_batch.astype(np.int64)).to(train_device) - f0f_t = torch.FloatTensor(f0f_batch).to(train_device) - lengths_t = torch.LongTensor([spec_len] * len(batch_samples)).to(train_device) - 
sid_t = torch.LongTensor([0] * len(batch_samples)).to(train_device) - spec_lengths = torch.LongTensor([spec_len] * len(batch_samples)).to(train_device) - - # Forward pass generator - # Args: phone, phone_lengths, pitch, pitchf, y (spec), y_lengths, ds - try: - y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = net_g( - hubert_t, lengths_t, f0_t, f0f_t, spec, spec_lengths, sid_t - ) - except Exception as e: - logger.warning(f"Generator forward failed: {e}") - continue - - # Slice wav at same position model generated (CRITICAL for proper loss) - # ids_slice is in latent space, multiply by hop_length to get waveform position - hop_length = 400 - segment_size_wav = 32 * hop_length # segment_size in latent * hop_length - y = slice_segments(wav_t, ids_slice * hop_length, segment_size_wav) - - # Discriminator forward - y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat.detach()) - - # Discriminator loss - loss_d = discriminator_loss(y_d_rs, y_d_gs) - - optim_d.zero_grad() - loss_d.backward() - optim_d.step() - - # Generator loss - y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat) - loss_gen = generator_loss(y_d_gs) - loss_fm = feature_loss(fmap_rs, fmap_gs) - loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) - - # Mel spectrogram loss (crucial for quality) - # Config: n_fft=2048, hop=400, win=2048, n_mels=125, fmin=0, fmax=None - y_mel = spec_to_mel_torch(spectrogram_torch(y.squeeze(1), 2048, 400, 2048), - n_fft=2048, num_mels=125, sampling_rate=target_sr, fmin=0, fmax=None) - y_hat_mel = spec_to_mel_torch(spectrogram_torch(y_hat.squeeze(1), 2048, 400, 2048), - n_fft=2048, num_mels=125, sampling_rate=target_sr, fmin=0, fmax=None) - # Align lengths if needed - min_len = min(y_mel.shape[2], y_hat_mel.shape[2]) - loss_mel = F.l1_loss(y_mel[:, :, :min_len], y_hat_mel[:, :, :min_len]) * 45 # c_mel = 45 - - loss_g = loss_gen + loss_fm + loss_mel + loss_kl - - optim_g.zero_grad() - loss_g.backward() - optim_g.step() - - total_loss_g += loss_g.item() - 
def train_rvc(
    data_dir: str,
    output_dir: str,
    epochs: int = 10,
    batch_size: int = 2,
    lr: float = 1e-5,  # Lower LR prevents overfitting on small data
    target_sr: int = 40000,
    progress_callback=None
):
    """Run RVC training to completion (blocking variant for CLI use).

    Drains ``train_rvc_generator`` and remembers the most recent non-empty
    checkpoint path and index path it yielded; the progress messages are
    discarded.

    Returns:
        ``(checkpoint_path, index_path)``. Either element is ``None`` when the
        generator never yielded a truthy value for it.
    """
    checkpoint_path, index_path = None, None
    updates = train_rvc_generator(
        data_dir, output_dir, epochs, batch_size, lr, target_sr, progress_callback
    )
    for _message, new_checkpoint, new_index in updates:
        # Keep the previous value whenever this update carries nothing new
        checkpoint_path = new_checkpoint or checkpoint_path
        index_path = new_index or index_path
    return checkpoint_path, index_path
- - # Index retrieval (speaker similarity) - if index_file is not None and index_rate > 0: - try: - import faiss - index_path = index_file.name if hasattr(index_file, 'name') else index_file - progress(0.4, "Loading index...") - index = faiss.read_index(index_path) - big_npy = index.reconstruct_n(0, index.ntotal) - - npy = feats[0].cpu().numpy().astype("float32") - score, ix = index.search(npy, k=8) - weight = np.square(1 / score) - weight /= weight.sum(axis=1, keepdims=True) - npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) - feats = torch.from_numpy(npy).unsqueeze(0).to(device) * index_rate + (1 - index_rate) * feats - except Exception as e: - logger.warning(f"Index retrieval failed: {e}") - - # Feature upsampling by 2x - feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) - - # Adjust length based on audio - p_len = min(audio_pad.shape[0] // window, feats.shape[1]) - - pitch, pitchf = None, None - if if_f0 == 1: - progress(0.5, f"Extracting F0 ({f0_method})...") - if f0_method == "rmvpe": - pitch, pitchf = extract_f0_rmvpe(audio_pad, 16000, pitch_shift) - elif f0_method == "harvest": - pitch, pitchf = extract_f0_harvest(audio_pad, 16000, pitch_shift) - else: - pitch, pitchf = extract_f0_pm(audio_pad, 16000, pitch_shift) - - pitch = pitch[:p_len] - pitchf = pitchf[:p_len] - - # Upsample feats0 for protect - if feats0 is not None: - feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) - - # Apply protect mechanism (preserve original features for unvoiced segments) - if protect < 0.5 and feats0 is not None: - pitchf_tensor = torch.from_numpy(pitchf).float().to(device) - pitchff = pitchf_tensor.clone() - pitchff[pitchf_tensor > 0] = 1 - pitchff[pitchf_tensor < 1] = protect - pitchff = pitchff.unsqueeze(0).unsqueeze(-1) - feats = feats[:, :p_len, :] * pitchff + feats0[:, :p_len, :] * (1 - pitchff) - - if len(pitch) < p_len: - pitch = np.pad(pitch, (0, p_len - len(pitch))) - pitchf = 
np.pad(pitchf, (0, p_len - len(pitchf))) - - pitch = torch.LongTensor(pitch).unsqueeze(0).to(device) - pitchf = torch.FloatTensor(pitchf).unsqueeze(0).to(device) - - p_len_tensor = torch.LongTensor([p_len]).to(device) - sid = torch.LongTensor([0]).to(device) - - progress(0.7, "Running inference...") - with torch.no_grad(): - if if_f0 == 1: - audio_out = model.infer(feats[:, :p_len, :], p_len_tensor, pitch, pitchf, sid)[0][0, 0].data.cpu().float().numpy() - else: - audio_out = model.infer(feats[:, :p_len, :], p_len_tensor, sid)[0][0, 0].data.cpu().float().numpy() - - # Remove padding from output - t_pad_tgt = int(t_pad * tgt_sr / 16000) - if len(audio_out) > 2 * t_pad_tgt: - audio_out = audio_out[t_pad_tgt:-t_pad_tgt] - - # RMS mixing - match volume dynamics of source audio - if volume_envelope != 1.0: - try: - source_at_tgt_sr = librosa.resample(audio, orig_sr=16000, target_sr=tgt_sr) - frame_len = tgt_sr // 2 * 2 - hop_len = tgt_sr // 2 - - rms_source = librosa.feature.rms(y=source_at_tgt_sr, frame_length=frame_len, hop_length=hop_len) - rms_output = librosa.feature.rms(y=audio_out, frame_length=frame_len, hop_length=hop_len) - - rms_source = F.interpolate( - torch.from_numpy(rms_source).float().unsqueeze(0), - size=audio_out.shape[0], mode="linear" - ).squeeze() - rms_output = F.interpolate( - torch.from_numpy(rms_output).float().unsqueeze(0), - size=audio_out.shape[0], mode="linear" - ).squeeze() - rms_output = torch.maximum(rms_output, torch.zeros_like(rms_output) + 1e-6) - - # Applio formula: target * (source^(1-rate) * output^(rate-1)) - audio_out = audio_out * (torch.pow(rms_source, 1 - volume_envelope) * torch.pow(rms_output, volume_envelope - 1)).numpy() - except Exception as e: - logger.warning(f"RMS mixing failed: {e}") - - # Final normalization - audio_max = np.abs(audio_out).max() / 0.99 - if audio_max > 1: - audio_out /= audio_max - - progress(0.9, "Saving output...") - fd, output_path = tempfile.mkstemp(suffix=".wav") - os.close(fd) - 
def load_example_model():
    """Download the default example model and index and stage them for Gradio.

    Both files are fetched with ``hf_hub_download`` from ``DEFAULT_MODEL_REPO``
    and copied into a fresh temporary directory, because Gradio 6 only serves
    files located in allowed directories (cwd or /tmp).

    Returns:
        ``(model_path, index_path, status_message)``. On any failure the two
        paths are ``None`` and the message starts with ``"Error: "``.
    """
    temp_dir = None
    try:
        logger.info(f"Downloading example model from {DEFAULT_MODEL_REPO}...")
        model_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_MODEL_FILE)
        index_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_INDEX_FILE)

        # Copy from the HF cache to a temp directory that Gradio may serve
        temp_dir = tempfile.mkdtemp()
        temp_model = os.path.join(temp_dir, DEFAULT_MODEL_FILE)
        temp_index = os.path.join(temp_dir, DEFAULT_INDEX_FILE)
        shutil.copy2(model_path, temp_model)
        shutil.copy2(index_path, temp_index)

        return temp_model, temp_index, f"Loaded: {DEFAULT_MODEL_REPO}"
    except Exception as e:
        logger.exception("Failed to download example model")
        # mkdtemp() never cleans up after itself; do not leave a half-filled
        # directory behind when one of the copies failed.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return None, None, f"Error: {str(e)}"
def dump_params(params: Optional[torch.Tensor], f: BinaryIO):
    """Append the raw values of ``params`` to the binary stream ``f``.

    Elements are written in logical row-major order using the host's native
    byte order, then the stream is flushed. ``None`` is accepted and writes
    nothing (e.g. the bias of a layer built without one).

    bfloat16 tensors are written as 2 bytes per element: each value is widened
    to float32, reinterpreted as two int16 halves, and only the half holding
    the bfloat16 bit pattern is kept.

    Args:
        params: Tensor (or ``nn.Parameter``) to serialize, or ``None``.
        f: Writable binary stream.
    """
    if params is None:
        return
    # .numpy() only works on CPU tensors; move first so that parameters of a
    # model living on an accelerator can be dumped as well.
    tensor = params.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        # contiguous() lets the dtype reinterpretation below work for any
        # stride layout while keeping the logical element order.
        halves = tensor.float().contiguous().view(torch.short).numpy().ravel()
        # NOTE(review): taking the odd-indexed halves selects the upper 16 bits
        # only on a little-endian host - confirm if big-endian must be supported.
        f.write(halves[1::2].tobytes())
    else:
        f.write(tensor.numpy().ravel().tobytes())
    f.flush()
class WSConv1d(CausalConv1d):
    """``CausalConv1d`` whose kernel is weight-standardized on every forward pass.

    The stored ``weight`` is never used directly: each output channel's kernel
    is centred, divided by ``sqrt(fan_in * var + 1e-8)`` and multiplied by the
    learnable per-channel ``gain`` before the convolution is applied.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        delay: int = 0,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            delay=delay,
        )
        # Re-initialize on top of the default Conv1d init: N(0, 1 / fan_in)
        fan_in = in_channels * kernel_size // groups
        self.weight.data.normal_(0.0, math.sqrt(1.0 / fan_in))
        if bias:
            self.bias.data.zero_()
        # One gain per output channel, broadcast over (in_channels / groups, kernel)
        self.gain = nn.Parameter(torch.ones((out_channels, 1, 1)))

    def standardized_weight(self) -> torch.Tensor:
        """Return the standardized, gain-scaled kernel used by ``forward``."""
        variance, mean = torch.var_mean(self.weight, [1, 2], keepdim=True)
        fan_in = self.in_channels * self.kernel_size[0] // self.groups
        inv_std = (fan_in * variance + 1e-8).rsqrt()
        return self.gain * inv_std * (self.weight - mean)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Convolve with the standardized kernel and drop the trailing ``trim`` frames."""
        convolved = F.conv1d(
            input,
            self.standardized_weight(),
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        # A slice ending at -0 would be empty, hence the explicit zero case
        return convolved if self.trim == 0 else convolved[:, :, : -self.trim]

    def merge_weights(self):
        """Bake the standardization into ``weight`` in place and reset ``gain`` to 1."""
        merged = self.standardized_weight().detach()
        self.weight.data[:] = merged
        self.gain.data.fill_(1.0)
class CrossAttention(nn.Module):
    """Multi-head cross-attention from a query sequence onto a key/value sequence.

    Queries are projected from ``in_q_channels`` to ``qk_channels``; keys and
    values come from a single projection of the ``kv`` input that is split
    into ``qk_channels`` (keys) and ``vo_channels`` (values). The attended
    values are projected to ``out_channels``.

    Args:
        qk_channels: Total query/key width; must be divisible by ``num_heads``.
        vo_channels: Total value width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        in_q_channels: Channel count of the ``q`` input.
        in_kv_channels: Channel count of the ``kv`` input.
        out_channels: Channel count of the output.
        dropout: Attention dropout probability, applied in training mode only.
    """

    def __init__(
        self,
        qk_channels: int,
        vo_channels: int,
        num_heads: int,
        in_q_channels: int,
        in_kv_channels: int,
        out_channels: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert qk_channels % num_heads == 0
        # forward() reshapes the values per head as well, so they must split evenly too
        assert vo_channels % num_heads == 0
        self.qk_channels = qk_channels
        self.vo_channels = vo_channels
        self.num_heads = num_heads
        self.in_q_channels = in_q_channels
        self.in_kv_channels = in_kv_channels
        self.out_channels = out_channels
        self.dropout = dropout
        self.head_qk_channels = qk_channels // num_heads
        self.head_vo_channels = vo_channels // num_heads
        # Every projection: weights ~ N(0, 1 / fan_in), zero bias
        self.q_projection = nn.Linear(in_q_channels, qk_channels)
        self.q_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_q_channels))
        self.q_projection.bias.data.zero_()
        self.kv_projection = nn.Linear(in_kv_channels, qk_channels + vo_channels)
        self.kv_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_kv_channels))
        self.kv_projection.bias.data.zero_()
        self.out_projection = nn.Linear(vo_channels, out_channels)
        self.out_projection.weight.data.normal_(0.0, math.sqrt(1.0 / vo_channels))
        self.out_projection.bias.data.zero_()

    def forward(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
    ) -> torch.Tensor:
        """Attend from ``q`` onto ``kv``.

        Args:
            q: ``[batch_size, q_length, in_q_channels]``.
            kv: ``[batch_size, kv_length, in_kv_channels]``.

        Returns:
            Tensor of shape ``[batch_size, q_length, out_channels]``.
        """
        batch_size, q_length, _ = q.size()
        _, kv_length, _ = kv.size()
        # [batch_size, q_length, qk_channels]
        q = self.q_projection(q)
        # [batch_size, kv_length, qk_channels + vo_channels]
        kv = self.kv_projection(kv)
        # [batch_size, kv_length, qk_channels], [batch_size, kv_length, vo_channels]
        k, v = kv.split([self.qk_channels, self.vo_channels], dim=2)
        # Split heads: [batch_size, num_heads, length, head_channels]
        q = q.view(
            batch_size, q_length, self.num_heads, self.head_qk_channels
        ).transpose(1, 2)
        k = k.view(
            batch_size, kv_length, self.num_heads, self.head_qk_channels
        ).transpose(1, 2)
        v = v.view(
            batch_size, kv_length, self.num_heads, self.head_vo_channels
        ).transpose(1, 2)
        # The functional API applies dropout whenever dropout_p > 0, regardless
        # of the module's mode, so it has to be switched off explicitly in
        # eval(); otherwise inference output would be randomly perturbed.
        dropout_p = self.dropout if self.training else 0.0
        # [batch_size, num_heads, q_length, head_vo_channels]
        attn_out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        # Merge heads: [batch_size, q_length, vo_channels]
        attn_out = (
            attn_out.transpose(1, 2)
            .contiguous()
            .view(batch_size, q_length, self.vo_channels)
        )
        # [batch_size, q_length, out_channels]
        attn_out = self.out_projection(attn_out)
        return attn_out

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write the query projection and the output projection to ``f``.

        ``f`` is either a writable binary stream or a path, which is opened
        with mode ``"wb"``. The key/value projection is not written here; see
        ``dump_kv``.

        Raises:
            TypeError: If ``f`` is neither a path nor an object with ``write``.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError

        # The query is scaled by head_qk_channels ** -0.25 here and the key by
        # the same factor in dump_kv, which together give the usual
        # 1 / sqrt(head_qk_channels) attention scale.
        # NOTE(review): presumably the consumer of the dump then applies no
        # scale of its own - confirm against the reader.
        scale = 1.0 / math.sqrt(math.sqrt(self.head_qk_channels))
        q_projection_weight = self.q_projection.weight.data.clone()
        q_projection_bias = self.q_projection.bias.data.clone()
        q_projection_weight *= scale
        q_projection_bias *= scale
        dump_params(q_projection_weight, f)
        dump_params(q_projection_bias, f)
        dump_layer(self.out_projection, f)

    def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write the key/value projection to ``f``, grouped per head.

        Layout: for each head its key weight then its value weight; after all
        weights, for each head its key bias then its value bias. ``f`` is a
        writable binary stream or a path (opened with mode ``"wb"``).

        Raises:
            TypeError: If ``f`` is neither a path nor an object with ``write``.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump_kv(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError

        kv_projection_weight = self.kv_projection.weight.data.clone()
        kv_projection_bias = self.kv_projection.bias.data.clone()
        k_projection_weight, v_projection_weight = kv_projection_weight.split(
            [self.qk_channels, self.vo_channels]
        )
        k_projection_bias, v_projection_bias = kv_projection_bias.split(
            [self.qk_channels, self.vo_channels]
        )
        # Key half of the attention scale (the query half is applied in dump)
        scale = 1.0 / math.sqrt(math.sqrt(self.head_qk_channels))
        k_projection_weight *= scale
        k_projection_bias *= scale
        # [qk_channels, in_kv_channels] -> [num_heads, head_qk_channels, in_kv_channels]
        k_projection_weight = k_projection_weight.view(
            self.num_heads, self.head_qk_channels, self.in_kv_channels
        )
        # [qk_channels] -> [num_heads, head_qk_channels]
        k_projection_bias = k_projection_bias.view(
            self.num_heads, self.head_qk_channels
        )
        # [vo_channels, in_kv_channels] -> [num_heads, head_vo_channels, in_kv_channels]
        v_projection_weight = v_projection_weight.view(
            self.num_heads, self.head_vo_channels, self.in_kv_channels
        )
        # [vo_channels] -> [num_heads, head_vo_channels]
        v_projection_bias = v_projection_bias.view(
            self.num_heads, self.head_vo_channels
        )
        for i in range(self.num_heads):
            # [head_qk_channels, in_kv_channels]
            dump_params(k_projection_weight[i], f)
            # [head_vo_channels, in_kv_channels]
            dump_params(v_projection_weight[i], f)
        for i in range(self.num_heads):
            # [head_qk_channels]
            dump_params(k_projection_bias[i], f)
            # [head_vo_channels]
            dump_params(v_projection_bias[i], f)
use_mha - self.cross_attention = cross_attention - if use_mha: - self.attn_norm = nn.LayerNorm(channels) - if cross_attention: - self.mha = CrossAttention( - qk_channels=attention_channels, - vo_channels=attention_channels, - num_heads=num_heads, - in_q_channels=channels, - in_kv_channels=kv_channels, - out_channels=channels, - dropout=attention_dropout, - ) - else: # self-attention - assert attention_channels is None - assert kv_channels is None - self.mha = nn.MultiheadAttention( - embed_dim=channels, - num_heads=num_heads, - dropout=attention_dropout, - batch_first=True, - ) - self.dwconv = CausalConv1d( - channels, channels, kernel_size=kernel_size, groups=channels - ) - self.norm = nn.LayerNorm(channels) - self.pwconv1 = nn.Linear(channels, intermediate_channels) - self.pwconv2 = nn.Linear(intermediate_channels, channels) - self.gamma = nn.Parameter(torch.full((channels,), layer_scale_init_value)) - self.dwconv.weight.data.normal_(0.0, math.sqrt(1.0 / kernel_size)) - self.dwconv.bias.data.zero_() - self.pwconv1.weight.data.normal_(0.0, math.sqrt(2.0 / channels)) - self.pwconv1.bias.data.zero_() - self.pwconv2.weight.data.normal_(0.0, math.sqrt(1.0 / intermediate_channels)) - self.pwconv2.bias.data.zero_() - if use_weight_standardization: - self.norm = nn.Identity() - self.dwconv = WSConv1d(channels, channels, kernel_size, groups=channels) - self.pwconv1 = WSLinear(channels, intermediate_channels) - self.pwconv2 = WSLinear(intermediate_channels, channels) - del self.gamma - if enable_scaling: - self.register_buffer("pre_scale", torch.tensor(pre_scale)) - self.register_buffer("post_scale", torch.tensor(post_scale)) - self.post_scale_weight = nn.Parameter(torch.ones(())) - - def forward( - self, - x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - kv: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if self.use_mha: - batch_size, channels, length = x.size() - if self.cross_attention: - assert kv is not None - else: - assert kv is None - assert 
length % 4 == 0 - identity = x - if self.cross_attention: - # kv: [batch_size, kv_length, kv_channels] - x = x.transpose(1, 2) - x = self.attn_norm(x) - x = self.mha(x, kv) - x = x.transpose(1, 2) - else: - x = x.view(batch_size, channels, length // 4, 4) - x = x.permute(0, 3, 2, 1) - x = x.reshape(batch_size * 4, length // 4, channels) - x = self.attn_norm(x) - x, _ = self.mha( - x, x, x, attn_mask=attn_mask, is_causal=True, need_weights=False - ) - x = x.view(batch_size, 4, length // 4, channels) - x = x.permute(0, 3, 2, 1) - x = x.reshape(batch_size, channels, length) - x += identity - - identity = x - if self.enable_scaling: - x = x * self.pre_scale - x = self.dwconv(x) - x = x.transpose(1, 2) - x = self.norm(x) - x = self.pwconv1(x) - x = F.gelu(x, approximate="tanh") - x = self.pwconv2(x) - if not self.use_weight_standardization: - x *= self.gamma - if self.enable_scaling: - x *= self.post_scale * self.post_scale_weight - x = x.transpose(1, 2) - x += identity - return x - - def merge_weights(self): - if self.use_mha: - if self.cross_attention: - assert isinstance(self.mha, CrossAttention) - self.mha.q_projection.bias.data += torch.mv( - self.mha.q_projection.weight.data, self.attn_norm.bias.data - ) - self.mha.q_projection.weight.data *= self.attn_norm.weight.data[None, :] - self.attn_norm.bias.data[:] = 0.0 - self.attn_norm.weight.data[:] = 1.0 - else: # self-attention - assert isinstance(self.mha, nn.MultiheadAttention) - self.mha.in_proj_bias.data += torch.mv( - self.mha.in_proj_weight.data, self.attn_norm.bias.data - ) - self.mha.in_proj_weight.data *= self.attn_norm.weight.data[None, :] - self.attn_norm.bias.data[:] = 0.0 - self.attn_norm.weight.data[:] = 1.0 - if self.use_weight_standardization: - self.dwconv.merge_weights() - self.pwconv1.merge_weights() - self.pwconv2.merge_weights() - else: - self.pwconv1.bias.data += torch.mv( - self.pwconv1.weight.data, self.norm.bias.data - ) - self.pwconv1.weight.data *= self.norm.weight.data[None, :] - 
class ConvNeXtStack(nn.Module):
    """Embedding convolution followed by a stack of ``ConvNeXtBlock`` layers.

    Pipeline of ``forward``: ``embed`` -> ``norm`` -> every block in
    ``convnext`` -> ``final_layer_norm``. With weight standardization enabled
    the embedding is a ``WSConv1d`` and both norms are identities. Weight
    standardization and attention cannot be combined.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        intermediate_channels: int,
        n_blocks: int,
        delay: int,
        embed_kernel_size: int,
        kernel_size: int,
        use_weight_standardization: bool = False,
        enable_scaling: bool = False,
        use_mha: bool = False,
        cross_attention: bool = False,
        kv_channels: Optional[int] = None,
    ):
        super().__init__()
        assert delay * 2 + 1 <= embed_kernel_size
        assert not (use_weight_standardization and use_mha)  # combination not supported
        self.use_weight_standardization = use_weight_standardization
        self.use_mha = use_mha
        self.cross_attention = cross_attention

        self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay)
        self.norm = nn.LayerNorm(channels)

        # Per-block (pre_scale, post_scale); neutral when scaling is disabled
        if enable_scaling:
            scales = [
                (1.0 / math.sqrt(1.0 + i / n_blocks), 1.0 / math.sqrt(n_blocks))
                for i in range(n_blocks)
            ]
        else:
            scales = [(1.0, 1.0)] * n_blocks
        self.convnext = nn.ModuleList(
            ConvNeXtBlock(
                channels=channels,
                intermediate_channels=intermediate_channels,
                layer_scale_init_value=1.0 / n_blocks,
                kernel_size=kernel_size,
                use_weight_standardization=use_weight_standardization,
                enable_scaling=enable_scaling,
                pre_scale=block_pre_scale,
                post_scale=block_post_scale,
                use_mha=use_mha,
                cross_attention=cross_attention,
                num_heads=4,
                attention_dropout=0.1,
                attention_channels=kv_channels,
                kv_channels=kv_channels,
            )
            for block_pre_scale, block_post_scale in scales
        )
        self.final_layer_norm = nn.LayerNorm(channels)

        embed_std = math.sqrt(0.5 / (embed_kernel_size * in_channels))
        self.embed.weight.data.normal_(0.0, embed_std)
        self.embed.bias.data.zero_()
        if use_weight_standardization:
            # Swap in the standardized embedding and disable both layer norms
            self.embed = WSConv1d(in_channels, channels, embed_kernel_size, delay=delay)
            self.norm = nn.Identity()
            self.final_layer_norm = nn.Identity()

    def forward(
        self, x: torch.Tensor, kv: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the stack on ``x``; ``kv`` is handed to every block unchanged.

        In self-attention mode the time axis is right-padded to a multiple of
        4 for the blocks (which require it) and cropped back afterwards.
        """
        x = self.embed(x)
        x = self.norm(x.transpose(1, 2)).transpose(1, 2)

        self_attention = self.use_mha and not self.cross_attention
        pad_length = 0
        attn_mask = None
        if self_attention:
            pad_length = -x.size(2) % 4
            if pad_length:
                x = F.pad(x, (0, pad_length))
            n_steps = x.size(2) // 4
            # True strictly above the diagonal: boolean causal mask
            attn_mask = torch.ones(
                (n_steps, n_steps), dtype=torch.bool, device=x.device
            ).triu(1)

        for conv_block in self.convnext:
            x = conv_block(x, attn_mask=attn_mask, kv=kv)

        if self_attention and pad_length:
            x = x[:, :, :-pad_length]
        return self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2)

    def merge_weights(self):
        """Call ``merge_weights`` on the embedding (if standardized) and on every block."""
        if self.use_weight_standardization:
            self.embed.merge_weights()
        for conv_block in self.convnext:
            conv_block.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write embed, norm, blocks and final norm to ``f`` (norms skipped with WS).

        ``f`` is a writable binary stream or a path (opened with mode ``"wb"``).

        Raises:
            TypeError: If ``f`` is neither a path nor an object with ``write``.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError

        has_norms = not self.use_weight_standardization
        dump_layer(self.embed, f)
        if has_norms:
            dump_layer(self.norm, f)
        dump_layer(self.convnext, f)
        if has_norms:
            dump_layer(self.final_layer_norm, f)

    def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write the key/value projections of all cross-attention blocks to ``f``.

        ``f`` is a writable binary stream or a path (opened with mode ``"wb"``).
        Only valid for a stack built with ``use_mha`` and ``cross_attention``.

        Raises:
            TypeError: If ``f`` is neither a path nor an object with ``write``.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump_kv(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError

        assert self.use_mha and self.cross_attention
        for conv_block in self.convnext:
            if conv_block.use_mha and conv_block.cross_attention:
                assert isinstance(conv_block, ConvNeXtBlock)
                assert hasattr(conv_block, "mha")
                assert isinstance(conv_block.mha, CrossAttention)
                conv_block.mha.dump_kv(f)
def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): - if isinstance(f, (str, bytes, os.PathLike)): - with open(f, "wb") as f: - self.dump(f) - return - if not hasattr(f, "write"): - raise TypeError - - dump_layer(self.conv0, f) - dump_layer(self.conv1, f) - dump_layer(self.conv2, f) - dump_layer(self.conv3, f) - dump_layer(self.conv4, f) - dump_layer(self.conv5, f) - - -class FeatureProjection(nn.Module): - def __init__(self, channels: int): - super().__init__() - self.norm = nn.LayerNorm(channels) - self.dropout = nn.Dropout(0.1) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # [batch_size, channels, length] - x = self.norm(x.transpose(1, 2)).transpose(1, 2) - x = self.dropout(x) - return x - - -class PhoneExtractor(nn.Module): - def __init__( - self, - phone_channels: int = 128, - hidden_channels: int = 128, - backbone_embed_kernel_size: int = 9, - kernel_size: int = 17, - n_blocks: int = 20, - ): - super().__init__() - self.feature_extractor = FeatureExtractor(hidden_channels) - self.feature_projection = FeatureProjection(hidden_channels) - self.backbone = ConvNeXtStack( - in_channels=hidden_channels, - channels=hidden_channels, - intermediate_channels=hidden_channels * 3, - n_blocks=n_blocks, - delay=0, - embed_kernel_size=backbone_embed_kernel_size, - kernel_size=kernel_size, - use_mha=True, - ) - self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1)) - - def forward( - self, x: torch.Tensor, return_stats: bool = True - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: - # x: [batch_size, 1, wav_length] - - stats = {} - - # [batch_size, 1, wav_length] -> [batch_size, feature_extractor_hidden_channels, length] - x = self.feature_extractor(x) - if return_stats: - stats["feature_norm"] = x.detach().norm(dim=1).mean() - # [batch_size, feature_extractor_hidden_channels, length] -> [batch_size, hidden_channels, length] - x = self.feature_projection(x) - # [batch_size, hidden_channels, length] - x = self.backbone(x) 
class VectorQuantizer(nn.Module):
    """Per-speaker top-k vector quantizer that can be attached as a forward hook.

    ``codebooks`` is a half-precision buffer of shape
    ``[n_speakers, codebook_size, channels]``. ``forward`` replaces every
    frame of its input by the mean of the ``topk`` code vectors (of the
    selected speaker) that have the largest dot product with the
    L2-normalised frame.

    Args:
        n_speakers: Number of codebooks.
        codebook_size: Number of code vectors per speaker.
        channels: Dimensionality of each code vector / input frame.
        topk: Number of nearest codes averaged per frame
            (``1 <= topk <= codebook_size``).
        training_time_vq: Behaviour in training mode -- ``"none"`` (identity),
            ``"self"`` (quantize with the target speaker ids) or ``"random"``
            (quantize with uniformly random speaker ids).
    """

    def __init__(
        self,
        n_speakers: int,
        codebook_size: int,
        channels: int,
        topk: int = 4,
        training_time_vq: Literal["none", "self", "random"] = "none",
    ):
        super().__init__()
        assert 1 <= topk <= codebook_size
        self.n_speakers = n_speakers
        self.codebook_size = codebook_size
        self.channels = channels
        self.topk = topk
        self.training_time_vq = training_time_vq

        self.register_buffer(
            "codebooks",
            torch.empty(n_speakers, codebook_size, channels, dtype=torch.half),
        )
        self.codebooks: torch.Tensor

        # VQ is applied through a hook so that the layer it acts on is easy
        # to change.
        self._hook_handle: Optional[torch.utils.hooks.RemovableHandle] = None
        self.target_speaker_ids: Optional[torch.Tensor] = None

        def _hook(_, __, output):
            return self(output, self.target_speaker_ids)

        self._hook_fn = _hook

    @torch.no_grad()
    def build_codebooks(
        self,
        collector_func: Callable,
        target_layer: nn.Module,
        inputs: Sequence[Iterable[torch.Tensor]],
        kmeans_n_iters: int = 50,
    ):
        """Fill ``codebooks`` by clustering activations of ``target_layer``.

        Args:
            collector_func: Called with each input tensor; it must run
                ``target_layer`` so that its output can be captured.
            target_layer: Module whose forward output is clustered.
            inputs: One iterable of input tensors per speaker.
            kmeans_n_iters: Number of k-means refinement iterations.

        Raises:
            RuntimeError: If no activation was captured for a speaker.
        """
        assert len(inputs) == self.n_speakers
        assert self._hook_handle is None, "hook already installed"
        device = next(self.buffers()).device

        for spk_id, inps in enumerate(tqdm(inputs, desc="Building codebooks")):
            collected: list[torch.Tensor] = []

            # TODO: subsample when there is too much data

            def _collect(_, __, output):
                # output: [batch_size, channels, length]
                collected.append(output.detach())

            handle = target_layer.register_forward_hook(_collect)
            try:
                for x in inps:
                    collector_func(x.to(device))
            finally:
                # Always detach the temporary hook, even if the collector
                # raised, so it cannot leak onto the layer.
                handle.remove()

            if not collected:
                raise RuntimeError(f"No activation collected for speaker {spk_id}")

            # [n_data, channels]
            activations = torch.cat(
                [
                    a.transpose(1, 2).reshape(a.size(0) * a.size(2), self.channels)
                    for a in collected
                ]
            )
            activations = activations.float()
            activations = F.normalize(activations, dim=1, eps=1e-6)
            # [codebook_size, channels]
            centers = (
                self._kmeans_plus_plus(activations, self.codebook_size, kmeans_n_iters)
                if activations.size(0) >= self.codebook_size
                else self._pad_replicate(activations, self.codebook_size)
            )
            self.codebooks[spk_id] = centers.to(self.codebooks.dtype)

    def forward(
        self, x: torch.Tensor, speaker_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Quantize ``x`` (``[batch_size, channels, length]``) with per-speaker codes.

        In eval mode ``x`` is returned unchanged when ``speaker_ids`` is
        ``None``. In training mode the behaviour follows
        ``training_time_vq``; in ``"self"`` mode ``target_speaker_ids`` must
        be set and is used whenever ``speaker_ids`` is not given.

        Raises:
            ValueError: In training mode, if ``training_time_vq`` is unknown,
                or is ``"self"`` while ``target_speaker_ids`` is unset.
        """
        batch_size, channels, length = x.size()
        assert channels == self.channels
        device = x.device
        dtype = x.dtype

        if self.training:
            if self.training_time_vq == "none":
                return x
            elif self.training_time_vq == "self":
                if self.target_speaker_ids is None:
                    raise ValueError("target_speaker_ids is not set")
                if speaker_ids is None:
                    # Direct calls without explicit ids fall back to the
                    # target speakers (the hook passes them explicitly).
                    speaker_ids = self.target_speaker_ids
            elif self.training_time_vq == "random":
                speaker_ids = torch.randint(
                    0, self.n_speakers, (batch_size,), device=device
                )
            else:
                raise ValueError(f"Unknown training_time_vq: {self.training_time_vq}")
        elif speaker_ids is None:
            return x
        speaker_ids = speaker_ids.to(device)

        # [batch_size, channels, length], unit norm along the channel axis
        q = F.normalize(x, dim=1, eps=1e-6)
        # [batch_size, codebook_size, channels]
        codes = self.codebooks[speaker_ids].to(q.dtype)
        # [batch_size, length, codebook_size]
        sim = torch.einsum("bcl,bkc->blk", q, codes)

        # [batch_size, length, topk]
        _, topk_idx = sim.topk(self.topk, dim=-1)
        # [batch_size, length, codebook_size, channels]
        expanded_codes = codes[:, None, :, :].expand(-1, length, -1, -1)
        # [batch_size, length, topk, channels]
        expanded_topk_idx = topk_idx[:, :, :, None].expand(-1, -1, -1, channels)
        # [batch_size, length, topk, channels]
        gathered = expanded_codes.gather(2, expanded_topk_idx)
        # [batch_size, length, channels]
        gathered = gathered.mean(2)
        # [batch_size, channels, length]
        return gathered.transpose(1, 2).to(dtype)

    def enable_hook(self, target_layer: nn.Module):
        """Install the quantizing forward hook on ``target_layer``.

        Raises:
            RuntimeError: If a hook is already installed.
        """
        if self._hook_handle is not None:
            raise RuntimeError("hook already installed")
        self._hook_handle = target_layer.register_forward_hook(self._hook_fn)

    def disable_hook(self):
        """Remove the hook installed by ``enable_hook``.

        Raises:
            RuntimeError: If no hook is installed.
        """
        if self._hook_handle is None:
            raise RuntimeError("hook not installed")
        self._hook_handle.remove()
        self._hook_handle = None

    def set_target_speaker_ids(self, speaker_ids: Optional[torch.Tensor]):
        """Set the speaker ids the hook passes to ``forward``."""
        # See forward() for the conditions under which these ids are used.
        self.target_speaker_ids = speaker_ids

    @staticmethod
    def _pad_replicate(x: torch.Tensor, n: int) -> torch.Tensor:
        """Return ``n`` rows of ``x``, cycling through them when ``x`` has fewer."""
        idx = torch.arange(n, device=x.device) % x.size(0)
        return x[idx]

    @staticmethod
    def _kmeans_plus_plus(
        x: torch.Tensor, n_clusters: int, n_iters: int = 50
    ) -> torch.Tensor:
        """Cluster the rows of ``x`` with k-means++ seeding and Lloyd iterations.

        Args:
            x: Data of shape ``[n_data, dim]``.
            n_clusters: Number of centers to return.
            n_iters: Number of refinement iterations.

        Returns:
            Centers of shape ``[n_clusters, dim]``.
        """
        n_data, _ = x.size()
        center_indices = [torch.randint(0, n_data, ()).item()]
        min_distances = torch.full((n_data,), math.inf, device=x.device)
        for _ in range(1, n_clusters):
            last_center_index = center_indices[-1]
            min_distances = min_distances.minimum(
                torch.cdist(x, x[last_center_index : last_center_index + 1])
                .float()
                .square_()
                .squeeze_(1)
            )
            total = min_distances.sum()
            if total > 0:
                probs = min_distances / (total + 1e-12)
            else:
                # Every point coincides with an already chosen center
                # (duplicate data); torch.multinomial rejects an all-zero
                # distribution, so sample uniformly instead.
                probs = torch.full_like(min_distances, 1.0 / n_data)
            center_indices.append(torch.multinomial(probs, 1).item())
        centers = x[center_indices]
        # Free the seeding buffer before the refinement loop.
        del min_distances
        for _ in range(n_iters):
            distances = torch.cdist(x, centers)  # [n_data, n_clusters]
            labels = distances.argmin(1)  # [n_data]
            # [n_clusters, dim]
            new_centers = torch.zeros_like(centers).index_add_(0, labels, x)
            # [n_clusters]
            counts = labels.bincount(minlength=n_clusters)
            empty = counts == 0
            if empty.any():
                warnings.warn("Some clusters have no assigned data points.")
            new_centers /= counts[:, None].clamp(min=1).float()
            # A cluster without members keeps its previous center instead of
            # collapsing to the zero vector.
            centers = torch.where(empty[:, None], centers, new_centers)
        return centers
class PitchEstimator(nn.Module):
    """Predict per-frame pitch-bin logits and a log-energy feature from a waveform.

    The waveform is converted by ``extract_pitch_features`` into
    instantaneous-frequency features and a correlation-difference feature;
    each goes through two 1x1 convolutions, the results are summed, passed
    through a ``ConvNeXtStack`` backbone and projected to ``pitch_bins``
    channels by a 1x1 convolution.
    """

    def __init__(
        self,
        input_instfreq_channels: int = 192,
        input_corr_channels: int = 256,
        pitch_bins: int = 448,
        channels: int = 192,
        intermediate_channels: int = 192 * 2,
        n_blocks: int = 9,
        delay: int = 1,  # 10 ms; 22.5 ms together with the feature extraction
        embed_kernel_size: int = 3,
        kernel_size: int = 33,
        pitch_bins_per_octave: int = 96,
    ):
        super().__init__()
        self.pitch_bins_per_octave = pitch_bins_per_octave

        self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1)
        self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1)
        self.corr_embed_0 = nn.Conv1d(input_corr_channels, channels, 1)
        self.corr_embed_1 = nn.Conv1d(channels, channels, 1)
        self.backbone = ConvNeXtStack(
            channels,
            channels,
            intermediate_channels,
            n_blocks,
            delay,
            embed_kernel_size,
            kernel_size,
            enable_scaling=True,
        )
        self.head = nn.Conv1d(channels, pitch_bins, 1)

    def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pitch_logits, energy)`` for ``wav``.

        Args:
            wav: ``[batch_size, 1, wav_length]``.

        Returns:
            ``pitch_logits`` of shape ``[batch_size, pitch_bins, length]`` and
            the ``energy`` tensor produced by ``extract_pitch_features``.
        """
        # wav: [batch_size, 1, wav_length]

        # [batch_size, input_instfreq_channels, length],
        # [batch_size, input_corr_channels, length]
        # Feature extraction runs with autocast disabled.
        # NOTE(review): the device type is chosen from CUDA availability, not
        # from wav.device -- confirm this is intended when running on CPU on a
        # CUDA-capable machine.
        with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False):
            instfreq_features, corr_diff, energy = extract_pitch_features(
                wav.squeeze(1),
                hop_length=160,
                win_length=560,
                max_corr_period=256,
                corr_win_length=304,
                instfreq_features_cutoff_bin=64,
            )
        instfreq_features = F.gelu(
            self.instfreq_embed_0(instfreq_features), approximate="tanh"
        )
        instfreq_features = self.instfreq_embed_1(instfreq_features)
        corr_diff = F.gelu(self.corr_embed_0(corr_diff), approximate="tanh")
        corr_diff = self.corr_embed_1(corr_diff)
        # [batch_size, channels, length]
        x = F.gelu(instfreq_features + corr_diff, approximate="tanh")
        x = self.backbone(x)
        # [batch_size, pitch_bins, length]
        x = self.head(x)
        return x, energy

    def sample_pitch(
        self, pitch: torch.Tensor, band_width: int = 4, return_features: bool = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Pick one pitch bin per frame from the logits ``pitch``.

        The logits are soft-maxed over the bin axis, bin 0 is excluded, the
        ``band_width``-wide band with the largest summed probability is
        located, and the most probable bin inside that band is returned.

        Args:
            pitch: Logits of shape ``[batch_size, pitch_bins, length]``.
            band_width: Number of adjacent bins summed per band.
            return_features: Also return a ``[batch_size, 3, length]`` tensor
                holding the bin-0 probability and the probability ratios of
                the bands one octave below / above the selected band.

        Returns:
            ``Long[batch_size, length]`` bin indices, optionally followed by
            the feature tensor.
        """
        # pitch: [batch_size, pitch_bins, length]
        # The returned pitch values never include 0.
        batch_size, pitch_bins, length = pitch.size()
        pitch = pitch.softmax(1)
        if return_features:
            # Probability of bin 0, saved before it is overwritten below.
            unvoiced_proba = pitch[:, :1, :].clone()
        # Large negative value so that bin 0 can never win the argmax below.
        pitch[:, 0, :] = -100.0
        pitch = (
            pitch.transpose(1, 2).contiguous().view(batch_size * length, 1, pitch_bins)
        )
        # Sum of band_width adjacent bins (convolution with an all-ones kernel).
        band_pitch = F.conv1d(
            pitch,
            torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width),
        )
        # [batch_size * length, 1, pitch_bins - band_width + 1] -> Long[batch_size * length, 1]
        quantized_band_pitch = band_pitch.argmax(2)
        if return_features:
            # [batch_size * length, 1, 1]
            band_proba = band_pitch.gather(2, quantized_band_pitch[:, :, None])
            # Band one octave below the selected one.
            # [batch_size * length, 1, 1]
            half_pitch_band_proba = band_pitch.gather(
                2,
                (quantized_band_pitch - self.pitch_bins_per_octave).clamp_(min=1)[
                    :, :, None
                ],
            )
            # No band exists one octave below: report 0.
            half_pitch_band_proba[
                quantized_band_pitch <= self.pitch_bins_per_octave
            ] = 0.0
            half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view(
                batch_size, 1, length
            )
            # Band one octave above the selected one.
            # [batch_size * length, 1, 1]
            double_pitch_band_proba = band_pitch.gather(
                2,
                (quantized_band_pitch + self.pitch_bins_per_octave).clamp_(
                    max=pitch_bins - band_width
                )[:, :, None],
            )
            # No band exists one octave above: report 0.
            double_pitch_band_proba[
                quantized_band_pitch
                > pitch_bins - band_width - self.pitch_bins_per_octave
            ] = 0.0
            double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view(
                batch_size, 1, length
            )
        # Long[1, pitch_bins]
        mask = torch.arange(pitch_bins, device=pitch.device)[None, :]
        # True for the bins inside the selected band.
        # bool[batch_size * length, pitch_bins]
        mask = (quantized_band_pitch <= mask) & (
            mask < quantized_band_pitch + band_width
        )
        # Long[batch_size, length]
        quantized_pitch = (pitch.squeeze(1) * mask).argmax(1).view(batch_size, length)

        if return_features:
            features = torch.cat(
                [unvoiced_proba, half_pitch_proba, double_pitch_proba], dim=1
            )
            # Long[batch_size, length], [batch_size, 3, length]
            return quantized_pitch, features
        else:
            return quantized_pitch

    def merge_weights(self):
        """Delegate to ``self.backbone.merge_weights()``."""
        self.backbone.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write all layers to ``f`` via ``dump_layer``.

        ``f`` may be a path (opened here in binary write mode) or an object
        with a ``write`` method; anything else raises ``TypeError``.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError

        # Layers are written in the order they are applied in forward().
        dump_layer(self.instfreq_embed_0, f)
        dump_layer(self.instfreq_embed_1, f)
        dump_layer(self.corr_embed_0, f)
        dump_layer(self.corr_embed_1, f)
        dump_layer(self.backbone, f)
        dump_layer(self.head, f)
def generate_noise(
    aperiodicity: torch.Tensor, delay: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate spectrally shaped noise together with its raw excitation.

    Uniform noise in ``[-0.5, 0.5)`` is analysed with a rectangular-window
    STFT, the DC bin is zeroed, the remaining bins are scaled by
    ``aperiodicity`` and the frames are re-synthesised by overlap-add with a
    Hann window.

    Args:
        aperiodicity: ``[batch_size, hop_length, length]`` per-bin gains.
        delay: Offset in samples at which the output starts; must satisfy
            ``0 <= delay < hop_length``.

    Returns:
        ``(noise, excitation)``, each of shape
        ``[batch_size, length * hop_length]``.
    """
    # aperiodicity: [batch_size, hop_length, length]
    batch_size, hop_length, length = aperiodicity.size()
    # Checked up front with both bounds, like overlap_add does: a negative
    # delay would otherwise silently produce mis-sliced or empty outputs.
    assert 0 <= delay < hop_length, (delay, hop_length)

    n_fft = 2 * hop_length
    n_samples = (length + 1) * hop_length
    excitation = torch.rand(batch_size, n_samples, device=aperiodicity.device)
    excitation -= 0.5
    # Analysis with a rectangular window
    # Complex[batch_size, hop_length + 1, length]
    noise = torch.stft(
        excitation,
        n_fft=n_fft,
        hop_length=hop_length,
        window=torch.ones(n_fft, device=excitation.device),
        center=False,
        return_complex=True,
    )
    assert noise.size(2) == aperiodicity.size(2)
    noise[:, 0, :] = 0.0
    noise[:, 1:, :] *= aperiodicity
    # Synthesis with a Hann window.
    # Note that torch.istft cannot be used here because it applies the
    # optimal synthesis window.
    # [batch_size, 2 * hop_length, length]
    noise = torch.fft.irfft(noise, n=n_fft, dim=1)
    noise *= torch.hann_window(n_fft, device=noise.device)[None, :, None]
    # [batch_size, (length + 1) * hop_length]
    noise = F.fold(
        noise,
        (1, n_samples),
        (1, n_fft),
        stride=(1, hop_length),
    ).squeeze_((1, 2))

    # Drop one hop in total: `delay` samples at the start, the rest at the end.
    end = length * hop_length + delay
    noise = noise[:, delay:end]
    excitation = excitation[:, delay:end]
    return noise, excitation  # [batch_size, length * hop_length]
def nuttall(n: int, device: torch.types.Device) -> torch.Tensor:
    """Return an ``n``-point four-term cosine-sum (Nuttall) window.

    The window is ``sum_k a_k * cos(k * t)`` evaluated at ``n`` points
    ``t`` spaced evenly over ``[0, 2*pi]`` (both end points included).

    Args:
        n: Number of samples.
        device: Device on which the window is created.
    """
    angle = torch.linspace(0, math.tau, n, device=device)
    harmonics = torch.arange(4.0, device=device)  # 0, 1, 2, 3
    weights = torch.tensor(
        [0.355768, -0.487396, 0.144232, -0.012604], device=device
    )
    # [4, n]: cos(k * t) for every harmonic k and sample position t
    cosines = torch.cos(harmonics[:, None] * angle)
    return weights @ cosines
def get_centroid(x: torch.Tensor, n_fft: int) -> torch.Tensor:
    """Return ``Re(X) * Re(Y) + Im(X) * Im(Y)`` for the normalised frame ``x``.

    ``X`` is the ``n_fft``-point real FFT of ``x`` after scaling it to unit
    L2 norm along the last axis, and ``Y`` is the FFT of the same signal
    multiplied by the ramp ``(1, 2, ..., len) / n_fft``.

    With ``D4C_PREVENT_ZERO_DIVISION`` enabled the norm is floored at
    ``6e-8`` so that an all-zero frame does not divide by zero.
    """
    x = torch.as_tensor(x)
    norm = x.norm(dim=-1, keepdim=True)
    if D4C_PREVENT_ZERO_DIVISION:
        norm = norm.clamp(min=6e-8)
    x = x / norm

    # Time ramp 1/n_fft, 2/n_fft, ... along the last axis.
    ramp = torch.arange(
        1, x.size(-1) + 1, dtype=x.dtype, device=x.device
    ).div_(n_fft)
    spec = torch.fft.rfft(x, n_fft)
    weighted_spec = torch.fft.rfft(x * ramp, n_fft)
    return spec.real * weighted_spec.real + spec.imag * weighted_spec.imag
def get_smoothed_power_spec(
    x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the DC-corrected, linearly smoothed power spectrum of a frame.

    The frame is cut out with ``get_windowed_waveform`` (Hann window, ratio
    2.0). Its RMS, relative to the window energy, is returned alongside the
    spectrum. With ``D4C_PREVENT_ZERO_DIVISION`` enabled the frame is first
    divided by ``rms * sqrt(n_fft)`` (floored at ``6e-8``).

    Returns:
        ``(smoothed_power_spec, rms)`` where ``rms`` is detached and has its
        trailing singleton axis removed.
    """
    x = torch.as_tensor(x)
    f0 = torch.as_tensor(f0)
    frame, window = get_windowed_waveform(x, sr, f0, position, 2.0, "hann", n_fft)

    window_energy = window.square().sum(-1, keepdim=True)
    frame_energy = frame.square().sum(-1, keepdim=True)
    rms = frame_energy.div_(window_energy).sqrt_()
    if D4C_PREVENT_ZERO_DIVISION:
        frame = frame / (rms * math.sqrt(n_fft)).clamp_(min=6e-8)

    power_spec = torch.fft.rfft(frame, n_fft).abs().square_()
    power_spec = dc_correction(power_spec, sr, n_fft, f0)
    # Smoothing width equals f0.
    power_spec = linear_smoothing(power_spec, sr, n_fft, f0)
    return power_spec, rms.detach().squeeze(-1)
def d4c_love_train(
    x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, threshold: float
) -> torch.Tensor:
    """Return a boolean voiced/unvoiced decision for every analysis position.

    A frame counts as voiced when its ``f0`` is non-zero and the spectral
    energy accumulated up to the 4000 Hz bin exceeds ``threshold`` times the
    energy accumulated up to the 7900 Hz bin (bins up to 100 Hz are ignored).

    Args:
        x: Waveform, passed to ``get_windowed_waveform``.
        sr: Sampling rate in Hz (an ``int``; used in integer arithmetic).
        f0: Fundamental frequency per position; ``0`` marks unvoiced frames.
        position: Analysis positions, passed to ``get_windowed_waveform``.
        threshold: Required ratio between the two cumulative energies.

    Returns:
        Boolean tensor, ``True`` where the frame is judged voiced.
    """
    x = torch.as_tensor(x)
    position = torch.as_tensor(position)
    f0 = torch.as_tensor(f0)
    vuv = f0 != 0
    lowest_f0 = 40
    # Clamp so that the window length (proportional to 1 / f0) stays bounded.
    f0 = f0.clamp(min=lowest_f0)
    n_fft = 1 << (3 * sr // lowest_f0).bit_length()
    # Bin indices of 100 Hz, 4000 Hz and 7900 Hz (rounded up).
    boundary0 = (100 * n_fft - 1) // sr + 1
    boundary1 = (4000 * n_fft - 1) // sr + 1
    boundary2 = (7900 * n_fft - 1) // sr + 1
    waveform, _ = get_windowed_waveform(x, sr, f0, position, 1.5, "blackman", n_fft)
    power_spec = torch.fft.rfft(waveform, n_fft).abs().square_()
    power_spec[..., : boundary0 + 1] = 0.0
    cumulative_spec = power_spec.cumsum_(-1)
    vuv = vuv & (
        cumulative_spec[..., boundary1] > threshold * cumulative_spec[..., boundary2]
    )
    return vuv
def d4c(
    x: torch.Tensor,
    f0: torch.Tensor,
    t: torch.Tensor,
    sr: int,
    threshold: float = 0.85,
    n_fft_spec: Optional[int] = None,
    coarse_only: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Estimate aperiodicity with D4C.

    Adapted from https://github.com/tuanad121/Python-WORLD/blob/master/world/d4c.py

    Args:
        x: Waveform; the last axis is time.
        f0: Fundamental frequency per analysis position (0 = unvoiced).
        t: Analysis positions, forwarded to the D4C helpers.
        sr: Sampling rate; must be an even integer value.
        threshold: Voiced/unvoiced threshold passed to ``d4c_love_train``.
        n_fft_spec: FFT size that defines the output frequency axis; derived
            from ``sr`` when ``None``.
        coarse_only: Return early with the band aperiodicity only.

    Returns:
        ``(coarse_aperiodicity, rms)`` when ``coarse_only`` is true, otherwise
        ``(aperiodicity, coarse_aperiodicity)`` where ``aperiodicity`` is a
        linear-scale value per bin of the ``n_fft_spec`` frequency axis.
    """
    FLOOR_F0 = 71
    FLOOR_F0_D4C = 47
    UPPER_LIMIT = 15000
    FREQ_INTERVAL = 3000

    assert sr == int(sr)
    sr = int(sr)
    assert sr % 2 == 0
    x = torch.as_tensor(x)
    f0 = torch.as_tensor(f0)
    positions = torch.as_tensor(t)
    device = x.device

    n_fft_d4c = 1 << (4 * sr // FLOOR_F0_D4C).bit_length()
    if n_fft_spec is None:
        n_fft_spec = 1 << (3 * sr // FLOOR_F0).bit_length()
    n_bands = min(UPPER_LIMIT, sr // 2 - FREQ_INTERVAL) // FREQ_INTERVAL
    assert n_bands >= 1
    window = nuttall(FREQ_INTERVAL * n_fft_d4c // sr * 2 + 1, device=device)
    freq_axis = torch.arange(n_fft_spec // 2 + 1, device=device) * (sr / n_fft_spec)

    # Add a singleton axis so the waveform broadcasts against the positions.
    frames_input = x[..., None, :]
    coarse_aperiodicity, rms = d4c_general_body(
        frames_input,
        sr,
        f0.clamp(min=FLOOR_F0_D4C),
        FREQ_INTERVAL,
        positions,
        n_fft_d4c,
        n_bands,
        window,
    )
    if coarse_only:
        return coarse_aperiodicity, rms

    # Nyquist must fall between the two band edges above the last band.
    assert (n_bands + 1) * FREQ_INTERVAL <= sr // 2 < (n_bands + 2) * FREQ_INTERVAL, sr

    split_freq = n_bands * FREQ_INTERVAL
    below_split = freq_axis < split_freq

    # Below the last band centre: interpolate between the band values, with
    # -60 dB prepended at 0 Hz.
    axis_low = (
        torch.arange(n_bands + 1, dtype=torch.float, device=device) * FREQ_INTERVAL
    )
    aperiodicity_low = interp(
        axis_low,
        F.pad(coarse_aperiodicity, (1, 0), value=-60.0),
        freq_axis[below_split],
    )
    # From the last band centre up to Nyquist: interpolate towards -1e-12 dB.
    axis_high = torch.tensor([split_freq, sr * 0.5], device=device)
    aperiodicity_high = interp(
        axis_high,
        F.pad(coarse_aperiodicity[..., -1:], (0, 1), value=-1e-12),
        freq_axis[~below_split],
    )

    aperiodicity_db = torch.cat([aperiodicity_low, aperiodicity_high], -1)
    aperiodicity = 10.0 ** (aperiodicity_db / 20.0)
    # Frames judged unvoiced are set to (almost) fully aperiodic.
    voiced = d4c_love_train(frames_input, sr, f0, positions, threshold)
    aperiodicity = torch.where(voiced[..., None], aperiodicity, 1 - 1e-12)

    return aperiodicity, coarse_aperiodicity
n_blocks=1, - delay=0, - embed_kernel_size=3, - kernel_size=33, - use_weight_standardization=True, - enable_scaling=True, - ) - self.aperiodicity_generator_post = WSConv1d(channels, hop_length, 1, bias=False) - self.register_buffer("aperiodicity_scale", torch.tensor(0.005)) - self.post_filter_generator = ConvNeXtStack( - in_channels=channels, - channels=channels, - intermediate_channels=channels * 2, - n_blocks=1, - delay=0, - embed_kernel_size=3, - kernel_size=33, - use_weight_standardization=True, - enable_scaling=True, - ) - self.post_filter_generator_post = WSConv1d(channels, 512, 1, bias=False) - self.register_buffer("post_filter_scale", torch.tensor(0.01)) - - def forward( - self, x: torch.Tensor, pitch: torch.Tensor, speaker_embedding: torch.Tensor - ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: - # x: [batch_size, channels, length] - # pitch: [batch_size, length] - # speaker_embedding: [batch_size, speaker_embedding_length, speaker_embedding_channels] - batch_size, _, length = x.size() - - x = self.prenet(x, speaker_embedding) - ir = self.ir_generator(x) - ir = F.silu(ir, inplace=True) - # [batch_size, 512, length] - ir = self.ir_generator_post(ir) - ir *= self.ir_scale - ir_amp = ir[:, : ir.size(1) // 2 + 1, :].exp() - ir_phase = F.pad(ir[:, ir.size(1) // 2 + 1 :, :], (0, 0, 1, 1)) - ir_phase[:, 1::2, :] += math.pi - # TODO: 直流成分が正の値しか取れないのを修正する - - # 最近傍補間 - # [batch_size, length * hop_length] - pitch = torch.repeat_interleave(pitch, self.hop_length, dim=1) - - # [batch_size, length * hop_length] - periodic_signal = overlap_add( - ir_amp, - ir_phase, - self.ir_window, - pitch, - self.hop_length, - delay=0, - sr=self.out_sample_rate, - ) - - aperiodicity = self.aperiodicity_generator(x) - aperiodicity = F.silu(aperiodicity, inplace=True) - # [batch_size, hop_length, length] - aperiodicity = self.aperiodicity_generator_post(aperiodicity) - aperiodicity *= self.aperiodicity_scale - # [batch_size, length * hop_length], [batch_size, length * hop_length] - 
aperiodic_signal, noise_excitation = generate_noise(aperiodicity, delay=0) - - post_filter = self.post_filter_generator(x) - post_filter = F.silu(post_filter, inplace=True) - # [batch_size, 512, length] - post_filter = self.post_filter_generator_post(post_filter) - post_filter *= self.post_filter_scale - post_filter[:, 0, :] += 1.0 - # [batch_size, length, 512] - post_filter = post_filter.transpose(1, 2) - with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): - periodic_signal = periodic_signal.float() - aperiodic_signal = aperiodic_signal.float() - post_filter = post_filter.float() - post_filter = torch.fft.rfft(post_filter, n=768) - - # [batch_size, length, 768] - periodic_signal = torch.fft.irfft( - torch.fft.rfft( - periodic_signal.view(batch_size, length, self.hop_length), n=768 - ) - * post_filter, - n=768, - ) - aperiodic_signal = torch.fft.irfft( - torch.fft.rfft( - aperiodic_signal.view(batch_size, length, self.hop_length), n=768 - ) - * post_filter, - n=768, - ) - periodic_signal = F.fold( - periodic_signal.transpose(1, 2), - (1, (length - 1) * self.hop_length + 768), - (1, 768), - stride=(1, self.hop_length), - ).squeeze_((1, 2)) - aperiodic_signal = F.fold( - aperiodic_signal.transpose(1, 2), - (1, (length - 1) * self.hop_length + 768), - (1, 768), - stride=(1, self.hop_length), - ).squeeze_((1, 2)) - periodic_signal = periodic_signal[:, 120 : 120 + length * self.hop_length] - aperiodic_signal = aperiodic_signal[:, 120 : 120 + length * self.hop_length] - noise_excitation = noise_excitation[:, 120:] - - # TODO: compensation の正確さが怪しくなってくる。今も本当に必要なのか? 
def compute_loudness(
    x: torch.Tensor, sr: int, win_lengths: list[int]
) -> list[torch.Tensor]:
    """Return log10 windowed energies of a frequency-weighted copy of ``x``.

    The waveform is filtered with a fixed impulse response (a treble shelf
    followed by a high-pass biquad) using FFT overlap-add, squared, and then
    summed under a Hann window for every requested window length.

    The filter's frequency response is cached per sampling rate on the
    function attribute ``compute_loudness.filter``. The cached tensor is moved
    (and re-cached) when a later call arrives on a different device; without
    that, the second device would hit a device-mismatch error in the multiply.

    Args:
        x: Waveforms of shape ``[batch_size, wav_length]``.
        sr: Sampling rate used to design the weighting filter.
        win_lengths: Analysis window lengths in samples; the hop is
            ``win_length // 4``.

    Returns:
        One tensor of shape ``[batch_size, n_frames]`` per window length.
    """
    # x: [batch_size, wav_length]
    assert x.ndim == 2
    n_fft = 2048
    chunk_length = n_fft // 2
    # chunk_length + n_taps - 1 == n_fft, so each chunk's linear convolution
    # fits in one FFT frame without circular wrap-around.
    n_taps = chunk_length + 1

    results = []
    with torch.amp.autocast(
        "cuda" if torch.cuda.is_available() else "cpu", enabled=False
    ):
        if not hasattr(compute_loudness, "filter"):
            compute_loudness.filter = {}
        if sr not in compute_loudness.filter:
            ir = torch.zeros(n_taps, device=x.device, dtype=torch.double)
            # NOTE(review): the impulse is halved here and doubled afterwards,
            # presumably to leave headroom inside the biquads -- confirm.
            ir[0] = 0.5
            ir = torchaudio.functional.treble_biquad(
                ir, sr, 4.0, 1500.0, 1.0 / math.sqrt(2)
            )
            ir = torchaudio.functional.highpass_biquad(ir, sr, 38.0, 0.5)
            ir *= 2.0
            compute_loudness.filter[sr] = torch.fft.rfft(ir, n=n_fft).to(
                torch.complex64
            )

        frequency_response = compute_loudness.filter[sr]
        if frequency_response.device != x.device:
            # The cache may have been filled from another device.
            frequency_response = frequency_response.to(x.device)
            compute_loudness.filter[sr] = frequency_response

        x = x.float()
        wav_length = x.size(-1)
        if wav_length % chunk_length != 0:
            x = F.pad(x, (0, chunk_length - wav_length % chunk_length))
        padded_wav_length = x.size(-1)
        # Overlap-add filtering: split into chunks, filter each in the
        # frequency domain, then sum the overlapping tails with F.fold.
        x = x.view(x.size()[:-1] + (padded_wav_length // chunk_length, chunk_length))
        x = torch.fft.irfft(
            torch.fft.rfft(x, n=n_fft) * frequency_response,
            n=n_fft,
        )
        x = F.fold(
            x.transpose(-2, -1),
            (1, padded_wav_length + chunk_length),
            (1, n_fft),
            stride=(1, chunk_length),
        ).squeeze_((-3, -2))[..., :wav_length]

        # ``x`` is a freshly computed tensor here, so squaring in place does
        # not touch the caller's input.
        x.square_()
        for win_length in win_lengths:
            hop_length = win_length // 4
            # [..., n_frames]
            energy = (
                x.unfold(-1, win_length, hop_length)
                .matmul(torch.hann_window(win_length, device=x.device))
                .add_(win_length / 4.0 * 1e-5)
                .log10_()
            )
            # If the filtered waveform is a unit-amplitude sine, the value is
            # roughly log10(win_length / 4); a difference of 1 equals 10 dB.
            results.append(energy)
    return results
{ - "phone_extractor": phone_extractor.eval().requires_grad_(False), - "pitch_estimator": pitch_estimator.eval().requires_grad_(False), - } - self.pitch_bins = pitch_bins - self.phone_noise_ratio = phone_noise_ratio - self.floor_noise_level = floor_noise_level - self.out_sample_rate = out_sample_rate = 24000 - phone_channels = 128 - self.vq = VectorQuantizer( - n_speakers=n_speakers, - codebook_size=512, - channels=phone_channels, - topk=vq_topk, - training_time_vq=training_time_vq, - ) - self.embed_phone = nn.Conv1d(phone_channels, hidden_channels, 1) - self.embed_phone.weight.data.normal_(0.0, math.sqrt(2.0 / (256 * 5))) - self.embed_phone.bias.data.zero_() - self.embed_quantized_pitch = nn.Embedding(pitch_bins, hidden_channels) - phase = ( - torch.arange(pitch_bins, dtype=torch.float)[:, None] - * ( - torch.arange(0, hidden_channels, 2, dtype=torch.float) - * (-math.log(10000.0) / hidden_channels) - ).exp_() - ) - self.embed_quantized_pitch.weight.data[:, 0::2] = phase.sin() - self.embed_quantized_pitch.weight.data[:, 1::2] = phase.cos_() - self.embed_quantized_pitch.weight.data *= math.sqrt(4.0 / 5.0) - self.embed_quantized_pitch.weight.requires_grad_(False) - self.embed_pitch_features = nn.Conv1d(4, hidden_channels, 1) - self.embed_pitch_features.weight.data.normal_(0.0, math.sqrt(2.0 / (4 * 5))) - self.embed_pitch_features.bias.data.zero_() - self.embed_speaker = nn.Embedding(n_speakers, hidden_channels) - self.embed_speaker.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) - self.embed_formant_shift = nn.Embedding(9, hidden_channels) - self.embed_formant_shift.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) - - self.key_value_speaker_embedding_length = 384 - self.key_value_speaker_embedding_channels = 128 - self.key_value_speaker_embedding = nn.Embedding( - n_speakers, - self.key_value_speaker_embedding_length - * self.key_value_speaker_embedding_channels, - ) - self.key_value_speaker_embedding.weight.data[0].normal_() - 
self.key_value_speaker_embedding.weight.data[1:] = ( - self.key_value_speaker_embedding.weight.data[0] - ) - - self.vocoder = Vocoder( - channels=hidden_channels, - speaker_embedding_channels=self.key_value_speaker_embedding_channels, - hop_length=out_sample_rate // 100, - n_pre_blocks=4, - out_sample_rate=out_sample_rate, - ) - self.melspectrograms = nn.ModuleList() - for win_length, n_mels in [ - (32, 5), - (64, 10), - (128, 20), - (256, 40), - (512, 80), - (1024, 160), - (2048, 320), - ]: - self.melspectrograms.append( - torchaudio.transforms.MelSpectrogram( - sample_rate=out_sample_rate, - n_fft=win_length, - win_length=win_length, - hop_length=win_length // 4, - n_mels=n_mels, - power=2, - norm="slaney", - mel_scale="slaney", - ) - ) - - def initialize_vq(self, inputs: Sequence[Iterable[torch.Tensor]]): - collector_func = self.frozen_modules["phone_extractor"].units - target_layer = self.frozen_modules["phone_extractor"].head - - self.vq.build_codebooks( - collector_func, - target_layer, - inputs, - ) - self.vq.enable_hook(target_layer) - - def enable_hook(self): - target_layer = self.frozen_modules["phone_extractor"].head - self.vq.enable_hook(target_layer) - - def _get_resampler( - self, orig_freq, new_freq, device, cache={} - ) -> torchaudio.transforms.Resample: - key = orig_freq, new_freq - if key in cache: - return cache[key] - resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to( - device, non_blocking=True - ) - cache[key] = resampler - return resampler - - def forward( - self, - x: torch.Tensor, - target_speaker_id: torch.Tensor, - formant_shift_semitone: torch.Tensor, - pitch_shift_semitone: Optional[torch.Tensor] = None, - slice_start_indices: Optional[torch.Tensor] = None, - slice_segment_length: Optional[int] = None, - return_stats: bool = False, - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: - # x: [batch_size, 1, wav_length] - # target_speaker_id: Long[batch_size] - # formant_shift_semitone: [batch_size] - # 
pitch_shift_semitone: [batch_size] - # slice_start_indices: [batch_size] - - batch_size, _, _ = x.size() - self.vq.set_target_speaker_ids(target_speaker_id) - - with torch.inference_mode(): - phone_extractor: PhoneExtractor = self.frozen_modules["phone_extractor"] - pitch_estimator: PitchEstimator = self.frozen_modules["pitch_estimator"] - # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length] - phone = phone_extractor.units(x).transpose(1, 2) - - if self.training and self.phone_noise_ratio != 0.0: - phone *= (1.0 - self.phone_noise_ratio) / phone.square().mean( - 1, keepdim=True - ).sqrt_() - noise = torch.randn_like(phone) - noise *= ( - self.phone_noise_ratio - / noise.square().mean(1, keepdim=True).sqrt_() - ) - phone += noise - # F.rms_norm は PyTorch >= 2.4 が必要 - phone *= ( - 1.0 - / phone.square() - .mean(1, keepdim=True) - .add_(torch.finfo(torch.float).eps) - .sqrt_() - ) - - # [batch_size, 1, wav_length] -> [batch_size, pitch_bins, length], [batch_size, 1, length] - pitch, energy = pitch_estimator(x) - # augmentation - if self.training: - # [batch_size, pitch_bins - 1] - weights = pitch.softmax(1)[:, 1:, :].mean(2) - # [batch_size] - mean_pitch = ( - weights - * torch.arange( - 1, - self.embed_quantized_pitch.num_embeddings, - device=weights.device, - ) - ).sum(1) / weights.sum(1) - mean_pitch = mean_pitch.round_().long() - target_pitch = torch.randint_like(mean_pitch, 64, 257) - shift = target_pitch - mean_pitch - shift_ratio = ( - 2.0 ** (shift.float() / pitch_estimator.pitch_bins_per_octave) - ).tolist() - shift = [] - interval_length = 100 # 1s - interval_zeros = torch.zeros( - (1, 1, interval_length * 160), device=x.device - ) - concatenated_shifted_x = [] - offsets = [0] - torch.backends.cudnn.benchmark = False - for i in range(batch_size): - shift_ratio_i = shift_ratio[i] - shift_ratio_fraction_i = Fraction.from_float( - shift_ratio_i - ).limit_denominator(30) - shift_numer_i = shift_ratio_fraction_i.numerator - shift_denom_i = 
shift_ratio_fraction_i.denominator - shift_ratio_i = shift_numer_i / shift_denom_i - shift_i = int( - round( - math.log2(shift_ratio_i) - * pitch_estimator.pitch_bins_per_octave - ) - ) - shift.append(shift_i) - shift_ratio[i] = shift_ratio_i - # [1, 1, wav_length / shift_ratio] - with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): - shifted_x_i = self._get_resampler( - shift_numer_i, shift_denom_i, x.device - )(x[i])[None] - if shifted_x_i.size(2) % 160 != 0: - shifted_x_i = F.pad( - shifted_x_i, - (0, 160 - shifted_x_i.size(2) % 160), - mode="reflect", - ) - assert shifted_x_i.size(2) % 160 == 0 - offsets.append( - offsets[-1] + interval_length + shifted_x_i.size(2) // 160 - ) - concatenated_shifted_x.extend([interval_zeros, shifted_x_i]) - if offsets[-1] % 256 != 0: - # 長さが同じ方が何かのキャッシュが効いて早くなるようなので - # 適当に 256 の倍数になるようにパディングして長さのパターン数を減らす - concatenated_shifted_x.append( - torch.zeros( - (1, 1, (256 - offsets[-1] % 256) * 160), device=x.device - ) - ) - # [batch_size, 1, sum(wav_length) + batch_size * 16000] - concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2) - assert concatenated_shifted_x.size(2) % (256 * 160) == 0 - # [1, pitch_bins, length / shift_ratio], [1, 1, length / shift_ratio] - concatenated_pitch, concatenated_energy = pitch_estimator( - concatenated_shifted_x - ) - for i in range(batch_size): - shift_i = shift[i] - shift_ratio_i = shift_ratio[i] - left = offsets[i] + interval_length - right = offsets[i + 1] - pitch_i = concatenated_pitch[:, :, left:right] - energy_i = concatenated_energy[:, :, left:right] - pitch_i = F.interpolate( - pitch_i, - scale_factor=shift_ratio_i, - mode="linear", - align_corners=False, - ) - energy_i = F.interpolate( - energy_i, - scale_factor=shift_ratio_i, - mode="linear", - align_corners=False, - ) - assert pitch_i.size(2) == energy_i.size(2) - assert abs(pitch_i.size(2) - pitch.size(2)) <= 10 - length = min(pitch_i.size(2), pitch.size(2)) - - if shift_i > 0: - pitch[i 
: i + 1, :1, :length] = pitch_i[:, :1, :length] - pitch[i : i + 1, 1:-shift_i, :length] = pitch_i[ - :, 1 + shift_i :, :length - ] - pitch[i : i + 1, -shift_i:, :length] = -10.0 - elif shift_i < 0: - pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] - pitch[i : i + 1, 1 : 1 - shift_i, :length] = -10.0 - pitch[i : i + 1, 1 - shift_i :, :length] = pitch_i[ - :, 1:shift_i, :length - ] - energy[i : i + 1, :, :length] = energy_i[:, :, :length] - torch.backends.cudnn.benchmark = True - - # [batch_size, pitch_bins, length] -> Long[batch_size, length], [batch_size, 3, length] - quantized_pitch, pitch_features = pitch_estimator.sample_pitch( - pitch, return_features=True - ) - if pitch_shift_semitone is not None: - quantized_pitch = torch.where( - quantized_pitch == 0, - quantized_pitch, - ( - quantized_pitch - + ( - pitch_shift_semitone[:, None] - * (pitch_estimator.pitch_bins_per_octave / 12.0) - ) - .round_() - .long() - ).clamp_(1, self.pitch_bins - 1), - ) - pitch = 55.0 * 2.0 ** ( - quantized_pitch.float() / pitch_estimator.pitch_bins_per_octave - ) - # phone が 2.5ms 先読みしているのに対して、 - # energy は 12.5ms, pitch_features は 22.5ms 先読みしているので、 - # ずらして phone に合わせる - energy = F.pad(energy[:, :, :-1], (1, 0), mode="reflect") - quantized_pitch = F.pad(quantized_pitch[:, :-2], (2, 0), mode="reflect") - pitch_features = F.pad(pitch_features[:, :, :-2], (2, 0), mode="reflect") - # [batch_size, 1, length], [batch_size, 3, length] -> [batch_size, 4, length] - pitch_features = torch.cat([energy, pitch_features], dim=1) - formant_shift_indices = ( - ((formant_shift_semitone + 2.0) * 2.0).round_().long() - ) - - phone = phone.clone() - quantized_pitch = quantized_pitch.clone() - pitch_features = pitch_features.clone() - formant_shift_indices = formant_shift_indices.clone() - pitch = pitch.clone() - - # [batch_sise, hidden_channels, length] - x = ( - self.embed_phone(phone) - + self.embed_quantized_pitch(quantized_pitch).transpose(1, 2) - + self.embed_pitch_features(pitch_features) 
def _normalize_melsp(self, x):
    """Floor ``x`` at 1e-10 and return its natural logarithm."""
    floored = torch.clamp(x, min=1e-10)
    # ``clamp`` produced a new tensor, so the in-place log leaves the
    # caller's ``x`` untouched.
    return floored.log_()
enabled=False): - periodic_signal = intermediates["periodic_signal"].float() - aperiodic_signal = intermediates["aperiodic_signal"].float() - noise_excitation = intermediates["noise_excitation"].float() - periodic_signal = periodic_signal[:, : noise_excitation.size(1)] - aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)] - y_hat_all = y_hat_all.float() - floor_noise = torch.randn_like(y_all) * self.floor_noise_level - y_all = y_all + floor_noise - y_hat_all += floor_noise - y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)] - y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)] - - y_loudness = compute_loudness( - y_all_truncated, self.out_sample_rate, loudness_win_lengths - ) - y_hat_loudness = compute_loudness( - y_hat_all_truncated, self.out_sample_rate, loudness_win_lengths - ) - for win_length, y_loudness_i, y_hat_loudness_i in zip( - loudness_win_lengths, y_loudness, y_hat_loudness - ): - loss_loudness_i = F.mse_loss(y_hat_loudness_i, y_loudness_i) - loss_loudness += loss_loudness_i * math.sqrt(win_length) - stats[f"loss_loudness_{win_length}"] = loss_loudness_i.item() - - for melspectrogram in self.melspectrograms: - melsp_periodic_signal = melspectrogram(periodic_signal) - melsp_aperiodic_signal = melspectrogram(aperiodic_signal) - melsp_noise_excitation = melspectrogram(noise_excitation) - # [1, n_mels, 1] - # 1/6 ... [-0.5, 0.5] の一様乱数の平均パワー - # 3/8 ... ハン窓をかけた時のパワー減衰 - # 0.5 ... 
謎 - reference_melsp = melspectrogram.mel_scale( - torch.full( - (1, melspectrogram.n_fft // 2 + 1, 1), - (1 / 6) * (3 / 8) * 0.5 * melspectrogram.win_length, - device=noisy_wavs_16k.device, - ) - ) - aperiodic_ratio = melsp_aperiodic_signal / ( - melsp_periodic_signal + melsp_aperiodic_signal + 1e-5 - ) - compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5) - - melsp_y_hat = melspectrogram(y_hat_all_truncated) - melsp_y_hat = melsp_y_hat * ( - (1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio - ) - y_hat_mel = self._normalize_melsp(melsp_y_hat) - - y_mel = self._normalize_melsp(melspectrogram(y_all_truncated)) - loss_mel_i = F.l1_loss(y_hat_mel, y_mel) - loss_mel += loss_mel_i - stats[ - f"loss_mel_{melspectrogram.win_length}_{melspectrogram.n_mels}" - ] = loss_mel_i.item() - - loss_mel /= len(self.melspectrograms) - - if enable_loss_ap: - t = ( - torch.arange(intermediates["pitch"].size(1), device=y_all.device) - * 0.01 - + 0.005 - ) - y_coarse_aperiodicity, y_rms = d4c( - y_all.squeeze(1), - intermediates["pitch"], - t, - self.vocoder.out_sample_rate, - coarse_only=True, - ) - y_coarse_aperiodicity = 10.0 ** (y_coarse_aperiodicity / 10.0) - y_hat_coarse_aperiodicity, y_hat_rms = d4c( - y_hat_all.squeeze(1), - intermediates["pitch"], - t, - self.vocoder.out_sample_rate, - coarse_only=True, - ) - y_hat_coarse_aperiodicity = 10.0 ** (y_hat_coarse_aperiodicity / 10.0) - rms = torch.maximum(y_rms, y_hat_rms) - loss_ap = F.mse_loss( - y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none" - ) - loss_ap *= (rms / (rms + 1e-3) * (rms > 1e-5))[:, :, None] - loss_ap = loss_ap.mean() - else: - loss_ap = torch.tensor(0.0) - - # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] - y_hat = beatrice_slice_segments( - y_hat_all, slice_start_indices * 240, slice_segment_length * 240 - ) - # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] - y = beatrice_slice_segments(y_all, 
def _normalize(tensor: torch.Tensor, dim: Union[int, Sequence[int]]) -> torch.Tensor:
    """Scale ``tensor`` to unit L2 norm over ``dim``.

    Args:
        tensor: Values to normalise.
        dim: One axis or several axes to reduce jointly when taking the norm
            (``SANConv2d`` passes ``[1, 2, 3]``).

    Returns:
        ``tensor`` divided by its norm, with the norm floored at 1e-6 so that
        all-zero slices map to zeros rather than NaN.
    """
    denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-6)
    return tensor / denom
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return half of the dilated kernel span, ``dilation * (kernel_size - 1)``, floored."""
    dilated_span = dilation * (kernel_size - 1)
    return dilated_span // 2
class BeatriceDiscriminatorR(nn.Module):
    """Discriminator that scores a magnitude STFT at a single resolution.

    Args:
        resolution: ``(n_fft, hop_length, win_length)`` used by
            ``_spectrogram``. (Previously annotated as ``int`` although it is
            length-checked and unpacked as three values.)
        san: When true the final layer is a ``SANConv2d``; otherwise it is a
            weight-normalised ``nn.Conv2d``.
    """

    def __init__(self, resolution: Sequence[int], san: bool = False):
        super().__init__()
        self.resolution = resolution
        self.san = san
        assert len(self.resolution) == 3, "resolution must be (n_fft, hop_length, win_length)"
        self.convs = nn.ModuleList(
            [
                weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
            ]
        )
        if san:
            self.conv_post = SANConv2d(32, 1, (3, 3), padding=(1, 1))
        else:
            self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(
        self, x: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
    ]:
        """Score ``x`` and collect intermediate feature maps.

        Args:
            x: Waveform batch; ``_spectrogram`` squeezes axis 1, so it is
                expected to be ``[batch_size, 1, wav_length]``.
            flg_san_train: Forwarded to the SAN output layer. When true the
                layer's output is unpacked as a ``(fun, dir)`` pair, so it
                should only be set on instances built with ``san=True``.

        Returns:
            ``(score, fmap)``. ``score`` is the flattened output of the last
            layer, or a ``(fun, dir)`` pair of flattened outputs when
            ``flg_san_train`` is true. ``fmap`` holds the activation after
            each conv plus the last layer's (``fun``) output.
        """
        fmap = []

        x = self._spectrogram(x).unsqueeze(1)
        for conv in self.convs:
            x = conv(x)
            x = F.silu(x, inplace=True)
            fmap.append(x)
        if self.san:
            x = self.conv_post(x, flg_san_train=flg_san_train)
        else:
            x = self.conv_post(x)
        if flg_san_train:
            x_fun, x_dir = x
            fmap.append(x_fun)
            x_fun = torch.flatten(x_fun, 1, -1)
            x_dir = torch.flatten(x_dir, 1, -1)
            x = x_fun, x_dir
        else:
            fmap.append(x)
            x = torch.flatten(x, 1, -1)

        return x, fmap

    def _spectrogram(self, x: torch.Tensor) -> torch.Tensor:
        """Return the magnitude STFT of ``x`` computed in float32."""
        n_fft, hop_length, win_length = self.resolution
        # Reflect-pad both ends by (n_fft - hop_length) // 2 because the STFT
        # below runs with center=False.
        x = F.pad(
            x, ((n_fft - hop_length) // 2, (n_fft - hop_length) // 2), mode="reflect"
        ).squeeze(1)
        with torch.amp.autocast(
            "cuda" if torch.cuda.is_available() else "cpu", enabled=False
        ):
            mag = torch.stft(
                x.float(),
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                # All-ones (rectangular) analysis window.
                window=torch.ones(win_length, device=x.device),
                center=False,
                return_complex=True,
            ).abs()

        return mag
d( - concatenated_y_y_hat, flg_san_train=flg_san_train - ) - y_d_r_fun, y_d_g_fun = torch.split(y_d_fun, batch_size) - y_d_r_dir, y_d_g_dir = torch.split(y_d_dir, batch_size) - y_d_r = y_d_r_fun, y_d_r_dir - y_d_g = y_d_g_fun, y_d_g_dir - else: - y_d, fmap = d(concatenated_y_y_hat, flg_san_train=flg_san_train) - y_d_r, y_d_g = torch.split(y_d, batch_size) - fmap_r = [] - fmap_g = [] - for fm in fmap: - fm_r, fm_g = torch.split(fm, batch_size) - fmap_r.append(fm_r) - fmap_g.append(fm_g) - y_d_rs.append(y_d_r) - y_d_gs.append(y_d_g) - fmap_rs.append(fmap_r) - fmap_gs.append(fmap_g) - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - def forward_and_compute_loss( - self, y: torch.Tensor, y_hat: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]]: - y_d_rs, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=self.san) - stats = {} - assert len(y_d_gs) == len(y_d_rs) == len(self.discriminators) - with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): - # discriminator loss - d_loss = 0.0 - for dr, dg, name in zip(y_d_rs, y_d_gs, self.discriminator_names): - if self.san: - dr_fun, dr_dir = map(lambda x: x.float(), dr) - dg_fun, dg_dir = map(lambda x: x.float(), dg) - r_loss_fun = F.softplus(1.0 - dr_fun).square().mean() - g_loss_fun = F.softplus(dg_fun).square().mean() - r_loss_dir = F.softplus(1.0 - dr_dir).square().mean() - g_loss_dir = -F.softplus(1.0 - dg_dir).square().mean() - r_loss = r_loss_fun + r_loss_dir - g_loss = g_loss_fun + g_loss_dir - else: - dr = dr.float() - dg = dg.float() - r_loss = (1.0 - dr).square().mean() - g_loss = dg.square().mean() - stats[f"{name}_dr_loss"] = r_loss.item() - stats[f"{name}_dg_loss"] = g_loss.item() - d_loss += r_loss + g_loss - # adversarial loss - adv_loss = 0.0 - for dg, name in zip(y_d_gs, self.discriminator_names): - if self.san: - dg_fun = dg[0].float() - g_loss = F.softplus(1.0 - dg_fun).square().mean() - else: - dg = dg.float() - g_loss = (1.0 - 
class GradBalancer:
    """Balance the gradients that several losses send into one shared tensor.

    Adapted from https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py

    For each named loss, the gradient with respect to a shared ``input``
    tensor is measured and an exponential moving average (EMA) of its norm
    is tracked. The per-loss gradients are then rescaled according to
    ``weights`` and ``total_norm``, summed, and back-propagated through
    ``input`` in a single ``input.backward`` call.
    """

    def __init__(
        self,
        weights: dict[str, float],
        rescale_grads: bool = True,
        total_norm: float = 1.0,
        ema_decay: float = 0.999,
        per_batch_item: bool = True,
    ):
        """Store the balancing configuration and create empty EMA state.

        Args:
            weights: Relative weight of each loss, keyed by loss name.
            rescale_grads: If true, each gradient is scaled by its weight
                ratio times ``total_norm`` divided by its EMA norm; if
                false, the raw weight is used as the scale.
            total_norm: Target norm used when ``rescale_grads`` is true.
            ema_decay: Decay factor of the running norm averages.
            per_batch_item: If true, the norm is taken over all dims except
                dim 0 and then averaged; otherwise one norm over the whole
                gradient is used.
        """
        self.weights = weights
        self.per_batch_item = per_batch_item
        self.total_norm = total_norm
        self.ema_decay = ema_decay
        self.rescale_grads = rescale_grads

        # Running decayed sum of the observed norms, per loss name.
        self.ema_total: dict[str, float] = defaultdict(float)
        # Running decayed sum of ones; dividing ema_total by it gives a
        # bias-corrected average even for the first few updates.
        self.ema_fix: dict[str, float] = defaultdict(float)

    def backward(
        self,
        losses: dict[str, torch.Tensor],
        input: torch.Tensor,
        scaler: Optional[torch.amp.GradScaler] = None,
        skip_update_ema: bool = False,
    ) -> dict[str, float]:
        """Back-propagate the balanced combination of ``losses`` through ``input``.

        Args:
            losses: Loss tensors keyed by name; the names must exist in
                ``self.weights``.
            input: Tensor that all losses depend on; gradients are measured
                with respect to it and the result is pushed into
                ``input.backward``.
            scaler: Optional gradient scaler; losses are scaled before the
                gradient is taken and norms are computed on unscaled values.
            skip_update_ema: If true, reuse the stored EMA norms instead of
                measuring per-loss gradients, and differentiate the weighted
                sum of losses in one pass.

        Returns:
            EMA norm statistics keyed ``grad_norm_value_<name>`` and
            ``grad_norm_ratio_<name>``, or an empty dict when a non-finite
            gradient was encountered.
        """
        stats = {}
        if skip_update_ema:
            # The stored EMA must already cover exactly these losses.
            assert len(losses) == len(self.ema_total)
            ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()}
        else:
            # For each loss, compute d loss / d input and its norm.
            norms = {}
            grads = {}
            for name, loss in losses.items():
                if scaler is not None:
                    loss = scaler.scale(loss)
                (grad,) = torch.autograd.grad(loss, [input], retain_graph=True)
                if not grad.isfinite().all():
                    # NOTE(review): the non-finite gradient is propagated as-is,
                    # presumably so the GradScaler detects it and skips the
                    # optimizer step; confirm against the training loop.
                    input.backward(grad)
                    return {}
                # Undo the loss scaling so the norms are in real units.
                grad = grad.detach() / (1.0 if scaler is None else scaler.get_scale())
                if self.per_batch_item:
                    dims = tuple(range(1, grad.dim()))
                    ema_norm = grad.norm(dim=dims).mean()
                else:
                    ema_norm = grad.norm()
                norms[name] = float(ema_norm)
                grads[name] = grad

            # Update the moving averages of the norms.
            for key, value in norms.items():
                self.ema_total[key] = self.ema_total[key] * self.ema_decay + value
                self.ema_fix[key] = self.ema_fix[key] * self.ema_decay + 1.0
            ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()}

            # Record the averaged norms and their share of the total for logging.
            total_ema_norm = sum(ema_norms.values())
            for k, ema_norm in ema_norms.items():
                stats[f"grad_norm_value_{k}"] = ema_norm
                stats[f"grad_norm_ratio_{k}"] = ema_norm / (total_ema_norm + 1e-12)

        # Compute each loss's share of the total weight.
        if self.rescale_grads:
            total_weights = sum([self.weights[k] for k in ema_norms])
            ratios = {k: w / total_weights for k, w in self.weights.items()}

        # Build the rebalanced gradient (or, when the EMA update is skipped,
        # the rebalanced scalar loss that is differentiated below).
        loss = 0.0
        for name, ema_norm in ema_norms.items():
            if self.rescale_grads:
                scale = ratios[name] * self.total_norm / (ema_norm + 1e-12)
            else:
                scale = self.weights[name]
            loss += (losses if skip_update_ema else grads)[name] * scale
        if scaler is not None:
            # Re-apply the loss scale so downstream gradients stay scaled.
            loss = scaler.scale(loss)
        if skip_update_ema:
            (loss,) = torch.autograd.grad(loss, [input])
        input.backward(loss)
        return stats

    def state_dict(self) -> dict[str, dict[str, float]]:
        """Return plain-dict copies of the EMA accumulators for checkpointing."""
        return {
            "ema_total": dict(self.ema_total),
            "ema_fix": dict(self.ema_fix),
        }

    def load_state_dict(self, state_dict):
        """Restore the EMA accumulators from a dict produced by ``state_dict``."""
        self.ema_total = defaultdict(float, state_dict["ema_total"])
        self.ema_fix = defaultdict(float, state_dict["ema_fix"])
def compute_grad_norm(
    model: nn.Module, return_stats: bool = False
) -> Union[float, tuple[float, dict[str, float]]]:
    """Return the global L2 norm of all gradients currently stored on *model*.

    Parameters whose ``grad`` is ``None`` are skipped.

    Args:
        model: Module whose parameter gradients are inspected.
        return_stats: If true, also return the per-parameter norms.

    Returns:
        The total norm as a float, or ``(total_norm, stats)`` when
        ``return_stats`` is true, where ``stats`` maps
        ``"grad_norm_<parameter name>"`` to that parameter's gradient norm.
    """
    squared_sum = 0.0
    stats: dict[str, float] = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        grad = p.grad.detach()
        param_norm = grad.norm().item()
        if not math.isfinite(param_norm):
            # The norm may have overflowed in a low-precision dtype;
            # retry in float32 before accepting a non-finite result.
            param_norm = grad.float().norm().item()
        squared_sum += param_norm * param_norm
        if return_stats:
            stats[f"grad_norm_{name}"] = param_norm
    total_norm = math.sqrt(squared_sum)
    if return_stats:
        return total_norm, stats
    return total_norm
def convolve(signal: torch.Tensor, ir: torch.Tensor) -> torch.Tensor:
    """Convolve *signal* with *ir* along the last axis using the FFT.

    The FFT size is ``1 << (len(signal) + len(ir) - 2).bit_length()``, so the
    circular convolution does not wrap around. The result is truncated to the
    length of *signal*; leading dimensions follow normal broadcasting of the
    element-wise spectrum product.
    """
    signal_length = signal.size(-1)
    fft_size = 1 << (signal_length + ir.size(-1) - 2).bit_length()
    signal_spectrum = torch.fft.rfft(signal, n=fft_size)
    ir_spectrum = torch.fft.rfft(ir, n=fft_size)
    convolved = torch.fft.irfft(signal_spectrum * ir_spectrum, n=fft_size)
    return convolved[..., :signal_length]
def random_filter(audio: torch.Tensor) -> torch.Tensor:
    """Apply an independent random second-order IIR filter to each row of *audio*.

    Six coefficients per row are drawn uniformly from [-0.375, 0.375); the
    first three form the ``a`` coefficients and the last three the ``b``
    coefficients, with the leading coefficient of each forced to 1.0. The
    filter is applied with ``torchaudio.functional.lfilter`` without clamping.
    """
    assert audio.ndim == 2
    n_rows = audio.size(0)
    coefficients = torch.rand(n_rows, 6) * 0.75 - 0.375
    # Columns 0 and 3 are the leading a- and b-coefficients respectively.
    coefficients[:, ::3] = 1.0
    a_coeffs = coefficients[:, :3]
    b_coeffs = coefficients[:, 3:]
    return torchaudio.functional.lfilter(audio, a_coeffs, b_coeffs, clamp=False)
def augment_audio(
    clean: torch.Tensor,
    sample_rate: int,
    noise_files: list[Union[str, bytes, os.PathLike]],
    ir_files: list[Union[str, bytes, os.PathLike]],
    snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0],
    formant_shift_probability: float = 0.5,
    formant_shift_semitone_min: float = -3.0,
    formant_shift_semitone_max: float = 3.0,
    reverb_probability: float = 0.5,
    lpf_probability: float = 0.2,
    lpf_cutoff_freq_candidates: list[float] = [2000.0, 3000.0, 4000.0, 6000.0],
) -> torch.Tensor:
    """Degrade a clean waveform with random formant shift, filtering, reverb and noise.

    Args:
        clean: Waveform with a leading dimension of size 1.
        sample_rate: Sample rate of ``clean``; impulse responses must match it.
        noise_files: Audio files that noise is sampled from.
        ir_files: Impulse-response files used for reverb.
        snr_candidates: SNR values in dB, one of which is picked at random.
            If empty, no noise is mixed in and the processed clean signal
            is returned.
        formant_shift_probability: Probability of applying a formant shift.
        formant_shift_semitone_min: Lower bound of the random formant shift.
        formant_shift_semitone_max: Upper bound of the random formant shift.
        reverb_probability: Probability of convolving with an impulse response.
        lpf_probability: Probability of applying a low-pass filter.
        lpf_cutoff_freq_candidates: Cut-off frequencies to choose from.

    Returns:
        The augmented waveform as a 1-D tensor (one row of the stacked
        clean/noise pair).

    The default list arguments are only read, never mutated.
    """
    # [1, wav_length]
    assert clean.size(0) == 1
    n_samples = clean.size(1)

    original_clean_rms = clean.square().mean().sqrt_()

    # Randomly formant-shift the clean signal.
    if torch.rand(()) < formant_shift_probability:
        clean = random_formant_shift(
            clean, sample_rate, formant_shift_semitone_min, formant_shift_semitone_max
        )

    # Fetch noise and stack it under the clean signal: row 0 clean, row 1 noise.
    noise = get_noise(n_samples, sample_rate, noise_files)
    signals = torch.cat([clean, noise])

    # Apply a different random filter to clean and noise.
    signals = random_filter(signals)

    # Apply reverb to clean and noise.
    if torch.rand(()) < reverb_probability:
        ir_file = ir_files[torch.randint(0, len(ir_files), ())]
        ir, sr = beatrice_load_audio(ir_file)
        # The impulse response must have two rows of exactly one second each.
        assert ir.size() == (2, sr), ir.size()
        assert sr == sample_rate, (sr, sample_rate)
        signals = convolve(signals, ir)

    # Apply the same low-pass filter to clean and noise.
    if torch.rand(()) < lpf_probability:
        if signals.abs().max() > 0.8:
            # Bring the peak down to 0.8 before filtering.
            signals /= signals.abs().max() * 1.25
        cutoff_freq = lpf_cutoff_freq_candidates[
            torch.randint(0, len(lpf_cutoff_freq_candidates), ())
        ]
        b, a = get_butterworth_lpf(cutoff_freq, sample_rate)
        signals = torchaudio.functional.lfilter(signals, a, b, clamp=False)

    # Restore the clean signal's original RMS level.
    clean, noise = signals
    clean_rms = clean.square().mean().sqrt_()
    clean *= original_clean_rms / clean_rms

    # Without SNR candidates there is nothing to mix; previously ``noisy`` was
    # only assigned inside the branch below, so this case had no valid result.
    noisy = clean
    if len(snr_candidates) >= 1:
        # Measure levels with a 4th-power mean, which emphasises peaks.
        clean_level = clean.square().square_().mean().sqrt_().sqrt_()
        noise_level = noise.square().square_().mean().sqrt_().sqrt_()
        # SNR
        snr = snr_candidates[torch.randint(0, len(snr_candidates), ())]
        # Mix the noise in at the chosen SNR.
        noisy = clean + noise * (
            0.1 ** (snr / 20.0) * clean_level / (noise_level + 1e-5)
        )

    return noisy
def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int, float]:
    """Load one training example and return clean/noisy waveforms.

    The file is reduced to one channel (picked at random when there are
    several), resampled to ``out_sample_rate`` with a random formant-shift
    ratio folded into the resampling factor, zero-padded by ``wav_length``
    on both sides, converted to ``in_sample_rate`` (with augmentation when
    noise files are configured), and finally scaled to a random amplitude.

    Returns:
        ``(clean_wav, noisy_wav_16k, speaker_id, formant_shift)`` where the
        waveforms are 1-D tensors and ``formant_shift`` is the chosen shift
        in semitones (a float).

    Raises:
        ValueError: If the clip is completely silent. Normalising it would
            divide by zero and return NaN samples.
    """
    file, speaker_id = self.audio_files[index]
    clean_wav, sample_rate = beatrice_load_audio(file)
    if clean_wav.size(0) != 1:
        # Multi-channel file: keep one randomly chosen channel.
        ch = torch.randint(0, clean_wav.size(0), ())
        clean_wav = clean_wav[ch : ch + 1]

    formant_shift_candidates = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]
    formant_shift = formant_shift_candidates[
        torch.randint(0, len(formant_shift_candidates), ()).item()
    ]

    # Fold the formant shift into the resampling ratio and approximate it
    # with a small rational so a cached resampler can be reused.
    resampler_fraction = Fraction(
        sample_rate / self.out_sample_rate * 2.0 ** (formant_shift / 12.0)
    ).limit_denominator(300)
    clean_wav = get_resampler(
        resampler_fraction.numerator, resampler_fraction.denominator
    )(clean_wav)

    assert clean_wav.size(0) == 1
    assert clean_wav.size(1) != 0

    clean_wav = F.pad(clean_wav, (self.wav_length, self.wav_length))

    wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)(clean_wav)
    if self.noise_files is None:
        noisy_wav_16k = wav_16k
    else:
        noisy_wav_16k = augment_audio(
            wav_16k,
            self.in_sample_rate,
            self.noise_files,
            self.ir_files,
            self.augmentation_snr_candidates,
            self.augmentation_formant_shift_probability,
            self.augmentation_formant_shift_semitone_min,
            self.augmentation_formant_shift_semitone_max,
            self.augmentation_reverb_probability,
            self.augmentation_lpf_probability,
            self.augmentation_lpf_cutoff_freq_candidates,
        )

    clean_wav = clean_wav.squeeze_(0)
    noisy_wav_16k = noisy_wav_16k.squeeze_(0)

    peak = clean_wav.abs().max()
    if peak == 0:
        # 0 * (amplitude / 0) is NaN; fail loudly and name the file instead.
        raise ValueError(f"Audio file contains only silence: {file}")

    # Randomise the volume.
    amplitude = torch.rand(()).item() * 0.899 + 0.1
    factor = amplitude / peak
    clean_wav *= factor
    noisy_wav_16k *= factor
    # Halve both signals until the noisy one no longer reaches full scale.
    while noisy_wav_16k.abs().max() >= 1.0:
        clean_wav *= 0.5
        noisy_wav_16k *= 0.5

    return clean_wav, noisy_wav_16k, speaker_id, formant_shift
def get_compressed_optimizer_state_dict(
    optimizer: torch.optim.Optimizer,
) -> dict:
    """Return ``optimizer.state_dict()`` with large state tensors stored as bfloat16.

    Every top-level entry other than ``"state"`` is passed through unchanged.
    Inside ``"state"``, floating-point tensors with at least one dimension
    are converted to bfloat16 to shrink checkpoints. Scalar (0-dim) and
    non-floating tensors are copied at their original precision. The
    optimiser ``step`` counter is such a scalar, and bfloat16 keeps only 8
    significant bits, so converting it would round e.g. 1001 to 1000.

    Raises:
        AssertionError: If any stored tensor contains a non-finite value.
    """
    state_dict = {}
    for k0, v0 in optimizer.state_dict().items():
        if k0 != "state":
            state_dict[k0] = v0
            continue
        state_dict[k0] = {}
        for k1, v1 in v0.items():
            state_dict[k0][k1] = {}
            for k2, v2 in v1.items():
                if not isinstance(v2, torch.Tensor):
                    state_dict[k0][k1][k2] = v2
                    continue
                if v2.dim() == 0 or not v2.is_floating_point():
                    # Tiny value: keep it exact, but detach it from the
                    # live optimiser state.
                    stored = v2.detach().clone()
                else:
                    stored = v2.bfloat16()
                assert stored.isfinite().all()
                state_dict[k0][k1][k2] = stored
    return state_dict
int(3.5 * sr) # 0.5s overlap - threshold = 0.01 # RMS threshold - - chunks_saved = 0 - for i, start in enumerate(range(0, len(audio) - chunk_size, hop)): - chunk = audio[start:start + chunk_size] - rms = np.sqrt(np.mean(chunk ** 2)) - - if rms > threshold: # Skip silence - # Normalize - max_val = np.abs(chunk).max() - if max_val > 0: - chunk = chunk / max_val * 0.9 - - chunk_path = os.path.join(speaker_dir, f"{speaker_name}_{chunks_saved:04d}.wav") - sf.write(chunk_path, chunk, sr) - chunks_saved += 1 - - logger.info(f"Beatrice preprocessing: {chunks_saved} chunks saved to {speaker_dir}") - return chunks_saved, output_dir - - -def train_beatrice_generator( - data_dir: str, - output_dir: str, - epochs: int = 30, - batch_size: int = 8, - lr_g: float = 5e-5, - lr_d: float = 5e-5, - use_augmentation: bool = False, - resume: bool = False, - progress_callback=None, -): - """Train Beatrice v2 model - generator yielding (message, model_path) tuples""" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Download pretrained - yield "Downloading pretrained models...", None - phone_extractor_path = download_beatrice_asset("phone_extractor") - pitch_estimator_path = download_beatrice_asset("pitch_estimator") - pretrained_model_path = download_beatrice_asset("pretrained_model") - - # Discover speakers from directory structure - # Expected: data_dir/speaker_name/*.wav - speakers = [] - training_filelist = [] - speaker_audio_files = [] - for speaker_dir in sorted(Path(data_dir).iterdir()): - if not speaker_dir.is_dir(): - continue - candidates = [f for f in sorted(speaker_dir.rglob("*")) - if f.is_file() and f.suffix.lower() in BEATRICE_AUDIO_FILE_SUFFIXES] - if candidates: - speaker_id = len(speakers) - speakers.append(speaker_dir.name) - training_filelist.extend([(f, speaker_id) for f in candidates]) - speaker_audio_files.append(candidates) - - n_speakers = len(speakers) - if n_speakers == 0: - yield "Error: No speakers found in data directory", None - 
return - - yield f"Found {n_speakers} speaker(s), {len(training_filelist)} files", None - - # Augmentation assets (optional) - noise_files = None - ir_files = None - if use_augmentation: - try: - noise_dir, ir_dir = download_beatrice_augmentation() - if noise_dir and ir_dir: - noise_files = sorted(list(Path(noise_dir).rglob("*.wav")) + list(Path(noise_dir).rglob("*.flac"))) - ir_files = sorted(list(Path(ir_dir).rglob("*.wav")) + list(Path(ir_dir).rglob("*.flac"))) - if noise_files and ir_files: - yield f"Loaded augmentation: {len(noise_files)} noise, {len(ir_files)} IR files", None - else: - noise_files = None - ir_files = None - except Exception as e: - yield f"Warning: Could not load augmentation assets: {e}", None - - # Build models - yield "Building models...", None - - phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False) - pe_ckpt = torch.load(phone_extractor_path, map_location="cpu", weights_only=True) - phone_extractor.load_state_dict(pe_ckpt["phone_extractor"], strict=False) - del pe_ckpt - - pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False) - pi_ckpt = torch.load(pitch_estimator_path, map_location="cpu", weights_only=True) - pitch_estimator.load_state_dict(pi_ckpt["pitch_estimator"]) - del pi_ckpt - - hidden_channels = 256 - pitch_bins = 448 - - net_g = ConverterNetwork( - phone_extractor, pitch_estimator, - n_speakers=n_speakers, - pitch_bins=pitch_bins, - hidden_channels=hidden_channels, - vq_topk=4, - training_time_vq="none", - phone_noise_ratio=0.5, - floor_noise_level=1e-3, - ).to(device) - - net_d = BeatriceMultiPeriodDiscriminator(san=True).to(device) - - # Optimizers - optim_g = torch.optim.AdamW(net_g.parameters(), lr_g, betas=(0.8, 0.99), eps=1e-6) - optim_d = torch.optim.AdamW(net_d.parameters(), lr_d, betas=(0.8, 0.99), eps=1e-6) - - grad_scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda") - grad_balancer = GradBalancer( - weights={ - "loss_loudness": 1.0, - "loss_mel": 45.0, 
- "loss_adv": 1.0, - "loss_fm": 2.0, - }, - ema_decay=0.999, - ) - - initial_iteration = 0 - os.makedirs(output_dir, exist_ok=True) - - # Load pretrained or resume - if resume: - latest_ckpt = os.path.join(output_dir, "checkpoint_latest.pt.gz") - if os.path.isfile(latest_ckpt): - yield "Resuming from checkpoint...", None - with gzip.open(latest_ckpt, "rb") as f: - ckpt = torch.load(f, map_location="cpu", weights_only=True) - net_g.load_state_dict(ckpt["net_g"], strict=False) - # Filter discriminator for shape mismatches - net_d_state = net_d.state_dict() - filtered_d = {k: v for k, v in ckpt["net_d"].items() - if k in net_d_state and v.shape == net_d_state[k].shape} - net_d.load_state_dict(filtered_d, strict=False) - optim_g.load_state_dict(get_decompressed_optimizer_state_dict(ckpt["optim_g"])) - optim_d.load_state_dict(get_decompressed_optimizer_state_dict(ckpt["optim_d"])) - if "grad_balancer" in ckpt: - grad_balancer.load_state_dict(ckpt["grad_balancer"]) - if "grad_scaler" in ckpt: - grad_scaler.load_state_dict(ckpt["grad_scaler"]) - initial_iteration = ckpt.get("iteration", 0) - del ckpt - else: - yield "No checkpoint found, starting fresh with pretrained", None - resume = False - - if not resume: - yield "Loading pretrained weights...", None - with gzip.open(pretrained_model_path, "rb") as f: - pretrained_ckpt = torch.load(f, map_location="cpu", weights_only=True) - # Adapt pretrained for our n_speakers - initial_speaker_emb = pretrained_ckpt["net_g"]["embed_speaker.weight"][:1] - pretrained_ckpt["net_g"]["embed_speaker.weight"] = initial_speaker_emb[[0] * n_speakers] - initial_kv_emb = pretrained_ckpt["net_g"]["key_value_speaker_embedding.weight"][:1] - pretrained_ckpt["net_g"]["key_value_speaker_embedding.weight"] = initial_kv_emb[[0] * n_speakers] - pretrained_ckpt["net_g"]["vq.codebooks"] = pretrained_ckpt["net_g"]["vq.codebooks"][[0] * n_speakers] - net_g.load_state_dict(pretrained_ckpt["net_g"], strict=False) - # Filter discriminator state dict for 
shape mismatches (pretrained may use san=False) - net_d_state = net_d.state_dict() - filtered_d = {k: v for k, v in pretrained_ckpt["net_d"].items() - if k in net_d_state and v.shape == net_d_state[k].shape} - net_d.load_state_dict(filtered_d, strict=False) - logger.info(f"Loaded {len(filtered_d)}/{len(pretrained_ckpt['net_d'])} discriminator weights") - # Don't load grad_balancer/grad_scaler from pretrained - our loss weights may differ - # These will be re-initialized fresh for fine-tuning - del pretrained_ckpt - - # Build VQ codebooks - yield "Building VQ codebooks...", None - - def wav_iterator(files): - for file in files: - wav, sr = beatrice_load_audio(file) - wav = wav.to(device) - if sr != 16000: - wav = get_resampler(sr, 16000, str(device))(wav) - yield wav[:, None, :] - - if resume: - net_g.enable_hook() - else: - net_g.initialize_vq([wav_iterator(files) for files in speaker_audio_files]) - - # Dataset - dataset = WavDataset( - training_filelist, - in_sample_rate=16000, - out_sample_rate=24000, - wav_length=96000, - segment_length=100, - noise_files=noise_files, - ir_files=ir_files, - ) - - _num_workers = min(4, os.cpu_count() or 1) - effective_batch = min(batch_size, len(training_filelist)) - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=_num_workers, - collate_fn=dataset.collate, - shuffle=True, - batch_size=effective_batch, - pin_memory=True, - drop_last=len(training_filelist) > effective_batch, - persistent_workers=_num_workers > 0, - ) - - # Calculate steps - steps_per_epoch = max(1, len(training_filelist) // batch_size) - total_steps = epochs * steps_per_epoch - warmup_steps = min(total_steps // 4, 5000) - - # LR scheduler with warmup - def lr_lambda(step): - if step < warmup_steps: - return step / max(1, warmup_steps) - return 0.999 ** (step - warmup_steps) - - scheduler_g = torch.optim.lr_scheduler.LambdaLR(optim_g, lr_lambda) - scheduler_d = torch.optim.lr_scheduler.LambdaLR(optim_d, lr_lambda) - - # Advance schedulers if 
resuming - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message=r"Detected call of `lr_scheduler\.step\(\)") - for _ in range(initial_iteration + 1): - scheduler_g.step() - scheduler_d.step() - - net_g.train() - net_d.train() - - yield f"Training {total_steps} steps ({epochs} epochs x {steps_per_epoch} steps/epoch)", None - - # Training loop - step = initial_iteration - data_iter = None - ckpt_path = None - - for epoch in range(epochs): - epoch_loss_g = 0.0 - epoch_loss_d = 0.0 - epoch_steps = 0 - - for batch_idx in range(steps_per_epoch): - if data_iter is None: - data_iter = iter(dataloader) - batch = next(data_iter, None) - if batch is None: - data_iter = iter(dataloader) - batch = next(data_iter, None) - if batch is None: - break - - clean_wavs, noisy_wavs_16k, slice_starts, speaker_ids, formant_shifts = \ - [x.to(device, non_blocking=True) for x in batch] - - with torch.amp.autocast(device.type, enabled=device.type == "cuda"): - # Generator forward - y, y_hat, y_hat_for_backward, loss_loudness, loss_mel, loss_ap, gen_stats = \ - net_g.forward_and_compute_loss( - noisy_wavs_16k[:, None, :], - speaker_ids, - formant_shifts, - slice_start_indices=slice_starts, - slice_segment_length=100, - y_all=clean_wavs[:, None, :], - ) - - # Discriminator forward - loss_disc, loss_adv, loss_fm, disc_stats = \ - net_d.forward_and_compute_loss(y, y_hat) - - # Discriminator backward - optim_d.zero_grad(set_to_none=True) - grad_scaler.scale(loss_disc).backward(retain_graph=True, inputs=list(net_d.parameters())) - grad_scaler.unscale_(optim_d) - - # Generator backward - optim_g.zero_grad(set_to_none=True) - grad_balancer.backward( - {"loss_loudness": loss_loudness, "loss_mel": loss_mel, - "loss_adv": loss_adv, "loss_fm": loss_fm}, - y_hat_for_backward, grad_scaler, - skip_update_ema=step > 10 and step % 5 != 0, - ) - grad_scaler.unscale_(optim_g) - - # Update - grad_scaler.step(optim_g) - grad_scaler.step(optim_d) - grad_scaler.update() - 
def convert_voice_beatrice(
    source_audio,
    model_file,
    target_speaker: int = 0,
    pitch_shift: int = 0,
    formant_shift: float = 0.0,
    progress=None,
):
    """Convert voice using a Beatrice v2 model.

    Args:
        source_audio: Path to the source audio (str or os.PathLike), or a file
            object exposing its path as ``.name``.
        model_file: Path to a Beatrice checkpoint (.pt.gz) as str or
            os.PathLike, or a file object exposing its path as ``.name``.
        target_speaker: Target speaker index. Values outside
            ``[0, n_speakers)`` fall back to speaker 0.
        pitch_shift: Pitch shift in semitones.
        formant_shift: Formant shift in semitones (-2 to 2).
        progress: Optional Gradio-style callback called as
            ``progress(fraction, message)``.

    Returns:
        (output_path, status_message) tuple. ``output_path`` is a temporary
        24 kHz .wav file on success and None on failure; failures are reported
        through ``status_message`` rather than raised.
    """
    # Resolve the checkpoint path. str / os.PathLike are tested first because a
    # pathlib.Path also has a ``.name`` attribute, which is only its basename.
    if isinstance(model_file, (str, os.PathLike)):
        model_path = os.fspath(model_file)
    elif hasattr(model_file, "name"):
        model_path = model_file.name
    else:
        return None, "Invalid model file"

    if not model_path or not os.path.exists(model_path):
        return None, f"Model file not found: {model_path}"

    # Resolve the source audio the same way so upload objects work too.
    if isinstance(source_audio, (str, os.PathLike)):
        audio_path = os.fspath(source_audio)
    else:
        audio_path = getattr(source_audio, "name", None)
    if not audio_path:
        return None, "No source audio provided"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pre-bind so the ``finally`` clause can release them on every exit path.
    net_g = phone_extractor = pitch_estimator = None

    try:
        if progress:
            progress(0.1, "Loading models...")

        # Download pretrained assets (phone extractor + pitch estimator)
        phone_extractor_path = download_beatrice_asset("phone_extractor")
        pitch_estimator_path = download_beatrice_asset("pitch_estimator")

        # Build phone extractor
        phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False)
        pe_ckpt = torch.load(phone_extractor_path, map_location="cpu", weights_only=True)
        phone_extractor.load_state_dict(pe_ckpt["phone_extractor"], strict=False)
        del pe_ckpt

        # Build pitch estimator
        pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False)
        pi_ckpt = torch.load(pitch_estimator_path, map_location="cpu", weights_only=True)
        pitch_estimator.load_state_dict(pi_ckpt["pitch_estimator"])
        del pi_ckpt

        if progress:
            progress(0.3, "Loading trained model...")

        # Load trained checkpoint
        with gzip.open(model_path, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu", weights_only=True)

        # Model hyper-parameters come from the checkpoint's "h" dict, with
        # fallbacks for checkpoints that do not carry one.
        n_speakers = checkpoint["net_g"]["embed_speaker.weight"].shape[0]
        h = checkpoint.get("h", {})
        hidden_channels = h.get("hidden_channels", 256)
        pitch_bins = h.get("pitch_bins", 448)
        speakers = checkpoint.get("speakers", [f"Speaker {i}" for i in range(n_speakers)])

        # Reject negative indices as well as too-large ones.
        if not 0 <= target_speaker < n_speakers:
            target_speaker = 0

        net_g = ConverterNetwork(
            phone_extractor, pitch_estimator,
            n_speakers=n_speakers,
            pitch_bins=pitch_bins,
            hidden_channels=hidden_channels,
            vq_topk=h.get("vq_topk", 4),
            training_time_vq=h.get("training_time_vq", "none"),
            phone_noise_ratio=h.get("phone_noise_ratio", 0.5),
            floor_noise_level=h.get("floor_noise_level", 1e-3),
        ).to(device).eval()

        net_g.load_state_dict(checkpoint["net_g"], strict=False)
        net_g.enable_hook()
        del checkpoint

        if progress:
            progress(0.5, "Converting voice...")

        # Load audio at 16kHz
        audio, sr = librosa.load(audio_path, sr=16000, mono=True)
        audio_tensor = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(0).to(device)

        original_length = audio_tensor.shape[-1]
        if original_length == 0:
            return None, "Source audio is empty"

        # Pad to multiple of 160 (phone extractor stride)
        if original_length % 160 != 0:
            pad_len = 160 - original_length % 160
            audio_tensor = F.pad(audio_tensor, (0, pad_len))

        # Convert
        with torch.inference_mode():
            y_hat = net_g(
                audio_tensor,
                torch.tensor([target_speaker], device=device),
                torch.tensor([formant_shift], device=device),
                torch.tensor([float(pitch_shift)], device=device),
            )

        # Output is 24kHz, trim to match input duration
        output_length = original_length // 160 * 240  # 16kHz→24kHz frame ratio
        output = y_hat.squeeze().cpu().numpy()[:output_length]

        # Save; do not leave an empty temp file behind if writing fails.
        fd, output_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            sf.write(output_path, output, 24000)
        except Exception:
            os.remove(output_path)
            raise

        speaker_name = speakers[target_speaker] if target_speaker < len(speakers) else f"Speaker {target_speaker}"
        return output_path, f"Converted using Beatrice v2 | 24kHz | Speaker: {speaker_name} | Pitch: {pitch_shift:+d} | Formant: {formant_shift:+.1f}"

    except Exception as e:
        logger.exception("Beatrice inference error")
        return None, f"Error: {str(e)}"

    finally:
        # Cleanup runs on success and on failure so a failed conversion does
        # not keep the models (and cached GPU memory) alive.
        net_g = phone_extractor = pitch_estimator = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def train_ui(
    audio_file,
    model_name: str,
    epochs: int,
    batch_size: int,
    sample_rate: int,
    f0_method: str = "rmvpe",
    progress=gr.Progress()
):
    """Gradio handler for RVC training, written as a generator for live logs.

    Every yielded item is a ``(model_path, index_path, log_text)`` tuple. The
    first two stay None until training has produced a checkpoint; ``log_text``
    is the full log accumulated so far, joined with newlines.
    """
    if audio_file is None:
        yield None, None, "❌ Please upload training audio"
        return

    if not (model_name and model_name.strip()):
        yield None, None, "❌ Please enter a model name"
        return

    device_label = "GPU (CUDA)" if torch.cuda.is_available() else "CPU"
    log_lines = []

    def snapshot(model=None, index=None):
        # One UI update carrying the whole log accumulated so far.
        return model, index, "\n".join(log_lines)

    try:
        model_name = sanitize_model_name(model_name)
        output_dir = f"trained_models/{model_name}"
        data_dir = f"{output_dir}/data"

        # Preprocessing phase
        log_lines.extend([
            f"🚀 Starting on {device_label}",
            f"📂 Output: {output_dir}",
        ])
        yield snapshot()

        progress(0.1, "Preprocessing...")
        log_lines.append("🔄 Preprocessing audio...")
        yield snapshot()

        if isinstance(audio_file, str):
            audio_path = audio_file
        else:
            audio_path = audio_file.name
        prepared = preprocess_audio_for_training(
            audio_path, data_dir, target_sr=sample_rate, f0_method=f0_method
        )

        if prepared is None:
            log_lines.append("❌ Preprocessing failed - no valid audio chunks")
            yield snapshot()
            return

        log_lines.extend([
            "✅ Preprocessing complete",
            f"🏋️ Training {epochs} epochs...",
            "─" * 40,
        ])
        yield snapshot()

        # Training phase: relay each generator message and remember the most
        # recent non-empty checkpoint / index paths.
        best_ckpt = None
        best_index = None
        trainer = train_rvc_generator(
            data_dir=data_dir,
            output_dir=output_dir,
            epochs=epochs,
            batch_size=batch_size,
            lr=1e-5,
            target_sr=sample_rate,
            progress_callback=progress
        )
        for message, ckpt_path, index_path in trainer:
            log_lines.append(message)
            yield snapshot()
            best_ckpt = ckpt_path or best_ckpt
            best_index = index_path or best_index

        if not best_ckpt:
            log_lines.append("❌ Training failed")
            yield snapshot()
        else:
            log_lines.append("─" * 40)
            log_lines.append("✅ Training complete!")
            log_lines.append(f"📦 Model: {best_ckpt}")
            if best_index:
                log_lines.append(f"📦 Index: {best_index}")
            progress(1.0, "Done!")
            yield snapshot(best_ckpt, best_index)

    except Exception as e:
        logger.exception("Training error")
        log_lines.append(f"❌ Error: {str(e)}")
        yield snapshot()
def convert_unified(source, m_type, rvc_dropdown_model, rvc_model, rvc_index, beat_model,
                    beat_speaker, beat_formant, pitch, f0, idx_rate, prot,
                    progress=gr.Progress()):
    """Route the Convert button to the RVC or Beatrice backend.

    For "RVC v2" an uploaded model wins over the dropdown selection. When only
    a dropdown entry is given, the model is taken from ``weights/`` and a
    same-stem ``.index`` file there is used if no index was uploaded. Any
    other model type goes to Beatrice. Returns whatever the chosen backend
    returns.
    """
    if m_type != "RVC v2":
        return convert_voice_beatrice(
            source, beat_model,
            target_speaker=int(beat_speaker),
            pitch_shift=int(pitch),
            formant_shift=float(beat_formant),
            progress=progress
        )

    class _FileObj:
        # Simple wrapper so convert_voice can call .name on a plain path
        def __init__(self, location):
            self.name = str(location)

    chosen_model, chosen_index = rvc_model, rvc_index
    if chosen_model is None and rvc_dropdown_model:
        weights_root = Path("weights")
        model_location = weights_root / rvc_dropdown_model
        chosen_model = _FileObj(model_location)
        if chosen_index is None:
            # Look for an index next to the model with the same stem
            index_location = weights_root / f"{model_location.stem}.index"
            if index_location.exists():
                chosen_index = _FileObj(index_location)

    return convert_voice(source, chosen_model, chosen_index, pitch, f0, idx_rate, prot, progress)
def train_unified(trainer, audio, name, rvc_epochs, rvc_batch, rvc_sr, rvc_f0,
                  beat_epochs, beat_batch, beat_resume, progress=gr.Progress()):
    """Gradio handler for the Start Training button (RVC v2 or Beatrice v2).

    A generator yielding ``(model_path, index_path, status_text)`` tuples.
    "RVC v2" is delegated to ``train_ui``. Any other trainer value runs the
    Beatrice pipeline: preprocess the audio into chunks, then relay the
    messages of ``train_beatrice_generator``. The index slot is always None
    for Beatrice.

    The accumulated log is kept on every exit path, including failures, so
    the status box never loses earlier progress lines.
    """
    if trainer == "RVC v2":
        # train_ui already yields 3-tuples in the right shape; ``yield from``
        # also forwards generator close() to it.
        yield from train_ui(audio, name, rvc_epochs, rvc_batch, rvc_sr, rvc_f0, progress)
        return

    # Beatrice training (embedded - downloads assets from HF)
    if audio is None:
        yield None, None, "❌ Please upload training audio"
        return
    if not name or name.strip() == "":
        yield None, None, "❌ Please enter a model name"
        return

    # Created before the try block so the except handler can always append.
    logs = []

    try:
        name = sanitize_model_name(name)
        output_dir = f"trained_models/beatrice_{name}"
        data_dir = f"{output_dir}/training_data"

        # Preprocess audio into speaker chunks
        logs.append(f"🚀 Starting Beatrice training on {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}")
        yield None, None, "\n".join(logs)

        progress(0.05, "Preprocessing audio...")
        audio_path = audio if isinstance(audio, str) else audio.name
        chunks, _ = preprocess_audio_for_beatrice(audio_path, data_dir, name)

        if chunks == 0:
            logs.append("❌ Preprocessing failed - no valid audio chunks")
            yield None, None, "\n".join(logs)
            return

        logs.append(f"✅ Preprocessed {chunks} audio chunks")
        logs.append("─" * 40)
        yield None, None, "\n".join(logs)

        # Train using embedded generator; remember the last checkpoint path.
        ckpt = None
        for msg, path in train_beatrice_generator(
            data_dir=data_dir,
            output_dir=output_dir,
            epochs=beat_epochs,
            batch_size=beat_batch,
            resume=beat_resume,
            progress_callback=progress
        ):
            logs.append(msg)
            yield None, None, "\n".join(logs)
            if path:
                ckpt = path

        if ckpt:
            logs.append("─" * 40)
            logs.append("✅ Training complete!")
            logs.append(f"📦 Model: {ckpt}")
            progress(1.0, "Done!")
            yield ckpt, None, "\n".join(logs)
        else:
            logs.append("❌ Training failed")
            yield None, None, "\n".join(logs)

    except Exception as e:
        logger.exception("Beatrice training error")
        logs.append(f"❌ Error: {str(e)}")
        yield None, None, "\n".join(logs)
def cli_convert(args):
    """CLI mode conversion - supports both RVC and Beatrice models.

    Args:
        args: Parsed ``infer`` namespace. Reads ``input``, ``output``,
            ``model``, ``example``, ``pitch``, ``f0`` and ``index_rate``, and
            optionally ``type``, ``index``, ``speaker`` and ``formant_shift``.

    Exits with status 1 when the input file does not exist, when no model is
    given, or when the conversion backend reports a failure. On success the
    converted audio is copied to ``args.output``.
    """
    print(f"Converting: {args.input}")
    print(f"Model: {args.model}")

    # Fail fast: otherwise a bad input path is only noticed after the model
    # (or the example download) has been loaded.
    if not os.path.isfile(args.input):
        print(f"Error: Input audio not found: {args.input}")
        sys.exit(1)

    # Auto-detect model type from extension or --type flag
    model_type = getattr(args, 'type', None)
    model_path = args.model
    index_path = getattr(args, 'index', None)

    if args.example:
        print("Downloading example model...")
        model_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_MODEL_FILE)
        index_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_INDEX_FILE)
        model_type = "rvc"
        print(f"Model: {model_path}")

    if not model_path:
        print("Error: No model specified. Use -m MODEL.pth or --example")
        sys.exit(1)

    # Auto-detect type from extension if not specified. ".pt.gz" also ends in
    # ".gz", so one case-insensitive suffix test covers both spellings.
    if model_type is None:
        model_type = "beatrice" if model_path.lower().endswith(".gz") else "rvc"

    if model_type == "beatrice":
        # Beatrice inference
        print("Using Beatrice v2 inference")
        output_path, status = convert_voice_beatrice(
            source_audio=args.input,
            model_file=model_path,
            target_speaker=getattr(args, 'speaker', 0),
            pitch_shift=args.pitch,
            formant_shift=getattr(args, 'formant_shift', 0.0),
        )
    else:
        # RVC inference: convert_voice is handed objects exposing the path
        # as ``.name``.
        class FileObj:
            def __init__(self, path):
                self.name = path

        model_file = FileObj(model_path)
        index_file = FileObj(index_path) if index_path else None

        output_path, status = convert_voice(
            source_audio=args.input,
            model_file=model_file,
            index_file=index_file,
            pitch_shift=args.pitch,
            f0_method=args.f0,
            index_rate=args.index_rate,
            progress=lambda *a, **k: None
        )

    if not output_path:
        print(f"Failed: {status}")
        sys.exit(1)

    # Create the destination directory so a finished conversion is not lost
    # to a missing folder.
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    shutil.copy(output_path, args.output)
    print(f"Output: {args.output}")
    print(status)
def cli_train(args):
    """Run RVC training from the command line.

    Preprocesses ``args.audio`` into ``<args.output>/data`` and then trains
    with ``train_rvc``. Exits with status 1 if preprocessing returns None or
    training returns no checkpoint path.
    """
    for banner in (
        "=== RVC Training ===",
        f"Input audio: {args.audio}",
        f"Output dir: {args.output}",
    ):
        print(banner)

    # Preprocessed data lives in a "data" folder inside the output directory
    data_dir = f"{args.output}/data"

    # Step 1: preprocess
    print("\n[1/2] Preprocessing audio...")
    prepared = preprocess_audio_for_training(
        args.audio, data_dir, target_sr=args.sr, f0_method=args.f0
    )
    if prepared is None:
        print("Preprocessing failed!")
        sys.exit(1)

    # Step 2: train
    print(f"\n[2/2] Training for {args.epochs} epochs...")
    model_path, index_path = train_rvc(
        data_dir=data_dir,
        output_dir=args.output,
        epochs=args.epochs,
        batch_size=args.batch,
        lr=args.lr,
        target_sr=args.sr
    )

    if not model_path:
        print("Training failed!")
        sys.exit(1)

    print("\nTraining complete!")
    print(f"Model saved: {model_path}")
    if index_path:
        print(f"Index saved: {index_path}")
required=True, help="Output audio file") - infer_parser.add_argument("-m", "--model", help="Model file (.pth for RVC, .pt.gz for Beatrice)") - infer_parser.add_argument("--type", choices=["rvc", "beatrice"], default=None, help="Model type (auto-detected from extension)") - infer_parser.add_argument("--index", help="Index file (.index) - RVC only") - infer_parser.add_argument("--example", action="store_true", help="Use example model (Benee-RVC)") - infer_parser.add_argument("-p", "--pitch", type=int, default=0, help="Pitch shift (-12 to 12)") - infer_parser.add_argument("--f0", choices=["rmvpe", "pm", "harvest"], default="rmvpe", help="F0 method (RVC only)") - infer_parser.add_argument("--index-rate", type=float, default=0.75, help="Index rate 0-1 (RVC only)") - infer_parser.add_argument("--speaker", type=int, default=0, help="Target speaker index (Beatrice only)") - infer_parser.add_argument("--formant-shift", type=float, default=0.0, help="Formant shift -2 to 2 (Beatrice only)") - - # Training subcommand (RVC) - train_parser = subparsers.add_parser("train", help="Train RVC model", - epilog=""" -Examples: - python app.py train -a voice.mp3 -o ./my_model --epochs 5 - python app.py train -a dataset.wav -o ./trained --epochs 10 --batch 4 - """) - train_parser.add_argument("-a", "--audio", required=True, help="Training audio file") - train_parser.add_argument("-o", "--output", required=True, help="Output directory for model") - train_parser.add_argument("--epochs", type=int, default=5, help="Number of epochs (default: 5)") - train_parser.add_argument("--batch", type=int, default=2, help="Batch size (default: 2)") - train_parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate (default: 1e-5)") - train_parser.add_argument("--sr", type=int, default=40000, help="Sample rate (default: 40000)") - train_parser.add_argument("--f0", choices=["rmvpe", "pm", "harvest"], default="rmvpe", help="F0 method (default: rmvpe)") - - # Beatrice v2 Training subcommand - 
beatrice_parser = subparsers.add_parser("train-beatrice", help="Train Beatrice v2 model (downloads assets from HF)", - epilog=""" -Examples: - python app.py train-beatrice -a voice.mp3 -o ./beatrice_model --epochs 20 - python app.py train-beatrice -a dataset.wav -o ./trained --epochs 50 --batch 8 --resume - -Pretrained assets are downloaded automatically from HuggingFace. - """) - beatrice_parser.add_argument("-a", "--audio", required=True, help="Training audio file") - beatrice_parser.add_argument("-o", "--output", required=True, help="Output directory for model") - beatrice_parser.add_argument("--epochs", type=int, default=20, help="Number of epochs (default: 20)") - beatrice_parser.add_argument("--batch", type=int, default=24, help="Batch size (default: 24, reduce for less VRAM)") - beatrice_parser.add_argument("--resume", action="store_true", help="Resume from checkpoint") - - args = parser.parse_args() - - if args.command == "infer": - cli_convert(args) - elif args.command == "train": - cli_train(args) - elif args.command == "train-beatrice": - cli_train_beatrice(args) - else: - parser.print_help() - else: - # No args = Gradio mode (Gradio 6 syntax) - demo.launch( - mcp_server=True, - show_error=True, - ssr_mode=False, - ) \ No newline at end of file