Spaces:
Sleeping
Sleeping
| """ | |
| RVC + Beatrice v2 Voice Conversion - Single-file app for HuggingFace Spaces | |
| RVC-Project + Beatrice v2 (fierce-cats/beatrice-trainer), consolidated into single file | |
| - Inference: RVC v2 (.pth) + Beatrice v2 (.pt.gz), CPU or GPU | |
| - Training: RVC v2 + Beatrice v2, GPU recommended | |
| Usage: | |
| CLI: python app.py infer -i input.wav -m model.pth -o output.wav | |
| python app.py infer -i input.wav -m beatrice.pt.gz -o output.wav | |
| Gradio: python app.py | |
| """ | |
| import os | |
| import sys | |
| # MPS fallback for macOS | |
| os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" | |
| import argparse | |
| import gc | |
| import gzip | |
| import json as json_module | |
| import logging | |
| import math | |
| import re | |
| import shutil | |
| import tempfile | |
| import warnings | |
| # Suppress known harmless warnings from HF Spaces / torch internals | |
| warnings.filterwarnings("ignore", message=".*torch.distributed.reduce_op.*", category=FutureWarning) | |
| warnings.filterwarnings("ignore", message=".*torch.nn.utils.weight_norm.*", category=FutureWarning) | |
| from collections import defaultdict | |
| from fractions import Fraction | |
| from functools import partial | |
| from pathlib import Path | |
| from random import Random | |
| from typing import Optional, List, Tuple, Union, BinaryIO, Literal, Sequence, Iterable, Callable | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn import Conv1d, ConvTranspose1d | |
| from torch.nn.utils import weight_norm, remove_weight_norm | |
| import librosa | |
| import pyworld | |
| import soundfile as sf | |
| import torchaudio | |
| from scipy import signal | |
| from huggingface_hub import hf_hub_download | |
| from tqdm.auto import tqdm | |
| # 48 Hz high-pass filter to remove low-frequency artifacts (same as Applio) | |
| FILTER_ORDER = 5 | |
| CUTOFF_FREQUENCY = 48 # Hz | |
| SAMPLE_RATE = 16000 # Hz | |
| bh, ah = signal.butter(N=FILTER_ORDER, Wn=CUTOFF_FREQUENCY, btype="high", fs=SAMPLE_RATE) | |
| def sanitize_model_name(name: str) -> str: | |
| """Sanitize model name for safe use in file paths""" | |
| name = os.path.basename(name.strip()) | |
| name = re.sub(r'[^\w\-.]', '_', name) | |
| return name or "unnamed_model" | |
| def list_rvc_models() -> list: | |
| """Scan the weights/ directory and return a sorted list of .pth model filenames.""" | |
| weights_dir = Path("weights") | |
| if not weights_dir.exists(): | |
| return [] | |
| return sorted([p.name for p in weights_dir.glob("*.pth")]) | |
| # Default example model | |
| DEFAULT_MODEL_REPO = "audo/Benee-RVC" | |
| DEFAULT_MODEL_FILE = "BENEE8000.pth" | |
| DEFAULT_INDEX_FILE = "added_IVF1054_Flat_nprobe_8.index" | |
| # RVC v2 pretrained weights from official repo | |
| RVC_PRETRAINED_REPO = "lj1995/VoiceConversionWebUI" | |
| RVC_PRETRAINED_V2 = { | |
| # Generator with f0 (pitch) | |
| "f0G48k": "pretrained_v2/f0G48k.pth", | |
| "f0G40k": "pretrained_v2/f0G40k.pth", | |
| "f0G32k": "pretrained_v2/f0G32k.pth", | |
| # Discriminator with f0 | |
| "f0D48k": "pretrained_v2/f0D48k.pth", | |
| "f0D40k": "pretrained_v2/f0D40k.pth", | |
| "f0D32k": "pretrained_v2/f0D32k.pth", | |
| # Generator without f0 | |
| "G48k": "pretrained_v2/G48k.pth", | |
| "G40k": "pretrained_v2/G40k.pth", | |
| "G32k": "pretrained_v2/G32k.pth", | |
| # Discriminator without f0 | |
| "D48k": "pretrained_v2/D48k.pth", | |
| "D40k": "pretrained_v2/D40k.pth", | |
| "D32k": "pretrained_v2/D32k.pth", | |
| } | |
| def download_pretrained_rvc(name: str) -> str: | |
| """Download RVC v2 pretrained weights from HuggingFace""" | |
| if name not in RVC_PRETRAINED_V2: | |
| raise ValueError(f"Unknown pretrained: {name}. Available: {list(RVC_PRETRAINED_V2.keys())}") | |
| filepath = RVC_PRETRAINED_V2[name] | |
| logger.info(f"Downloading pretrained {name} from {RVC_PRETRAINED_REPO}...") | |
| return hf_hub_download(repo_id=RVC_PRETRAINED_REPO, filename=filepath) | |
| # Beatrice v2 pretrained assets | |
| BEATRICE_REPO = "fierce-cats/beatrice-trainer" | |
| BEATRICE_PRETRAINED = { | |
| "phone_extractor": "assets/pretrained/122_checkpoint_03000000.pt", | |
| "pitch_estimator": "assets/pretrained/104_3_checkpoint_00300000.pt", | |
| "pretrained_model": "assets/pretrained/151_checkpoint_libritts_r_200_02750000.pt.gz", | |
| } | |
| def download_beatrice_asset(name: str) -> str: | |
| """Download Beatrice v2 pretrained asset from HuggingFace""" | |
| if name not in BEATRICE_PRETRAINED: | |
| raise ValueError(f"Unknown asset: {name}. Available: {list(BEATRICE_PRETRAINED.keys())}") | |
| filepath = BEATRICE_PRETRAINED[name] | |
| logger.info(f"Downloading Beatrice asset {name} from {BEATRICE_REPO}...") | |
| return hf_hub_download(repo_id=BEATRICE_REPO, filename=filepath) | |
| def download_beatrice_augmentation(): | |
| """Download Beatrice augmentation assets (noise + IR) - optional for training""" | |
| try: | |
| from huggingface_hub import snapshot_download | |
| cache_dir = snapshot_download(repo_id=BEATRICE_REPO, allow_patterns=["assets/noise/*", "assets/ir/*"]) | |
| noise_dir = os.path.join(cache_dir, "assets", "noise") | |
| ir_dir = os.path.join(cache_dir, "assets", "ir") | |
| if os.path.isdir(noise_dir) and os.path.isdir(ir_dir): | |
| return noise_dir, ir_dir | |
| return None, None | |
| except Exception as e: | |
| logger.warning(f"Could not download augmentation assets: {e}") | |
| return None, None | |
| def load_pretrained_weights(model: nn.Module, pretrained_path: str) -> None: | |
| """Load pretrained weights into model, handling speaker embedding mismatch""" | |
| logger.info(f"Loading pretrained weights: {pretrained_path}") | |
| state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True) | |
| # Handle different checkpoint formats | |
| if "model" in state_dict: | |
| state_dict = state_dict["model"] | |
| # Filter out mismatched keys, but handle emb_g specially | |
| model_state = model.state_dict() | |
| filtered_state = {} | |
| skipped = [] | |
| for k, v in state_dict.items(): | |
| if k in model_state: | |
| if v.shape == model_state[k].shape: | |
| filtered_state[k] = v | |
| elif k == "emb_g.weight": | |
| # Initialize our speaker embedding with mean of pretrained embeddings | |
| # This gives a much better starting point than random initialization | |
| mean_emb = v.mean(dim=0, keepdim=True) # [1, 256] | |
| num_speakers = model_state[k].shape[0] | |
| filtered_state[k] = mean_emb.expand(num_speakers, -1).clone() | |
| logger.info(f"Initialized emb_g from pretrained mean ({v.shape[0]} -> {num_speakers} speakers)") | |
| else: | |
| skipped.append(f"{k}: {v.shape} vs {model_state[k].shape}") | |
| else: | |
| skipped.append(f"{k}: not in model") | |
| if skipped: | |
| logger.info(f"Skipped {len(skipped)} mismatched keys") | |
| model.load_state_dict(filtered_state, strict=False) | |
| logger.info(f"Loaded {len(filtered_state)}/{len(state_dict)} pretrained weights") | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Device selection: | |
| # - Inference: Always CPU (HF Spaces free tier, also works everywhere) | |
| # - Training: GPU if available for speed, CPU fallback | |
| device = torch.device("cpu") # For inference | |
| train_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # For training | |
| logger.info(f"Inference device: {device}") | |
| logger.info(f"Training device: {train_device}") | |
| # ============================================================ | |
| # CPU OPTIMIZATION — Locked 2-core HuggingFace Spaces config | |
| # ============================================================ | |
| # Restrict PyTorch to exactly 2 physical cores. | |
| # OpenMP and MKL must both be capped before any tensor ops fire. | |
| _CPU_CORES = 2 | |
| torch.set_num_threads(_CPU_CORES) | |
| torch.set_num_interop_threads(_CPU_CORES) | |
| os.environ["OMP_NUM_THREADS"] = str(_CPU_CORES) | |
| os.environ["MKL_NUM_THREADS"] = str(_CPU_CORES) | |
| os.environ["OPENBLAS_NUM_THREADS"] = str(_CPU_CORES) | |
| os.environ["VECLIB_MAXIMUM_THREADS"] = str(_CPU_CORES) | |
| os.environ["NUMEXPR_NUM_THREADS"] = str(_CPU_CORES) | |
| # torch.inference_mode is heavier than no_grad but also frees the | |
| # autograd graph eagerly, which helps on a memory-constrained CPU. | |
| # Enable oneDNN graph fusion (fuses conv+bn, linear+relu etc. into | |
| # single kernels — measurable speedup on Intel Xeon VMs). | |
| torch.backends.mkldnn.enabled = True | |
| try: | |
| torch.jit.enable_onednn_fusion(True) | |
| except Exception: | |
| pass | |
| logger.info(f"PyTorch CPU threads: {torch.get_num_threads()} (interop={torch.get_num_interop_threads()})") | |
| # ============================================================ | |
| # MEMORY MANAGEMENT — purge_memory() | |
| # Call this between every heavy operation to prevent OOM on | |
| # the 16 GB HuggingFace Spaces free-tier CPU instance. | |
| # ============================================================ | |
| import ctypes, platform | |
| def purge_memory(*tensors_or_arrays): | |
| """ | |
| Aggressively free memory after a heavy generation step. | |
| Pass any tensors / numpy arrays that should be deleted. | |
| The function: | |
| 1. Deletes every passed object from caller scope. | |
| 2. Runs Python gc (two passes: first collects cycles, | |
| second collects anything the first pass freed). | |
| 3. On Linux (HuggingFace Spaces), calls malloc_trim(0) via | |
| ctypes so glibc returns freed pages to the OS immediately. | |
| Without this, RSS can stay high even after gc.collect(). | |
| 4. Clears CUDA cache if a GPU is somehow available. | |
| """ | |
| for obj in tensors_or_arrays: | |
| try: | |
| del obj | |
| except Exception: | |
| pass | |
| # Two-pass gc: cycles first, then their referents | |
| gc.collect() | |
| gc.collect() | |
| # Return glibc memory to the OS (Linux only — HF Spaces is Linux) | |
| if platform.system() == "Linux": | |
| try: | |
| ctypes.CDLL("libc.so.6").malloc_trim(0) | |
| except Exception: | |
| pass | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| torch.cuda.ipc_collect() | |
| # ============================================================ | |
| # COMMONS - Helper functions from infer/lib/infer_pack/commons.py | |
| # ============================================================ | |
| def init_weights(m, mean=0.0, std=0.01): | |
| classname = m.__class__.__name__ | |
| if classname.find("Conv") != -1: | |
| m.weight.data.normal_(mean, std) | |
| def get_padding(kernel_size, dilation=1): | |
| return int((kernel_size * dilation - dilation) / 2) | |
| def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None): | |
| if max_length is None: | |
| max_length = length.max() | |
| x = torch.arange(max_length, dtype=length.dtype, device=length.device) | |
| return x.unsqueeze(0) < length.unsqueeze(1) | |
| def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): | |
| n_channels_int = n_channels[0] | |
| in_act = input_a + input_b | |
| t_act = torch.tanh(in_act[:, :n_channels_int, :]) | |
| s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) | |
| return t_act * s_act | |
| def slice_segments(x, ids_str, segment_size=4): | |
| """Slice segments from tensor""" | |
| ret = torch.zeros_like(x[:, :, :segment_size]) | |
| for i in range(x.size(0)): | |
| idx_str = ids_str[i] | |
| idx_end = idx_str + segment_size | |
| ret[i] = x[i, :, idx_str:idx_end] | |
| return ret | |
| def slice_segments2(x, ids_str, segment_size=4): | |
| """Slice segments from 2D tensor""" | |
| ret = torch.zeros_like(x[:, :segment_size]) | |
| for i in range(x.size(0)): | |
| idx_str = ids_str[i] | |
| idx_end = idx_str + segment_size | |
| ret[i] = x[i, idx_str:idx_end] | |
| return ret | |
| def rand_slice_segments(x, x_lengths=None, segment_size=4): | |
| """Random slice segments""" | |
| b, d, t = x.size() | |
| if x_lengths is None: | |
| x_lengths = t | |
| ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=1) | |
| ids_str = (torch.rand([b], device=x.device) * ids_str_max.float()).long() | |
| ret = slice_segments(x, ids_str, segment_size) | |
| return ret, ids_str | |
| # ============================================================ | |
| # MODULES - From infer/lib/infer_pack/modules.py | |
| # ============================================================ | |
| LRELU_SLOPE = 0.1 | |
| class LayerNorm(nn.Module): | |
| def __init__(self, channels, eps=1e-5): | |
| super().__init__() | |
| self.channels = channels | |
| self.eps = eps | |
| self.gamma = nn.Parameter(torch.ones(channels)) | |
| self.beta = nn.Parameter(torch.zeros(channels)) | |
| def forward(self, x): | |
| x = x.transpose(1, -1) | |
| x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) | |
| return x.transpose(1, -1) | |
| class WN(nn.Module): | |
| def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): | |
| super().__init__() | |
| assert kernel_size % 2 == 1 | |
| self.hidden_channels = hidden_channels | |
| self.kernel_size = (kernel_size,) | |
| self.dilation_rate = dilation_rate | |
| self.n_layers = n_layers | |
| self.gin_channels = gin_channels | |
| self.p_dropout = float(p_dropout) | |
| self.in_layers = nn.ModuleList() | |
| self.res_skip_layers = nn.ModuleList() | |
| self.drop = nn.Dropout(float(p_dropout)) | |
| if gin_channels != 0: | |
| cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) | |
| self.cond_layer = weight_norm(cond_layer, name="weight") | |
| for i in range(n_layers): | |
| dilation = dilation_rate ** i | |
| padding = int((kernel_size * dilation - dilation) / 2) | |
| in_layer = nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding) | |
| in_layer = weight_norm(in_layer, name="weight") | |
| self.in_layers.append(in_layer) | |
| if i < n_layers - 1: | |
| res_skip_channels = 2 * hidden_channels | |
| else: | |
| res_skip_channels = hidden_channels | |
| res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1) | |
| res_skip_layer = weight_norm(res_skip_layer, name="weight") | |
| self.res_skip_layers.append(res_skip_layer) | |
| def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None): | |
| output = torch.zeros_like(x) | |
| n_channels_tensor = torch.IntTensor([self.hidden_channels]) | |
| if g is not None: | |
| g = self.cond_layer(g) | |
| for i, (in_layer, res_skip_layer) in enumerate(zip(self.in_layers, self.res_skip_layers)): | |
| x_in = in_layer(x) | |
| if g is not None: | |
| cond_offset = i * 2 * self.hidden_channels | |
| g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] | |
| else: | |
| g_l = torch.zeros_like(x_in) | |
| acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) | |
| acts = self.drop(acts) | |
| res_skip_acts = res_skip_layer(acts) | |
| if i < self.n_layers - 1: | |
| res_acts = res_skip_acts[:, :self.hidden_channels, :] | |
| x = (x + res_acts) * x_mask | |
| output = output + res_skip_acts[:, self.hidden_channels:, :] | |
| else: | |
| output = output + res_skip_acts | |
| return output * x_mask | |
| def remove_weight_norm(self): | |
| if self.gin_channels != 0: | |
| remove_weight_norm(self.cond_layer) | |
| for l in self.in_layers: | |
| remove_weight_norm(l) | |
| for l in self.res_skip_layers: | |
| remove_weight_norm(l) | |
| class ResBlock1(nn.Module): | |
| def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): | |
| super().__init__() | |
| self.convs1 = nn.ModuleList([ | |
| weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), | |
| weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), | |
| weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))), | |
| ]) | |
| self.convs1.apply(init_weights) | |
| self.convs2 = nn.ModuleList([ | |
| weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), | |
| weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), | |
| weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), | |
| ]) | |
| self.convs2.apply(init_weights) | |
| self.lrelu_slope = LRELU_SLOPE | |
| def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None): | |
| for c1, c2 in zip(self.convs1, self.convs2): | |
| xt = F.leaky_relu(x, self.lrelu_slope) | |
| if x_mask is not None: | |
| xt = xt * x_mask | |
| xt = c1(xt) | |
| xt = F.leaky_relu(xt, self.lrelu_slope) | |
| if x_mask is not None: | |
| xt = xt * x_mask | |
| xt = c2(xt) | |
| x = xt + x | |
| if x_mask is not None: | |
| x = x * x_mask | |
| return x | |
| def remove_weight_norm(self): | |
| for l in self.convs1: | |
| remove_weight_norm(l) | |
| for l in self.convs2: | |
| remove_weight_norm(l) | |
| class ResBlock2(nn.Module): | |
| def __init__(self, channels, kernel_size=3, dilation=(1, 3)): | |
| super().__init__() | |
| self.convs = nn.ModuleList([ | |
| weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), | |
| weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), | |
| ]) | |
| self.convs.apply(init_weights) | |
| self.lrelu_slope = LRELU_SLOPE | |
| def forward(self, x, x_mask: Optional[torch.Tensor] = None): | |
| for c in self.convs: | |
| xt = F.leaky_relu(x, self.lrelu_slope) | |
| if x_mask is not None: | |
| xt = xt * x_mask | |
| xt = c(xt) | |
| x = xt + x | |
| if x_mask is not None: | |
| x = x * x_mask | |
| return x | |
| def remove_weight_norm(self): | |
| for l in self.convs: | |
| remove_weight_norm(l) | |
| class Flip(nn.Module): | |
| def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): | |
| x = torch.flip(x, [1]) | |
| if not reverse: | |
| logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) | |
| return x, logdet | |
| else: | |
| return x, torch.zeros([1], device=x.device) | |
| class ResidualCouplingLayer(nn.Module): | |
| def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False): | |
| assert channels % 2 == 0 | |
| super().__init__() | |
| self.channels = channels | |
| self.hidden_channels = hidden_channels | |
| self.half_channels = channels // 2 | |
| self.mean_only = mean_only | |
| self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) | |
| self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=float(p_dropout), gin_channels=gin_channels) | |
| self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) | |
| self.post.weight.data.zero_() | |
| self.post.bias.data.zero_() | |
| def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): | |
| x0, x1 = torch.split(x, [self.half_channels] * 2, 1) | |
| h = self.pre(x0) * x_mask | |
| h = self.enc(h, x_mask, g=g) | |
| stats = self.post(h) * x_mask | |
| if not self.mean_only: | |
| m, logs = torch.split(stats, [self.half_channels] * 2, 1) | |
| else: | |
| m = stats | |
| logs = torch.zeros_like(m) | |
| if not reverse: | |
| x1 = m + x1 * torch.exp(logs) * x_mask | |
| x = torch.cat([x0, x1], 1) | |
| logdet = torch.sum(logs, [1, 2]) | |
| return x, logdet | |
| else: | |
| x1 = (x1 - m) * torch.exp(-logs) * x_mask | |
| x = torch.cat([x0, x1], 1) | |
| return x, torch.zeros([1]) | |
| def remove_weight_norm(self): | |
| self.enc.remove_weight_norm() | |
| # ============================================================ | |
| # ATTENTIONS - From infer/lib/infer_pack/attentions.py | |
| # ============================================================ | |
| class MultiHeadAttention(nn.Module): | |
| def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, proximal_bias=False, proximal_init=False): | |
| super().__init__() | |
| assert channels % n_heads == 0 | |
| self.channels = channels | |
| self.out_channels = out_channels | |
| self.n_heads = n_heads | |
| self.p_dropout = p_dropout | |
| self.window_size = window_size | |
| self.heads_share = heads_share | |
| self.proximal_bias = proximal_bias | |
| self.proximal_init = proximal_init | |
| self.k_channels = channels // n_heads | |
| self.conv_q = nn.Conv1d(channels, channels, 1) | |
| self.conv_k = nn.Conv1d(channels, channels, 1) | |
| self.conv_v = nn.Conv1d(channels, channels, 1) | |
| self.conv_o = nn.Conv1d(channels, out_channels, 1) | |
| self.drop = nn.Dropout(p_dropout) | |
| if window_size is not None: | |
| n_heads_rel = 1 if heads_share else n_heads | |
| rel_stddev = self.k_channels ** -0.5 | |
| self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) | |
| self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) | |
| nn.init.xavier_uniform_(self.conv_q.weight) | |
| nn.init.xavier_uniform_(self.conv_k.weight) | |
| nn.init.xavier_uniform_(self.conv_v.weight) | |
| if proximal_init: | |
| with torch.no_grad(): | |
| self.conv_k.weight.copy_(self.conv_q.weight) | |
| self.conv_k.bias.copy_(self.conv_q.bias) | |
| def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): | |
| q = self.conv_q(x) | |
| k = self.conv_k(c) | |
| v = self.conv_v(c) | |
| x, _ = self.attention(q, k, v, mask=attn_mask) | |
| x = self.conv_o(x) | |
| return x | |
| def attention(self, query, key, value, mask=None): | |
| b, d, t_s = key.size() | |
| t_t = query.size(2) | |
| query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) | |
| key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) | |
| value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) | |
| scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) | |
| if self.window_size is not None: | |
| key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) | |
| rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) | |
| scores_local = self._relative_position_to_absolute_position(rel_logits) | |
| scores = scores + scores_local | |
| if self.proximal_bias: | |
| scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) | |
| if mask is not None: | |
| scores = scores.masked_fill(mask == 0, -1e4) | |
| p_attn = F.softmax(scores, dim=-1) | |
| p_attn = self.drop(p_attn) | |
| output = torch.matmul(p_attn, value) | |
| if self.window_size is not None: | |
| relative_weights = self._absolute_position_to_relative_position(p_attn) | |
| value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) | |
| output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) | |
| output = output.transpose(2, 3).contiguous().view(b, d, t_t) | |
| return output, p_attn | |
| def _matmul_with_relative_values(self, x, y): | |
| return torch.matmul(x, y.unsqueeze(0)) | |
| def _matmul_with_relative_keys(self, x, y): | |
| return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) | |
| def _get_relative_embeddings(self, relative_embeddings, length): | |
| pad_length = max(length - (self.window_size + 1), 0) | |
| slice_start_position = max((self.window_size + 1) - length, 0) | |
| slice_end_position = slice_start_position + 2 * length - 1 | |
| if pad_length > 0: | |
| padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) | |
| else: | |
| padded_relative_embeddings = relative_embeddings | |
| return padded_relative_embeddings[:, slice_start_position:slice_end_position] | |
| def _relative_position_to_absolute_position(self, x): | |
| batch, heads, length, _ = x.size() | |
| x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]) | |
| x_flat = x.view([batch, heads, length * 2 * length]) | |
| x_flat = F.pad(x_flat, [0, int(length) - 1, 0, 0, 0, 0]) | |
| x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:] | |
| return x_final | |
| def _absolute_position_to_relative_position(self, x): | |
| batch, heads, length, _ = x.size() | |
| x = F.pad(x, [0, int(length) - 1, 0, 0, 0, 0, 0, 0]) | |
| x_flat = x.view([batch, heads, int(length ** 2) + int(length * (length - 1))]) | |
| x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) | |
| x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] | |
| return x_final | |
| def _attention_bias_proximal(self, length): | |
| r = torch.arange(length, dtype=torch.float32) | |
| diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) | |
| return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) | |
| class FFN(nn.Module): | |
| def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.out_channels = out_channels | |
| self.filter_channels = filter_channels | |
| self.kernel_size = kernel_size | |
| self.p_dropout = p_dropout | |
| self.causal = causal | |
| self.is_activation = activation == "gelu" | |
| self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) | |
| self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) | |
| self.drop = nn.Dropout(p_dropout) | |
| def forward(self, x: torch.Tensor, x_mask: torch.Tensor): | |
| x = self.conv_1(self._padding(x, x_mask)) | |
| if self.is_activation: | |
| x = x * torch.sigmoid(1.702 * x) | |
| else: | |
| x = torch.relu(x) | |
| x = self.drop(x) | |
| x = self.conv_2(self._padding(x, x_mask)) | |
| return x * x_mask | |
| def _padding(self, x, x_mask): | |
| if self.causal: | |
| if self.kernel_size == 1: | |
| return x * x_mask | |
| pad_l = self.kernel_size - 1 | |
| return F.pad(x * x_mask, [pad_l, 0, 0, 0, 0, 0]) | |
| else: | |
| if self.kernel_size == 1: | |
| return x * x_mask | |
| pad_l = (self.kernel_size - 1) // 2 | |
| pad_r = self.kernel_size // 2 | |
| return F.pad(x * x_mask, [pad_l, pad_r, 0, 0, 0, 0]) | |
| class Encoder(nn.Module): | |
| def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs): | |
| super().__init__() | |
| self.hidden_channels = hidden_channels | |
| self.n_layers = int(n_layers) | |
| self.drop = nn.Dropout(p_dropout) | |
| self.attn_layers = nn.ModuleList() | |
| self.norm_layers_1 = nn.ModuleList() | |
| self.ffn_layers = nn.ModuleList() | |
| self.norm_layers_2 = nn.ModuleList() | |
| for i in range(self.n_layers): | |
| self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) | |
| self.norm_layers_1.append(LayerNorm(hidden_channels)) | |
| self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) | |
| self.norm_layers_2.append(LayerNorm(hidden_channels)) | |
| def forward(self, x, x_mask): | |
| attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) | |
| x = x * x_mask | |
| for attn, norm1, ffn, norm2 in zip(self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2): | |
| y = attn(x, x, attn_mask) | |
| y = self.drop(y) | |
| x = norm1(x + y) | |
| y = ffn(x, x_mask) | |
| y = self.drop(y) | |
| x = norm2(x + y) | |
| return x * x_mask | |
| # ============================================================ | |
| # MODELS - From infer/lib/infer_pack/models.py | |
| # ============================================================ | |
| sr2sr = {"32k": 32000, "40k": 40000, "48k": 48000} | |
| class TextEncoder256(nn.Module): | |
| def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True): | |
| super().__init__() | |
| self.out_channels = out_channels | |
| self.hidden_channels = hidden_channels | |
| self.emb_phone = nn.Linear(256, hidden_channels) | |
| self.lrelu = nn.LeakyReLU(0.1, inplace=True) | |
| if f0: | |
| self.emb_pitch = nn.Embedding(256, hidden_channels) | |
| self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) | |
| self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) | |
| def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor): | |
| if pitch is None: | |
| x = self.emb_phone(phone) | |
| else: | |
| x = self.emb_phone(phone) + self.emb_pitch(pitch) | |
| x = x * math.sqrt(self.hidden_channels) | |
| x = self.lrelu(x) | |
| x = torch.transpose(x, 1, -1) | |
| x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype) | |
| x = self.encoder(x * x_mask, x_mask) | |
| stats = self.proj(x) * x_mask | |
| m, logs = torch.split(stats, self.out_channels, dim=1) | |
| return m, logs, x_mask | |
| class TextEncoder768(nn.Module): | |
| def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True): | |
| super().__init__() | |
| self.out_channels = out_channels | |
| self.hidden_channels = hidden_channels | |
| self.emb_phone = nn.Linear(768, hidden_channels) | |
| self.lrelu = nn.LeakyReLU(0.1, inplace=True) | |
| if f0: | |
| self.emb_pitch = nn.Embedding(256, hidden_channels) | |
| self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) | |
| self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) | |
| def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor): | |
| if pitch is None: | |
| x = self.emb_phone(phone) | |
| else: | |
| x = self.emb_phone(phone) + self.emb_pitch(pitch) | |
| x = x * math.sqrt(self.hidden_channels) | |
| x = self.lrelu(x) | |
| x = torch.transpose(x, 1, -1) | |
| x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype) | |
| x = self.encoder(x * x_mask, x_mask) | |
| stats = self.proj(x) * x_mask | |
| m, logs = torch.split(stats, self.out_channels, dim=1) | |
| return m, logs, x_mask | |
| class ResidualCouplingBlock(nn.Module): | |
| def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): | |
| super().__init__() | |
| self.n_flows = n_flows | |
| self.flows = nn.ModuleList() | |
| for i in range(n_flows): | |
| self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) | |
| self.flows.append(Flip()) | |
| def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): | |
| if not reverse: | |
| for flow in self.flows: | |
| x, _ = flow(x, x_mask, g=g, reverse=reverse) | |
| else: | |
| for flow in self.flows[::-1]: | |
| x, _ = flow.forward(x, x_mask, g=g, reverse=reverse) | |
| return x | |
| def remove_weight_norm(self): | |
| for i in range(self.n_flows): | |
| self.flows[i * 2].remove_weight_norm() | |
| class PosteriorEncoder(nn.Module): | |
| def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0): | |
| super().__init__() | |
| self.out_channels = out_channels | |
| self.pre = nn.Conv1d(in_channels, hidden_channels, 1) | |
| self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) | |
| self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) | |
| def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None): | |
| x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) | |
| x = self.pre(x) * x_mask | |
| x = self.enc(x, x_mask, g=g) | |
| stats = self.proj(x) * x_mask | |
| m, logs = torch.split(stats, self.out_channels, dim=1) | |
| z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask | |
| return z, m, logs, x_mask | |
| def remove_weight_norm(self): | |
| self.enc.remove_weight_norm() | |
| class Generator(nn.Module): | |
| def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): | |
| super().__init__() | |
| self.num_kernels = len(resblock_kernel_sizes) | |
| self.num_upsamples = len(upsample_rates) | |
| self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) | |
| resblock_class = ResBlock1 if resblock == "1" else ResBlock2 | |
| self.ups = nn.ModuleList() | |
| for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): | |
| self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2))) | |
| self.resblocks = nn.ModuleList() | |
| for i in range(len(self.ups)): | |
| ch = upsample_initial_channel // (2 ** (i + 1)) | |
| for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): | |
| self.resblocks.append(resblock_class(ch, k, d)) | |
| self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) | |
| self.ups.apply(init_weights) | |
| if gin_channels != 0: | |
| self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) | |
| def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None): | |
| x = self.conv_pre(x) | |
| if g is not None: | |
| x = x + self.cond(g) | |
| for i in range(self.num_upsamples): | |
| x = F.leaky_relu(x, LRELU_SLOPE) | |
| x = self.ups[i](x) | |
| xs = None | |
| for j in range(self.num_kernels): | |
| if xs is None: | |
| xs = self.resblocks[i * self.num_kernels + j](x) | |
| else: | |
| xs += self.resblocks[i * self.num_kernels + j](x) | |
| x = xs / self.num_kernels | |
| x = F.leaky_relu(x) | |
| x = self.conv_post(x) | |
| x = torch.tanh(x) | |
| return x | |
| def remove_weight_norm(self): | |
| for l in self.ups: | |
| remove_weight_norm(l) | |
| for l in self.resblocks: | |
| l.remove_weight_norm() | |
| class SineGen(nn.Module): | |
| def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0): | |
| super().__init__() | |
| self.sine_amp = sine_amp | |
| self.noise_std = noise_std | |
| self.harmonic_num = harmonic_num | |
| self.dim = harmonic_num + 1 | |
| self.sampling_rate = samp_rate | |
| self.voiced_threshold = voiced_threshold | |
| def _f02uv(self, f0): | |
| uv = torch.ones_like(f0) | |
| uv = uv * (f0 > self.voiced_threshold) | |
| return uv.float() | |
| def forward(self, f0: torch.Tensor, upp: int): | |
| with torch.no_grad(): | |
| f0 = f0[:, None].transpose(1, 2) | |
| f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) | |
| f0_buf[:, :, 0] = f0[:, :, 0] | |
| for idx in range(self.harmonic_num): | |
| f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2) | |
| rad_values = (f0_buf / self.sampling_rate) % 1 | |
| rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device) | |
| rand_ini[:, 0] = 0 | |
| rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini | |
| tmp_over_one = torch.cumsum(rad_values, 1) | |
| tmp_over_one *= upp | |
| tmp_over_one = F.interpolate(tmp_over_one.transpose(2, 1), scale_factor=float(upp), mode="linear", align_corners=True).transpose(2, 1) | |
| rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1) | |
| tmp_over_one %= 1 | |
| tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 | |
| cumsum_shift = torch.zeros_like(rad_values) | |
| cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 | |
| sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi) | |
| sine_waves = sine_waves * self.sine_amp | |
| uv = self._f02uv(f0) | |
| uv = F.interpolate(uv.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1) | |
| noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 | |
| noise = noise_amp * torch.randn_like(sine_waves) | |
| sine_waves = sine_waves * uv + noise | |
| return sine_waves, uv, noise | |
| class SourceModuleHnNSF(nn.Module): | |
| def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0, is_half=False): | |
| super().__init__() | |
| self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod) | |
| self.l_linear = nn.Linear(harmonic_num + 1, 1) | |
| self.l_tanh = nn.Tanh() | |
| def forward(self, x: torch.Tensor, upp: int = 1): | |
| sine_wavs, uv, _ = self.l_sin_gen(x, upp) | |
| sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype) | |
| sine_merge = self.l_tanh(self.l_linear(sine_wavs)) | |
| return sine_merge, None, None | |
| class GeneratorNSF(nn.Module): | |
| def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, is_half=False): | |
| super().__init__() | |
| self.num_kernels = len(resblock_kernel_sizes) | |
| self.num_upsamples = len(upsample_rates) | |
| self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates)) | |
| self.m_source = SourceModuleHnNSF(sampling_rate=sr, harmonic_num=0, is_half=is_half) | |
| self.noise_convs = nn.ModuleList() | |
| self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) | |
| resblock_class = ResBlock1 if resblock == "1" else ResBlock2 | |
| self.ups = nn.ModuleList() | |
| for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): | |
| c_cur = upsample_initial_channel // (2 ** (i + 1)) | |
| self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2))) | |
| if i + 1 < len(upsample_rates): | |
| stride_f0 = math.prod(upsample_rates[i + 1:]) | |
| self.noise_convs.append(Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) | |
| else: | |
| self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) | |
| self.resblocks = nn.ModuleList() | |
| for i in range(len(self.ups)): | |
| ch = upsample_initial_channel // (2 ** (i + 1)) | |
| for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): | |
| self.resblocks.append(resblock_class(ch, k, d)) | |
| self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) | |
| self.ups.apply(init_weights) | |
| if gin_channels != 0: | |
| self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) | |
| self.upp = math.prod(upsample_rates) | |
| self.lrelu_slope = LRELU_SLOPE | |
| def forward(self, x, f0, g: Optional[torch.Tensor] = None): | |
| har_source, _, _ = self.m_source(f0, self.upp) | |
| har_source = har_source.transpose(1, 2) | |
| x = self.conv_pre(x) | |
| if g is not None: | |
| x = x + self.cond(g) | |
| for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)): | |
| if i < self.num_upsamples: | |
| x = F.leaky_relu(x, self.lrelu_slope) | |
| x = ups(x) | |
| x_source = noise_convs(har_source) | |
| x = x + x_source | |
| xs = None | |
| l = [i * self.num_kernels + j for j in range(self.num_kernels)] | |
| for j, resblock in enumerate(self.resblocks): | |
| if j in l: | |
| if xs is None: | |
| xs = resblock(x) | |
| else: | |
| xs += resblock(x) | |
| x = xs / self.num_kernels | |
| x = F.leaky_relu(x) | |
| x = self.conv_post(x) | |
| x = torch.tanh(x) | |
| return x | |
| def remove_weight_norm(self): | |
| for l in self.ups: | |
| remove_weight_norm(l) | |
| for l in self.resblocks: | |
| l.remove_weight_norm() | |
| # Synthesizer classes for different model versions | |
| class SynthesizerTrnMs256NSFsid(nn.Module): | |
| """RVC v1 model with f0""" | |
| def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, **kwargs): | |
| super().__init__() | |
| if isinstance(sr, str): | |
| sr = sr2sr[sr] | |
| self.segment_size = segment_size | |
| self.gin_channels = gin_channels | |
| self.spk_embed_dim = spk_embed_dim | |
| self.enc_p = TextEncoder256(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) | |
| self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs.get("is_half", False)) | |
| self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) | |
| self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) | |
| self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) | |
| def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: torch.Tensor, nsff0: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): | |
| g = self.emb_g(sid).unsqueeze(-1) | |
| m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) | |
| z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask | |
| if rate is not None: | |
| head = int(z_p.shape[2] * (1 - rate.item())) | |
| z_p = z_p[:, :, head:] | |
| x_mask = x_mask[:, :, head:] | |
| nsff0 = nsff0[:, head:] | |
| z = self.flow(z_p, x_mask, g=g, reverse=True) | |
| o = self.dec(z * x_mask, nsff0, g=g) | |
| return o, x_mask, (z, z_p, m_p, logs_p) | |
| class SynthesizerTrnMs768NSFsid(nn.Module): | |
| """RVC v2 model with f0""" | |
| def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, **kwargs): | |
| super().__init__() | |
| if isinstance(sr, str): | |
| sr = sr2sr[sr] | |
| self.segment_size = segment_size | |
| self.gin_channels = gin_channels | |
| self.spk_embed_dim = spk_embed_dim | |
| self.enc_p = TextEncoder768(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) | |
| self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs.get("is_half", False)) | |
| self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) | |
| self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) | |
| self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) | |
| def forward(self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds): | |
| """Training forward pass""" | |
| g = self.emb_g(ds).unsqueeze(-1) | |
| m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) | |
| z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) | |
| z_p = self.flow(z, y_mask, g=g) | |
| z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size) | |
| pitchf = slice_segments2(pitchf, ids_slice, self.segment_size) | |
| o = self.dec(z_slice, pitchf, g=g) | |
| return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) | |
| def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: torch.Tensor, nsff0: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): | |
| g = self.emb_g(sid).unsqueeze(-1) | |
| m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) | |
| z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask | |
| if rate is not None: | |
| head = int(z_p.shape[2] * (1.0 - rate.item())) | |
| z_p = z_p[:, :, head:] | |
| x_mask = x_mask[:, :, head:] | |
| nsff0 = nsff0[:, head:] | |
| z = self.flow(z_p, x_mask, g=g, reverse=True) | |
| o = self.dec(z * x_mask, nsff0, g=g) | |
| return o, x_mask, (z, z_p, m_p, logs_p) | |
| class SynthesizerTrnMs256NSFsid_nono(nn.Module): | |
| """RVC v1 model without f0""" | |
| def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs): | |
| super().__init__() | |
| self.segment_size = segment_size | |
| self.gin_channels = gin_channels | |
| self.enc_p = TextEncoder256(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), f0=False) | |
| self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) | |
| self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) | |
| self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) | |
| self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) | |
| def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): | |
| g = self.emb_g(sid).unsqueeze(-1) | |
| m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) | |
| z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask | |
| if rate is not None: | |
| head = int(z_p.shape[2] * (1.0 - rate.item())) | |
| z_p = z_p[:, :, head:] | |
| x_mask = x_mask[:, :, head:] | |
| z = self.flow(z_p, x_mask, g=g, reverse=True) | |
| o = self.dec(z * x_mask, g=g) | |
| return o, x_mask, (z, z_p, m_p, logs_p) | |
| class SynthesizerTrnMs768NSFsid_nono(nn.Module): | |
| """RVC v2 model without f0""" | |
| def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs): | |
| super().__init__() | |
| self.segment_size = segment_size | |
| self.gin_channels = gin_channels | |
| self.enc_p = TextEncoder768(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), f0=False) | |
| self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) | |
| self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) | |
| self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) | |
| self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) | |
| def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): | |
| g = self.emb_g(sid).unsqueeze(-1) | |
| m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) | |
| z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask | |
| if rate is not None: | |
| head = int(z_p.shape[2] * (1.0 - rate.item())) | |
| z_p = z_p[:, :, head:] | |
| x_mask = x_mask[:, :, head:] | |
| z = self.flow(z_p, x_mask, g=g, reverse=True) | |
| o = self.dec(z * x_mask, g=g) | |
| return o, x_mask, (z, z_p, m_p, logs_p) | |
| # ============================================================ | |
| # DISCRIMINATOR - For training | |
| # ============================================================ | |
| class DiscriminatorS(nn.Module): | |
| def __init__(self, use_spectral_norm=False): | |
| super().__init__() | |
| norm_f = nn.utils.spectral_norm if use_spectral_norm else weight_norm | |
| self.convs = nn.ModuleList([ | |
| norm_f(Conv1d(1, 16, 15, 1, padding=7)), | |
| norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), | |
| norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), | |
| norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), | |
| norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), | |
| norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), | |
| ]) | |
| self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) | |
| def forward(self, x): | |
| fmap = [] | |
| for l in self.convs: | |
| x = l(x) | |
| x = F.leaky_relu(x, 0.1) | |
| fmap.append(x) | |
| x = self.conv_post(x) | |
| fmap.append(x) | |
| x = torch.flatten(x, 1, -1) | |
| return x, fmap | |
| class DiscriminatorP(nn.Module): | |
| def __init__(self, period, use_spectral_norm=False): | |
| super().__init__() | |
| self.period = period | |
| norm_f = nn.utils.spectral_norm if use_spectral_norm else weight_norm | |
| self.convs = nn.ModuleList([ | |
| norm_f(nn.Conv2d(1, 32, (5, 1), (3, 1), padding=(2, 0))), | |
| norm_f(nn.Conv2d(32, 128, (5, 1), (3, 1), padding=(2, 0))), | |
| norm_f(nn.Conv2d(128, 512, (5, 1), (3, 1), padding=(2, 0))), | |
| norm_f(nn.Conv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0))), | |
| norm_f(nn.Conv2d(1024, 1024, (5, 1), 1, padding=(2, 0))), | |
| ]) | |
| self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) | |
| def forward(self, x): | |
| fmap = [] | |
| b, c, t = x.shape | |
| if t % self.period != 0: | |
| n_pad = self.period - (t % self.period) | |
| x = F.pad(x, (0, n_pad), "reflect") | |
| t = t + n_pad | |
| x = x.view(b, c, t // self.period, self.period) | |
| for l in self.convs: | |
| x = l(x) | |
| x = F.leaky_relu(x, 0.1) | |
| fmap.append(x) | |
| x = self.conv_post(x) | |
| fmap.append(x) | |
| x = torch.flatten(x, 1, -1) | |
| return x, fmap | |
| class MultiPeriodDiscriminator(nn.Module): | |
| def __init__(self, use_spectral_norm=False): | |
| super().__init__() | |
| periods = [2, 3, 5, 7, 11, 17, 23, 37] # 8 periods for v2 pretrained (9 total discriminators) | |
| self.discriminators = nn.ModuleList( | |
| [DiscriminatorS(use_spectral_norm)] + | |
| [DiscriminatorP(p, use_spectral_norm) for p in periods] | |
| ) | |
| def forward(self, y, y_hat): | |
| y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] | |
| for d in self.discriminators: | |
| y_d_r, fmap_r = d(y) | |
| y_d_g, fmap_g = d(y_hat) | |
| y_d_rs.append(y_d_r) | |
| y_d_gs.append(y_d_g) | |
| fmap_rs.append(fmap_r) | |
| fmap_gs.append(fmap_g) | |
| return y_d_rs, y_d_gs, fmap_rs, fmap_gs | |
| # ============================================================ | |
| # TRAINING LOSSES | |
| # ============================================================ | |
| def feature_loss(fmap_r, fmap_g): | |
| loss = 0 | |
| for dr, dg in zip(fmap_r, fmap_g): | |
| for rl, gl in zip(dr, dg): | |
| loss += torch.mean(torch.abs(rl.float().detach() - gl.float())) | |
| return loss * 2 | |
| def discriminator_loss(disc_real_outputs, disc_generated_outputs): | |
| loss = 0 | |
| for dr, dg in zip(disc_real_outputs, disc_generated_outputs): | |
| loss += torch.mean((1 - dr.float()) ** 2) + torch.mean(dg.float() ** 2) | |
| return loss | |
| def generator_loss(disc_outputs): | |
| loss = 0 | |
| for dg in disc_outputs: | |
| loss += torch.mean((1 - dg.float()) ** 2) | |
| return loss | |
| def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): | |
| z_p, logs_q, m_p, logs_p, z_mask = [x.float() for x in [z_p, logs_q, m_p, logs_p, z_mask]] | |
| kl = logs_p - logs_q - 0.5 + 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) | |
| return torch.sum(kl * z_mask) / torch.sum(z_mask) | |
| # ============================================================ | |
| # HUBERT EXTRACTION - Using torchaudio bundle | |
| # ============================================================ | |
| # ContentVec model for v1 (256-dim) and HuBERT for v2 (768-dim) | |
| _contentvec_model = None # For v1 models (256-dim output) | |
| _hubert_model = None # For v2 models (768-dim output) | |
| _hubert_bundle = None | |
| CONTENTVEC_REPO = "IAHispano/Applio" | |
| CONTENTVEC_MODEL = "Resources/embedders/contentvec/pytorch_model.bin" | |
| CONTENTVEC_CONFIG = "Resources/embedders/contentvec/config.json" | |
| def load_contentvec(): | |
| """Load ContentVec model from HuggingFace for v1 models (256-dim output)""" | |
| global _contentvec_model | |
| if _contentvec_model is None: | |
| try: | |
| from transformers import HubertModel, HubertConfig | |
| logger.info("Loading ContentVec model from HuggingFace...") | |
| # Download model files | |
| model_path = hf_hub_download(repo_id=CONTENTVEC_REPO, filename=CONTENTVEC_MODEL) | |
| config_path = hf_hub_download(repo_id=CONTENTVEC_REPO, filename=CONTENTVEC_CONFIG) | |
| # Create model with final_proj layer | |
| class HubertModelWithFinalProj(HubertModel): | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) | |
| config = HubertConfig.from_pretrained(config_path) | |
| _contentvec_model = HubertModelWithFinalProj(config) | |
| state_dict = torch.load(model_path, map_location="cpu", weights_only=True) | |
| _contentvec_model.load_state_dict(state_dict) | |
| _contentvec_model.to(device).eval() | |
| logger.info(f"ContentVec loaded: hidden={config.hidden_size}, proj={config.classifier_proj_size}") | |
| except Exception as e: | |
| logger.warning(f"Failed to load ContentVec: {e}, falling back to torchaudio HuBERT") | |
| _contentvec_model = None | |
| return _contentvec_model | |
| def load_hubert(): | |
| """Load HuBERT model via torchaudio for v2 models (768-dim output)""" | |
| global _hubert_model, _hubert_bundle | |
| if _hubert_model is None: | |
| import torchaudio | |
| logger.info("Loading HuBERT model via torchaudio...") | |
| _hubert_bundle = torchaudio.pipelines.HUBERT_BASE | |
| _hubert_model = _hubert_bundle.get_model().to(device) | |
| _hubert_model.eval() | |
| logger.info("HuBERT model loaded") | |
| return _hubert_model, _hubert_bundle | |
| def extract_hubert_features(audio: np.ndarray, sr: int = 16000, version: str = "v2") -> torch.Tensor: | |
| """Extract ContentVec features from audio (same as Applio) | |
| v1 models: Use ContentVec with final_proj (256-dim) | |
| v2 models: Use ContentVec without final_proj (768-dim) | |
| """ | |
| audio = audio.astype(np.float32) | |
| if np.abs(audio).max() > 1.0: | |
| audio = audio / np.abs(audio).max() | |
| inputs = torch.from_numpy(audio).unsqueeze(0).to(device) | |
| # Use ContentVec for ALL versions (same as Applio) | |
| contentvec = load_contentvec() | |
| if contentvec is not None: | |
| with torch.no_grad(): | |
| output = contentvec(inputs) | |
| if version == "v1": | |
| # v1: use final_proj for 256-dim | |
| feats = contentvec.final_proj(output.last_hidden_state) | |
| else: | |
| # v2: use raw hidden state (768-dim) | |
| feats = output.last_hidden_state | |
| return feats | |
| # Fallback to torchaudio HuBERT if ContentVec not available | |
| logger.warning("ContentVec not available, using torchaudio HuBERT (results may be degraded)") | |
| hubert, bundle = load_hubert() | |
| with torch.no_grad(): | |
| features, _ = hubert.extract_features(inputs) | |
| layer_idx = 11 if version == "v2" else 8 | |
| feats = features[min(layer_idx, len(features)-1)] | |
| if version == "v1": | |
| proj = nn.Linear(768, 256, bias=False).to(device) | |
| with torch.no_grad(): | |
| w = torch.zeros(256, 768) | |
| for i in range(256): | |
| w[i, i*3:(i+1)*3] = 1/3 | |
| proj.weight.copy_(w) | |
| feats = proj(feats) | |
| return feats | |
| # ============================================================ | |
| # F0 EXTRACTION | |
| # ============================================================ | |
| def extract_f0_pm(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0) -> Tuple[np.ndarray, np.ndarray]: | |
| """Extract F0 using parselmouth (pm method)""" | |
| import parselmouth | |
| p_len = audio.shape[0] // 160 + 1 | |
| f0_min = 65 | |
| f0_max = 1100 | |
| l_pad = int(np.ceil(1.5 / f0_min * 16000)) | |
| r_pad = l_pad + 1 | |
| s = parselmouth.Sound(np.pad(audio, (l_pad, r_pad)), 16000).to_pitch_ac( | |
| time_step=0.01, voicing_threshold=0.6, pitch_floor=f0_min, pitch_ceiling=f0_max, | |
| ) | |
| f0 = s.selected_array["frequency"] | |
| if len(f0) < p_len: | |
| f0 = np.pad(f0, (0, p_len - len(f0))) | |
| f0 = f0[:p_len] | |
| f0 *= pow(2, f0_up_key / 12) | |
| return f0_to_coarse(f0) | |
| def extract_f0_harvest(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0) -> Tuple[np.ndarray, np.ndarray]: | |
| """Extract F0 using pyworld harvest""" | |
| import pyworld | |
| from scipy import signal as scipy_signal | |
| f0, t = pyworld.harvest(audio.astype(np.double), fs=16000, f0_ceil=1100, f0_floor=50, frame_period=10) | |
| f0 = scipy_signal.medfilt(f0, 3) | |
| f0 *= pow(2, f0_up_key / 12) | |
| return f0_to_coarse(f0) | |
| def f0_to_coarse(f0: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: | |
| """Convert f0 to coarse representation""" | |
| f0_min = 50 | |
| f0_max = 1100 | |
| f0_mel_min = 1127 * np.log(1 + f0_min / 700) | |
| f0_mel_max = 1127 * np.log(1 + f0_max / 700) | |
| f0bak = f0.copy() | |
| f0_mel = 1127 * np.log(1 + f0 / 700) | |
| f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1 | |
| f0_mel[f0_mel <= 1] = 1 | |
| f0_mel[f0_mel > 255] = 255 | |
| f0_coarse = np.rint(f0_mel).astype(np.int32) | |
| return f0_coarse, f0bak | |
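| # Worked example (approximate): f0 = 440 Hz -> mel = 1127*ln(1 + 440/700) ~ 549.6, scaled between | |
| # mel(50 Hz) ~ 77.8 and mel(1100 Hz) ~ 1064.4 into coarse bin ~122; unvoiced frames (f0 = 0) stay at | |
| # bin 1, while the untouched float f0 (f0bak) is returned for NSF excitation. | |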
| # ============================================================ | |
| # RMVPE F0 EXTRACTION (from Applio - IAHispano/Applio) | |
| # ============================================================ | |
| class RMVPE_ConvBlockRes(nn.Module): | |
| def __init__(self, in_channels, out_channels, momentum=0.01): | |
| super().__init__() | |
| self.conv = nn.Sequential( | |
| nn.Conv2d(in_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False), | |
| nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), | |
| nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False), | |
| nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), | |
| ) | |
| self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) if in_channels != out_channels else None | |
| def forward(self, x): | |
| r = self.conv(x) | |
| return (r + self.shortcut(x)) if self.shortcut is not None else (r + x) | |
| class RMVPE_ResEncoderBlock(nn.Module): | |
| def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): | |
| super().__init__() | |
| self.conv = nn.ModuleList([RMVPE_ConvBlockRes(in_channels, out_channels, momentum)]) | |
| for _ in range(n_blocks - 1): | |
| self.conv.append(RMVPE_ConvBlockRes(out_channels, out_channels, momentum)) | |
| self.kernel_size = kernel_size | |
| if kernel_size is not None: | |
| self.pool = nn.AvgPool2d(kernel_size=kernel_size) | |
| def forward(self, x): | |
| for c in self.conv: | |
| x = c(x) | |
| return (x, self.pool(x)) if self.kernel_size is not None else x | |
| class RMVPE_Encoder(nn.Module): | |
| def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): | |
| super().__init__() | |
| self.n_encoders = n_encoders | |
| self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) | |
| self.layers = nn.ModuleList() | |
| for _ in range(n_encoders): | |
| self.layers.append(RMVPE_ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum)) | |
| in_channels = out_channels | |
| out_channels *= 2 | |
| in_size //= 2 | |
| self.out_size = in_size | |
| self.out_channel = out_channels | |
| def forward(self, x): | |
| concat_tensors = [] | |
| x = self.bn(x) | |
| for layer in self.layers: | |
| t, x = layer(x) | |
| concat_tensors.append(t) | |
| return x, concat_tensors | |
| class RMVPE_Intermediate(nn.Module): | |
| def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): | |
| super().__init__() | |
| self.layers = nn.ModuleList([RMVPE_ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)]) | |
| for _ in range(n_inters - 1): | |
| self.layers.append(RMVPE_ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) | |
| def forward(self, x): | |
| for layer in self.layers: | |
| x = layer(x) | |
| return x | |
| class RMVPE_ResDecoderBlock(nn.Module): | |
| def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): | |
| super().__init__() | |
| out_padding = (0, 1) if stride == (1, 2) else (1, 1) | |
| self.conv1 = nn.Sequential( | |
| nn.ConvTranspose2d(in_channels, out_channels, (3, 3), stride, (1, 1), out_padding, bias=False), | |
| nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), | |
| ) | |
| self.conv2 = nn.ModuleList([RMVPE_ConvBlockRes(out_channels * 2, out_channels, momentum)]) | |
| for _ in range(n_blocks - 1): | |
| self.conv2.append(RMVPE_ConvBlockRes(out_channels, out_channels, momentum)) | |
| def forward(self, x, concat_tensor): | |
| x = self.conv1(x) | |
| x = torch.cat((x, concat_tensor), dim=1) | |
| for c in self.conv2: | |
| x = c(x) | |
| return x | |
| class RMVPE_Decoder(nn.Module): | |
| def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): | |
| super().__init__() | |
| self.layers = nn.ModuleList() | |
| for _ in range(n_decoders): | |
| out_channels = in_channels // 2 | |
| self.layers.append(RMVPE_ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) | |
| in_channels = out_channels | |
| self.n_decoders = n_decoders | |
| def forward(self, x, concat_tensors): | |
| for i in range(self.n_decoders): | |
| x = self.layers[i](x, concat_tensors[-1 - i]) | |
| return x | |
| class RMVPE_DeepUnet(nn.Module): | |
| def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): | |
| super().__init__() | |
| self.encoder = RMVPE_Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels) | |
| self.intermediate = RMVPE_Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) | |
| self.decoder = RMVPE_Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) | |
| def forward(self, x): | |
| x, concat_tensors = self.encoder(x) | |
| x = self.intermediate(x) | |
| x = self.decoder(x, concat_tensors) | |
| return x | |
| class RMVPE_BiGRU(nn.Module): | |
| def __init__(self, input_features, hidden_features, num_layers): | |
| super().__init__() | |
| self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) | |
| def forward(self, x): | |
| return self.gru(x)[0] | |
| RMVPE_N_MELS = 128 | |
| RMVPE_N_CLASS = 360 | |
| class RMVPE_E2E(nn.Module): | |
| def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): | |
| super().__init__() | |
| self.unet = RMVPE_DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) | |
| self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) | |
| if n_gru: | |
| self.fc = nn.Sequential( | |
| RMVPE_BiGRU(3 * 128, 256, n_gru), | |
| nn.Linear(512, RMVPE_N_CLASS), nn.Dropout(0.25), nn.Sigmoid(), | |
| ) | |
| else: | |
| self.fc = nn.Sequential(nn.Linear(3 * RMVPE_N_MELS, RMVPE_N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) | |
| def forward(self, mel): | |
| mel = mel.transpose(-1, -2).unsqueeze(1) | |
| x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) | |
| return self.fc(x) | |
| class RMVPE_MelSpectrogram(nn.Module): | |
| def __init__(self, n_mel_channels=128, sample_rate=16000, win_length=1024, hop_length=160, n_fft=None, mel_fmin=30, mel_fmax=8000, clamp=1e-5): | |
| super().__init__() | |
| from librosa.filters import mel as librosa_mel | |
| n_fft = win_length if n_fft is None else n_fft | |
| self.hann_window = {} | |
| mel_basis = librosa_mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True) | |
| self.register_buffer("mel_basis", torch.from_numpy(mel_basis).float()) | |
| self.n_fft = n_fft | |
| self.hop_length = hop_length | |
| self.win_length = win_length | |
| self.clamp = clamp | |
| def forward(self, audio, keyshift=0, speed=1, center=True): | |
| factor = 2 ** (keyshift / 12) | |
| n_fft_new = int(np.round(self.n_fft * factor)) | |
| win_length_new = int(np.round(self.win_length * factor)) | |
| hop_length_new = int(np.round(self.hop_length * speed)) | |
| key = f"{keyshift}_{audio.device}" | |
| if key not in self.hann_window: | |
| self.hann_window[key] = torch.hann_window(win_length_new).to(audio.device) | |
| fft = torch.stft(audio, n_fft=n_fft_new, hop_length=hop_length_new, win_length=win_length_new, | |
| window=self.hann_window[key], center=center, return_complex=True) | |
| magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) | |
| if keyshift != 0: | |
| size = self.n_fft // 2 + 1 | |
| resize = magnitude.size(1) | |
| if resize < size: | |
| magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) | |
| magnitude = magnitude[:, :size, :] * self.win_length / win_length_new | |
| mel_output = torch.matmul(self.mel_basis, magnitude) | |
| return torch.log(torch.clamp(mel_output, min=self.clamp)) | |
| _rmvpe_model = None | |
| def load_rmvpe(): | |
| """Download and load RMVPE model for f0 extraction""" | |
| global _rmvpe_model | |
| if _rmvpe_model is None: | |
| logger.info("Downloading RMVPE model...") | |
| rmvpe_path = hf_hub_download(repo_id="IAHispano/Applio", filename="Resources/predictors/rmvpe.pt") | |
| model = RMVPE_E2E(4, 1, (2, 2)) | |
| ckpt = torch.load(rmvpe_path, map_location="cpu", weights_only=True) | |
| model.load_state_dict(ckpt) | |
| model.eval().to(device) | |
| mel_extractor = RMVPE_MelSpectrogram().to(device) | |
| cents_mapping = 20 * np.arange(RMVPE_N_CLASS) + 1997.3794084376191 | |
| _rmvpe_model = (model, mel_extractor, np.pad(cents_mapping, (4, 4))) | |
| logger.info("RMVPE model loaded") | |
| return _rmvpe_model | |
| def extract_f0_rmvpe(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0, thred: float = 0.03) -> Tuple[np.ndarray, np.ndarray]: | |
| """Extract F0 using RMVPE (best quality, neural network based)""" | |
| model, mel_extractor, cents_mapping = load_rmvpe() | |
| audio_t = torch.from_numpy(audio).float().to(device).unsqueeze(0) | |
| mel = mel_extractor(audio_t, center=True) | |
| del audio_t | |
| # mel2hidden with chunking | |
| with torch.no_grad(): | |
| n_frames = mel.shape[-1] | |
| mel_padded = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect") | |
| chunks = [] | |
| for start in range(0, mel_padded.shape[-1], 32000): | |
| end = min(start + 32000, mel_padded.shape[-1]) | |
| chunks.append(model(mel_padded[..., start:end])) | |
| hidden = torch.cat(chunks, dim=1)[:, :n_frames].squeeze(0).cpu().numpy() | |
| # Decode hidden to f0 | |
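| # Each of the 360 bins spans 20 cents (cents_mapping[i] = 20*i + 1997.38, i.e. roughly 31.7 Hz up to | |
| # about 2 kHz). A weighted average over +/-4 bins around the argmax refines the estimate, and frames | |
| # whose peak salience falls below `thred` are marked unvoiced (f0 = 0). | |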
| center = np.argmax(hidden, axis=1) | |
| salience = np.pad(hidden, ((0, 0), (4, 4))) | |
| center += 4 | |
| todo_salience = [] | |
| todo_cents = [] | |
| for idx in range(salience.shape[0]): | |
| s, e = center[idx] - 4, center[idx] + 5 | |
| todo_salience.append(salience[idx, s:e]) | |
| todo_cents.append(cents_mapping[s:e]) | |
| todo_salience = np.array(todo_salience) | |
| todo_cents = np.array(todo_cents) | |
| cents_pred = np.sum(todo_salience * todo_cents, 1) / np.sum(todo_salience, 1) | |
| cents_pred[np.max(salience, axis=1) <= thred] = 0 | |
| f0 = 10 * (2 ** (cents_pred / 1200)) | |
| f0[f0 == 10] = 0 | |
| f0 *= pow(2, f0_up_key / 12) | |
| return f0_to_coarse(f0) | |
| # ============================================================ | |
| # MODEL LOADING | |
| # ============================================================ | |
| _model_cache = {} | |
| def load_rvc_model(model_path: str): | |
| """Load RVC model and auto-detect version""" | |
| if model_path in _model_cache: | |
| return _model_cache[model_path] | |
| logger.info(f"Loading RVC model: {model_path}") | |
| try: | |
| cpt = torch.load(model_path, map_location="cpu", weights_only=True) | |
| except Exception: | |
| logger.warning("Model requires unsafe loading - may be an older format") | |
| cpt = torch.load(model_path, map_location="cpu", weights_only=False) | |
| weight_key = None | |
| for key in ["weight", "model", "state_dict", "net_g"]: | |
| if key in cpt: | |
| weight_key = key | |
| break | |
| if weight_key is None: | |
| raise ValueError(f"Cannot find model weights. Keys: {list(cpt.keys())}") | |
| config = cpt.get("config", None) | |
| if config is None: | |
| config = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000] | |
| logger.warning("No config found, using v2 defaults") | |
| version = cpt.get("version", "v1") | |
| if_f0 = cpt.get("f0", 1) | |
| if weight_key in cpt: | |
| emb_weight = cpt[weight_key].get("emb_g.weight") | |
| if emb_weight is not None: | |
| config[-3] = emb_weight.shape[0] | |
| sr = config[-1] if isinstance(config[-1], int) else 40000 | |
| if version == "v1": | |
| model_class = SynthesizerTrnMs256NSFsid if if_f0 == 1 else SynthesizerTrnMs256NSFsid_nono | |
| else: | |
| model_class = SynthesizerTrnMs768NSFsid if if_f0 == 1 else SynthesizerTrnMs768NSFsid_nono | |
| model = model_class( | |
| spec_channels=config[0], segment_size=config[1], inter_channels=config[2], | |
| hidden_channels=config[3], filter_channels=config[4], n_heads=config[5], | |
| n_layers=config[6], kernel_size=config[7], p_dropout=config[8], | |
| resblock=config[9], resblock_kernel_sizes=config[10], | |
| resblock_dilation_sizes=config[11], upsample_rates=config[12], | |
| upsample_initial_channel=config[13], upsample_kernel_sizes=config[14], | |
| spk_embed_dim=config[15], gin_channels=config[16], sr=sr, is_half=False | |
| ) | |
| model.load_state_dict(cpt[weight_key], strict=False) | |
| model.eval().to(device) | |
| _model_cache[model_path] = (model, sr, version, if_f0) | |
| logger.info(f"Model loaded: version={version}, f0={if_f0}, sr={sr}") | |
| return model, sr, version, if_f0 | |
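| # Illustrative call (path is hypothetical): | |
| # model, tgt_sr, version, if_f0 = load_rvc_model("weights/my_voice.pth") | |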
| # ============================================================ | |
| # TRAINING - Simplified for CPU testing | |
| # ============================================================ | |
| def spectrogram_torch(y, n_fft, hop_size, win_size, center=False): | |
| """Compute spectrogram""" | |
| hann_window = torch.hann_window(win_size).to(y.device) | |
| y = F.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect').squeeze(1) | |
| spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, | |
| center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) | |
| spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6) | |
| return spec | |
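| # Sizing note: a 3.7 s training chunk at 40 kHz is 148000 samples; with hop_size=400 that is ~370 | |
| # spectrogram frames, which is the length HuBERT features and f0 are aligned to during training. | |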
| # Mel spectrogram for training loss | |
| _mel_basis_cache = {} | |
| def spec_to_mel_torch(spec, n_fft=2048, num_mels=125, sampling_rate=40000, fmin=0, fmax=None): | |
| """Convert spectrogram to mel spectrogram""" | |
| from librosa.filters import mel as librosa_mel_fn | |
| global _mel_basis_cache | |
| if fmax is None: | |
| fmax = sampling_rate // 2 | |
| key = f"{n_fft}_{num_mels}_{sampling_rate}_{fmin}_{fmax}_{spec.dtype}_{spec.device}" | |
| if key not in _mel_basis_cache: | |
| mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) | |
| _mel_basis_cache[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) | |
| melspec = torch.matmul(_mel_basis_cache[key], spec) | |
| melspec = torch.log(torch.clamp(melspec, min=1e-5)) # Log-amplitude | |
| return melspec | |
| def preprocess_audio_for_training(audio_path: str, output_dir: str, target_sr: int = 40000, f0_method: str = "rmvpe"): | |
| """Preprocess audio file for training - slice and extract features""" | |
| import scipy.signal as signal | |
| os.makedirs(output_dir, exist_ok=True) | |
| os.makedirs(f"{output_dir}/wavs", exist_ok=True) | |
| os.makedirs(f"{output_dir}/hubert", exist_ok=True) | |
| os.makedirs(f"{output_dir}/f0", exist_ok=True) | |
| logger.info(f"Preprocessing: {audio_path}") | |
| # Load and resample audio | |
| audio, sr = librosa.load(audio_path, sr=target_sr, mono=True) | |
| # High-pass filter | |
| bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=target_sr) | |
| audio = signal.lfilter(bh, ah, audio) | |
| # Slice into chunks (3.7 seconds with 0.3 overlap) | |
| chunk_size = int(3.7 * target_sr) | |
| hop = int(3.4 * target_sr) | |
| chunks = [] | |
| for i, start in enumerate(range(0, len(audio) - chunk_size, hop)): | |
| chunk = audio[start:start + chunk_size] | |
| # Normalize | |
| max_val = np.abs(chunk).max() | |
| if max_val > 0.01: # Skip silence | |
| chunk = chunk / max_val * 0.9 | |
| chunks.append((i, chunk)) | |
| if not chunks: | |
| logger.warning("No valid audio chunks found") | |
| return None | |
| logger.info(f"Created {len(chunks)} chunks") | |
| # Save chunks and extract features | |
| manifest = [] | |
| for idx, chunk in chunks: | |
| # Save wav | |
| wav_path = f"{output_dir}/wavs/{idx:04d}.wav" | |
| sf.write(wav_path, chunk, target_sr) | |
| # Resample to 16k for HuBERT | |
| chunk_16k = librosa.resample(chunk, orig_sr=target_sr, target_sr=16000) | |
| # Extract HuBERT features | |
| feats = extract_hubert_features(chunk_16k, sr=16000, version="v2") | |
| hubert_path = f"{output_dir}/hubert/{idx:04d}.npy" | |
| np.save(hubert_path, feats.squeeze(0).cpu().numpy()) | |
| # Extract F0 | |
| if f0_method == "rmvpe": | |
| f0_coarse, f0 = extract_f0_rmvpe(chunk_16k, 16000, 0) | |
| elif f0_method == "harvest": | |
| f0_coarse, f0 = extract_f0_harvest(chunk_16k, 16000, 0) | |
| else: | |
| f0_coarse, f0 = extract_f0_pm(chunk_16k, 16000, 0) | |
| f0_path = f"{output_dir}/f0/{idx:04d}.npy" | |
| np.save(f0_path, np.stack([f0_coarse, f0], axis=0)) | |
| manifest.append(f"{idx:04d}") | |
| # Save manifest | |
| with open(f"{output_dir}/manifest.txt", "w") as f: | |
| f.write("\n".join(manifest)) | |
| logger.info(f"Preprocessing complete: {len(manifest)} samples") | |
| return output_dir | |
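| # Illustrative call (paths are hypothetical): | |
| # preprocess_audio_for_training("voice.wav", "dataset/spk0", target_sr=40000, f0_method="rmvpe") | |
| # writes dataset/spk0/{wavs,hubert,f0}/NNNN.* plus manifest.txt, one entry per ~3.7 s chunk. | |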
| def train_rvc_generator( | |
| data_dir: str, | |
| output_dir: str, | |
| epochs: int = 10, | |
| batch_size: int = 2, | |
| lr: float = 1e-5, # Lower LR prevents overfitting on small data | |
| target_sr: int = 40000, | |
| progress_callback=None | |
| ): | |
| """Generator version of train_rvc - yields (epoch_msg, ckpt_path, index_path) tuples""" | |
| logger.info(f"Starting training: {data_dir} -> {output_dir}") | |
| os.makedirs(output_dir, exist_ok=True) | |
| # Load manifest | |
| with open(f"{data_dir}/manifest.txt") as f: | |
| samples = [l.strip() for l in f if l.strip()] | |
| if len(samples) < 1: | |
| logger.error("No training samples found") | |
| return None | |
| logger.info(f"Training with {len(samples)} samples") | |
| # Model config (v2 40k defaults) | |
| config = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 1, 256, target_sr] | |
| # Create models (v2 only - 768-dim HuBERT features) | |
| net_g = SynthesizerTrnMs768NSFsid( | |
| spec_channels=config[0], segment_size=config[1], inter_channels=config[2], | |
| hidden_channels=config[3], filter_channels=config[4], n_heads=config[5], | |
| n_layers=config[6], kernel_size=config[7], p_dropout=config[8], | |
| resblock=config[9], resblock_kernel_sizes=config[10], | |
| resblock_dilation_sizes=config[11], upsample_rates=config[12], | |
| upsample_initial_channel=config[13], upsample_kernel_sizes=config[14], | |
| spk_embed_dim=config[15], gin_channels=config[16], sr=target_sr | |
| ).to(train_device) | |
| net_d = MultiPeriodDiscriminator().to(train_device) | |
| logger.info(f"Training on device: {train_device}") | |
| # Download and load pretrained weights (essential for good results) | |
| sr_key = f"{target_sr // 1000}k" # e.g., "40k" | |
| try: | |
| pretrain_g_path = download_pretrained_rvc(f"f0G{sr_key}") | |
| pretrain_d_path = download_pretrained_rvc(f"f0D{sr_key}") | |
| load_pretrained_weights(net_g, pretrain_g_path) | |
| load_pretrained_weights(net_d, pretrain_d_path) | |
| except Exception as e: | |
| logger.warning(f"Failed to load pretrained weights: {e}") | |
| logger.warning("Training from scratch (results may be poor)") | |
| # Optimizers (after loading pretrained weights) | |
| optim_g = torch.optim.AdamW(net_g.parameters(), lr=lr, betas=(0.8, 0.99)) | |
| optim_d = torch.optim.AdamW(net_d.parameters(), lr=lr, betas=(0.8, 0.99)) | |
| # LR scheduler (matches Applio - exponential decay) | |
| lr_decay = 0.999875 | |
| scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=lr_decay) | |
| scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=lr_decay) | |
| net_g.train() | |
| net_d.train() | |
| # Training loop | |
| for epoch in range(epochs): | |
| total_loss_g, total_loss_d = 0, 0 | |
| np.random.shuffle(samples) | |
| for i in range(0, len(samples), batch_size): | |
| batch_samples = samples[i:i+batch_size] | |
| # Load batch data | |
| wavs, huberts, f0s = [], [], [] | |
| for s in batch_samples: | |
| wav, _ = librosa.load(f"{data_dir}/wavs/{s}.wav", sr=target_sr, mono=True) | |
| hubert = np.load(f"{data_dir}/hubert/{s}.npy") | |
| # Upsample 50Hz -> 100Hz using interpolation (same as inference) | |
| hubert_t = torch.from_numpy(hubert).unsqueeze(0).permute(0, 2, 1) # (1, 768, seq) | |
| hubert_t = F.interpolate(hubert_t, scale_factor=2, mode='linear', align_corners=False) | |
| hubert = hubert_t.permute(0, 2, 1).squeeze(0).numpy() # (seq*2, 768) | |
| f0_data = np.load(f"{data_dir}/f0/{s}.npy") | |
| wavs.append(wav) | |
| huberts.append(hubert) | |
| f0s.append(f0_data) | |
| # Compute spectrogram first to get target length | |
| max_wav_len = max(len(w) for w in wavs) | |
| wav_batch = np.zeros((len(wavs), max_wav_len)) | |
| for j, w in enumerate(wavs): | |
| wav_batch[j, :len(w)] = w | |
| wav_t = torch.FloatTensor(wav_batch).unsqueeze(1).to(train_device) | |
| spec = spectrogram_torch(wav_t.squeeze(1), 2048, 400, 2048) | |
| spec_len = spec.shape[2] # Target length for all features | |
| # Pad/truncate features to match spec length exactly | |
| hubert_batch = np.zeros((len(huberts), spec_len, huberts[0].shape[1])) | |
| f0_batch = np.zeros((len(f0s), spec_len)) | |
| f0f_batch = np.zeros((len(f0s), spec_len)) | |
| for j, (h, f) in enumerate(zip(huberts, f0s)): | |
| # Truncate or pad HuBERT to spec_len | |
| h_len = min(h.shape[0], spec_len) | |
| hubert_batch[j, :h_len] = h[:h_len] | |
| # Truncate or pad F0 to spec_len | |
| f0_len = min(f.shape[1], spec_len) | |
| f0_batch[j, :f0_len] = f[0, :f0_len] | |
| f0f_batch[j, :f0_len] = f[1, :f0_len] | |
| # To tensors - all features now have spec_len | |
| hubert_t = torch.FloatTensor(hubert_batch).to(train_device) | |
| f0_t = torch.LongTensor(f0_batch.astype(np.int64)).to(train_device) | |
| f0f_t = torch.FloatTensor(f0f_batch).to(train_device) | |
| lengths_t = torch.LongTensor([spec_len] * len(batch_samples)).to(train_device) | |
| sid_t = torch.LongTensor([0] * len(batch_samples)).to(train_device) | |
| spec_lengths = torch.LongTensor([spec_len] * len(batch_samples)).to(train_device) | |
| # Forward pass generator | |
| # Args: phone, phone_lengths, pitch, pitchf, y (spec), y_lengths, ds | |
| try: | |
| y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = net_g( | |
| hubert_t, lengths_t, f0_t, f0f_t, spec, spec_lengths, sid_t | |
| ) | |
| except Exception as e: | |
| logger.warning(f"Generator forward failed: {e}") | |
| continue | |
| # Slice wav at same position model generated (CRITICAL for proper loss) | |
| # ids_slice is in latent space, multiply by hop_length to get waveform position | |
| hop_length = 400 | |
| segment_size_wav = 32 * hop_length # segment_size in latent * hop_length | |
| y = slice_segments(wav_t, ids_slice * hop_length, segment_size_wav) | |
| # Discriminator forward | |
| y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat.detach()) | |
| # Discriminator loss | |
| loss_d = discriminator_loss(y_d_rs, y_d_gs) | |
| optim_d.zero_grad() | |
| loss_d.backward() | |
| optim_d.step() | |
| # Generator loss | |
| y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat) | |
| loss_gen = generator_loss(y_d_gs) | |
| loss_fm = feature_loss(fmap_rs, fmap_gs) | |
| loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) | |
| # Mel spectrogram loss (crucial for quality) | |
| # Config: n_fft=2048, hop=400, win=2048, n_mels=125, fmin=0, fmax=None | |
| y_mel = spec_to_mel_torch(spectrogram_torch(y.squeeze(1), 2048, 400, 2048), | |
| n_fft=2048, num_mels=125, sampling_rate=target_sr, fmin=0, fmax=None) | |
| y_hat_mel = spec_to_mel_torch(spectrogram_torch(y_hat.squeeze(1), 2048, 400, 2048), | |
| n_fft=2048, num_mels=125, sampling_rate=target_sr, fmin=0, fmax=None) | |
| # Align lengths if needed | |
| min_len = min(y_mel.shape[2], y_hat_mel.shape[2]) | |
| loss_mel = F.l1_loss(y_mel[:, :, :min_len], y_hat_mel[:, :, :min_len]) * 45 # c_mel = 45 | |
| loss_g = loss_gen + loss_fm + loss_mel + loss_kl | |
| optim_g.zero_grad() | |
| loss_g.backward() | |
| optim_g.step() | |
| total_loss_g += loss_g.item() | |
| total_loss_d += loss_d.item() | |
| avg_loss_g = total_loss_g / max(1, len(samples) // batch_size) | |
| avg_loss_d = total_loss_d / max(1, len(samples) // batch_size) | |
| epoch_msg = f"Epoch {epoch+1}/{epochs} - G: {avg_loss_g:.2f}, D: {avg_loss_d:.2f}" | |
| logger.info(epoch_msg) | |
| # Update progress callback if provided | |
| if progress_callback: | |
| progress_pct = 0.30 + (0.65 * (epoch + 1) / epochs) | |
| progress_callback(progress_pct, epoch_msg) | |
| # Yield epoch message for live UI updates | |
| yield epoch_msg, None, None | |
| # Step LR schedulers | |
| scheduler_g.step() | |
| scheduler_d.step() | |
| # Save checkpoint | |
| ckpt_path = f"{output_dir}/model.pth" | |
| torch.save({ | |
| "weight": net_g.state_dict(), | |
| "config": config, | |
| "version": "v2", # v2 only | |
| "f0": 1, | |
| }, ckpt_path) | |
| logger.info(f"Saved checkpoint: {ckpt_path}") | |
| # Generate index file for better speaker similarity | |
| index_path = None | |
| try: | |
| import faiss | |
| hubert_dir = f"{data_dir}/hubert" | |
| npys = [] | |
| for name in sorted(os.listdir(hubert_dir)): | |
| if name.endswith('.npy'): | |
| phone = np.load(os.path.join(hubert_dir, name)) | |
| npys.append(phone) | |
| if npys: | |
| big_npy = np.concatenate(npys, axis=0) | |
| n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39) | |
| n_ivf = max(1, n_ivf) # Ensure at least 1 | |
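| # Heuristic: ~16*sqrt(N) IVF lists, capped at N//39 so each list keeps roughly 39+ training vectors | |
| # (faiss warns when it gets fewer than that per centroid). | |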
| index = faiss.index_factory(big_npy.shape[1], f"IVF{n_ivf},Flat") | |
| index.train(big_npy) | |
| index.add(big_npy) | |
| index_path = f"{output_dir}/model.index" | |
| faiss.write_index(index, index_path) | |
| logger.info(f"Saved index: {index_path}") | |
| except Exception as e: | |
| logger.warning(f"Failed to generate index: {e}") | |
| # Cleanup training models to free memory | |
| purge_memory(net_g, net_d, optim_g, optim_d, scheduler_g, scheduler_d) | |
| yield "Training complete!", ckpt_path, index_path | |
| def train_rvc( | |
| data_dir: str, | |
| output_dir: str, | |
| epochs: int = 10, | |
| batch_size: int = 2, | |
| lr: float = 1e-5, # Lower LR prevents overfitting on small data | |
| target_sr: int = 40000, | |
| progress_callback=None | |
| ): | |
| """Non-generator wrapper for CLI use - returns (checkpoint_path, index_path)""" | |
| ckpt = None | |
| idx = None | |
| for msg, path, index in train_rvc_generator(data_dir, output_dir, epochs, batch_size, lr, target_sr, progress_callback): | |
| if path: | |
| ckpt = path | |
| if index: | |
| idx = index | |
| return ckpt, idx | |
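| # Illustrative call (paths are hypothetical): | |
| # ckpt, idx = train_rvc("dataset/spk0", "trained/my_voice", epochs=10, batch_size=2) | |
| # ckpt points at trained/my_voice/model.pth; idx is model.index, or None if faiss indexing failed. | |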
| # ============================================================ | |
| # INFERENCE | |
| # ============================================================ | |
| def convert_voice( | |
| source_audio: str, | |
| model_file, | |
| index_file=None, | |
| pitch_shift: int = 0, | |
| f0_method: str = "pm", | |
| index_rate: float = 0.5, | |
| protect: float = 0.33, | |
| volume_envelope: float = 1.0, | |
| progress=gr.Progress() | |
| ) -> Tuple[str, str]: | |
| """Convert voice using RVC model (Applio-compatible pipeline).""" | |
| try: | |
| if source_audio is None: | |
| return None, "Please upload source audio" | |
| if model_file is None: | |
| return None, "Please upload RVC model (.pth)" | |
| model_path = model_file.name if hasattr(model_file, 'name') else model_file | |
| progress(0.1, "Loading model...") | |
| model, tgt_sr, version, if_f0 = load_rvc_model(model_path) | |
| progress(0.2, "Loading audio...") | |
| audio, sr = librosa.load(source_audio, sr=16000, mono=True) | |
| # Apply 48Hz high-pass filter (critical - removes low-frequency artifacts) | |
| audio = signal.filtfilt(bh, ah, audio) | |
| # Normalize audio | |
| audio_max = np.abs(audio).max() / 0.95 | |
| if audio_max > 1: | |
| audio /= audio_max | |
| # Pipeline constants (same as Applio) | |
| window = 160 # Critical for feature/pitch alignment | |
| x_pad = 1 # Padding in seconds | |
| t_pad = 16000 * x_pad # Padding in samples | |
| # Pad audio | |
| audio_pad = np.pad(audio, (t_pad, t_pad), mode="reflect") | |
| p_len = audio_pad.shape[0] // window | |
| progress(0.3, "Extracting features...") | |
| feats = extract_hubert_features(audio_pad, sr=16000, version=version) | |
| # Save original features for protect mechanism | |
| feats0 = feats.clone() if if_f0 == 1 and protect < 0.5 else None | |
| # Index retrieval (speaker similarity) | |
| if index_file is not None and index_rate > 0: | |
| try: | |
| import faiss | |
| index_path = index_file.name if hasattr(index_file, 'name') else index_file | |
| progress(0.4, "Loading index...") | |
| index = faiss.read_index(index_path) | |
| big_npy = index.reconstruct_n(0, index.ntotal) | |
| npy = feats[0].cpu().numpy().astype("float32") | |
| score, ix = index.search(npy, k=8) | |
| weight = np.square(1 / score) | |
| weight /= weight.sum(axis=1, keepdims=True) | |
| npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) | |
| feats = torch.from_numpy(npy).unsqueeze(0).to(device) * index_rate + (1 - index_rate) * feats | |
| except Exception as e: | |
| logger.warning(f"Index retrieval failed: {e}") | |
| # Feature upsampling by 2x | |
| feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) | |
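| # HuBERT/ContentVec frames are 20 ms (50 Hz); doubling them matches the 10 ms (100 Hz) grid implied | |
| # by window=160 samples at 16 kHz, so features line up with the f0 sequence below. | |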
| # Adjust length based on audio | |
| p_len = min(audio_pad.shape[0] // window, feats.shape[1]) | |
| pitch, pitchf = None, None | |
| if if_f0 == 1: | |
| progress(0.5, f"Extracting F0 ({f0_method})...") | |
| if f0_method == "rmvpe": | |
| pitch, pitchf = extract_f0_rmvpe(audio_pad, 16000, pitch_shift) | |
| elif f0_method == "harvest": | |
| pitch, pitchf = extract_f0_harvest(audio_pad, 16000, pitch_shift) | |
| else: | |
| pitch, pitchf = extract_f0_pm(audio_pad, 16000, pitch_shift) | |
| pitch = pitch[:p_len] | |
| pitchf = pitchf[:p_len] | |
| # Upsample feats0 for protect | |
| if feats0 is not None: | |
| feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) | |
| # Apply protect mechanism (preserve original features for unvoiced segments) | |
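| # pitchff is 1.0 on voiced frames and `protect` on unvoiced ones, so unvoiced frames keep up to | |
| # (1 - protect) of the pre-index-blend features, which helps preserve voiceless consonants and breaths. | |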
| if protect < 0.5 and feats0 is not None: | |
| pitchf_tensor = torch.from_numpy(pitchf).float().to(device) | |
| pitchff = pitchf_tensor.clone() | |
| pitchff[pitchf_tensor > 0] = 1 | |
| pitchff[pitchf_tensor < 1] = protect | |
| pitchff = pitchff.unsqueeze(0).unsqueeze(-1) | |
| feats = feats[:, :p_len, :] * pitchff + feats0[:, :p_len, :] * (1 - pitchff) | |
| if len(pitch) < p_len: | |
| pitch = np.pad(pitch, (0, p_len - len(pitch))) | |
| pitchf = np.pad(pitchf, (0, p_len - len(pitchf))) | |
| pitch = torch.LongTensor(pitch).unsqueeze(0).to(device) | |
| pitchf = torch.FloatTensor(pitchf).unsqueeze(0).to(device) | |
| p_len_tensor = torch.LongTensor([p_len]).to(device) | |
| sid = torch.LongTensor([0]).to(device) | |
| progress(0.7, "Running inference...") | |
| with torch.no_grad(): | |
| if if_f0 == 1: | |
| audio_out = model.infer(feats[:, :p_len, :], p_len_tensor, pitch, pitchf, sid)[0][0, 0].data.cpu().float().numpy() | |
| else: | |
| audio_out = model.infer(feats[:, :p_len, :], p_len_tensor, sid)[0][0, 0].data.cpu().float().numpy() | |
| # Remove padding from output | |
| t_pad_tgt = int(t_pad * tgt_sr / 16000) | |
| if len(audio_out) > 2 * t_pad_tgt: | |
| audio_out = audio_out[t_pad_tgt:-t_pad_tgt] | |
| # RMS mixing - match volume dynamics of source audio | |
| if volume_envelope != 1.0: | |
| try: | |
| source_at_tgt_sr = librosa.resample(audio, orig_sr=16000, target_sr=tgt_sr) | |
| frame_len = tgt_sr // 2 * 2 | |
| hop_len = tgt_sr // 2 | |
| rms_source = librosa.feature.rms(y=source_at_tgt_sr, frame_length=frame_len, hop_length=hop_len) | |
| rms_output = librosa.feature.rms(y=audio_out, frame_length=frame_len, hop_length=hop_len) | |
| rms_source = F.interpolate( | |
| torch.from_numpy(rms_source).float().unsqueeze(0), | |
| size=audio_out.shape[0], mode="linear" | |
| ).squeeze() | |
| rms_output = F.interpolate( | |
| torch.from_numpy(rms_output).float().unsqueeze(0), | |
| size=audio_out.shape[0], mode="linear" | |
| ).squeeze() | |
| rms_output = torch.maximum(rms_output, torch.zeros_like(rms_output) + 1e-6) | |
| # Applio formula: target * (source^(1-rate) * output^(rate-1)) | |
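| # At volume_envelope = 0 the output is rescaled to follow the source RMS envelope; values in between | |
| # blend the two, and 1.0 (the default) skips this branch and keeps the model's own dynamics. | |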
| audio_out = audio_out * (torch.pow(rms_source, 1 - volume_envelope) * torch.pow(rms_output, volume_envelope - 1)).numpy() | |
| except Exception as e: | |
| logger.warning(f"RMS mixing failed: {e}") | |
| # Final normalization | |
| audio_max = np.abs(audio_out).max() / 0.99 | |
| if audio_max > 1: | |
| audio_out /= audio_max | |
| progress(0.9, "Saving output...") | |
| fd, output_path = tempfile.mkstemp(suffix=".wav") | |
| os.close(fd) | |
| sf.write(output_path, audio_out, tgt_sr) | |
| # Aggressive memory purge after inference — frees glibc arena on Linux | |
| _model_cache.clear() | |
| _cleanup_args = [model, feats, audio_out, audio, audio_pad] | |
| if feats0 is not None: | |
| _cleanup_args.append(feats0) | |
| purge_memory(*_cleanup_args) | |
| return output_path, f"Converted: {version}, sr={tgt_sr}, pitch={pitch_shift:+d}" | |
| except Exception as e: | |
| logger.exception("Conversion failed") | |
| _model_cache.clear() | |
| purge_memory() | |
| return None, f"Error: {str(e)}" | |
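| # Illustrative call (file names are hypothetical): | |
| # out_path, msg = convert_voice("speech.wav", "weights/model.pth", pitch_shift=12, f0_method="rmvpe") | |
| # On success out_path is a temp .wav at the model's sample rate; on failure it is None and msg holds the error. | |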
| # ============================================================ | |
| # DEFAULT MODEL DOWNLOAD | |
| # ============================================================ | |
| def load_example_model(): | |
| """Download and load the default example model from HuggingFace""" | |
| import shutil | |
| try: | |
| logger.info(f"Downloading example model from {DEFAULT_MODEL_REPO}...") | |
| model_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_MODEL_FILE) | |
| index_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_INDEX_FILE) | |
| # Gradio 6 requires files to be in allowed directories (cwd or /tmp) | |
| # Copy from HF cache to temp directory | |
| temp_dir = tempfile.mkdtemp() | |
| temp_model = os.path.join(temp_dir, DEFAULT_MODEL_FILE) | |
| temp_index = os.path.join(temp_dir, DEFAULT_INDEX_FILE) | |
| shutil.copy2(model_path, temp_model) | |
| shutil.copy2(index_path, temp_index) | |
| return temp_model, temp_index, f"Loaded: {DEFAULT_MODEL_REPO}" | |
| except Exception as e: | |
| logger.exception("Failed to download example model") | |
| return None, None, f"Error: {str(e)}" | |
| # ============================================================ | |
| # BEATRICE V2 MODEL | |
| # ============================================================ | |
| def beatrice_load_audio(file, **kwargs): | |
| """Load audio using soundfile directly (for Beatrice dataset)""" | |
| data, sr = sf.read(file, dtype='float32') | |
| # soundfile returns (samples, channels), convert to torch (channels, samples) | |
| wav = torch.from_numpy(data) | |
| if wav.ndim == 1: | |
| wav = wav.unsqueeze(0) # mono -> (1, samples) | |
| else: | |
| wav = wav.T # (samples, channels) -> (channels, samples) | |
| return wav, sr | |
| class AttrDict(dict): | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self.__dict__ = self | |
| def dump_params(params: torch.Tensor, f: BinaryIO): | |
| if params is None: | |
| return | |
| if params.dtype == torch.bfloat16: | |
| f.write( | |
| params.detach() | |
| .clone() | |
| .float() | |
| .view(torch.short) | |
| .numpy() | |
| .ravel()[1::2] | |
| .tobytes() | |
| ) | |
| else: | |
| f.write(params.detach().numpy().ravel().tobytes()) | |
| f.flush() | |
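| # Note on the bfloat16 branch above: bfloat16 is the upper half of a float32, so viewing the float32 | |
| # copy as int16 and keeping every second short (little-endian) recovers the raw bfloat16 payload. | |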
| def dump_layer(layer: nn.Module, f: BinaryIO): | |
| dump = partial(dump_params, f=f) | |
| if hasattr(layer, "dump"): | |
| layer.dump(f) | |
| elif isinstance(layer, (nn.Linear, nn.Conv1d, nn.LayerNorm)): | |
| dump(layer.weight) | |
| dump(layer.bias) | |
| elif isinstance(layer, nn.MultiheadAttention): | |
| embed_dim = layer.embed_dim | |
| num_heads = layer.num_heads | |
| # [3 * embed_dim, embed_dim] | |
| in_proj_weight = layer.in_proj_weight.data.clone() | |
| in_proj_weight[: 2 * embed_dim] *= 1.0 / math.sqrt( | |
| math.sqrt(embed_dim // num_heads) | |
| ) | |
| in_proj_weight = in_proj_weight.view( | |
| 3, num_heads, embed_dim // num_heads, embed_dim | |
| ) | |
| # [num_heads, 3, embed_dim / num_heads, embed_dim] | |
| in_proj_weight = in_proj_weight.transpose(0, 1) | |
| # [3 * embed_dim] | |
| in_proj_bias = layer.in_proj_bias.data.clone() | |
| in_proj_bias[: 2 * embed_dim] *= 1.0 / math.sqrt( | |
| math.sqrt(embed_dim // num_heads) | |
| ) | |
| in_proj_bias = in_proj_bias.view(3, num_heads, embed_dim // num_heads) | |
| # [num_heads, 3, embed_dim / num_heads] | |
| in_proj_bias = in_proj_bias.transpose(0, 1) | |
| dump(in_proj_weight) | |
| dump(in_proj_bias) | |
| dump(layer.out_proj.weight) | |
| dump(layer.out_proj.bias) | |
| elif isinstance(layer, nn.Embedding): | |
| dump(layer.weight) | |
| elif isinstance(layer, nn.Parameter): | |
| dump(layer) | |
| elif isinstance(layer, nn.ModuleList): | |
| for layer_i in layer: | |
| dump_layer(layer_i, f) | |
| else: | |
| assert False, layer | |
| class CausalConv1d(nn.Conv1d): | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| out_channels: int, | |
| kernel_size: int, | |
| stride: int = 1, | |
| dilation: int = 1, | |
| groups: int = 1, | |
| bias: bool = True, | |
| delay: int = 0, | |
| ): | |
| padding = (kernel_size - 1) * dilation - delay | |
| self.trim = (kernel_size - 1) * dilation - 2 * delay | |
| if self.trim < 0: | |
| raise ValueError | |
| super().__init__( | |
| in_channels, | |
| out_channels, | |
| kernel_size=kernel_size, | |
| stride=stride, | |
| padding=padding, | |
| dilation=dilation, | |
| groups=groups, | |
| bias=bias, | |
| ) | |
| def forward(self, input: torch.Tensor) -> torch.Tensor: | |
| result = super().forward(input) | |
| if self.trim == 0: | |
| return result | |
| else: | |
| return result[:, :, : -self.trim] | |
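| # CausalConv1d: symmetric padding plus right-trim leaves (kernel_size-1)*dilation - delay samples of | |
| # left context and only `delay` samples of look-ahead, so with delay=0 output frame t never sees inputs after t. | |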
| class WSConv1d(CausalConv1d): | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| out_channels: int, | |
| kernel_size: int, | |
| stride: int = 1, | |
| dilation: int = 1, | |
| groups: int = 1, | |
| bias: bool = True, | |
| delay: int = 0, | |
| ): | |
| super().__init__( | |
| in_channels, | |
| out_channels, | |
| kernel_size=kernel_size, | |
| stride=stride, | |
| dilation=dilation, | |
| groups=groups, | |
| bias=bias, | |
| delay=delay, | |
| ) | |
| self.weight.data.normal_( | |
| 0.0, math.sqrt(1.0 / (in_channels * kernel_size // groups)) | |
| ) | |
| if bias: | |
| self.bias.data.zero_() | |
| self.gain = nn.Parameter(torch.ones((out_channels, 1, 1))) | |
| def standardized_weight(self) -> torch.Tensor: | |
| var, mean = torch.var_mean(self.weight, [1, 2], keepdim=True) | |
| scale = ( | |
| self.gain | |
| * ( | |
| self.in_channels * self.kernel_size[0] // self.groups * var + 1e-8 | |
| ).rsqrt() | |
| ) | |
| return scale * (self.weight - mean) | |
| def forward(self, input: torch.Tensor) -> torch.Tensor: | |
| result = F.conv1d( | |
| input, | |
| self.standardized_weight(), | |
| self.bias, | |
| self.stride, | |
| self.padding, | |
| self.dilation, | |
| self.groups, | |
| ) | |
| if self.trim == 0: | |
| return result | |
| else: | |
| return result[:, :, : -self.trim] | |
| def merge_weights(self): | |
| self.weight.data[:] = self.standardized_weight().detach() | |
| self.gain.data.fill_(1.0) | |
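| # Weight standardization: the effective kernel is gain * (W - mean(W)) * rsqrt(fan_in * var(W) + 1e-8), | |
| # with fan_in = in_channels * kernel_size / groups; merge_weights() bakes this into `weight` and resets | |
| # gain so the exported model needs no standardization at run time. | |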
| class WSLinear(nn.Linear): | |
| def __init__(self, in_features: int, out_features: int, bias: bool = True): | |
| super().__init__(in_features, out_features, bias) | |
| self.weight.data.normal_(0.0, math.sqrt(1.0 / in_features)) | |
| self.bias.data.zero_() | |
| self.gain = nn.Parameter(torch.ones((out_features, 1))) | |
| def standardized_weight(self) -> torch.Tensor: | |
| var, mean = torch.var_mean(self.weight, 1, keepdim=True) | |
| scale = self.gain * (self.in_features * var + 1e-8).rsqrt() | |
| return scale * (self.weight - mean) | |
| def forward(self, input: torch.Tensor) -> torch.Tensor: | |
| return F.linear(input, self.standardized_weight(), self.bias) | |
| def merge_weights(self): | |
| self.weight.data[:] = self.standardized_weight().detach() | |
| self.gain.data.fill_(1.0) | |
| class CrossAttention(nn.Module): | |
| def __init__( | |
| self, | |
| qk_channels: int, | |
| vo_channels: int, | |
| num_heads: int, | |
| in_q_channels: int, | |
| in_kv_channels: int, | |
| out_channels: int, | |
| dropout: float = 0.0, | |
| ): | |
| super().__init__() | |
| assert qk_channels % num_heads == 0 | |
| self.qk_channels = qk_channels | |
| self.vo_channels = vo_channels | |
| self.num_heads = num_heads | |
| self.in_q_channels = in_q_channels | |
| self.in_kv_channels = in_kv_channels | |
| self.out_channels = out_channels | |
| self.dropout = dropout | |
| self.head_qk_channels = qk_channels // num_heads | |
| self.head_vo_channels = vo_channels // num_heads | |
| self.q_projection = nn.Linear(in_q_channels, qk_channels) | |
| self.q_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_q_channels)) | |
| self.q_projection.bias.data.zero_() | |
| self.kv_projection = nn.Linear(in_kv_channels, qk_channels + vo_channels) | |
| self.kv_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_kv_channels)) | |
| self.kv_projection.bias.data.zero_() | |
| self.out_projection = nn.Linear(vo_channels, out_channels) | |
| self.out_projection.weight.data.normal_(0.0, math.sqrt(1.0 / vo_channels)) | |
| self.out_projection.bias.data.zero_() | |
| def forward( | |
| self, | |
| q: torch.Tensor, | |
| kv: torch.Tensor, | |
| ) -> torch.Tensor: | |
| # q: [batch_size, q_length, in_q_channels] | |
| # kv: [batch_size, kv_length, in_kv_channels] | |
| batch_size, q_length, _ = q.size() | |
| _, kv_length, _ = kv.size() | |
| # [batch_size, q_length, qk_channels] | |
| q = self.q_projection(q) | |
| # [batch_size, kv_length, qk_channels + vo_channels] | |
| kv = self.kv_projection(kv) | |
| # [batch_size, kv_length, qk_channels], [batch_size, kv_length, vo_channels] | |
| k, v = kv.split([self.qk_channels, self.vo_channels], dim=2) | |
| q = q.view( | |
| batch_size, q_length, self.num_heads, self.head_qk_channels | |
| ).transpose(1, 2) | |
| k = k.view( | |
| batch_size, kv_length, self.num_heads, self.head_qk_channels | |
| ).transpose(1, 2) | |
| v = v.view( | |
| batch_size, kv_length, self.num_heads, self.head_vo_channels | |
| ).transpose(1, 2) | |
| # [batch_size, num_heads, q_length, head_vo_channels] | |
| attn_out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout) | |
| # [batch_size, q_length, vo_channels] | |
| attn_out = ( | |
| attn_out.transpose(1, 2) | |
| .contiguous() | |
| .view(batch_size, q_length, self.vo_channels) | |
| ) | |
| # [batch_size, q_length, out_channels] | |
| attn_out = self.out_projection(attn_out) | |
| return attn_out | |
| def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| q_projection_weight = self.q_projection.weight.data.clone() | |
| q_projection_bias = self.q_projection.bias.data.clone() | |
| q_projection_weight *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) | |
| q_projection_bias *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) | |
| dump_params(q_projection_weight, f) | |
| dump_params(q_projection_bias, f) | |
| dump_layer(self.out_projection, f) | |
| def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump_kv(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| kv_projection_weight = self.kv_projection.weight.data.clone() | |
| kv_projection_bias = self.kv_projection.bias.data.clone() | |
| k_projection_weight, v_projection_weight = kv_projection_weight.split( | |
| [self.qk_channels, self.vo_channels] | |
| ) | |
| k_projection_bias, v_projection_bias = kv_projection_bias.split( | |
| [self.qk_channels, self.vo_channels] | |
| ) | |
| k_projection_weight *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) | |
| k_projection_bias *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) | |
| # [qk_channels, in_kv_channels] -> [num_heads, head_qk_channels, in_kv_channels] | |
| k_projection_weight = k_projection_weight.view( | |
| self.num_heads, self.head_qk_channels, self.in_kv_channels | |
| ) | |
| # [qk_channels] -> [num_heads, head_qk_channels] | |
| k_projection_bias = k_projection_bias.view( | |
| self.num_heads, self.head_qk_channels | |
| ) | |
| # [vo_channels, in_kv_channels] -> [num_heads, head_vo_channels, in_kv_channels] | |
| v_projection_weight = v_projection_weight.view( | |
| self.num_heads, self.head_vo_channels, self.in_kv_channels | |
| ) | |
| # [vo_channels] -> [num_heads, head_vo_channels] | |
| v_projection_bias = v_projection_bias.view( | |
| self.num_heads, self.head_vo_channels | |
| ) | |
| for i in range(self.num_heads): | |
| # [head_qk_channels, in_kv_channels] | |
| dump_params(k_projection_weight[i], f) | |
| # [head_vo_channels, in_kv_channels] | |
| dump_params(v_projection_weight[i], f) | |
| for i in range(self.num_heads): | |
| # [head_qk_channels] | |
| dump_params(k_projection_bias[i], f) | |
| # [head_vo_channels] | |
| dump_params(v_projection_bias[i], f) | |
| class ConvNeXtBlock(nn.Module): | |
| def __init__( | |
| self, | |
| channels: int, | |
| intermediate_channels: int, | |
| layer_scale_init_value: float, | |
| kernel_size: int = 7, | |
| use_weight_standardization: bool = False, | |
| enable_scaling: bool = False, | |
| pre_scale: float = 1.0, | |
| post_scale: float = 1.0, | |
| use_mha: bool = False, | |
| cross_attention: bool = False, | |
| num_heads: int = 4, | |
| attention_dropout: float = 0.1, | |
| attention_channels: Optional[int] = None, | |
| kv_channels: Optional[int] = None, | |
| ): | |
| super().__init__() | |
| self.use_weight_standardization = use_weight_standardization | |
| self.enable_scaling = enable_scaling | |
| self.use_mha = use_mha | |
| self.cross_attention = cross_attention | |
| if use_mha: | |
| self.attn_norm = nn.LayerNorm(channels) | |
| if cross_attention: | |
| self.mha = CrossAttention( | |
| qk_channels=attention_channels, | |
| vo_channels=attention_channels, | |
| num_heads=num_heads, | |
| in_q_channels=channels, | |
| in_kv_channels=kv_channels, | |
| out_channels=channels, | |
| dropout=attention_dropout, | |
| ) | |
| else: # self-attention | |
| assert attention_channels is None | |
| assert kv_channels is None | |
| self.mha = nn.MultiheadAttention( | |
| embed_dim=channels, | |
| num_heads=num_heads, | |
| dropout=attention_dropout, | |
| batch_first=True, | |
| ) | |
| self.dwconv = CausalConv1d( | |
| channels, channels, kernel_size=kernel_size, groups=channels | |
| ) | |
| self.norm = nn.LayerNorm(channels) | |
| self.pwconv1 = nn.Linear(channels, intermediate_channels) | |
| self.pwconv2 = nn.Linear(intermediate_channels, channels) | |
| self.gamma = nn.Parameter(torch.full((channels,), layer_scale_init_value)) | |
| self.dwconv.weight.data.normal_(0.0, math.sqrt(1.0 / kernel_size)) | |
| self.dwconv.bias.data.zero_() | |
| self.pwconv1.weight.data.normal_(0.0, math.sqrt(2.0 / channels)) | |
| self.pwconv1.bias.data.zero_() | |
| self.pwconv2.weight.data.normal_(0.0, math.sqrt(1.0 / intermediate_channels)) | |
| self.pwconv2.bias.data.zero_() | |
| if use_weight_standardization: | |
| self.norm = nn.Identity() | |
| self.dwconv = WSConv1d(channels, channels, kernel_size, groups=channels) | |
| self.pwconv1 = WSLinear(channels, intermediate_channels) | |
| self.pwconv2 = WSLinear(intermediate_channels, channels) | |
| del self.gamma | |
| if enable_scaling: | |
| self.register_buffer("pre_scale", torch.tensor(pre_scale)) | |
| self.register_buffer("post_scale", torch.tensor(post_scale)) | |
| self.post_scale_weight = nn.Parameter(torch.ones(())) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| attn_mask: Optional[torch.Tensor] = None, | |
| kv: Optional[torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| if self.use_mha: | |
| batch_size, channels, length = x.size() | |
| if self.cross_attention: | |
| assert kv is not None | |
| else: | |
| assert kv is None | |
| assert length % 4 == 0 | |
| identity = x | |
| if self.cross_attention: | |
| # kv: [batch_size, kv_length, kv_channels] | |
| x = x.transpose(1, 2) | |
| x = self.attn_norm(x) | |
| x = self.mha(x, kv) | |
| x = x.transpose(1, 2) | |
| else: | |
| x = x.view(batch_size, channels, length // 4, 4) | |
| x = x.permute(0, 3, 2, 1) | |
| x = x.reshape(batch_size * 4, length // 4, channels) | |
| x = self.attn_norm(x) | |
| x, _ = self.mha( | |
| x, x, x, attn_mask=attn_mask, is_causal=True, need_weights=False | |
| ) | |
| x = x.view(batch_size, 4, length // 4, channels) | |
| x = x.permute(0, 3, 2, 1) | |
| x = x.reshape(batch_size, channels, length) | |
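| # The reshapes above split the sequence into 4 interleaved phase subsequences (frames j, j+4, j+8, ...), | |
| # run causal self-attention on each at a quarter of the frame rate, then interleave the results back. | |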
| x += identity | |
| identity = x | |
| if self.enable_scaling: | |
| x = x * self.pre_scale | |
| x = self.dwconv(x) | |
| x = x.transpose(1, 2) | |
| x = self.norm(x) | |
| x = self.pwconv1(x) | |
| x = F.gelu(x, approximate="tanh") | |
| x = self.pwconv2(x) | |
| if not self.use_weight_standardization: | |
| x *= self.gamma | |
| if self.enable_scaling: | |
| x *= self.post_scale * self.post_scale_weight | |
| x = x.transpose(1, 2) | |
| x += identity | |
| return x | |
| def merge_weights(self): | |
| if self.use_mha: | |
| if self.cross_attention: | |
| assert isinstance(self.mha, CrossAttention) | |
| self.mha.q_projection.bias.data += torch.mv( | |
| self.mha.q_projection.weight.data, self.attn_norm.bias.data | |
| ) | |
| self.mha.q_projection.weight.data *= self.attn_norm.weight.data[None, :] | |
| self.attn_norm.bias.data[:] = 0.0 | |
| self.attn_norm.weight.data[:] = 1.0 | |
| else: # self-attention | |
| assert isinstance(self.mha, nn.MultiheadAttention) | |
| self.mha.in_proj_bias.data += torch.mv( | |
| self.mha.in_proj_weight.data, self.attn_norm.bias.data | |
| ) | |
| self.mha.in_proj_weight.data *= self.attn_norm.weight.data[None, :] | |
| self.attn_norm.bias.data[:] = 0.0 | |
| self.attn_norm.weight.data[:] = 1.0 | |
| if self.use_weight_standardization: | |
| self.dwconv.merge_weights() | |
| self.pwconv1.merge_weights() | |
| self.pwconv2.merge_weights() | |
| else: | |
| self.pwconv1.bias.data += torch.mv( | |
| self.pwconv1.weight.data, self.norm.bias.data | |
| ) | |
| self.pwconv1.weight.data *= self.norm.weight.data[None, :] | |
| self.norm.bias.data[:] = 0.0 | |
| self.norm.weight.data[:] = 1.0 | |
| self.pwconv2.weight.data *= self.gamma.data[:, None] | |
| self.pwconv2.bias.data *= self.gamma.data | |
| self.gamma.data[:] = 1.0 | |
| if self.enable_scaling: | |
| self.dwconv.weight.data *= self.pre_scale.data | |
| self.pre_scale.data.fill_(1.0) | |
| self.pwconv2.weight.data *= ( | |
| self.post_scale.data * self.post_scale_weight.data | |
| ) | |
| self.pwconv2.bias.data *= self.post_scale.data * self.post_scale_weight.data | |
| self.post_scale.data.fill_(1.0) | |
| self.post_scale_weight.data.fill_(1.0) | |
| def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| if self.use_mha: | |
| dump_layer(self.mha, f) | |
| dump_layer(self.dwconv, f) | |
| dump_layer(self.pwconv1, f) | |
| dump_layer(self.pwconv2, f) | |
| class ConvNeXtStack(nn.Module): | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| channels: int, | |
| intermediate_channels: int, | |
| n_blocks: int, | |
| delay: int, | |
| embed_kernel_size: int, | |
| kernel_size: int, | |
| use_weight_standardization: bool = False, | |
| enable_scaling: bool = False, | |
| use_mha: bool = False, | |
| cross_attention: bool = False, | |
| kv_channels: Optional[int] = None, | |
| ): | |
| super().__init__() | |
| assert delay * 2 + 1 <= embed_kernel_size | |
| assert not (use_weight_standardization and use_mha)  # combination not supported | |
| self.use_weight_standardization = use_weight_standardization | |
| self.use_mha = use_mha | |
| self.cross_attention = cross_attention | |
| self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay) | |
| self.norm = nn.LayerNorm(channels) | |
| self.convnext = nn.ModuleList() | |
| for i in range(n_blocks): | |
| pre_scale = 1.0 / math.sqrt(1.0 + i / n_blocks) if enable_scaling else 1.0 | |
| post_scale = 1.0 / math.sqrt(n_blocks) if enable_scaling else 1.0 | |
| block = ConvNeXtBlock( | |
| channels=channels, | |
| intermediate_channels=intermediate_channels, | |
| layer_scale_init_value=1.0 / n_blocks, | |
| kernel_size=kernel_size, | |
| use_weight_standardization=use_weight_standardization, | |
| enable_scaling=enable_scaling, | |
| pre_scale=pre_scale, | |
| post_scale=post_scale, | |
| use_mha=use_mha, | |
| cross_attention=cross_attention, | |
| num_heads=4, | |
| attention_dropout=0.1, | |
| attention_channels=kv_channels, | |
| kv_channels=kv_channels, | |
| ) | |
| self.convnext.append(block) | |
| self.final_layer_norm = nn.LayerNorm(channels) | |
| self.embed.weight.data.normal_( | |
| 0.0, math.sqrt(0.5 / (embed_kernel_size * in_channels)) | |
| ) | |
| self.embed.bias.data.zero_() | |
| if use_weight_standardization: | |
| self.embed = WSConv1d(in_channels, channels, embed_kernel_size, delay=delay) | |
| self.norm = nn.Identity() | |
| self.final_layer_norm = nn.Identity() | |
| def forward( | |
| self, x: torch.Tensor, kv: Optional[torch.Tensor] = None | |
| ) -> torch.Tensor: | |
| x = self.embed(x) | |
| x = self.norm(x.transpose(1, 2)).transpose(1, 2) | |
| if self.use_mha and not self.cross_attention: | |
| pad_length = -x.size(2) % 4 | |
| if pad_length: | |
| x = F.pad(x, (0, pad_length)) | |
| t40 = x.size(2) // 4 | |
| attn_mask = torch.ones((t40, t40), dtype=torch.bool, device=x.device).triu( | |
| 1 | |
| ) | |
| else: | |
| attn_mask = None | |
| for conv_block in self.convnext: | |
| x = conv_block(x, attn_mask=attn_mask, kv=kv) | |
| if self.use_mha and not self.cross_attention and pad_length: | |
| x = x[:, :, :-pad_length] | |
| x = self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2) | |
| return x | |
| def merge_weights(self): | |
| if self.use_weight_standardization: | |
| self.embed.merge_weights() | |
| for conv_block in self.convnext: | |
| conv_block.merge_weights() | |
| def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| dump_layer(self.embed, f) | |
| if not self.use_weight_standardization: | |
| dump_layer(self.norm, f) | |
| dump_layer(self.convnext, f) | |
| if not self.use_weight_standardization: | |
| dump_layer(self.final_layer_norm, f) | |
| def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump_kv(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| assert self.use_mha and self.cross_attention | |
| for conv_block in self.convnext: | |
| if not conv_block.use_mha or not conv_block.cross_attention: | |
| continue | |
| assert isinstance(conv_block, ConvNeXtBlock) | |
| assert hasattr(conv_block, "mha") | |
| assert isinstance(conv_block.mha, CrossAttention) | |
| conv_block.mha.dump_kv(f) | |
| class FeatureExtractor(nn.Module): | |
| def __init__(self, hidden_channels: int): | |
| super().__init__() | |
| # fmt: off | |
| self.conv0 = weight_norm(nn.Conv1d(1, hidden_channels // 8, 10, 5, bias=False)) | |
| self.conv1 = weight_norm(nn.Conv1d(hidden_channels // 8, hidden_channels // 4, 3, 2, bias=False)) | |
| self.conv2 = weight_norm(nn.Conv1d(hidden_channels // 4, hidden_channels // 2, 3, 2, bias=False)) | |
| self.conv3 = weight_norm(nn.Conv1d(hidden_channels // 2, hidden_channels, 3, 2, bias=False)) | |
| self.conv4 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 3, 2, bias=False)) | |
| self.conv5 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 2, 2, bias=False)) | |
| # fmt: on | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| # x: [batch_size, 1, wav_length] | |
| wav_length = x.size(2) | |
| if wav_length % 160 != 0: | |
| warnings.warn("wav_length % 160 != 0") | |
| x = F.pad(x, (40, 40)) | |
| x = F.gelu(self.conv0(x), approximate="tanh") | |
| x = F.gelu(self.conv1(x), approximate="tanh") | |
| x = F.gelu(self.conv2(x), approximate="tanh") | |
| x = F.gelu(self.conv3(x), approximate="tanh") | |
| x = F.gelu(self.conv4(x), approximate="tanh") | |
| x = F.gelu(self.conv5(x), approximate="tanh") | |
| # [batch_size, hidden_channels, wav_length / 160] | |
| return x | |
| def remove_weight_norm(self): | |
| remove_weight_norm(self.conv0) | |
| remove_weight_norm(self.conv1) | |
| remove_weight_norm(self.conv2) | |
| remove_weight_norm(self.conv3) | |
| remove_weight_norm(self.conv4) | |
| remove_weight_norm(self.conv5) | |
| def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| dump_layer(self.conv0, f) | |
| dump_layer(self.conv1, f) | |
| dump_layer(self.conv2, f) | |
| dump_layer(self.conv3, f) | |
| dump_layer(self.conv4, f) | |
| dump_layer(self.conv5, f) | |
| class FeatureProjection(nn.Module): | |
| def __init__(self, channels: int): | |
| super().__init__() | |
| self.norm = nn.LayerNorm(channels) | |
| self.dropout = nn.Dropout(0.1) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| # [batch_size, channels, length] | |
| x = self.norm(x.transpose(1, 2)).transpose(1, 2) | |
| x = self.dropout(x) | |
| return x | |
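| # PhoneExtractor: waveform -> frame-level phone (content) features. | |
| # Pipeline: FeatureExtractor (10 ms frames) -> FeatureProjection (LayerNorm + | |
| # dropout) -> ConvNeXtStack backbone with self-attention -> 1x1 conv head that | |
| # outputs phone_channels (default 128) channels per frame. units() returns the | |
| # same output transposed to [batch_size, length, phone_channels]. | |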
| class PhoneExtractor(nn.Module): | |
| def __init__( | |
| self, | |
| phone_channels: int = 128, | |
| hidden_channels: int = 128, | |
| backbone_embed_kernel_size: int = 9, | |
| kernel_size: int = 17, | |
| n_blocks: int = 20, | |
| ): | |
| super().__init__() | |
| self.feature_extractor = FeatureExtractor(hidden_channels) | |
| self.feature_projection = FeatureProjection(hidden_channels) | |
| self.backbone = ConvNeXtStack( | |
| in_channels=hidden_channels, | |
| channels=hidden_channels, | |
| intermediate_channels=hidden_channels * 3, | |
| n_blocks=n_blocks, | |
| delay=0, | |
| embed_kernel_size=backbone_embed_kernel_size, | |
| kernel_size=kernel_size, | |
| use_mha=True, | |
| ) | |
| self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1)) | |
| def forward( | |
| self, x: torch.Tensor, return_stats: bool = True | |
| ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: | |
| # x: [batch_size, 1, wav_length] | |
| stats = {} | |
| # [batch_size, 1, wav_length] -> [batch_size, feature_extractor_hidden_channels, length] | |
| x = self.feature_extractor(x) | |
| if return_stats: | |
| stats["feature_norm"] = x.detach().norm(dim=1).mean() | |
| # [batch_size, feature_extractor_hidden_channels, length] -> [batch_size, hidden_channels, length] | |
| x = self.feature_projection(x) | |
| # [batch_size, hidden_channels, length] | |
| x = self.backbone(x) | |
| # [batch_size, hidden_channels, length] -> [batch_size, phone_channels, length] | |
| phone = self.head(F.gelu(x, approximate="tanh")) | |
| results = [phone] | |
| if return_stats: | |
| stats["code_norm"] = phone.detach().norm(dim=1).mean() | |
| results.append(stats) | |
| if len(results) == 1: | |
| return results[0] | |
| return tuple(results) | |
| def units(self, x: torch.Tensor) -> torch.Tensor: | |
| # x: [batch_size, 1, wav_length] | |
| # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length] | |
| phone = self.forward(x, return_stats=False) | |
| # [batch_size, phone_channels, length] -> [batch_size, length, phone_channels] | |
| phone = phone.transpose(1, 2) | |
| # [batch_size, length, phone_channels] | |
| return phone | |
| def remove_weight_norm(self): | |
| self.feature_extractor.remove_weight_norm() | |
| remove_weight_norm(self.head) | |
| def merge_weights(self): | |
| self.backbone.merge_weights() | |
| self.backbone.embed.bias.data += ( | |
| ( | |
| self.feature_projection.norm.bias.data[None, :, None] | |
| * self.backbone.embed.weight.data # [o, i, k] | |
| ) | |
| .sum(1) | |
| .sum(1) | |
| ) | |
| self.backbone.embed.weight.data *= self.feature_projection.norm.weight.data[ | |
| None, :, None | |
| ] | |
| self.feature_projection.norm.bias.data[:] = 0.0 | |
| self.feature_projection.norm.weight.data[:] = 1.0 | |
| def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| dump_layer(self.feature_extractor, f) | |
| dump_layer(self.backbone, f) | |
| dump_layer(self.head, f) | |
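| # VectorQuantizer: per-speaker codebooks over the phone extractor's head output. | |
| # build_codebooks() collects head activations for each speaker, L2-normalizes | |
| # them, and runs k-means++ to obtain codebook_size centers. At conversion time | |
| # each frame is replaced by the mean of its topk most similar codebook entries | |
| # (cosine similarity), pulling source content features toward the target speaker. | |
| # The quantizer is attached to a target layer (here the phone extractor head) as | |
| # a forward hook via enable_hook(). | |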
| class VectorQuantizer(nn.Module): | |
| def __init__( | |
| self, | |
| n_speakers: int, | |
| codebook_size: int, | |
| channels: int, | |
| topk: int = 4, | |
| training_time_vq: Literal["none", "self", "random"] = "none", | |
| ): | |
| super().__init__() | |
| assert 1 <= topk <= codebook_size | |
| self.n_speakers = n_speakers | |
| self.codebook_size = codebook_size | |
| self.channels = channels | |
| self.topk = topk | |
| self.training_time_vq = training_time_vq | |
| self.register_buffer( | |
| "codebooks", | |
| torch.empty(n_speakers, codebook_size, channels, dtype=torch.half), | |
| ) | |
| self.codebooks: torch.Tensor | |
| # Implemented as a hook so the point where VQ is applied can be changed easily | |
| self._hook_handle: Optional[torch.utils.hooks.RemovableHandle] = None | |
| self.target_speaker_ids: Optional[torch.Tensor] = None | |
| def _hook(_, __, output): | |
| return self(output, self.target_speaker_ids) | |
| self._hook_fn = _hook | |
| def build_codebooks( | |
| self, | |
| collector_func: Callable, | |
| target_layer: nn.Module, | |
| inputs: Sequence[Iterable[torch.Tensor]], | |
| kmeans_n_iters: int = 50, | |
| ): | |
| assert len(inputs) == self.n_speakers | |
| assert self._hook_handle is None, "hook already installed" | |
| device = next(self.buffers()).device | |
| for spk_id, inps in enumerate(tqdm(inputs, desc="Building codebooks")): | |
| activations: list[torch.Tensor] = [] | |
| # TODO: subsample the activations when there is too much data | |
| def _collect(_, __, output): | |
| # output: [batch_size, channels, length] | |
| activations.append(output.detach()) | |
| handle = target_layer.register_forward_hook(_collect) | |
| for x in inps: | |
| collector_func(x.to(device)) | |
| handle.remove() | |
| if not activations: | |
| raise RuntimeError(f"No activation collected for speaker {spk_id}") | |
| # [n_data, channels] | |
| activations: torch.Tensor = torch.cat( | |
| [ | |
| a.transpose(1, 2).reshape(a.size(0) * a.size(2), self.channels) | |
| for a in activations | |
| ] | |
| ) | |
| activations = activations.float() | |
| activations = F.normalize(activations, dim=1, eps=1e-6) | |
| # [codebook_size, channels] | |
| centers = ( | |
| self._kmeans_plus_plus(activations, self.codebook_size, kmeans_n_iters) | |
| if activations.size(0) >= self.codebook_size | |
| else self._pad_replicate(activations, self.codebook_size) | |
| ) | |
| self.codebooks[spk_id] = centers.to(self.codebooks.dtype) | |
| def forward( | |
| self, x: torch.Tensor, speaker_ids: Optional[torch.Tensor] = None | |
| ) -> torch.Tensor: | |
| batch_size, channels, length = x.size() | |
| assert channels == self.channels | |
| device = x.device | |
| dtype = x.dtype | |
| if self.training: | |
| if self.training_time_vq == "none": | |
| return x | |
| elif self.training_time_vq == "self": | |
| if self.target_speaker_ids is None: | |
| raise ValueError("target_speaker_ids is not set") | |
| elif self.training_time_vq == "random": | |
| speaker_ids = torch.randint( | |
| 0, self.n_speakers, (batch_size,), device=device | |
| ) | |
| else: | |
| raise ValueError(f"Unknown training_time_vq: {self.training_time_vq}") | |
| else: | |
| if speaker_ids is None: | |
| return x | |
| speaker_ids = speaker_ids.to(device) | |
| # Normalize each frame's feature vector; shape stays [batch_size, channels, length] | |
| q = F.normalize(x, dim=1, eps=1e-6) | |
| codes = self.codebooks[speaker_ids].to(q.dtype) | |
| # [batch_size, length, codebook_size] | |
| sim = torch.einsum("bcl,bkc->blk", q, codes) | |
| # [batch_size, length, topk] | |
| _, topk_idx = sim.topk(self.topk, dim=-1) | |
| # [batch_size, length, codebook_size, channels] | |
| expanded_codes = codes[:, None, :, :].expand(-1, length, -1, -1) | |
| # [batch_size, length, topk, channels] | |
| expanded_topk_idx = topk_idx[:, :, :, None].expand(-1, -1, -1, channels) | |
| # [batch_size, length, topk, channels] | |
| gathered = expanded_codes.gather(2, expanded_topk_idx) | |
| # [batch_size, length, channels] | |
| gathered = gathered.mean(2) | |
| # [batch_size, channels, length] | |
| return gathered.transpose(1, 2).to(dtype) | |
| def enable_hook(self, target_layer: nn.Module): | |
| if self._hook_handle is not None: | |
| raise RuntimeError("hook already installed") | |
| self._hook_handle = target_layer.register_forward_hook(self._hook_fn) | |
| def disable_hook(self): | |
| if self._hook_handle is None: | |
| raise RuntimeError("hook not installed") | |
| self._hook_handle.remove() | |
| self._hook_handle = None | |
| def set_target_speaker_ids(self, speaker_ids: Optional[torch.Tensor]): | |
| # See forward() for the conditions under which these speaker ids are used | |
| self.target_speaker_ids = speaker_ids | |
| @staticmethod | |
| def _pad_replicate(x: torch.Tensor, n: int) -> torch.Tensor: | |
| # If there are fewer than n data points, pad by replicating existing ones | |
| idx = torch.arange(n, device=x.device) % x.size(0) | |
| return x[idx] | |
| @staticmethod | |
| def _kmeans_plus_plus( | |
| x: torch.Tensor, n_clusters: int, n_iters: int = 50 | |
| ) -> torch.Tensor: | |
| n_data, _ = x.size() | |
| center_indices = [torch.randint(0, n_data, ()).item()] | |
| min_distances = torch.full((n_data,), math.inf, device=x.device) | |
| for _ in range(1, n_clusters): | |
| last_center_index = center_indices[-1] | |
| min_distances = min_distances.minimum( | |
| torch.cdist(x, x[last_center_index : last_center_index + 1]) | |
| .float() | |
| .square_() | |
| .squeeze_(1) | |
| ) | |
| probs = min_distances / (min_distances.sum() + 1e-12) | |
| center_indices.append(torch.multinomial(probs, 1).item()) | |
| centers = x[center_indices] | |
| del min_distances, probs | |
| for _ in range(n_iters): | |
| distances = torch.cdist(x, centers) # [n_data, n_clusters] | |
| labels = distances.argmin(1) # [n_data] | |
| # [n_clusters, dim] | |
| new_centers = torch.zeros_like(centers).index_add_(0, labels, x) | |
| # [n_clusters] | |
| counts = labels.bincount(minlength=n_clusters) | |
| if (counts == 0).sum().item() != 0: | |
| # TODO: handle clusters that have no assigned data points | |
| warnings.warn("Some clusters have no assigned data points.") | |
| new_centers /= counts[:, None].clamp_(min=1).float() | |
| centers = new_centers | |
| return centers | |
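| # extract_pitch_features: hand-crafted per-frame features for pitch estimation. | |
| # The 16 kHz waveform is framed with a 35 ms window every 10 ms, and three groups | |
| # of features are returned: (1) log-power spectrum plus instantaneous-frequency | |
| # (frame-to-frame phase difference) features for the lowest | |
| # instfreq_features_cutoff_bin bins, (2) a normalized difference function over | |
| # lags up to max_corr_period samples (sum of squared differences between the | |
| # frame and its lagged copy, as in YIN-style pitch estimators), and (3) a | |
| # per-frame log energy. | |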
| def extract_pitch_features( | |
| y: torch.Tensor, # [..., wav_length] | |
| hop_length: int = 160, # 10ms | |
| win_length: int = 560, # 35ms | |
| max_corr_period: int = 256, # 16ms, 62.5Hz (16000 / 256) | |
| corr_win_length: int = 304, # 19ms | |
| instfreq_features_cutoff_bin: int = 64, # 1828Hz (16000 * 64 / 560) | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| assert max_corr_period + corr_win_length == win_length | |
| # Apply padding | |
| padding_length = (win_length - hop_length) // 2 | |
| y = F.pad(y, (padding_length, padding_length)) | |
| # Split into frames | |
| # [..., win_length, n_frames] | |
| y_frames = y.unfold(-1, win_length, hop_length).transpose_(-2, -1) | |
| # Complex spectrogram | |
| # Complex[..., (win_length // 2 + 1), n_frames] | |
| spec: torch.Tensor = torch.fft.rfft(y_frames, n=win_length, dim=-2) | |
| # Complex[..., instfreq_features_cutoff_bin, n_frames] | |
| spec = spec[..., :instfreq_features_cutoff_bin, :] | |
| # Log power spectrogram | |
| log_power_spec = spec.abs().add_(1e-5).log10_() | |
| # Time difference of the instantaneous phase | |
| # The value at time 0 is 0 | |
| delta_spec = spec[..., :, 1:] * spec[..., :, :-1].conj() | |
| delta_spec /= delta_spec.abs().add_(1e-5) | |
| delta_spec = torch.cat( | |
| [torch.zeros_like(delta_spec[..., :, :1]), delta_spec], dim=-1 | |
| ) | |
| # [..., instfreq_features_cutoff_bin * 3, n_frames] | |
| instfreq_features = torch.cat( | |
| [log_power_spec, delta_spec.real, delta_spec.imag], dim=-2 | |
| ) | |
| # Autocorrelation | |
| # The original plan was to multiply this by 2.0 / corr_win_length, but that value | |
| # is proportional to the squared amplitude and no good way was found to | |
| # standardize its variance for input to the NN, so the idea was dropped | |
| flipped_y_frames = y_frames.flip((-2,)) | |
| a = torch.fft.rfft(flipped_y_frames, n=win_length, dim=-2) | |
| b = torch.fft.rfft(y_frames[..., -corr_win_length:, :], n=win_length, dim=-2) | |
| # [..., max_corr_period, n_frames] | |
| corr = torch.fft.irfft(a * b, n=win_length, dim=-2)[..., corr_win_length:, :] | |
| # Energy terms | |
| energy = flipped_y_frames.square_().cumsum_(-2) | |
| energy0 = energy[..., corr_win_length - 1 : corr_win_length, :] | |
| energy = energy[..., corr_win_length:, :] - energy[..., :-corr_win_length, :] | |
| # Difference function | |
| corr_diff = (energy0 + energy).sub_(corr.mul_(2.0)) | |
| assert corr_diff.min() >= -1e-3, corr_diff.min() | |
| corr_diff.clamp_(min=0.0) # guard against numerical error | |
| # Standardize | |
| corr_diff *= 2.0 / corr_win_length | |
| corr_diff.sqrt_() | |
| # Energy fed to the conversion model | |
| energy = ( | |
| (y_frames * torch.signal.windows.cosine(win_length, device=y.device)[..., None]) | |
| .square_() | |
| .sum(-2, keepdim=True) | |
| ) | |
| energy.clamp_(min=1e-3).log10_() # >= -3; roughly 2.15 for a unit-amplitude sine wave | |
| energy *= 0.5 # >= -1.5; roughly 1.07 for a unit-amplitude sine wave; a difference of 1 is a 20 dB amplitude difference | |
| return ( | |
| instfreq_features, # [..., instfreq_features_cutoff_bin * 3, n_frames] | |
| corr_diff, # [..., max_corr_period, n_frames] | |
| energy, # [..., 1, n_frames] | |
| ) | |
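| # PitchEstimator: classifies each 10 ms frame into one of pitch_bins log-spaced | |
| # pitch classes (bin 0 = unvoiced, pitch_bins_per_octave bins per octave). | |
| # sample_pitch() decodes the class posteriors with a band-wise argmax and can | |
| # also return unvoiced / half-pitch / double-pitch probability features. | |
| # Downstream (ConverterNetwork.forward) a quantized bin b is mapped to Hz as | |
| # 55.0 * 2 ** (b / pitch_bins_per_octave), e.g. bin 96 -> 110 Hz, bin 192 -> 220 Hz. | |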
| class PitchEstimator(nn.Module): | |
| def __init__( | |
| self, | |
| input_instfreq_channels: int = 192, | |
| input_corr_channels: int = 256, | |
| pitch_bins: int = 448, | |
| channels: int = 192, | |
| intermediate_channels: int = 192 * 2, | |
| n_blocks: int = 9, | |
| delay: int = 1, # 10 ms; 22.5 ms when combined with feature extraction | |
| embed_kernel_size: int = 3, | |
| kernel_size: int = 33, | |
| pitch_bins_per_octave: int = 96, | |
| ): | |
| super().__init__() | |
| self.pitch_bins_per_octave = pitch_bins_per_octave | |
| self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1) | |
| self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1) | |
| self.corr_embed_0 = nn.Conv1d(input_corr_channels, channels, 1) | |
| self.corr_embed_1 = nn.Conv1d(channels, channels, 1) | |
| self.backbone = ConvNeXtStack( | |
| channels, | |
| channels, | |
| intermediate_channels, | |
| n_blocks, | |
| delay, | |
| embed_kernel_size, | |
| kernel_size, | |
| enable_scaling=True, | |
| ) | |
| self.head = nn.Conv1d(channels, pitch_bins, 1) | |
| def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | |
| # wav: [batch_size, 1, wav_length] | |
| # [batch_size, input_instfreq_channels, length], | |
| # [batch_size, input_corr_channels, length] | |
| with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): | |
| instfreq_features, corr_diff, energy = extract_pitch_features( | |
| wav.squeeze(1), | |
| hop_length=160, | |
| win_length=560, | |
| max_corr_period=256, | |
| corr_win_length=304, | |
| instfreq_features_cutoff_bin=64, | |
| ) | |
| instfreq_features = F.gelu( | |
| self.instfreq_embed_0(instfreq_features), approximate="tanh" | |
| ) | |
| instfreq_features = self.instfreq_embed_1(instfreq_features) | |
| corr_diff = F.gelu(self.corr_embed_0(corr_diff), approximate="tanh") | |
| corr_diff = self.corr_embed_1(corr_diff) | |
| # [batch_size, channels, length] | |
| x = F.gelu(instfreq_features + corr_diff, approximate="tanh") | |
| x = self.backbone(x) | |
| # [batch_size, pitch_bins, length] | |
| x = self.head(x) | |
| return x, energy | |
| def sample_pitch( | |
| self, pitch: torch.Tensor, band_width: int = 4, return_features: bool = False | |
| ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: | |
| # pitch: [batch_size, pitch_bins, length] | |
| # The returned pitch values never include 0 | |
| batch_size, pitch_bins, length = pitch.size() | |
| pitch = pitch.softmax(1) | |
| if return_features: | |
| unvoiced_proba = pitch[:, :1, :].clone() | |
| pitch[:, 0, :] = -100.0 | |
| pitch = ( | |
| pitch.transpose(1, 2).contiguous().view(batch_size * length, 1, pitch_bins) | |
| ) | |
| band_pitch = F.conv1d( | |
| pitch, | |
| torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width), | |
| ) | |
| # [batch_size * length, 1, pitch_bins - band_width + 1] -> Long[batch_size * length, 1] | |
| quantized_band_pitch = band_pitch.argmax(2) | |
| if return_features: | |
| # [batch_size * length, 1] | |
| band_proba = band_pitch.gather(2, quantized_band_pitch[:, :, None]) | |
| # [batch_size * length, 1] | |
| half_pitch_band_proba = band_pitch.gather( | |
| 2, | |
| (quantized_band_pitch - self.pitch_bins_per_octave).clamp_(min=1)[ | |
| :, :, None | |
| ], | |
| ) | |
| half_pitch_band_proba[ | |
| quantized_band_pitch <= self.pitch_bins_per_octave | |
| ] = 0.0 | |
| half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view( | |
| batch_size, 1, length | |
| ) | |
| # [batch_size * length, 1] | |
| double_pitch_band_proba = band_pitch.gather( | |
| 2, | |
| (quantized_band_pitch + self.pitch_bins_per_octave).clamp_( | |
| max=pitch_bins - band_width | |
| )[:, :, None], | |
| ) | |
| double_pitch_band_proba[ | |
| quantized_band_pitch | |
| > pitch_bins - band_width - self.pitch_bins_per_octave | |
| ] = 0.0 | |
| double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view( | |
| batch_size, 1, length | |
| ) | |
| # Long[1, pitch_bins] | |
| mask = torch.arange(pitch_bins, device=pitch.device)[None, :] | |
| # bool[batch_size * length, pitch_bins] | |
| mask = (quantized_band_pitch <= mask) & ( | |
| mask < quantized_band_pitch + band_width | |
| ) | |
| # Long[batch_size, length] | |
| quantized_pitch = (pitch.squeeze(1) * mask).argmax(1).view(batch_size, length) | |
| if return_features: | |
| features = torch.cat( | |
| [unvoiced_proba, half_pitch_proba, double_pitch_proba], dim=1 | |
| ) | |
| # Long[batch_size, length], [batch_size, 3, length] | |
| return quantized_pitch, features | |
| else: | |
| return quantized_pitch | |
| def merge_weights(self): | |
| self.backbone.merge_weights() | |
| def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| dump_layer(self.instfreq_embed_0, f) | |
| dump_layer(self.instfreq_embed_1, f) | |
| dump_layer(self.corr_embed_0, f) | |
| dump_layer(self.corr_embed_1, f) | |
| dump_layer(self.backbone, f) | |
| dump_layer(self.head, f) | |
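| # overlap_add: pitch-synchronous synthesis of the periodic component. | |
| # ir_amp / ir_phase parameterize one impulse-response spectrum per 10 ms frame. | |
| # Pitch marks are placed wherever the accumulated normalized phase wraps past | |
| # 1.0; each mark takes the IR of its frame, applies a linear-phase term for the | |
| # fractional-sample delay, converts it back to the time domain, windows it, and | |
| # scatters it into the output buffer with index_put_(accumulate=True). | |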
| def overlap_add( | |
| ir_amp: torch.Tensor, | |
| ir_phase: torch.Tensor, | |
| window: torch.Tensor, | |
| pitch: torch.Tensor, | |
| hop_length: int = 240, | |
| delay: int = 0, | |
| sr: float = 24000.0, | |
| ) -> torch.Tensor: | |
| batch_size, ir_length, length = ir_amp.size() | |
| ir_length = (ir_length - 1) * 2 | |
| assert ir_phase.size() == ir_amp.size() | |
| assert window.size() == (ir_length,), (window.size(), ir_amp.size()) | |
| assert pitch.size() == (batch_size, length * hop_length) | |
| assert 0 <= delay < ir_length, (delay, ir_length) | |
| # Normalized angular frequency [in units of 2π rad] | |
| normalized_freq = pitch / sr | |
| # Set the initial phase [in units of 2π rad] randomly | |
| normalized_freq[:, 0] = torch.rand(batch_size, device=pitch.device) | |
| with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): | |
| phase = (normalized_freq.double().cumsum_(1) % 1.0).float() | |
| # Find the overlap positions (pitch marks) | |
| # [n_pitchmarks], [n_pitchmarks] | |
| indices0, indices1 = torch.nonzero(phase[:, :-1] > phase[:, 1:], as_tuple=True) | |
| # Fractional part (phase delay) of each overlap position | |
| numer = 1.0 - phase[indices0, indices1] | |
| # [n_pitchmarks] | |
| fractional_part = numer / (numer + phase[indices0, indices1 + 1]) | |
| # Compute the values to overlap | |
| # Complex[n_pitchmarks, ir_length / 2 + 1] | |
| ir_amp = ir_amp[indices0, :, indices1 // hop_length] | |
| ir_phase = ir_phase[indices0, :, indices1 // hop_length] | |
| # Amount of phase delay [rad] | |
| # [n_pitchmarks, ir_length / 2 + 1] | |
| delay_phase = ( | |
| torch.arange(ir_length // 2 + 1, device=pitch.device, dtype=torch.float32)[ | |
| None, : | |
| ] | |
| * (-math.tau / ir_length) | |
| * fractional_part[:, None] | |
| ) | |
| # Complex[n_pitchmarks, ir_length / 2 + 1] | |
| spec = torch.polar(ir_amp, ir_phase + delay_phase) | |
| # [n_pitchmarks, ir_length] | |
| ir = torch.fft.irfft(spec, n=ir_length, dim=1) | |
| ir *= window | |
| # Flatten the values to be added into per-sample entries | |
| # [n_pitchmarks * ir_length] | |
| ir = ir.ravel() | |
| # Long[n_pitchmarks * ir_length] | |
| indices0 = indices0[:, None].expand(-1, ir_length).ravel() | |
| # Long[n_pitchmarks * ir_length] | |
| indices1 = ( | |
| indices1[:, None] + torch.arange(ir_length, device=pitch.device) | |
| ).ravel() | |
| # Overlap-add | |
| overlap_added_signal = torch.zeros( | |
| (batch_size, length * hop_length + ir_length), device=pitch.device | |
| ) | |
| overlap_added_signal.index_put_((indices0, indices1), ir, accumulate=True) | |
| overlap_added_signal = overlap_added_signal[:, delay : -ir_length + delay] | |
| return overlap_added_signal | |
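| # generate_noise: synthesis of the aperiodic (noise) component. | |
| # A uniform noise excitation is analyzed with a rectangular-window STFT | |
| # (hop_length hop, 2 * hop_length FFT), shaped per frame by the predicted | |
| # aperiodicity spectrum, then re-synthesized with a Hann window and F.fold | |
| # overlap-add. Both the shaped noise and the raw excitation are returned. | |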
| def generate_noise( | |
| aperiodicity: torch.Tensor, delay: int = 0 | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| # aperiodicity: [batch_size, hop_length, length] | |
| batch_size, hop_length, length = aperiodicity.size() | |
| excitation = torch.rand( | |
| batch_size, (length + 1) * hop_length, device=aperiodicity.device | |
| ) | |
| excitation -= 0.5 | |
| n_fft = 2 * hop_length | |
| # Analyze with a rectangular window | |
| # Complex[batch_size, hop_length + 1, length] | |
| noise = torch.stft( | |
| excitation, | |
| n_fft=n_fft, | |
| hop_length=hop_length, | |
| window=torch.ones(n_fft, device=excitation.device), | |
| center=False, | |
| return_complex=True, | |
| ) | |
| assert noise.size(2) == aperiodicity.size(2) | |
| noise[:, 0, :] = 0.0 | |
| noise[:, 1:, :] *= aperiodicity | |
| # Synthesize with a Hann window | |
| # Note: torch.istft cannot be used here because it applies the optimal synthesis window | |
| # [batch_size, 2 * hop_length, length] | |
| noise = torch.fft.irfft(noise, n=2 * hop_length, dim=1) | |
| noise *= torch.hann_window(2 * hop_length, device=noise.device)[None, :, None] | |
| # [batch_size, (length + 1) * hop_length] | |
| noise = F.fold( | |
| noise, | |
| (1, (length + 1) * hop_length), | |
| (1, 2 * hop_length), | |
| stride=(1, hop_length), | |
| ).squeeze_((1, 2)) | |
| assert delay < hop_length | |
| noise = noise[:, delay : -hop_length + delay] | |
| excitation = excitation[:, delay : -hop_length + delay] | |
| return noise, excitation # [batch_size, length * hop_length] | |
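| # The functions below are a batched PyTorch port of WORLD's D4C band-aperiodicity | |
| # estimator (see the d4c() docstring for the reference implementation). | |
| # D4C_PREVENT_ZERO_DIVISION adds small clamps that the reference code lacks. | |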
| D4C_PREVENT_ZERO_DIVISION = True # set to False for the original upstream behavior | |
| def interp(x: torch.Tensor, y: torch.Tensor, xi: torch.Tensor) -> torch.Tensor: | |
| # Assumes x is monotonically increasing and evenly spaced | |
| # Assumes no extrapolation occurs | |
| x = torch.as_tensor(x) | |
| y = torch.as_tensor(y) | |
| xi = torch.as_tensor(xi) | |
| if xi.ndim < y.ndim: | |
| diff_ndim = y.ndim - xi.ndim | |
| xi = xi.view(tuple([1] * diff_ndim) + xi.size()) | |
| if xi.size()[:-1] != y.size()[:-1]: | |
| xi = xi.expand(y.size()[:-1] + (xi.size(-1),)) | |
| assert (x.min(-1).values == x[..., 0]).all() | |
| assert (x.max(-1).values == x[..., -1]).all() | |
| assert (xi.min(-1).values >= x[..., 0]).all() | |
| assert (xi.max(-1).values <= x[..., -1]).all() | |
| delta_x = (x[..., -1].double() - x[..., 0].double()) / (x.size(-1) - 1.0) | |
| delta_x = delta_x.to(x.dtype) | |
| xi = (xi - x[..., :1]).div_(delta_x[..., None]) | |
| xi_base = xi.floor() | |
| xi_fraction = xi.sub_(xi_base) | |
| xi_base = xi_base.long() | |
| delta_y = y.diff(dim=-1, append=y[..., -1:]) | |
| yi = y.gather(-1, xi_base) + delta_y.gather(-1, xi_base) * xi_fraction | |
| return yi | |
| def linear_smoothing( | |
| group_delay: torch.Tensor, sr: int, n_fft: int, width: torch.Tensor | |
| ) -> torch.Tensor: | |
| group_delay = torch.as_tensor(group_delay) | |
| assert group_delay.size(-1) == n_fft // 2 + 1 | |
| width = torch.as_tensor(width) | |
| boundary = (width.max() * n_fft / sr).long() + 1 | |
| dtype = group_delay.dtype | |
| device = group_delay.device | |
| fft_resolution = sr / n_fft | |
| mirroring_freq_axis = ( | |
| torch.arange(-boundary, n_fft // 2 + 1 + boundary, dtype=dtype, device=device) | |
| .add(0.5) | |
| .mul(fft_resolution) | |
| ) | |
| if group_delay.ndim == 1: | |
| mirroring_spec = F.pad( | |
| group_delay[None], (boundary, boundary), mode="reflect" | |
| ).squeeze_(0) | |
| elif group_delay.ndim >= 4: | |
| shape = group_delay.size() | |
| mirroring_spec = F.pad( | |
| group_delay.view(math.prod(shape[:-1]), group_delay.size(-1)), | |
| (boundary, boundary), | |
| mode="reflect", | |
| ).view(shape[:-1] + (shape[-1] + 2 * boundary,)) | |
| else: | |
| mirroring_spec = F.pad(group_delay, (boundary, boundary), mode="reflect") | |
| mirroring_segment = mirroring_spec.mul(fft_resolution).cumsum_(-1) | |
| center_freq = torch.arange(n_fft // 2 + 1, dtype=dtype, device=device).mul_( | |
| fft_resolution | |
| ) | |
| low_freq = center_freq - width[..., None] * 0.5 | |
| high_freq = center_freq + width[..., None] * 0.5 | |
| levels = interp( | |
| mirroring_freq_axis, mirroring_segment, torch.cat([low_freq, high_freq], dim=-1) | |
| ) | |
| low_levels, high_levels = levels.split([n_fft // 2 + 1] * 2, dim=-1) | |
| smoothed = (high_levels - low_levels).div_(width[..., None]) | |
| return smoothed | |
| def dc_correction( | |
| spec: torch.Tensor, sr: int, n_fft: int, f0: torch.Tensor | |
| ) -> torch.Tensor: | |
| spec = torch.as_tensor(spec) | |
| f0 = torch.as_tensor(f0) | |
| dtype = spec.dtype | |
| device = spec.device | |
| upper_limit = 2 + (f0 * (n_fft / sr)).long() | |
| max_upper_limit = upper_limit.max() | |
| upper_limit_mask = ( | |
| torch.arange(max_upper_limit - 1, device=device) < (upper_limit - 1)[..., None] | |
| ) | |
| low_freq_axis = torch.arange(max_upper_limit + 1, dtype=dtype, device=device) * ( | |
| sr / n_fft | |
| ) | |
| low_freq_replica = interp( | |
| f0[..., None] - low_freq_axis.flip(-1), | |
| spec[..., : max_upper_limit + 1].flip(-1), | |
| low_freq_axis[..., : max_upper_limit - 1] * upper_limit_mask, | |
| ) | |
| output = spec.clone() | |
| output[..., : max_upper_limit - 1] += low_freq_replica * upper_limit_mask | |
| return output | |
| def nuttall(n: int, device: torch.types.Device) -> torch.Tensor: | |
| t = torch.linspace(0, math.tau, n, device=device) | |
| coefs = torch.tensor([0.355768, -0.487396, 0.144232, -0.012604], device=device) | |
| terms = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device) | |
| cos_matrix = (terms[:, None] * t).cos_() # [4, n] | |
| window = coefs.matmul(cos_matrix) | |
| return window | |
| def get_windowed_waveform( | |
| x: torch.Tensor, | |
| sr: int, | |
| f0: torch.Tensor, | |
| position: torch.Tensor, | |
| half_window_length_ratio: float, | |
| window_type: Literal["hann", "blackman"], | |
| n_fft: int, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| x = torch.as_tensor(x) | |
| f0 = torch.as_tensor(f0) | |
| position = torch.as_tensor(position) | |
| current_sample = position * sr | |
| # [...] | |
| half_window_length = (half_window_length_ratio * sr / f0).add_(0.5).long() | |
| # [..., fft_size] | |
| base_index = -half_window_length[..., None] + torch.arange(n_fft, device=x.device) | |
| base_index_mask = base_index <= half_window_length[..., None] | |
| # [..., fft_size] | |
| safe_index = ((current_sample + 0.501).long()[..., None] + base_index).clamp_( | |
| 0, x.size(-1) - 1 | |
| ) | |
| # [..., fft_size] | |
| time_axis = base_index.to(x.dtype).div_(half_window_length_ratio) | |
| # [...] | |
| normalized_f0 = math.pi / sr * f0 | |
| # [..., fft_size] | |
| phase = time_axis.mul_(normalized_f0[..., None]) | |
| if window_type == "hann": | |
| window = phase.cos_().mul_(0.5).add_(0.5) | |
| elif window_type == "blackman": | |
| window = phase.mul(2.0).cos_().mul_(0.08).add_(phase.cos().mul_(0.5)).add_(0.42) | |
| else: | |
| assert False | |
| window *= base_index_mask | |
| prefix_shape = tuple( | |
| max(x_size, i_size) for x_size, i_size in zip(x.size(), safe_index.size()) | |
| )[:-1] | |
| waveform = ( | |
| x.expand(prefix_shape + (-1,)) | |
| .gather(-1, safe_index.expand(prefix_shape + (-1,))) | |
| .mul_(window) | |
| ) | |
| if not D4C_PREVENT_ZERO_DIVISION: | |
| waveform += torch.randn_like(window).mul_(1e-12) | |
| waveform *= base_index_mask | |
| waveform -= window * waveform.sum(-1, keepdim=True).div_( | |
| window.sum(-1, keepdim=True) | |
| ) | |
| return waveform, window | |
| def get_centroid(x: torch.Tensor, n_fft: int) -> torch.Tensor: | |
| x = torch.as_tensor(x) | |
| if D4C_PREVENT_ZERO_DIVISION: | |
| x = x / x.norm(dim=-1, keepdim=True).clamp(min=6e-8) | |
| else: | |
| x = x / x.norm(dim=-1, keepdim=True) | |
| spec0 = torch.fft.rfft(x, n_fft) | |
| spec1 = torch.fft.rfft( | |
| x * torch.arange(1, x.size(-1) + 1, dtype=x.dtype, device=x.device).div_(n_fft), | |
| n_fft, | |
| ) | |
| centroid = spec0.real * spec1.real + spec0.imag * spec1.imag | |
| return centroid | |
| def get_static_centroid( | |
| x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int | |
| ) -> torch.Tensor: | |
| """First step: calculation of temporally static parameters on basis of group delay""" | |
| x1, _ = get_windowed_waveform( | |
| x, sr, f0, position + 0.25 / f0, 2.0, "blackman", n_fft | |
| ) | |
| x2, _ = get_windowed_waveform( | |
| x, sr, f0, position - 0.25 / f0, 2.0, "blackman", n_fft | |
| ) | |
| centroid1 = get_centroid(x1, n_fft) | |
| centroid2 = get_centroid(x2, n_fft) | |
| return dc_correction(centroid1 + centroid2, sr, n_fft, f0) | |
| def get_smoothed_power_spec( | |
| x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| x = torch.as_tensor(x) | |
| f0 = torch.as_tensor(f0) | |
| x, window = get_windowed_waveform(x, sr, f0, position, 2.0, "hann", n_fft) | |
| window_weight = window.square().sum(-1, keepdim=True) | |
| rms = x.square().sum(-1, keepdim=True).div_(window_weight).sqrt_() | |
| if D4C_PREVENT_ZERO_DIVISION: | |
| x = x / (rms * math.sqrt(n_fft)).clamp_(min=6e-8) | |
| smoothed_power_spec = torch.fft.rfft(x, n_fft).abs().square_() | |
| smoothed_power_spec = dc_correction(smoothed_power_spec, sr, n_fft, f0) | |
| smoothed_power_spec = linear_smoothing(smoothed_power_spec, sr, n_fft, f0) | |
| return smoothed_power_spec, rms.detach().squeeze(-1) | |
| def get_static_group_delay( | |
| static_centroid: torch.Tensor, | |
| smoothed_power_spec: torch.Tensor, | |
| sr: int, | |
| f0: torch.Tensor, | |
| n_fft: int, | |
| ) -> torch.Tensor: | |
| """Second step: calculation of parameter shaping""" | |
| if D4C_PREVENT_ZERO_DIVISION: | |
| smoothed_power_spec = smoothed_power_spec.clamp(min=6e-8) | |
| static_group_delay = static_centroid / smoothed_power_spec # t_g | |
| static_group_delay = linear_smoothing( | |
| static_group_delay, sr, n_fft, f0 * 0.5 | |
| ) # t_gs | |
| smoothed_group_delay = linear_smoothing(static_group_delay, sr, n_fft, f0) # t_gb | |
| static_group_delay = static_group_delay - smoothed_group_delay # t_D | |
| return static_group_delay | |
| def get_coarse_aperiodicity( | |
| group_delay: torch.Tensor, | |
| sr: int, | |
| n_fft: int, | |
| freq_interval: int, | |
| n_aperiodicities: int, | |
| window: torch.Tensor, | |
| ) -> torch.Tensor: | |
| """Third step: estimation of band-aperiodicity""" | |
| group_delay = torch.as_tensor(group_delay) | |
| window = torch.as_tensor(window) | |
| boundary = int(round(n_fft * 8 / window.size(-1))) | |
| half_window_length = window.size(-1) // 2 | |
| coarse_aperiodicity = torch.empty( | |
| group_delay.size()[:-1] + (n_aperiodicities,), | |
| dtype=group_delay.dtype, | |
| device=group_delay.device, | |
| ) | |
| for i in range(n_aperiodicities): | |
| center = freq_interval * (i + 1) * n_fft // sr | |
| segment = ( | |
| group_delay[ | |
| ..., center - half_window_length : center + half_window_length + 1 | |
| ] | |
| * window | |
| ) | |
| power_spec: torch.Tensor = torch.fft.rfft(segment, n_fft).abs().square_() | |
| cumulative_power_spec = power_spec.sort(-1).values.cumsum_(-1) | |
| if D4C_PREVENT_ZERO_DIVISION: | |
| cumulative_power_spec.clamp_(min=6e-8) | |
| coarse_aperiodicity[..., i] = ( | |
| cumulative_power_spec[..., n_fft // 2 - boundary - 1] | |
| / cumulative_power_spec[..., -1] | |
| ) | |
| coarse_aperiodicity.log10_().mul_(10.0) | |
| return coarse_aperiodicity | |
| def d4c_love_train( | |
| x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, threshold: float | |
| ) -> int: | |
| x = torch.as_tensor(x) | |
| position = torch.as_tensor(position) | |
| f0: torch.Tensor = torch.as_tensor(f0) | |
| vuv = f0 != 0 | |
| lowest_f0 = 40 | |
| f0 = f0.clamp(min=lowest_f0) | |
| n_fft = 1 << (3 * sr // lowest_f0).bit_length() | |
| boundary0 = (100 * n_fft - 1) // sr + 1 | |
| boundary1 = (4000 * n_fft - 1) // sr + 1 | |
| boundary2 = (7900 * n_fft - 1) // sr + 1 | |
| waveform, _ = get_windowed_waveform(x, sr, f0, position, 1.5, "blackman", n_fft) | |
| power_spec = torch.fft.rfft(waveform, n_fft).abs().square_() | |
| power_spec[..., : boundary0 + 1] = 0.0 | |
| cumulative_spec = power_spec.cumsum_(-1) | |
| vuv = vuv & ( | |
| cumulative_spec[..., boundary1] > threshold * cumulative_spec[..., boundary2] | |
| ) | |
| return vuv | |
| def d4c_general_body( | |
| x: torch.Tensor, | |
| sr: int, | |
| f0: torch.Tensor, | |
| freq_interval: int, | |
| position: torch.Tensor, | |
| n_fft: int, | |
| n_aperiodicities: int, | |
| window: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| static_centroid = get_static_centroid(x, sr, f0, position, n_fft) | |
| smoothed_power_spec, rms = get_smoothed_power_spec(x, sr, f0, position, n_fft) | |
| static_group_delay = get_static_group_delay( | |
| static_centroid, smoothed_power_spec, sr, f0, n_fft | |
| ) | |
| coarse_aperiodicity = get_coarse_aperiodicity( | |
| static_group_delay, sr, n_fft, freq_interval, n_aperiodicities, window | |
| ) | |
| coarse_aperiodicity.add_((f0[..., None] - 100.0).div_(50.0)).clamp_(max=0.0) | |
| return coarse_aperiodicity, rms | |
| def d4c( | |
| x: torch.Tensor, | |
| f0: torch.Tensor, | |
| t: torch.Tensor, | |
| sr: int, | |
| threshold: float = 0.85, | |
| n_fft_spec: Optional[int] = None, | |
| coarse_only: bool = False, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Adapted from https://github.com/tuanad121/Python-WORLD/blob/master/world/d4c.py""" | |
| FLOOR_F0 = 71 | |
| FLOOR_F0_D4C = 47 | |
| UPPER_LIMIT = 15000 | |
| FREQ_INTERVAL = 3000 | |
| assert sr == int(sr) | |
| sr = int(sr) | |
| assert sr % 2 == 0 | |
| x = torch.as_tensor(x) | |
| f0 = torch.as_tensor(f0) | |
| temporal_positions = torch.as_tensor(t) | |
| n_fft_d4c = 1 << (4 * sr // FLOOR_F0_D4C).bit_length() | |
| if n_fft_spec is None: | |
| n_fft_spec = 1 << (3 * sr // FLOOR_F0).bit_length() | |
| n_aperiodicities = min(UPPER_LIMIT, sr // 2 - FREQ_INTERVAL) // FREQ_INTERVAL | |
| assert n_aperiodicities >= 1 | |
| window_length = FREQ_INTERVAL * n_fft_d4c // sr * 2 + 1 | |
| window = nuttall(window_length, device=x.device) | |
| freq_axis = torch.arange(n_fft_spec // 2 + 1, device=x.device) * (sr / n_fft_spec) | |
| coarse_aperiodicity, rms = d4c_general_body( | |
| x[..., None, :], | |
| sr, | |
| f0.clamp(min=FLOOR_F0_D4C), | |
| FREQ_INTERVAL, | |
| temporal_positions, | |
| n_fft_d4c, | |
| n_aperiodicities, | |
| window, | |
| ) | |
| if coarse_only: | |
| return coarse_aperiodicity, rms | |
| even_coarse_axis = ( | |
| torch.arange(n_aperiodicities + 3, device=x.device) * FREQ_INTERVAL | |
| ) | |
| assert even_coarse_axis[-2] <= sr // 2 < even_coarse_axis[-1], sr | |
| coarse_axis_low = ( | |
| torch.arange(n_aperiodicities + 1, dtype=torch.float, device=x.device) | |
| * FREQ_INTERVAL | |
| ) | |
| aperiodicity_low = interp( | |
| coarse_axis_low, | |
| F.pad(coarse_aperiodicity, (1, 0), value=-60.0), | |
| freq_axis[freq_axis < n_aperiodicities * FREQ_INTERVAL], | |
| ) | |
| coarse_axis_high = torch.tensor( | |
| [n_aperiodicities * FREQ_INTERVAL, sr * 0.5], device=x.device | |
| ) | |
| aperiodicity_high = interp( | |
| coarse_axis_high, | |
| F.pad(coarse_aperiodicity[..., -1:], (0, 1), value=-1e-12), | |
| freq_axis[freq_axis >= n_aperiodicities * FREQ_INTERVAL], | |
| ) | |
| aperiodicity = torch.cat([aperiodicity_low, aperiodicity_high], -1) | |
| aperiodicity = 10.0 ** (aperiodicity / 20.0) | |
| vuv = d4c_love_train(x[..., None, :], sr, f0, temporal_positions, threshold) | |
| aperiodicity = torch.where(vuv[..., None], aperiodicity, 1 - 1e-12) | |
| return aperiodicity, coarse_aperiodicity | |
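| # Vocoder: hidden features + pitch + speaker embedding -> 24 kHz waveform. | |
| # A ConvNeXt prenet with cross-attention over the per-speaker key/value embedding | |
| # feeds three heads: an impulse-response generator for the periodic component | |
| # (rendered by overlap_add), an aperiodicity generator for the noise component | |
| # (rendered by generate_noise), and a learned FIR post filter applied to both in | |
| # the frequency domain. hop_length defaults to 240 samples, i.e. 10 ms at 24 kHz. | |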
| class Vocoder(nn.Module): | |
| def __init__( | |
| self, | |
| channels: int, | |
| speaker_embedding_channels: int = 128, | |
| hop_length: int = 240, | |
| n_pre_blocks: int = 4, | |
| out_sample_rate: float = 24000.0, | |
| ): | |
| super().__init__() | |
| self.hop_length = hop_length | |
| self.out_sample_rate = out_sample_rate | |
| self.prenet = ConvNeXtStack( | |
| in_channels=channels, | |
| channels=channels, | |
| intermediate_channels=channels * 2, | |
| n_blocks=n_pre_blocks, | |
| delay=2, # 20 ms delay | |
| embed_kernel_size=7, | |
| kernel_size=33, | |
| enable_scaling=True, | |
| use_mha=True, | |
| cross_attention=True, | |
| kv_channels=speaker_embedding_channels, | |
| ) | |
| self.ir_generator = ConvNeXtStack( | |
| in_channels=channels, | |
| channels=channels, | |
| intermediate_channels=channels * 2, | |
| n_blocks=2, | |
| delay=0, | |
| embed_kernel_size=3, | |
| kernel_size=33, | |
| use_weight_standardization=True, | |
| enable_scaling=True, | |
| ) | |
| self.ir_generator_post = WSConv1d(channels, 512, 1) | |
| self.register_buffer("ir_scale", torch.tensor(1.0)) | |
| self.ir_window = nn.Parameter(torch.ones(512)) | |
| self.aperiodicity_generator = ConvNeXtStack( | |
| in_channels=channels, | |
| channels=channels, | |
| intermediate_channels=channels * 2, | |
| n_blocks=1, | |
| delay=0, | |
| embed_kernel_size=3, | |
| kernel_size=33, | |
| use_weight_standardization=True, | |
| enable_scaling=True, | |
| ) | |
| self.aperiodicity_generator_post = WSConv1d(channels, hop_length, 1, bias=False) | |
| self.register_buffer("aperiodicity_scale", torch.tensor(0.005)) | |
| self.post_filter_generator = ConvNeXtStack( | |
| in_channels=channels, | |
| channels=channels, | |
| intermediate_channels=channels * 2, | |
| n_blocks=1, | |
| delay=0, | |
| embed_kernel_size=3, | |
| kernel_size=33, | |
| use_weight_standardization=True, | |
| enable_scaling=True, | |
| ) | |
| self.post_filter_generator_post = WSConv1d(channels, 512, 1, bias=False) | |
| self.register_buffer("post_filter_scale", torch.tensor(0.01)) | |
| def forward( | |
| self, x: torch.Tensor, pitch: torch.Tensor, speaker_embedding: torch.Tensor | |
| ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: | |
| # x: [batch_size, channels, length] | |
| # pitch: [batch_size, length] | |
| # speaker_embedding: [batch_size, speaker_embedding_length, speaker_embedding_channels] | |
| batch_size, _, length = x.size() | |
| x = self.prenet(x, speaker_embedding) | |
| ir = self.ir_generator(x) | |
| ir = F.silu(ir, inplace=True) | |
| # [batch_size, 512, length] | |
| ir = self.ir_generator_post(ir) | |
| ir *= self.ir_scale | |
| ir_amp = ir[:, : ir.size(1) // 2 + 1, :].exp() | |
| ir_phase = F.pad(ir[:, ir.size(1) // 2 + 1 :, :], (0, 0, 1, 1)) | |
| ir_phase[:, 1::2, :] += math.pi | |
| # TODO: fix the DC component only being able to take positive values | |
| # Nearest-neighbor interpolation | |
| # [batch_size, length * hop_length] | |
| pitch = torch.repeat_interleave(pitch, self.hop_length, dim=1) | |
| # [batch_size, length * hop_length] | |
| periodic_signal = overlap_add( | |
| ir_amp, | |
| ir_phase, | |
| self.ir_window, | |
| pitch, | |
| self.hop_length, | |
| delay=0, | |
| sr=self.out_sample_rate, | |
| ) | |
| aperiodicity = self.aperiodicity_generator(x) | |
| aperiodicity = F.silu(aperiodicity, inplace=True) | |
| # [batch_size, hop_length, length] | |
| aperiodicity = self.aperiodicity_generator_post(aperiodicity) | |
| aperiodicity *= self.aperiodicity_scale | |
| # [batch_size, length * hop_length], [batch_size, length * hop_length] | |
| aperiodic_signal, noise_excitation = generate_noise(aperiodicity, delay=0) | |
| post_filter = self.post_filter_generator(x) | |
| post_filter = F.silu(post_filter, inplace=True) | |
| # [batch_size, 512, length] | |
| post_filter = self.post_filter_generator_post(post_filter) | |
| post_filter *= self.post_filter_scale | |
| post_filter[:, 0, :] += 1.0 | |
| # [batch_size, length, 512] | |
| post_filter = post_filter.transpose(1, 2) | |
| with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): | |
| periodic_signal = periodic_signal.float() | |
| aperiodic_signal = aperiodic_signal.float() | |
| post_filter = post_filter.float() | |
| post_filter = torch.fft.rfft(post_filter, n=768) | |
| # [batch_size, length, 768] | |
| periodic_signal = torch.fft.irfft( | |
| torch.fft.rfft( | |
| periodic_signal.view(batch_size, length, self.hop_length), n=768 | |
| ) | |
| * post_filter, | |
| n=768, | |
| ) | |
| aperiodic_signal = torch.fft.irfft( | |
| torch.fft.rfft( | |
| aperiodic_signal.view(batch_size, length, self.hop_length), n=768 | |
| ) | |
| * post_filter, | |
| n=768, | |
| ) | |
| periodic_signal = F.fold( | |
| periodic_signal.transpose(1, 2), | |
| (1, (length - 1) * self.hop_length + 768), | |
| (1, 768), | |
| stride=(1, self.hop_length), | |
| ).squeeze_((1, 2)) | |
| aperiodic_signal = F.fold( | |
| aperiodic_signal.transpose(1, 2), | |
| (1, (length - 1) * self.hop_length + 768), | |
| (1, 768), | |
| stride=(1, self.hop_length), | |
| ).squeeze_((1, 2)) | |
| periodic_signal = periodic_signal[:, 120 : 120 + length * self.hop_length] | |
| aperiodic_signal = aperiodic_signal[:, 120 : 120 + length * self.hop_length] | |
| noise_excitation = noise_excitation[:, 120:] | |
| # TODO: the accuracy of the compensation becomes questionable here. Is it still really needed? | |
| # [batch_size, 1, length * hop_length] | |
| y_g_hat = (periodic_signal + aperiodic_signal)[:, None, :] | |
| return y_g_hat, { | |
| "periodic_signal": periodic_signal.detach(), | |
| "aperiodic_signal": aperiodic_signal.detach(), | |
| "noise_excitation": noise_excitation.detach(), | |
| } | |
| def merge_weights(self): | |
| self.prenet.merge_weights() | |
| self.ir_generator.merge_weights() | |
| self.ir_generator_post.merge_weights() | |
| self.aperiodicity_generator.merge_weights() | |
| self.aperiodicity_generator_post.merge_weights() | |
| self.ir_generator_post.weight.data *= self.ir_scale | |
| self.ir_generator_post.bias.data *= self.ir_scale | |
| self.ir_scale.fill_(1.0) | |
| self.aperiodicity_generator_post.weight.data *= self.aperiodicity_scale | |
| self.aperiodicity_scale.fill_(1.0) | |
| self.post_filter_generator.merge_weights() | |
| self.post_filter_generator_post.merge_weights() | |
| self.post_filter_generator_post.weight.data *= self.post_filter_scale | |
| self.post_filter_scale.fill_(1.0) | |
| def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| dump_layer(self.prenet, f) | |
| dump_layer(self.ir_generator, f) | |
| dump_layer(self.ir_generator_post, f) | |
| dump_layer(self.ir_window, f) | |
| dump_layer(self.aperiodicity_generator, f) | |
| dump_layer(self.aperiodicity_generator_post, f) | |
| dump_layer(self.post_filter_generator, f) | |
| dump_layer(self.post_filter_generator_post, f) | |
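| # compute_loudness: multi-resolution log-energy of a perceptually weighted signal. | |
| # The waveform is filtered with a high-shelf boost plus a high-pass biquad | |
| # (similar in spirit to the K-weighting pre-filter used by loudness standards), | |
| # then windowed log-energies are computed for each requested window length. | |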
| def compute_loudness( | |
| x: torch.Tensor, sr: int, win_lengths: list[int] | |
| ) -> list[torch.Tensor]: | |
| # x: [batch_size, wav_length] | |
| assert x.ndim == 2 | |
| n_fft = 2048 | |
| chunk_length = n_fft // 2 | |
| n_taps = chunk_length + 1 | |
| results = [] | |
| with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): | |
| if not hasattr(compute_loudness, "filter"): | |
| compute_loudness.filter = {} | |
| if sr not in compute_loudness.filter: | |
| ir = torch.zeros(n_taps, device=x.device, dtype=torch.double) | |
| ir[0] = 0.5 | |
| ir = torchaudio.functional.treble_biquad( | |
| ir, sr, 4.0, 1500.0, 1.0 / math.sqrt(2) | |
| ) | |
| ir = torchaudio.functional.highpass_biquad(ir, sr, 38.0, 0.5) | |
| ir *= 2.0 | |
| compute_loudness.filter[sr] = torch.fft.rfft(ir, n=n_fft).to( | |
| torch.complex64 | |
| ) | |
| x = x.float() | |
| wav_length = x.size(-1) | |
| if wav_length % chunk_length != 0: | |
| x = F.pad(x, (0, chunk_length - wav_length % chunk_length)) | |
| padded_wav_length = x.size(-1) | |
| x = x.view(x.size()[:-1] + (padded_wav_length // chunk_length, chunk_length)) | |
| x = torch.fft.irfft( | |
| torch.fft.rfft(x, n=n_fft) * compute_loudness.filter[sr], | |
| n=n_fft, | |
| ) | |
| x = F.fold( | |
| x.transpose(-2, -1), | |
| (1, padded_wav_length + chunk_length), | |
| (1, n_fft), | |
| stride=(1, chunk_length), | |
| ).squeeze_((-3, -2))[..., :wav_length] | |
| x.square_() | |
| for win_length in win_lengths: | |
| hop_length = win_length // 4 | |
| # [..., n_frames] | |
| energy = ( | |
| x.unfold(-1, win_length, hop_length) | |
| .matmul(torch.hann_window(win_length, device=x.device)) | |
| .add_(win_length / 4.0 * 1e-5) | |
| .log10_() | |
| ) | |
| # Roughly log10(win_length/4) if the filtered waveform is a unit-amplitude sine wave; a difference of 1 is a 10 dB difference | |
| results.append(energy) | |
| return results | |
| def beatrice_slice_segments( | |
| x: torch.Tensor, start_indices: torch.Tensor, segment_length: int | |
| ) -> torch.Tensor: | |
| batch_size, channels, _ = x.size() | |
| # [batch_size, 1, segment_size] | |
| indices = start_indices[:, None, None] + torch.arange( | |
| segment_length, device=start_indices.device | |
| ) | |
| # [batch_size, channels, segment_size] | |
| indices = indices.expand(batch_size, channels, segment_length) | |
| return x.gather(2, indices) | |
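| # ConverterNetwork: the end-to-end conversion network. | |
| # It combines frozen PhoneExtractor / PitchEstimator modules, a per-speaker | |
| # VectorQuantizer hooked onto the phone extractor head, embeddings for phone | |
| # features, quantized pitch, pitch features, target speaker and formant shift, | |
| # and a Vocoder that synthesizes 24 kHz audio (240-sample hop) from their sum. | |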
| class ConverterNetwork(nn.Module): | |
| def __init__( | |
| self, | |
| phone_extractor: PhoneExtractor, | |
| pitch_estimator: PitchEstimator, | |
| n_speakers: int, | |
| pitch_bins: int, | |
| hidden_channels: int, | |
| vq_topk: int = 4, | |
| training_time_vq: Literal["none", "self", "random"] = "none", | |
| phone_noise_ratio: float = 0.5, | |
| floor_noise_level: float = 1e-3, | |
| ): | |
| super().__init__() | |
| self.frozen_modules = { | |
| "phone_extractor": phone_extractor.eval().requires_grad_(False), | |
| "pitch_estimator": pitch_estimator.eval().requires_grad_(False), | |
| } | |
| self.pitch_bins = pitch_bins | |
| self.phone_noise_ratio = phone_noise_ratio | |
| self.floor_noise_level = floor_noise_level | |
| self.out_sample_rate = out_sample_rate = 24000 | |
| phone_channels = 128 | |
| self.vq = VectorQuantizer( | |
| n_speakers=n_speakers, | |
| codebook_size=512, | |
| channels=phone_channels, | |
| topk=vq_topk, | |
| training_time_vq=training_time_vq, | |
| ) | |
| self.embed_phone = nn.Conv1d(phone_channels, hidden_channels, 1) | |
| self.embed_phone.weight.data.normal_(0.0, math.sqrt(2.0 / (256 * 5))) | |
| self.embed_phone.bias.data.zero_() | |
| self.embed_quantized_pitch = nn.Embedding(pitch_bins, hidden_channels) | |
| phase = ( | |
| torch.arange(pitch_bins, dtype=torch.float)[:, None] | |
| * ( | |
| torch.arange(0, hidden_channels, 2, dtype=torch.float) | |
| * (-math.log(10000.0) / hidden_channels) | |
| ).exp_() | |
| ) | |
| self.embed_quantized_pitch.weight.data[:, 0::2] = phase.sin() | |
| self.embed_quantized_pitch.weight.data[:, 1::2] = phase.cos_() | |
| self.embed_quantized_pitch.weight.data *= math.sqrt(4.0 / 5.0) | |
| self.embed_quantized_pitch.weight.requires_grad_(False) | |
| self.embed_pitch_features = nn.Conv1d(4, hidden_channels, 1) | |
| self.embed_pitch_features.weight.data.normal_(0.0, math.sqrt(2.0 / (4 * 5))) | |
| self.embed_pitch_features.bias.data.zero_() | |
| self.embed_speaker = nn.Embedding(n_speakers, hidden_channels) | |
| self.embed_speaker.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) | |
| self.embed_formant_shift = nn.Embedding(9, hidden_channels) | |
| self.embed_formant_shift.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) | |
| self.key_value_speaker_embedding_length = 384 | |
| self.key_value_speaker_embedding_channels = 128 | |
| self.key_value_speaker_embedding = nn.Embedding( | |
| n_speakers, | |
| self.key_value_speaker_embedding_length | |
| * self.key_value_speaker_embedding_channels, | |
| ) | |
| self.key_value_speaker_embedding.weight.data[0].normal_() | |
| self.key_value_speaker_embedding.weight.data[1:] = ( | |
| self.key_value_speaker_embedding.weight.data[0] | |
| ) | |
| self.vocoder = Vocoder( | |
| channels=hidden_channels, | |
| speaker_embedding_channels=self.key_value_speaker_embedding_channels, | |
| hop_length=out_sample_rate // 100, | |
| n_pre_blocks=4, | |
| out_sample_rate=out_sample_rate, | |
| ) | |
| self.melspectrograms = nn.ModuleList() | |
| for win_length, n_mels in [ | |
| (32, 5), | |
| (64, 10), | |
| (128, 20), | |
| (256, 40), | |
| (512, 80), | |
| (1024, 160), | |
| (2048, 320), | |
| ]: | |
| self.melspectrograms.append( | |
| torchaudio.transforms.MelSpectrogram( | |
| sample_rate=out_sample_rate, | |
| n_fft=win_length, | |
| win_length=win_length, | |
| hop_length=win_length // 4, | |
| n_mels=n_mels, | |
| power=2, | |
| norm="slaney", | |
| mel_scale="slaney", | |
| ) | |
| ) | |
| def initialize_vq(self, inputs: Sequence[Iterable[torch.Tensor]]): | |
| collector_func = self.frozen_modules["phone_extractor"].units | |
| target_layer = self.frozen_modules["phone_extractor"].head | |
| self.vq.build_codebooks( | |
| collector_func, | |
| target_layer, | |
| inputs, | |
| ) | |
| self.vq.enable_hook(target_layer) | |
| def enable_hook(self): | |
| target_layer = self.frozen_modules["phone_extractor"].head | |
| self.vq.enable_hook(target_layer) | |
| def _get_resampler( | |
| self, orig_freq, new_freq, device, cache={} | |
| ) -> torchaudio.transforms.Resample: | |
| key = orig_freq, new_freq | |
| if key in cache: | |
| return cache[key] | |
| resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to( | |
| device, non_blocking=True | |
| ) | |
| cache[key] = resampler | |
| return resampler | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| target_speaker_id: torch.Tensor, | |
| formant_shift_semitone: torch.Tensor, | |
| pitch_shift_semitone: Optional[torch.Tensor] = None, | |
| slice_start_indices: Optional[torch.Tensor] = None, | |
| slice_segment_length: Optional[int] = None, | |
| return_stats: bool = False, | |
| ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: | |
| # x: [batch_size, 1, wav_length] | |
| # target_speaker_id: Long[batch_size] | |
| # formant_shift_semitone: [batch_size] | |
| # pitch_shift_semitone: [batch_size] | |
| # slice_start_indices: [batch_size] | |
| batch_size, _, _ = x.size() | |
| self.vq.set_target_speaker_ids(target_speaker_id) | |
| with torch.inference_mode(): | |
| phone_extractor: PhoneExtractor = self.frozen_modules["phone_extractor"] | |
| pitch_estimator: PitchEstimator = self.frozen_modules["pitch_estimator"] | |
| # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length] | |
| phone = phone_extractor.units(x).transpose(1, 2) | |
| if self.training and self.phone_noise_ratio != 0.0: | |
| phone *= (1.0 - self.phone_noise_ratio) / phone.square().mean( | |
| 1, keepdim=True | |
| ).sqrt_() | |
| noise = torch.randn_like(phone) | |
| noise *= ( | |
| self.phone_noise_ratio | |
| / noise.square().mean(1, keepdim=True).sqrt_() | |
| ) | |
| phone += noise | |
| # F.rms_norm requires PyTorch >= 2.4, so do the RMS normalization manually | |
| phone *= ( | |
| 1.0 | |
| / phone.square() | |
| .mean(1, keepdim=True) | |
| .add_(torch.finfo(torch.float).eps) | |
| .sqrt_() | |
| ) | |
| # [batch_size, 1, wav_length] -> [batch_size, pitch_bins, length], [batch_size, 1, length] | |
| pitch, energy = pitch_estimator(x) | |
| # augmentation | |
| if self.training: | |
| # [batch_size, pitch_bins - 1] | |
| weights = pitch.softmax(1)[:, 1:, :].mean(2) | |
| # [batch_size] | |
| mean_pitch = ( | |
| weights | |
| * torch.arange( | |
| 1, | |
| self.embed_quantized_pitch.num_embeddings, | |
| device=weights.device, | |
| ) | |
| ).sum(1) / weights.sum(1) | |
| mean_pitch = mean_pitch.round_().long() | |
| target_pitch = torch.randint_like(mean_pitch, 64, 257) | |
| shift = target_pitch - mean_pitch | |
| shift_ratio = ( | |
| 2.0 ** (shift.float() / pitch_estimator.pitch_bins_per_octave) | |
| ).tolist() | |
| shift = [] | |
| interval_length = 100 # 1s | |
| interval_zeros = torch.zeros( | |
| (1, 1, interval_length * 160), device=x.device | |
| ) | |
| concatenated_shifted_x = [] | |
| offsets = [0] | |
| torch.backends.cudnn.benchmark = False | |
| for i in range(batch_size): | |
| shift_ratio_i = shift_ratio[i] | |
| shift_ratio_fraction_i = Fraction.from_float( | |
| shift_ratio_i | |
| ).limit_denominator(30) | |
| shift_numer_i = shift_ratio_fraction_i.numerator | |
| shift_denom_i = shift_ratio_fraction_i.denominator | |
| shift_ratio_i = shift_numer_i / shift_denom_i | |
| shift_i = int( | |
| round( | |
| math.log2(shift_ratio_i) | |
| * pitch_estimator.pitch_bins_per_octave | |
| ) | |
| ) | |
| shift.append(shift_i) | |
| shift_ratio[i] = shift_ratio_i | |
| # [1, 1, wav_length / shift_ratio] | |
| with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): | |
| shifted_x_i = self._get_resampler( | |
| shift_numer_i, shift_denom_i, x.device | |
| )(x[i])[None] | |
| if shifted_x_i.size(2) % 160 != 0: | |
| shifted_x_i = F.pad( | |
| shifted_x_i, | |
| (0, 160 - shifted_x_i.size(2) % 160), | |
| mode="reflect", | |
| ) | |
| assert shifted_x_i.size(2) % 160 == 0 | |
| offsets.append( | |
| offsets[-1] + interval_length + shifted_x_i.size(2) // 160 | |
| ) | |
| concatenated_shifted_x.extend([interval_zeros, shifted_x_i]) | |
| if offsets[-1] % 256 != 0: | |
| # Equal lengths seem to hit some cache and run faster, so pad to a multiple | |
| # of 256 to reduce the number of distinct length patterns | |
| concatenated_shifted_x.append( | |
| torch.zeros( | |
| (1, 1, (256 - offsets[-1] % 256) * 160), device=x.device | |
| ) | |
| ) | |
| # [batch_size, 1, sum(wav_length) + batch_size * 16000] | |
| concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2) | |
| assert concatenated_shifted_x.size(2) % (256 * 160) == 0 | |
| # [1, pitch_bins, length / shift_ratio], [1, 1, length / shift_ratio] | |
| concatenated_pitch, concatenated_energy = pitch_estimator( | |
| concatenated_shifted_x | |
| ) | |
| for i in range(batch_size): | |
| shift_i = shift[i] | |
| shift_ratio_i = shift_ratio[i] | |
| left = offsets[i] + interval_length | |
| right = offsets[i + 1] | |
| pitch_i = concatenated_pitch[:, :, left:right] | |
| energy_i = concatenated_energy[:, :, left:right] | |
| pitch_i = F.interpolate( | |
| pitch_i, | |
| scale_factor=shift_ratio_i, | |
| mode="linear", | |
| align_corners=False, | |
| ) | |
| energy_i = F.interpolate( | |
| energy_i, | |
| scale_factor=shift_ratio_i, | |
| mode="linear", | |
| align_corners=False, | |
| ) | |
| assert pitch_i.size(2) == energy_i.size(2) | |
| assert abs(pitch_i.size(2) - pitch.size(2)) <= 10 | |
| length = min(pitch_i.size(2), pitch.size(2)) | |
| if shift_i > 0: | |
| pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] | |
| pitch[i : i + 1, 1:-shift_i, :length] = pitch_i[ | |
| :, 1 + shift_i :, :length | |
| ] | |
| pitch[i : i + 1, -shift_i:, :length] = -10.0 | |
| elif shift_i < 0: | |
| pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] | |
| pitch[i : i + 1, 1 : 1 - shift_i, :length] = -10.0 | |
| pitch[i : i + 1, 1 - shift_i :, :length] = pitch_i[ | |
| :, 1:shift_i, :length | |
| ] | |
| energy[i : i + 1, :, :length] = energy_i[:, :, :length] | |
| torch.backends.cudnn.benchmark = True | |
| # [batch_size, pitch_bins, length] -> Long[batch_size, length], [batch_size, 3, length] | |
| quantized_pitch, pitch_features = pitch_estimator.sample_pitch( | |
| pitch, return_features=True | |
| ) | |
| if pitch_shift_semitone is not None: | |
| quantized_pitch = torch.where( | |
| quantized_pitch == 0, | |
| quantized_pitch, | |
| ( | |
| quantized_pitch | |
| + ( | |
| pitch_shift_semitone[:, None] | |
| * (pitch_estimator.pitch_bins_per_octave / 12.0) | |
| ) | |
| .round_() | |
| .long() | |
| ).clamp_(1, self.pitch_bins - 1), | |
| ) | |
| pitch = 55.0 * 2.0 ** ( | |
| quantized_pitch.float() / pitch_estimator.pitch_bins_per_octave | |
| ) | |
| # phone looks ahead by 2.5 ms, while energy looks ahead by 12.5 ms and | |
| # pitch_features by 22.5 ms, so shift them so that they line up with phone | |
| energy = F.pad(energy[:, :, :-1], (1, 0), mode="reflect") | |
| quantized_pitch = F.pad(quantized_pitch[:, :-2], (2, 0), mode="reflect") | |
| pitch_features = F.pad(pitch_features[:, :, :-2], (2, 0), mode="reflect") | |
| # [batch_size, 1, length], [batch_size, 3, length] -> [batch_size, 4, length] | |
| pitch_features = torch.cat([energy, pitch_features], dim=1) | |
| formant_shift_indices = ( | |
| ((formant_shift_semitone + 2.0) * 2.0).round_().long() | |
| ) | |
| phone = phone.clone() | |
| quantized_pitch = quantized_pitch.clone() | |
| pitch_features = pitch_features.clone() | |
| formant_shift_indices = formant_shift_indices.clone() | |
| pitch = pitch.clone() | |
| # [batch_size, hidden_channels, length] | |
| x = ( | |
| self.embed_phone(phone) | |
| + self.embed_quantized_pitch(quantized_pitch).transpose(1, 2) | |
| + self.embed_pitch_features(pitch_features) | |
| + ( | |
| self.embed_speaker(target_speaker_id)[:, :, None] | |
| + self.embed_formant_shift(formant_shift_indices)[:, :, None] | |
| ) | |
| ) | |
| if slice_start_indices is not None: | |
| assert slice_segment_length is not None | |
| # [batch_size, hidden_channels, length] -> [batch_size, hidden_channels, segment_length] | |
| x = beatrice_slice_segments(x, slice_start_indices, slice_segment_length) | |
| x = F.silu(x, inplace=True) | |
| speaker_embedding = self.key_value_speaker_embedding(target_speaker_id).view( | |
| batch_size, | |
| self.key_value_speaker_embedding_length, | |
| self.key_value_speaker_embedding_channels, | |
| ) | |
| # [batch_size, hidden_channels, segment_length] -> [batch_size, 1, segment_length * 240] | |
| y_g_hat, stats = self.vocoder(x, pitch, speaker_embedding) | |
| stats["pitch"] = pitch | |
| if return_stats: | |
| return y_g_hat, stats | |
| else: | |
| return y_g_hat | |
| def _normalize_melsp(self, x): | |
| return x.clamp(min=1e-10).log_() | |
| def forward_and_compute_loss( | |
| self, | |
| noisy_wavs_16k: torch.Tensor, | |
| target_speaker_id: torch.Tensor, | |
| formant_shift_semitone: torch.Tensor, | |
| slice_start_indices: torch.Tensor, | |
| slice_segment_length: int, | |
| y_all: torch.Tensor, | |
| enable_loss_ap: bool = False, | |
| ) -> tuple[ | |
| torch.Tensor, | |
| torch.Tensor, | |
| torch.Tensor, | |
| torch.Tensor, | |
| torch.Tensor, | |
| torch.Tensor, | |
| dict[str, float], | |
| ]: | |
| # noisy_wavs_16k: [batch_size, 1, wav_length] | |
| # target_speaker_id: Long[batch_size] | |
| # formant_shift_semitone: [batch_size] | |
| # slice_start_indices: [batch_size] | |
| # slice_segment_length: int | |
| # y_all: [batch_size, 1, wav_length] | |
| stats = {} | |
| loss_mel = 0.0 | |
| loss_loudness = 0.0 | |
| loudness_win_lengths = [512, 1024, 2048, 4096] | |
| # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240] | |
| y_hat_all, intermediates = self( | |
| noisy_wavs_16k, | |
| target_speaker_id, | |
| formant_shift_semitone, | |
| return_stats=True, | |
| ) | |
| y_hat_all = y_hat_all.detach().where(y_all == 0.0, y_hat_all) | |
| with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): | |
| periodic_signal = intermediates["periodic_signal"].float() | |
| aperiodic_signal = intermediates["aperiodic_signal"].float() | |
| noise_excitation = intermediates["noise_excitation"].float() | |
| periodic_signal = periodic_signal[:, : noise_excitation.size(1)] | |
| aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)] | |
| y_hat_all = y_hat_all.float() | |
| floor_noise = torch.randn_like(y_all) * self.floor_noise_level | |
| y_all = y_all + floor_noise | |
| y_hat_all += floor_noise | |
| y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)] | |
| y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)] | |
| y_loudness = compute_loudness( | |
| y_all_truncated, self.out_sample_rate, loudness_win_lengths | |
| ) | |
| y_hat_loudness = compute_loudness( | |
| y_hat_all_truncated, self.out_sample_rate, loudness_win_lengths | |
| ) | |
| for win_length, y_loudness_i, y_hat_loudness_i in zip( | |
| loudness_win_lengths, y_loudness, y_hat_loudness | |
| ): | |
| loss_loudness_i = F.mse_loss(y_hat_loudness_i, y_loudness_i) | |
| loss_loudness += loss_loudness_i * math.sqrt(win_length) | |
| stats[f"loss_loudness_{win_length}"] = loss_loudness_i.item() | |
| for melspectrogram in self.melspectrograms: | |
| melsp_periodic_signal = melspectrogram(periodic_signal) | |
| melsp_aperiodic_signal = melspectrogram(aperiodic_signal) | |
| melsp_noise_excitation = melspectrogram(noise_excitation) | |
| # [1, n_mels, 1] | |
| # 1/6 ... mean power of uniform noise in [-0.5, 0.5] | |
| # 3/8 ... power attenuation from applying a Hann window | |
| # 0.5 ... unexplained (a mystery in the original code) | |
| reference_melsp = melspectrogram.mel_scale( | |
| torch.full( | |
| (1, melspectrogram.n_fft // 2 + 1, 1), | |
| (1 / 6) * (3 / 8) * 0.5 * melspectrogram.win_length, | |
| device=noisy_wavs_16k.device, | |
| ) | |
| ) | |
| aperiodic_ratio = melsp_aperiodic_signal / ( | |
| melsp_periodic_signal + melsp_aperiodic_signal + 1e-5 | |
| ) | |
| compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5) | |
| melsp_y_hat = melspectrogram(y_hat_all_truncated) | |
| melsp_y_hat = melsp_y_hat * ( | |
| (1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio | |
| ) | |
| y_hat_mel = self._normalize_melsp(melsp_y_hat) | |
| y_mel = self._normalize_melsp(melspectrogram(y_all_truncated)) | |
| loss_mel_i = F.l1_loss(y_hat_mel, y_mel) | |
| loss_mel += loss_mel_i | |
| stats[ | |
| f"loss_mel_{melspectrogram.win_length}_{melspectrogram.n_mels}" | |
| ] = loss_mel_i.item() | |
| loss_mel /= len(self.melspectrograms) | |
| if enable_loss_ap: | |
| t = ( | |
| torch.arange(intermediates["pitch"].size(1), device=y_all.device) | |
| * 0.01 | |
| + 0.005 | |
| ) | |
| y_coarse_aperiodicity, y_rms = d4c( | |
| y_all.squeeze(1), | |
| intermediates["pitch"], | |
| t, | |
| self.vocoder.out_sample_rate, | |
| coarse_only=True, | |
| ) | |
| y_coarse_aperiodicity = 10.0 ** (y_coarse_aperiodicity / 10.0) | |
| y_hat_coarse_aperiodicity, y_hat_rms = d4c( | |
| y_hat_all.squeeze(1), | |
| intermediates["pitch"], | |
| t, | |
| self.vocoder.out_sample_rate, | |
| coarse_only=True, | |
| ) | |
| y_hat_coarse_aperiodicity = 10.0 ** (y_hat_coarse_aperiodicity / 10.0) | |
| rms = torch.maximum(y_rms, y_hat_rms) | |
| loss_ap = F.mse_loss( | |
| y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none" | |
| ) | |
| loss_ap *= (rms / (rms + 1e-3) * (rms > 1e-5))[:, :, None] | |
| loss_ap = loss_ap.mean() | |
| else: | |
| loss_ap = torch.tensor(0.0) | |
| # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] | |
| y_hat = beatrice_slice_segments( | |
| y_hat_all, slice_start_indices * 240, slice_segment_length * 240 | |
| ) | |
| # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] | |
| y = beatrice_slice_segments(y_all, slice_start_indices * 240, slice_segment_length * 240) | |
| return y, y_hat, y_hat_all, loss_loudness, loss_mel, loss_ap, stats | |
| def merge_weights(self): | |
| self.vocoder.merge_weights() | |
| def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| dump_layer(self.embed_phone, f) | |
| dump_layer(self.embed_quantized_pitch, f) | |
| dump_layer(self.embed_pitch_features, f) | |
| dump_layer(self.vocoder, f) | |
| def dump_speaker_embeddings(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump_speaker_embeddings(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| dump_params(self.vq.codebooks, f) | |
| dump_layer(self.embed_speaker, f) | |
| dump_layer(self.embed_formant_shift, f) | |
| dump_layer(self.key_value_speaker_embedding, f) | |
| def dump_embedding_setter(self, f: Union[BinaryIO, str, bytes, os.PathLike]): | |
| if isinstance(f, (str, bytes, os.PathLike)): | |
| with open(f, "wb") as f: | |
| self.dump_embedding_setter(f) | |
| return | |
| if not hasattr(f, "write"): | |
| raise TypeError | |
| self.vocoder.prenet.dump_kv(f) | |
| # Discriminator | |
| def _normalize(tensor: torch.Tensor, dim: int) -> torch.Tensor: | |
| denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-6) | |
| return tensor / denom | |
| class SANConv2d(nn.Conv2d): | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| out_channels: int, | |
| kernel_size: int, | |
| stride: int = 1, | |
| padding: int = 0, | |
| dilation: int = 1, | |
| bias: bool = True, | |
| padding_mode="zeros", | |
| device=None, | |
| dtype=None, | |
| ): | |
| super().__init__( | |
| in_channels, | |
| out_channels, | |
| kernel_size, | |
| stride, | |
| padding=padding, | |
| dilation=dilation, | |
| groups=1, | |
| bias=bias, | |
| padding_mode=padding_mode, | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-6) | |
| self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight)) | |
| self.scale = nn.parameter.Parameter(scale.view(out_channels)) | |
| if bias: | |
| self.bias = nn.parameter.Parameter( | |
| torch.zeros(in_channels, device=device, dtype=dtype) | |
| ) | |
| else: | |
| self.register_parameter("bias", None) | |
| def forward( | |
| self, input: torch.Tensor, flg_san_train: bool = False | |
| ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: | |
| if self.bias is not None: | |
| input = input + self.bias.view(self.in_channels, 1, 1) | |
| normalized_weight = self._get_normalized_weight() | |
| scale = self.scale.view(self.out_channels, 1, 1) | |
| if flg_san_train: | |
| out_fun = F.conv2d( | |
| input, | |
| normalized_weight.detach(), | |
| None, | |
| self.stride, | |
| self.padding, | |
| self.dilation, | |
| self.groups, | |
| ) | |
| out_dir = F.conv2d( | |
| input.detach(), | |
| normalized_weight, | |
| None, | |
| self.stride, | |
| self.padding, | |
| self.dilation, | |
| self.groups, | |
| ) | |
| out = out_fun * scale, out_dir * scale.detach() | |
| else: | |
| out = F.conv2d( | |
| input, | |
| normalized_weight, | |
| None, | |
| self.stride, | |
| self.padding, | |
| self.dilation, | |
| self.groups, | |
| ) | |
| out = out * scale | |
| return out | |
| def normalize_weight(self): | |
| self.weight.data = self._get_normalized_weight() | |
| def _get_normalized_weight(self) -> torch.Tensor: | |
| return _normalize(self.weight, dim=[1, 2, 3]) | |
| def get_padding(kernel_size: int, dilation: int = 1) -> int: | |
| return (kernel_size * dilation - dilation) // 2 | |
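| # e.g. get_padding(5) == 2 and get_padding(3, dilation=2) == 2, i.e. "same" padding for odd kernels | |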
| class BeatriceDiscriminatorP(nn.Module): | |
| def __init__( | |
| self, period: int, kernel_size: int = 5, stride: int = 3, san: bool = False | |
| ): | |
| super().__init__() | |
| self.period = period | |
| self.san = san | |
| # fmt: off | |
| self.convs = nn.ModuleList([ | |
| weight_norm(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), | |
| weight_norm(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), | |
| weight_norm(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), | |
| weight_norm(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), | |
| weight_norm(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, (get_padding(kernel_size, 1), 0))), | |
| ]) | |
| # fmt: on | |
| if san: | |
| self.conv_post = SANConv2d(1024, 1, (3, 1), 1, (1, 0)) | |
| else: | |
| self.conv_post = weight_norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0))) | |
| def forward( | |
| self, x: torch.Tensor, flg_san_train: bool = False | |
| ) -> tuple[ | |
| Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor] | |
| ]: | |
| fmap = [] | |
| b, c, t = x.shape | |
| if t % self.period != 0: | |
| n_pad = self.period - (t % self.period) | |
| x = F.pad(x, (0, n_pad), "reflect") | |
| t = t + n_pad | |
| x = x.view(b, c, t // self.period, self.period) | |
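| # [b, c, t] -> [b, c, t // period, period]; the (kernel_size, 1) convs below then see samples spaced `period` apart | |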
| for conv in self.convs: | |
| x = conv(x) | |
| x = F.silu(x, inplace=True) | |
| fmap.append(x) | |
| if self.san: | |
| x = self.conv_post(x, flg_san_train=flg_san_train) | |
| else: | |
| x = self.conv_post(x) | |
| if flg_san_train: | |
| x_fun, x_dir = x | |
| fmap.append(x_fun) | |
| x_fun = torch.flatten(x_fun, 1, -1) | |
| x_dir = torch.flatten(x_dir, 1, -1) | |
| x = x_fun, x_dir | |
| else: | |
| fmap.append(x) | |
| x = torch.flatten(x, 1, -1) | |
| return x, fmap | |
| class BeatriceDiscriminatorR(nn.Module): | |
| def __init__(self, resolution: list[int], san: bool = False): | |
| super().__init__() | |
| self.resolution = resolution | |
| self.san = san | |
| assert len(self.resolution) == 3 | |
| self.convs = nn.ModuleList( | |
| [ | |
| weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), | |
| weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), | |
| weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), | |
| weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), | |
| weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), | |
| ] | |
| ) | |
| if san: | |
| self.conv_post = SANConv2d(32, 1, (3, 3), padding=(1, 1)) | |
| else: | |
| self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) | |
| def forward( | |
| self, x: torch.Tensor, flg_san_train: bool = False | |
| ) -> tuple[ | |
| Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor] | |
| ]: | |
| fmap = [] | |
| x = self._spectrogram(x).unsqueeze(1) | |
| for conv in self.convs: | |
| x = conv(x) | |
| x = F.silu(x, inplace=True) | |
| fmap.append(x) | |
| if self.san: | |
| x = self.conv_post(x, flg_san_train=flg_san_train) | |
| else: | |
| x = self.conv_post(x) | |
| if flg_san_train: | |
| x_fun, x_dir = x | |
| fmap.append(x_fun) | |
| x_fun = torch.flatten(x_fun, 1, -1) | |
| x_dir = torch.flatten(x_dir, 1, -1) | |
| x = x_fun, x_dir | |
| else: | |
| fmap.append(x) | |
| x = torch.flatten(x, 1, -1) | |
| return x, fmap | |
| def _spectrogram(self, x: torch.Tensor) -> torch.Tensor: | |
| n_fft, hop_length, win_length = self.resolution | |
| x = F.pad( | |
| x, ((n_fft - hop_length) // 2, (n_fft - hop_length) // 2), mode="reflect" | |
| ).squeeze(1) | |
| with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): | |
| mag = torch.stft( | |
| x.float(), | |
| n_fft=n_fft, | |
| hop_length=hop_length, | |
| win_length=win_length, | |
| window=torch.ones(win_length, device=x.device), | |
| center=False, | |
| return_complex=True, | |
| ).abs() | |
| return mag | |
| class BeatriceMultiPeriodDiscriminator(nn.Module): | |
| def __init__(self, san: bool = False): | |
| super().__init__() | |
| resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]] | |
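| # Each resolution is [n_fft, hop_length, win_length], unpacked in BeatriceDiscriminatorR._spectrogram | |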
| periods = [2, 3, 5, 7, 11] | |
| self.discriminators = nn.ModuleList( | |
| [BeatriceDiscriminatorR(r, san=san) for r in resolutions] | |
| + [BeatriceDiscriminatorP(p, san=san) for p in periods] | |
| ) | |
| self.discriminator_names = [f"R_{n}_{h}_{w}" for n, h, w in resolutions] + [ | |
| f"P_{p}" for p in periods | |
| ] | |
| self.san = san | |
| def forward( | |
| self, y: torch.Tensor, y_hat: torch.Tensor, flg_san_train: bool = False | |
| ) -> tuple[ | |
| list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]], | |
| list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]], | |
| list[list[torch.Tensor]], | |
| list[list[torch.Tensor]], | |
| ]: | |
| batch_size = y.size(0) | |
| concatenated_y_y_hat = torch.cat([y, y_hat]) | |
| y_d_rs = [] | |
| y_d_gs = [] | |
| fmap_rs = [] | |
| fmap_gs = [] | |
| for d in self.discriminators: | |
| if flg_san_train: | |
| (y_d_fun, y_d_dir), fmap = d( | |
| concatenated_y_y_hat, flg_san_train=flg_san_train | |
| ) | |
| y_d_r_fun, y_d_g_fun = torch.split(y_d_fun, batch_size) | |
| y_d_r_dir, y_d_g_dir = torch.split(y_d_dir, batch_size) | |
| y_d_r = y_d_r_fun, y_d_r_dir | |
| y_d_g = y_d_g_fun, y_d_g_dir | |
| else: | |
| y_d, fmap = d(concatenated_y_y_hat, flg_san_train=flg_san_train) | |
| y_d_r, y_d_g = torch.split(y_d, batch_size) | |
| fmap_r = [] | |
| fmap_g = [] | |
| for fm in fmap: | |
| fm_r, fm_g = torch.split(fm, batch_size) | |
| fmap_r.append(fm_r) | |
| fmap_g.append(fm_g) | |
| y_d_rs.append(y_d_r) | |
| y_d_gs.append(y_d_g) | |
| fmap_rs.append(fmap_r) | |
| fmap_gs.append(fmap_g) | |
| return y_d_rs, y_d_gs, fmap_rs, fmap_gs | |
| def forward_and_compute_loss( | |
| self, y: torch.Tensor, y_hat: torch.Tensor | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]]: | |
| y_d_rs, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=self.san) | |
| stats = {} | |
| assert len(y_d_gs) == len(y_d_rs) == len(self.discriminators) | |
| with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): | |
| # discriminator loss | |
| d_loss = 0.0 | |
| for dr, dg, name in zip(y_d_rs, y_d_gs, self.discriminator_names): | |
| if self.san: | |
| dr_fun, dr_dir = map(lambda x: x.float(), dr) | |
| dg_fun, dg_dir = map(lambda x: x.float(), dg) | |
| r_loss_fun = F.softplus(1.0 - dr_fun).square().mean() | |
| g_loss_fun = F.softplus(dg_fun).square().mean() | |
| r_loss_dir = F.softplus(1.0 - dr_dir).square().mean() | |
| g_loss_dir = -F.softplus(1.0 - dg_dir).square().mean() | |
| r_loss = r_loss_fun + r_loss_dir | |
| g_loss = g_loss_fun + g_loss_dir | |
| else: | |
| dr = dr.float() | |
| dg = dg.float() | |
| r_loss = (1.0 - dr).square().mean() | |
| g_loss = dg.square().mean() | |
| stats[f"{name}_dr_loss"] = r_loss.item() | |
| stats[f"{name}_dg_loss"] = g_loss.item() | |
| d_loss += r_loss + g_loss | |
| # adversarial loss | |
| adv_loss = 0.0 | |
| for dg, name in zip(y_d_gs, self.discriminator_names): | |
| if self.san: | |
| dg_fun = dg[0].float() | |
| g_loss = F.softplus(1.0 - dg_fun).square().mean() | |
| else: | |
| dg = dg.float() | |
| g_loss = (1.0 - dg).square().mean() | |
| stats[f"{name}_gg_loss"] = g_loss.item() | |
| adv_loss += g_loss | |
| # feature matching loss | |
| fm_loss = 0.0 | |
| for fr, fg, name in zip(fmap_rs, fmap_gs, self.discriminator_names): | |
| fm_loss_i = 0.0 | |
| for j, (r, g) in enumerate(zip(fr, fg)): | |
| fm_loss_ij = (r.detach().float() - g.float()).abs().mean() | |
| stats[f"~{name}_fm_loss_{j}"] = fm_loss_ij.item() | |
| fm_loss_i += fm_loss_ij | |
| stats[f"{name}_fm_loss"] = fm_loss_i.item() | |
| fm_loss += fm_loss_i | |
| return d_loss, adv_loss, fm_loss, stats | |
| class GradBalancer: | |
| """Adapted from https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py""" | |
| def __init__( | |
| self, | |
| weights: dict[str, float], | |
| rescale_grads: bool = True, | |
| total_norm: float = 1.0, | |
| ema_decay: float = 0.999, | |
| per_batch_item: bool = True, | |
| ): | |
| self.weights = weights | |
| self.per_batch_item = per_batch_item | |
| self.total_norm = total_norm | |
| self.ema_decay = ema_decay | |
| self.rescale_grads = rescale_grads | |
| self.ema_total: dict[str, float] = defaultdict(float) | |
| self.ema_fix: dict[str, float] = defaultdict(float) | |
| def backward( | |
| self, | |
| losses: dict[str, torch.Tensor], | |
| input: torch.Tensor, | |
| scaler: Optional[torch.amp.GradScaler] = None, | |
| skip_update_ema: bool = False, | |
| ) -> dict[str, float]: | |
| stats = {} | |
| if skip_update_ema: | |
| assert len(losses) == len(self.ema_total) | |
| ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()} | |
| else: | |
| # Compute d(loss)/d(input) and its norm for each loss | |
| norms = {} | |
| grads = {} | |
| for name, loss in losses.items(): | |
| if scaler is not None: | |
| loss = scaler.scale(loss) | |
| (grad,) = torch.autograd.grad(loss, [input], retain_graph=True) | |
| if not grad.isfinite().all(): | |
| input.backward(grad) | |
| return {} | |
| grad = grad.detach() / (1.0 if scaler is None else scaler.get_scale()) | |
| if self.per_batch_item: | |
| dims = tuple(range(1, grad.dim())) | |
| ema_norm = grad.norm(dim=dims).mean() | |
| else: | |
| ema_norm = grad.norm() | |
| norms[name] = float(ema_norm) | |
| grads[name] = grad | |
| # Update the exponential moving average of the norms | |
| for key, value in norms.items(): | |
| self.ema_total[key] = self.ema_total[key] * self.ema_decay + value | |
| self.ema_fix[key] = self.ema_fix[key] * self.ema_decay + 1.0 | |
| ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()} | |
| # Record statistics | |
| total_ema_norm = sum(ema_norms.values()) | |
| for k, ema_norm in ema_norms.items(): | |
| stats[f"grad_norm_value_{k}"] = ema_norm | |
| stats[f"grad_norm_ratio_{k}"] = ema_norm / (total_ema_norm + 1e-12) | |
| # Compute the relative weight assigned to each loss | |
| if self.rescale_grads: | |
| total_weights = sum([self.weights[k] for k in ema_norms]) | |
| ratios = {k: w / total_weights for k, w in self.weights.items()} | |
| # Rescale the gradients | |
| loss = 0.0 | |
| for name, ema_norm in ema_norms.items(): | |
| if self.rescale_grads: | |
| scale = ratios[name] * self.total_norm / (ema_norm + 1e-12) | |
| else: | |
| scale = self.weights[name] | |
| loss += (losses if skip_update_ema else grads)[name] * scale | |
| if scaler is not None: | |
| loss = scaler.scale(loss) | |
| if skip_update_ema: | |
| (loss,) = torch.autograd.grad(loss, [input]) | |
| input.backward(loss) | |
| return stats | |
| def state_dict(self) -> dict[str, dict[str, float]]: | |
| return { | |
| "ema_total": dict(self.ema_total), | |
| "ema_fix": dict(self.ema_fix), | |
| } | |
| def load_state_dict(self, state_dict): | |
| self.ema_total = defaultdict(float, state_dict["ema_total"]) | |
| self.ema_fix = defaultdict(float, state_dict["ema_fix"]) | |
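| # Hedged usage sketch (loss names mirror the trainer below; values are illustrative): | |
| # balancer = GradBalancer(weights={"loss_mel": 45.0, "loss_adv": 1.0}) | |
| # stats = balancer.backward({"loss_mel": loss_mel, "loss_adv": loss_adv}, y_hat, grad_scaler) | |
| # Each per-loss gradient w.r.t. y_hat is rescaled so the weighted EMA norms sum to total_norm. | |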
| class QualityTester(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.utmos = torch.hub.load( | |
| "tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True | |
| ).eval() | |
| def compute_mos(self, wav: torch.Tensor) -> dict[str, list[float]]: | |
| res = {"utmos": self.utmos(wav, sr=16000).tolist()} | |
| return res | |
| def test( | |
| self, converted_wav: torch.Tensor, source_wav: torch.Tensor | |
| ) -> dict[str, list[float]]: | |
| # [batch_size, wav_length] | |
| res = {} | |
| res.update(self.compute_mos(converted_wav)) | |
| return res | |
| def test_many( | |
| self, converted_wavs: list[torch.Tensor], source_wavs: list[torch.Tensor] | |
| ) -> tuple[dict[str, float], dict[str, list[float]]]: | |
| # list[batch_size, wav_length] | |
| results = defaultdict(list) | |
| assert len(converted_wavs) == len(source_wavs) | |
| for converted_wav, source_wav in zip(converted_wavs, source_wavs): | |
| res = self.test(converted_wav, source_wav) | |
| for metric_name, value in res.items(): | |
| results[metric_name].extend(value) | |
| return { | |
| metric_name: sum(values) / len(values) | |
| for metric_name, values in results.items() | |
| }, results | |
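| # Hedged usage sketch, assuming UTMOS accepts [batch, samples] waveforms at 16 kHz: | |
| # tester = QualityTester().eval() | |
| # utmos_scores = tester.compute_mos(wav_16k)["utmos"] | |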
| def compute_grad_norm( | |
| model: nn.Module, return_stats: bool = False | |
| ) -> Union[float, dict[str, float]]: | |
| total_norm = 0.0 | |
| stats = {} | |
| for name, p in model.named_parameters(): | |
| if p.grad is None: | |
| continue | |
| param_norm = p.grad.data.norm().item() | |
| if not math.isfinite(param_norm): | |
| param_norm = p.grad.data.float().norm().item() | |
| total_norm += param_norm * param_norm | |
| if return_stats: | |
| stats[f"grad_norm_{name}"] = param_norm | |
| total_norm = math.sqrt(total_norm) | |
| if return_stats: | |
| return total_norm, stats | |
| else: | |
| return total_norm | |
| def compute_mean_f0( | |
| files: list[Path], method: Literal["dio", "harvest"] = "dio" | |
| ) -> float: | |
| sum_log_f0 = 0.0 | |
| n_frames = 0 | |
| for file in files: | |
| wav, sr = beatrice_load_audio(file) | |
| if method == "dio": | |
| f0, _ = pyworld.dio(wav.ravel().numpy().astype(np.float64), sr) | |
| elif method == "harvest": | |
| f0, _ = pyworld.harvest(wav.ravel().numpy().astype(np.float64), sr) | |
| else: | |
| raise ValueError(f"Invalid method: {method}") | |
| f0 = f0[f0 > 0] | |
| sum_log_f0 += float(np.log(f0).sum()) | |
| n_frames += len(f0) | |
| if n_frames == 0: | |
| return math.nan | |
| mean_log_f0 = sum_log_f0 / n_frames | |
| return math.exp(mean_log_f0) | |
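| # Hedged usage sketch (path illustrative): geometric-mean F0 in Hz for one speaker: | |
| # mean_f0 = compute_mean_f0(sorted(Path("data/alice").glob("*.wav")), method="harvest") | |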
| def get_resampler( | |
| sr_before: int, sr_after: int, device="cpu", cache={} | |
| ) -> torchaudio.transforms.Resample: | |
| if not isinstance(device, str): | |
| device = str(device) | |
| if (sr_before, sr_after, device) not in cache: | |
| cache[(sr_before, sr_after, device)] = torchaudio.transforms.Resample( | |
| sr_before, sr_after | |
| ).to(device) | |
| return cache[(sr_before, sr_after, device)] | |
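| # Hedged usage sketch: resamplers are cached per (sr_before, sr_after, device): | |
| # wav_16k = get_resampler(48000, 16000)(wav_48k) # wav_48k: [channels, samples] | |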
| def convolve(signal: torch.Tensor, ir: torch.Tensor) -> torch.Tensor: | |
| n = 1 << (signal.size(-1) + ir.size(-1) - 2).bit_length() | |
| res = torch.fft.irfft(torch.fft.rfft(signal, n=n) * torch.fft.rfft(ir, n=n), n=n) | |
| return res[..., : signal.size(-1)] | |
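| # FFT-based linear convolution truncated to the input length; used below to apply | |
| # impulse responses, e.g. (hedged sketch): reverberant = convolve(dry_wav, ir_wav) | |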
| def random_formant_shift( | |
| wav: torch.Tensor, | |
| sample_rate: int, | |
| formant_shift_semitone_min: float = -3.0, | |
| formant_shift_semitone_max: float = 3.0, | |
| ) -> torch.Tensor: | |
| assert wav.ndim == 2 | |
| assert wav.size(0) == 1 | |
| device = wav.device | |
| hop_length = 256 | |
| # [wav_length] | |
| wav_np = wav.ravel().double().cpu().numpy() | |
| f0, t = pyworld.dio( | |
| wav_np, | |
| sample_rate, | |
| f0_floor=55, | |
| f0_ceil=1400, | |
| frame_period=hop_length * 1000 / sample_rate, | |
| ) | |
| f0 = pyworld.stonemask(wav_np, f0, t, sample_rate) | |
| world_sp = pyworld.cheaptrick(wav_np, f0, t, sample_rate) | |
| world_sp = ( | |
| torch.from_numpy(world_sp).float().to(device).sqrt_()[None] | |
| ) # [1, length, n_fft // 2 + 1] | |
| n_fft = win_length = (world_sp.size(2) - 1) * 2 | |
| window = torch.hann_window(win_length, device=device) | |
| # [1, n_fft // 2 + 1, length] | |
| stft_sp = torch.stft( | |
| wav, | |
| n_fft=n_fft, | |
| hop_length=hop_length, | |
| win_length=win_length, | |
| window=window, | |
| return_complex=True, | |
| ) | |
| assert world_sp.size(1) == stft_sp.size(2), (world_sp.size(), stft_sp.size()) | |
| assert world_sp.size(2) == stft_sp.size(1), (world_sp.size(), stft_sp.size()) | |
| shift_semitones = ( | |
| torch.rand(()).item() | |
| * (formant_shift_semitone_max - formant_shift_semitone_min) | |
| + formant_shift_semitone_min | |
| ) | |
| shift_ratio = 2.0 ** (shift_semitones / 12.0) | |
| shifted_world_sp = F.interpolate( | |
| world_sp, scale_factor=shift_ratio, mode="linear", align_corners=True | |
| ) | |
| if shifted_world_sp.size(2) > n_fft // 2 + 1: | |
| shifted_world_sp = shifted_world_sp[:, :, : n_fft // 2 + 1] | |
| elif shifted_world_sp.size(2) < n_fft // 2 + 1: | |
| shifted_world_sp = F.pad( | |
| shifted_world_sp, (0, n_fft // 2 + 1 - shifted_world_sp.size(2)) | |
| ) | |
| ratio = ((shifted_world_sp + 1e-5) / (world_sp + 1e-5)).clamp(0.1, 10.0) | |
| stft_sp *= ratio.transpose(-2, -1) # [1, n_fft // 2 + 1, length] | |
| out = torch.istft( | |
| stft_sp, | |
| n_fft=n_fft, | |
| hop_length=hop_length, | |
| win_length=win_length, | |
| window=window, | |
| length=wav.size(-1), | |
| ) | |
| return out | |
| def random_filter(audio: torch.Tensor) -> torch.Tensor: | |
| assert audio.ndim == 2 | |
| ab = torch.rand(audio.size(0), 6) * 0.75 - 0.375 | |
| a, b = ab[:, :3], ab[:, 3:] | |
| a[:, 0] = 1.0 | |
| b[:, 0] = 1.0 | |
| audio = torchaudio.functional.lfilter(audio, a, b, clamp=False) | |
| return audio | |
| def get_noise( | |
| n_samples: int, sample_rate: float, files: list[Union[str, bytes, os.PathLike]] | |
| ) -> torch.Tensor: | |
| resample_augmentation_candidates = [0.9, 0.95, 1.0, 1.05, 1.1] | |
| wavs = [] | |
| current_length = 0 | |
| while current_length < n_samples: | |
| idx_files = torch.randint(0, len(files), ()) | |
| file = files[idx_files] | |
| wav, sr = beatrice_load_audio(file) | |
| assert wav.size(0) == 1 | |
| augmented_sample_rate = int( | |
| round( | |
| sample_rate | |
| * resample_augmentation_candidates[ | |
| torch.randint(0, len(resample_augmentation_candidates), ()) | |
| ] | |
| ) | |
| ) | |
| resampler = get_resampler(sr, augmented_sample_rate) | |
| wav = resampler(wav) | |
| wav = random_filter(wav) | |
| wav *= 0.99 / (wav.abs().max() + 1e-5) | |
| wavs.append(wav) | |
| current_length += wav.size(1) | |
| start = torch.randint(0, current_length - n_samples + 1, ()) | |
| wav = torch.cat(wavs, dim=1)[:, start : start + n_samples] | |
| assert wav.size() == (1, n_samples), wav.size() | |
| return wav | |
| def get_butterworth_lpf( | |
| cutoff_freq: float, sample_rate: int, cache={} | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| if (cutoff_freq, sample_rate) not in cache: | |
| q = math.sqrt(0.5) | |
| omega = math.tau * cutoff_freq / sample_rate | |
| cos_omega = math.cos(omega) | |
| alpha = math.sin(omega) / (2.0 * q) | |
| b1 = (1.0 - cos_omega) / (1.0 + alpha) | |
| b0 = b1 * 0.5 | |
| a1 = -2.0 * cos_omega / (1.0 + alpha) | |
| a2 = (1.0 - alpha) / (1.0 + alpha) | |
| cache[(cutoff_freq, sample_rate)] = ( | |
| torch.tensor([b0, b1, b0]), | |
| torch.tensor([1.0, a1, a2]), | |
| ) | |
| return cache[(cutoff_freq, sample_rate)] | |
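| # Hedged usage sketch, mirroring the call in augment_audio() below: | |
| # b, a = get_butterworth_lpf(4000.0, 16000) | |
| # filtered = torchaudio.functional.lfilter(signals, a, b, clamp=False) | |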
| def augment_audio( | |
| clean: torch.Tensor, | |
| sample_rate: int, | |
| noise_files: list[Union[str, bytes, os.PathLike]], | |
| ir_files: list[Union[str, bytes, os.PathLike]], | |
| snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0], | |
| formant_shift_probability: float = 0.5, | |
| formant_shift_semitone_min: float = -3.0, | |
| formant_shift_semitone_max: float = 3.0, | |
| reverb_probability: float = 0.5, | |
| lpf_probability: float = 0.2, | |
| lpf_cutoff_freq_candidates: list[float] = [2000.0, 3000.0, 4000.0, 6000.0], | |
| ) -> torch.Tensor: | |
| # [1, wav_length] | |
| assert clean.size(0) == 1 | |
| n_samples = clean.size(1) | |
| original_clean_rms = clean.square().mean().sqrt_() | |
| # Apply a random formant shift to clean | |
| if torch.rand(()) < formant_shift_probability: | |
| clean = random_formant_shift( | |
| clean, sample_rate, formant_shift_semitone_min, formant_shift_semitone_max | |
| ) | |
| # Fetch noise and concatenate it with clean | |
| noise = get_noise(n_samples, sample_rate, noise_files) | |
| signals = torch.cat([clean, noise]) | |
| # Apply different random filters to clean and noise | |
| signals = random_filter(signals) | |
| # Apply reverb to clean and noise | |
| if torch.rand(()) < reverb_probability: | |
| ir_file = ir_files[torch.randint(0, len(ir_files), ())] | |
| ir, sr = beatrice_load_audio(ir_file) | |
| assert ir.size() == (2, sr), ir.size() | |
| assert sr == sample_rate, (sr, sample_rate) | |
| signals = convolve(signals, ir) | |
| # Apply the same LPF to clean and noise | |
| if torch.rand(()) < lpf_probability: | |
| if signals.abs().max() > 0.8: | |
| signals /= signals.abs().max() * 1.25 | |
| cutoff_freq = lpf_cutoff_freq_candidates[ | |
| torch.randint(0, len(lpf_cutoff_freq_candidates), ()) | |
| ] | |
| b, a = get_butterworth_lpf(cutoff_freq, sample_rate) | |
| signals = torchaudio.functional.lfilter(signals, a, b, clamp=False) | |
| # Restore the original loudness of clean | |
| clean, noise = signals | |
| clean_rms = clean.square().mean().sqrt_() | |
| clean *= original_clean_rms / clean_rms | |
| if len(snr_candidates) >= 1: | |
| # Measure clean / noise levels with extra weight on peaks (fourth root of the mean fourth power) | |
| clean_level = clean.square().square_().mean().sqrt_().sqrt_() | |
| noise_level = noise.square().square_().mean().sqrt_().sqrt_() | |
| # SNR | |
| snr = snr_candidates[torch.randint(0, len(snr_candidates), ())] | |
| # Generate the noisy mixture (noise scaled to 10 ** (-snr / 20) of the clean level) | |
| noisy = clean + noise * ( | |
| 0.1 ** (snr / 20.0) * clean_level / (noise_level + 1e-5) | |
| ) | |
| return noisy | |
| class WavDataset(torch.utils.data.Dataset): | |
| def __init__( | |
| self, | |
| audio_files: list[tuple[Path, int]], | |
| in_sample_rate: int = 16000, | |
| out_sample_rate: int = 24000, | |
| wav_length: int = 4 * 24000, # 4s | |
| segment_length: int = 100, # 1s | |
| noise_files: Optional[list[Union[str, bytes, os.PathLike]]] = None, | |
| ir_files: Optional[list[Union[str, bytes, os.PathLike]]] = None, | |
| augmentation_snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0], | |
| augmentation_formant_shift_probability: float = 0.5, | |
| augmentation_formant_shift_semitone_min: float = -3.0, | |
| augmentation_formant_shift_semitone_max: float = 3.0, | |
| augmentation_reverb_probability: float = 0.5, | |
| augmentation_lpf_probability: float = 0.2, | |
| augmentation_lpf_cutoff_freq_candidates: list[float] = [ | |
| 2000.0, | |
| 3000.0, | |
| 4000.0, | |
| 6000.0, | |
| ], | |
| ): | |
| self.audio_files = audio_files | |
| self.in_sample_rate = in_sample_rate | |
| self.out_sample_rate = out_sample_rate | |
| self.wav_length = wav_length | |
| self.segment_length = segment_length | |
| self.noise_files = noise_files | |
| self.ir_files = ir_files | |
| self.augmentation_snr_candidates = augmentation_snr_candidates | |
| self.augmentation_formant_shift_probability = ( | |
| augmentation_formant_shift_probability | |
| ) | |
| self.augmentation_formant_shift_semitone_min = ( | |
| augmentation_formant_shift_semitone_min | |
| ) | |
| self.augmentation_formant_shift_semitone_max = ( | |
| augmentation_formant_shift_semitone_max | |
| ) | |
| self.augmentation_reverb_probability = augmentation_reverb_probability | |
| self.augmentation_lpf_probability = augmentation_lpf_probability | |
| self.augmentation_lpf_cutoff_freq_candidates = ( | |
| augmentation_lpf_cutoff_freq_candidates | |
| ) | |
| if (noise_files is None) is not (ir_files is None): | |
| raise ValueError("noise_files and ir_files must be both None or not None") | |
| self.in_hop_length = in_sample_rate // 100 | |
| self.out_hop_length = out_sample_rate // 100 # 10 ms hop | |
| def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int, int]: | |
| file, speaker_id = self.audio_files[index] | |
| clean_wav, sample_rate = beatrice_load_audio(file) | |
| if clean_wav.size(0) != 1: | |
| ch = torch.randint(0, clean_wav.size(0), ()) | |
| clean_wav = clean_wav[ch : ch + 1] | |
| formant_shift_candidates = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0] | |
| formant_shift = formant_shift_candidates[ | |
| torch.randint(0, len(formant_shift_candidates), ()).item() | |
| ] | |
| resampler_fraction = Fraction( | |
| sample_rate / self.out_sample_rate * 2.0 ** (formant_shift / 12.0) | |
| ).limit_denominator(300) | |
| clean_wav = get_resampler( | |
| resampler_fraction.numerator, resampler_fraction.denominator | |
| )(clean_wav) | |
| assert clean_wav.size(0) == 1 | |
| assert clean_wav.size(1) != 0 | |
| clean_wav = F.pad(clean_wav, (self.wav_length, self.wav_length)) | |
| if self.noise_files is None: | |
| noisy_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)( | |
| clean_wav | |
| ) | |
| else: | |
| clean_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)( | |
| clean_wav | |
| ) | |
| noisy_wav_16k = augment_audio( | |
| clean_wav_16k, | |
| self.in_sample_rate, | |
| self.noise_files, | |
| self.ir_files, | |
| self.augmentation_snr_candidates, | |
| self.augmentation_formant_shift_probability, | |
| self.augmentation_formant_shift_semitone_min, | |
| self.augmentation_formant_shift_semitone_max, | |
| self.augmentation_reverb_probability, | |
| self.augmentation_lpf_probability, | |
| self.augmentation_lpf_cutoff_freq_candidates, | |
| ) | |
| clean_wav = clean_wav.squeeze_(0) | |
| noisy_wav_16k = noisy_wav_16k.squeeze_(0) | |
| # Randomize the loudness | |
| amplitude = torch.rand(()).item() * 0.899 + 0.1 | |
| factor = amplitude / clean_wav.abs().max() | |
| clean_wav *= factor | |
| noisy_wav_16k *= factor | |
| while noisy_wav_16k.abs().max() >= 1.0: | |
| clean_wav *= 0.5 | |
| noisy_wav_16k *= 0.5 | |
| return clean_wav, noisy_wav_16k, speaker_id, formant_shift | |
| def __len__(self) -> int: | |
| return len(self.audio_files) | |
| def collate( | |
| self, batch: list[tuple[torch.Tensor, torch.Tensor, int, int]] | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| assert self.wav_length % self.out_hop_length == 0 | |
| length = self.wav_length // self.out_hop_length | |
| clean_wavs = [] | |
| noisy_wavs = [] | |
| slice_starts = [] | |
| speaker_ids = [] | |
| formant_shifts = [] | |
| for clean_wav, noisy_wav, speaker_id, formant_shift in batch: | |
| # Pick one voiced sample position at random | |
| (voiced,) = clean_wav.nonzero(as_tuple=True) | |
| assert voiced.numel() != 0 | |
| center = voiced[torch.randint(0, voiced.numel(), ()).item()].item() | |
| # Choose the slice interval so that the voiced position is centered | |
| slice_start = center - self.segment_length * self.out_hop_length // 2 | |
| assert slice_start >= 0 | |
| # Randomly cut out a wav_length-long span that contains the slice interval | |
| r = torch.randint(0, length - self.segment_length + 1, ()).item() | |
| offset = slice_start - r * self.out_hop_length | |
| clean_wavs.append(clean_wav[offset : offset + self.wav_length]) | |
| offset_in_sample_rate = int( | |
| round(offset * self.in_sample_rate / self.out_sample_rate) | |
| ) | |
| noisy_wavs.append( | |
| noisy_wav[ | |
| offset_in_sample_rate : offset_in_sample_rate | |
| + length * self.in_hop_length | |
| ] | |
| ) | |
| slice_start = r | |
| slice_starts.append(slice_start) | |
| speaker_ids.append(speaker_id) | |
| formant_shifts.append(formant_shift) | |
| clean_wavs = torch.stack(clean_wavs) | |
| noisy_wavs = torch.stack(noisy_wavs) | |
| slice_starts = torch.tensor(slice_starts) | |
| speaker_ids = torch.tensor(speaker_ids) | |
| formant_shifts = torch.tensor(formant_shifts) | |
| return ( | |
| clean_wavs, # [batch_size, wav_length] | |
| noisy_wavs, # [batch_size, wav_length] | |
| slice_starts, # Long[batch_size] | |
| speaker_ids, # Long[batch_size] | |
| formant_shifts, # Long[batch_size] | |
| ) | |
| AUDIO_FILE_SUFFIXES = { | |
| ".wav", | |
| ".aif", | |
| ".aiff", | |
| ".fla", | |
| ".flac", | |
| ".oga", | |
| ".ogg", | |
| ".opus", | |
| ".mp3", | |
| } | |
| def get_compressed_optimizer_state_dict( | |
| optimizer: torch.optim.Optimizer, | |
| ) -> dict: | |
| state_dict = {} | |
| for k0, v0 in optimizer.state_dict().items(): | |
| if k0 != "state": | |
| state_dict[k0] = v0 | |
| continue | |
| state_dict[k0] = {} | |
| for k1, v1 in v0.items(): | |
| state_dict[k0][k1] = {} | |
| for k2, v2 in v1.items(): | |
| if isinstance(v2, torch.Tensor): | |
| state_dict[k0][k1][k2] = v2.bfloat16() | |
| assert state_dict[k0][k1][k2].isfinite().all() | |
| else: | |
| state_dict[k0][k1][k2] = v2 | |
| return state_dict | |
| def get_decompressed_optimizer_state_dict(compressed_state_dict: dict) -> dict: | |
| state_dict = {} | |
| for k0, v0 in compressed_state_dict.items(): | |
| if k0 != "state": | |
| state_dict[k0] = v0 | |
| continue | |
| state_dict[k0] = {} | |
| for k1, v1 in v0.items(): | |
| state_dict[k0][k1] = {} | |
| for k2, v2 in v1.items(): | |
| if isinstance(v2, torch.Tensor): | |
| state_dict[k0][k1][k2] = v2.float() | |
| assert state_dict[k0][k1][k2].isfinite().all() | |
| else: | |
| state_dict[k0][k1][k2] = v2 | |
| return state_dict | |
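| # Checkpoints store optimizer moments as bfloat16 and restore them to float32 on load. | |
| # Hedged usage sketch: | |
| # compact = get_compressed_optimizer_state_dict(optim_g) | |
| # optim_g.load_state_dict(get_decompressed_optimizer_state_dict(compact)) | |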
| # ============================================================ | |
| # BEATRICE V2 TRAINING - Embedded (downloads assets from HuggingFace) | |
| # ============================================================ | |
| BEATRICE_AUDIO_FILE_SUFFIXES = {".wav", ".aif", ".aiff", ".fla", ".flac", ".oga", ".ogg", ".opus", ".mp3"} | |
| def preprocess_audio_for_beatrice(audio_path: str, output_dir: str, speaker_name: str = "speaker"): | |
| """Preprocess audio for Beatrice training using silence-based splitting""" | |
| # Create speaker directory structure required by Beatrice | |
| speaker_dir = os.path.join(output_dir, speaker_name) | |
| os.makedirs(speaker_dir, exist_ok=True) | |
| # Load audio at 16kHz (Beatrice input sample rate) | |
| audio, sr = librosa.load(audio_path, sr=16000, mono=True) | |
| # Simple energy gating: fixed-length overlapping chunks, keep only chunks above an RMS threshold | |
| chunk_size = int(4.0 * sr) # 4 second chunks | |
| hop = int(3.5 * sr) # 0.5s overlap | |
| threshold = 0.01 # RMS threshold | |
| chunks_saved = 0 | |
| for i, start in enumerate(range(0, len(audio) - chunk_size, hop)): | |
| chunk = audio[start:start + chunk_size] | |
| rms = np.sqrt(np.mean(chunk ** 2)) | |
| if rms > threshold: # Skip silence | |
| # Normalize | |
| max_val = np.abs(chunk).max() | |
| if max_val > 0: | |
| chunk = chunk / max_val * 0.9 | |
| chunk_path = os.path.join(speaker_dir, f"{speaker_name}_{chunks_saved:04d}.wav") | |
| sf.write(chunk_path, chunk, sr) | |
| chunks_saved += 1 | |
| logger.info(f"Beatrice preprocessing: {chunks_saved} chunks saved to {speaker_dir}") | |
| return chunks_saved, output_dir | |
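| # Hedged usage sketch (paths and speaker name illustrative): | |
| # n_chunks, data_dir = preprocess_audio_for_beatrice("voice.wav", "beatrice_data", "alice") | |
| # -> writes beatrice_data/alice/alice_0000.wav, alice_0001.wav, ... | |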
| def train_beatrice_generator( | |
| data_dir: str, | |
| output_dir: str, | |
| epochs: int = 30, | |
| batch_size: int = 8, | |
| lr_g: float = 5e-5, | |
| lr_d: float = 5e-5, | |
| use_augmentation: bool = False, | |
| resume: bool = False, | |
| progress_callback=None, | |
| ): | |
| """Train Beatrice v2 model - generator yielding (message, model_path) tuples""" | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # Download pretrained | |
| yield "Downloading pretrained models...", None | |
| phone_extractor_path = download_beatrice_asset("phone_extractor") | |
| pitch_estimator_path = download_beatrice_asset("pitch_estimator") | |
| pretrained_model_path = download_beatrice_asset("pretrained_model") | |
| # Discover speakers from directory structure | |
| # Expected: data_dir/speaker_name/*.wav | |
| speakers = [] | |
| training_filelist = [] | |
| speaker_audio_files = [] | |
| for speaker_dir in sorted(Path(data_dir).iterdir()): | |
| if not speaker_dir.is_dir(): | |
| continue | |
| candidates = [f for f in sorted(speaker_dir.rglob("*")) | |
| if f.is_file() and f.suffix.lower() in BEATRICE_AUDIO_FILE_SUFFIXES] | |
| if candidates: | |
| speaker_id = len(speakers) | |
| speakers.append(speaker_dir.name) | |
| training_filelist.extend([(f, speaker_id) for f in candidates]) | |
| speaker_audio_files.append(candidates) | |
| n_speakers = len(speakers) | |
| if n_speakers == 0: | |
| yield "Error: No speakers found in data directory", None | |
| return | |
| yield f"Found {n_speakers} speaker(s), {len(training_filelist)} files", None | |
| # Augmentation assets (optional) | |
| noise_files = None | |
| ir_files = None | |
| if use_augmentation: | |
| try: | |
| noise_dir, ir_dir = download_beatrice_augmentation() | |
| if noise_dir and ir_dir: | |
| noise_files = sorted(list(Path(noise_dir).rglob("*.wav")) + list(Path(noise_dir).rglob("*.flac"))) | |
| ir_files = sorted(list(Path(ir_dir).rglob("*.wav")) + list(Path(ir_dir).rglob("*.flac"))) | |
| if noise_files and ir_files: | |
| yield f"Loaded augmentation: {len(noise_files)} noise, {len(ir_files)} IR files", None | |
| else: | |
| noise_files = None | |
| ir_files = None | |
| except Exception as e: | |
| yield f"Warning: Could not load augmentation assets: {e}", None | |
| # Build models | |
| yield "Building models...", None | |
| phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False) | |
| pe_ckpt = torch.load(phone_extractor_path, map_location="cpu", weights_only=True) | |
| phone_extractor.load_state_dict(pe_ckpt["phone_extractor"], strict=False) | |
| del pe_ckpt | |
| pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False) | |
| pi_ckpt = torch.load(pitch_estimator_path, map_location="cpu", weights_only=True) | |
| pitch_estimator.load_state_dict(pi_ckpt["pitch_estimator"]) | |
| del pi_ckpt | |
| hidden_channels = 256 | |
| pitch_bins = 448 | |
| net_g = ConverterNetwork( | |
| phone_extractor, pitch_estimator, | |
| n_speakers=n_speakers, | |
| pitch_bins=pitch_bins, | |
| hidden_channels=hidden_channels, | |
| vq_topk=4, | |
| training_time_vq="none", | |
| phone_noise_ratio=0.5, | |
| floor_noise_level=1e-3, | |
| ).to(device) | |
| net_d = BeatriceMultiPeriodDiscriminator(san=True).to(device) | |
| # Optimizers | |
| optim_g = torch.optim.AdamW(net_g.parameters(), lr_g, betas=(0.8, 0.99), eps=1e-6) | |
| optim_d = torch.optim.AdamW(net_d.parameters(), lr_d, betas=(0.8, 0.99), eps=1e-6) | |
| grad_scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda") | |
| grad_balancer = GradBalancer( | |
| weights={ | |
| "loss_loudness": 1.0, | |
| "loss_mel": 45.0, | |
| "loss_adv": 1.0, | |
| "loss_fm": 2.0, | |
| }, | |
| ema_decay=0.999, | |
| ) | |
| initial_iteration = 0 | |
| os.makedirs(output_dir, exist_ok=True) | |
| # Load pretrained or resume | |
| if resume: | |
| latest_ckpt = os.path.join(output_dir, "checkpoint_latest.pt.gz") | |
| if os.path.isfile(latest_ckpt): | |
| yield "Resuming from checkpoint...", None | |
| with gzip.open(latest_ckpt, "rb") as f: | |
| ckpt = torch.load(f, map_location="cpu", weights_only=True) | |
| net_g.load_state_dict(ckpt["net_g"], strict=False) | |
| # Filter discriminator for shape mismatches | |
| net_d_state = net_d.state_dict() | |
| filtered_d = {k: v for k, v in ckpt["net_d"].items() | |
| if k in net_d_state and v.shape == net_d_state[k].shape} | |
| net_d.load_state_dict(filtered_d, strict=False) | |
| optim_g.load_state_dict(get_decompressed_optimizer_state_dict(ckpt["optim_g"])) | |
| optim_d.load_state_dict(get_decompressed_optimizer_state_dict(ckpt["optim_d"])) | |
| if "grad_balancer" in ckpt: | |
| grad_balancer.load_state_dict(ckpt["grad_balancer"]) | |
| if "grad_scaler" in ckpt: | |
| grad_scaler.load_state_dict(ckpt["grad_scaler"]) | |
| initial_iteration = ckpt.get("iteration", 0) | |
| del ckpt | |
| else: | |
| yield "No checkpoint found, starting fresh with pretrained", None | |
| resume = False | |
| if not resume: | |
| yield "Loading pretrained weights...", None | |
| with gzip.open(pretrained_model_path, "rb") as f: | |
| pretrained_ckpt = torch.load(f, map_location="cpu", weights_only=True) | |
| # Adapt pretrained for our n_speakers | |
| initial_speaker_emb = pretrained_ckpt["net_g"]["embed_speaker.weight"][:1] | |
| pretrained_ckpt["net_g"]["embed_speaker.weight"] = initial_speaker_emb[[0] * n_speakers] | |
| initial_kv_emb = pretrained_ckpt["net_g"]["key_value_speaker_embedding.weight"][:1] | |
| pretrained_ckpt["net_g"]["key_value_speaker_embedding.weight"] = initial_kv_emb[[0] * n_speakers] | |
| pretrained_ckpt["net_g"]["vq.codebooks"] = pretrained_ckpt["net_g"]["vq.codebooks"][[0] * n_speakers] | |
| net_g.load_state_dict(pretrained_ckpt["net_g"], strict=False) | |
| # Filter discriminator state dict for shape mismatches (pretrained may use san=False) | |
| net_d_state = net_d.state_dict() | |
| filtered_d = {k: v for k, v in pretrained_ckpt["net_d"].items() | |
| if k in net_d_state and v.shape == net_d_state[k].shape} | |
| net_d.load_state_dict(filtered_d, strict=False) | |
| logger.info(f"Loaded {len(filtered_d)}/{len(pretrained_ckpt['net_d'])} discriminator weights") | |
| # Don't load grad_balancer/grad_scaler from pretrained - our loss weights may differ | |
| # These will be re-initialized fresh for fine-tuning | |
| del pretrained_ckpt | |
| # Build VQ codebooks | |
| yield "Building VQ codebooks...", None | |
| def wav_iterator(files): | |
| for file in files: | |
| wav, sr = beatrice_load_audio(file) | |
| wav = wav.to(device) | |
| if sr != 16000: | |
| wav = get_resampler(sr, 16000, str(device))(wav) | |
| yield wav[:, None, :] | |
| if resume: | |
| net_g.enable_hook() | |
| else: | |
| net_g.initialize_vq([wav_iterator(files) for files in speaker_audio_files]) | |
| # Dataset | |
| dataset = WavDataset( | |
| training_filelist, | |
| in_sample_rate=16000, | |
| out_sample_rate=24000, | |
| wav_length=96000, | |
| segment_length=100, | |
| noise_files=noise_files, | |
| ir_files=ir_files, | |
| ) | |
| _num_workers = min(4, os.cpu_count() or 1) | |
| effective_batch = min(batch_size, len(training_filelist)) | |
| dataloader = torch.utils.data.DataLoader( | |
| dataset, | |
| num_workers=_num_workers, | |
| collate_fn=dataset.collate, | |
| shuffle=True, | |
| batch_size=effective_batch, | |
| pin_memory=True, | |
| drop_last=len(training_filelist) > effective_batch, | |
| persistent_workers=_num_workers > 0, | |
| ) | |
| # Calculate steps | |
| steps_per_epoch = max(1, len(training_filelist) // batch_size) | |
| total_steps = epochs * steps_per_epoch | |
| warmup_steps = min(total_steps // 4, 5000) | |
| # LR scheduler with warmup | |
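| # Linear ramp up to the base LR over warmup_steps, then exponential decay of 0.999 per step | |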
| def lr_lambda(step): | |
| if step < warmup_steps: | |
| return step / max(1, warmup_steps) | |
| return 0.999 ** (step - warmup_steps) | |
| scheduler_g = torch.optim.lr_scheduler.LambdaLR(optim_g, lr_lambda) | |
| scheduler_d = torch.optim.lr_scheduler.LambdaLR(optim_d, lr_lambda) | |
| # Advance schedulers if resuming | |
| with warnings.catch_warnings(): | |
| warnings.filterwarnings("ignore", message=r"Detected call of `lr_scheduler\.step\(\)") | |
| for _ in range(initial_iteration + 1): | |
| scheduler_g.step() | |
| scheduler_d.step() | |
| net_g.train() | |
| net_d.train() | |
| yield f"Training {total_steps} steps ({epochs} epochs x {steps_per_epoch} steps/epoch)", None | |
| # Training loop | |
| step = initial_iteration | |
| data_iter = None | |
| ckpt_path = None | |
| for epoch in range(epochs): | |
| epoch_loss_g = 0.0 | |
| epoch_loss_d = 0.0 | |
| epoch_steps = 0 | |
| for batch_idx in range(steps_per_epoch): | |
| if data_iter is None: | |
| data_iter = iter(dataloader) | |
| batch = next(data_iter, None) | |
| if batch is None: | |
| data_iter = iter(dataloader) | |
| batch = next(data_iter, None) | |
| if batch is None: | |
| break | |
| clean_wavs, noisy_wavs_16k, slice_starts, speaker_ids, formant_shifts = \ | |
| [x.to(device, non_blocking=True) for x in batch] | |
| with torch.amp.autocast(device.type, enabled=device.type == "cuda"): | |
| # Generator forward | |
| y, y_hat, y_hat_for_backward, loss_loudness, loss_mel, loss_ap, gen_stats = \ | |
| net_g.forward_and_compute_loss( | |
| noisy_wavs_16k[:, None, :], | |
| speaker_ids, | |
| formant_shifts, | |
| slice_start_indices=slice_starts, | |
| slice_segment_length=100, | |
| y_all=clean_wavs[:, None, :], | |
| ) | |
| # Discriminator forward | |
| loss_disc, loss_adv, loss_fm, disc_stats = \ | |
| net_d.forward_and_compute_loss(y, y_hat) | |
| # Discriminator backward | |
| optim_d.zero_grad(set_to_none=True) | |
| grad_scaler.scale(loss_disc).backward(retain_graph=True, inputs=list(net_d.parameters())) | |
| grad_scaler.unscale_(optim_d) | |
| # Generator backward | |
| optim_g.zero_grad(set_to_none=True) | |
| grad_balancer.backward( | |
| {"loss_loudness": loss_loudness, "loss_mel": loss_mel, | |
| "loss_adv": loss_adv, "loss_fm": loss_fm}, | |
| y_hat_for_backward, grad_scaler, | |
| skip_update_ema=step > 10 and step % 5 != 0, | |
| ) | |
| grad_scaler.unscale_(optim_g) | |
| # Update | |
| grad_scaler.step(optim_g) | |
| grad_scaler.step(optim_d) | |
| grad_scaler.update() | |
| optim_g.zero_grad(set_to_none=True) | |
| optim_d.zero_grad(set_to_none=True) | |
| scheduler_g.step() | |
| scheduler_d.step() | |
| epoch_loss_g += loss_mel.item() | |
| epoch_loss_d += loss_disc.item() | |
| epoch_steps += 1 | |
| step += 1 | |
| if progress_callback: | |
| progress_callback(step / total_steps) | |
| avg_loss_g = epoch_loss_g / max(1, epoch_steps) | |
| avg_loss_d = epoch_loss_d / max(1, epoch_steps) | |
| yield f"Epoch {epoch+1}/{epochs} | G loss: {avg_loss_g:.4f} | D loss: {avg_loss_d:.4f} | LR: {scheduler_g.get_last_lr()[0]:.2e}", None | |
| # Save checkpoint periodically | |
| if (epoch + 1) % max(1, epochs // 5) == 0 or epoch == epochs - 1: | |
| ckpt_path = os.path.join(output_dir, f"checkpoint_{step:08d}.pt.gz") | |
| with gzip.open(ckpt_path, "wb") as f: | |
| torch.save({ | |
| "iteration": step, | |
| "net_g": net_g.state_dict(), | |
| "phone_extractor": phone_extractor.state_dict(), | |
| "pitch_estimator": pitch_estimator.state_dict(), | |
| "net_d": {k: v.half() for k, v in net_d.state_dict().items()}, | |
| "optim_g": get_compressed_optimizer_state_dict(optim_g), | |
| "optim_d": get_compressed_optimizer_state_dict(optim_d), | |
| "grad_balancer": grad_balancer.state_dict(), | |
| "grad_scaler": grad_scaler.state_dict(), | |
| "h": { | |
| "hidden_channels": hidden_channels, | |
| "pitch_bins": pitch_bins, | |
| "vq_topk": 4, | |
| "training_time_vq": "none", | |
| "phone_noise_ratio": 0.5, | |
| "floor_noise_level": 1e-3, | |
| "san": True, | |
| }, | |
| "speakers": speakers, | |
| }, f) | |
| shutil.copy(ckpt_path, os.path.join(output_dir, "checkpoint_latest.pt.gz")) | |
| yield f"Saved checkpoint: {ckpt_path}", ckpt_path | |
| # Cleanup | |
| purge_memory(net_g, net_d, optim_g, optim_d, phone_extractor, pitch_estimator) | |
| yield "Training complete!", ckpt_path | |
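| # Hedged usage sketch (paths illustrative): the trainer streams (message, checkpoint_path) tuples: | |
| # for msg, ckpt in train_beatrice_generator("beatrice_data", "beatrice_out", epochs=10): | |
| # print(msg) # ckpt stays None until the first checkpoint is written | |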
| def convert_voice_beatrice( | |
| source_audio, | |
| model_file, | |
| target_speaker: int = 0, | |
| pitch_shift: int = 0, | |
| formant_shift: float = 0.0, | |
| progress=None, | |
| ): | |
| """Convert voice using Beatrice v2 model | |
| Args: | |
| source_audio: Path to source audio file | |
| model_file: Path to Beatrice checkpoint (.pt.gz) or file object with .name | |
| target_speaker: Target speaker index | |
| pitch_shift: Pitch shift in semitones | |
| formant_shift: Formant shift in semitones (-2 to 2) | |
| progress: Gradio progress callback | |
| Returns: | |
| (output_path, status_message) tuple | |
| """ | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # Get model path | |
| if hasattr(model_file, 'name'): | |
| model_path = model_file.name | |
| elif isinstance(model_file, str): | |
| model_path = model_file | |
| else: | |
| return None, "Invalid model file" | |
| if not model_path or not os.path.exists(model_path): | |
| return None, f"Model file not found: {model_path}" | |
| try: | |
| if progress: | |
| progress(0.1, "Loading models...") | |
| # Download pretrained assets (phone extractor + pitch estimator) | |
| phone_extractor_path = download_beatrice_asset("phone_extractor") | |
| pitch_estimator_path = download_beatrice_asset("pitch_estimator") | |
| # Build phone extractor | |
| phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False) | |
| pe_ckpt = torch.load(phone_extractor_path, map_location="cpu", weights_only=True) | |
| phone_extractor.load_state_dict(pe_ckpt["phone_extractor"], strict=False) | |
| del pe_ckpt | |
| # Build pitch estimator | |
| pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False) | |
| pi_ckpt = torch.load(pitch_estimator_path, map_location="cpu", weights_only=True) | |
| pitch_estimator.load_state_dict(pi_ckpt["pitch_estimator"]) | |
| del pi_ckpt | |
| if progress: | |
| progress(0.3, "Loading trained model...") | |
| # Load trained checkpoint | |
| with gzip.open(model_path, "rb") as f: | |
| checkpoint = torch.load(f, map_location="cpu", weights_only=True) | |
| # Determine model params from checkpoint | |
| n_speakers = checkpoint["net_g"]["embed_speaker.weight"].shape[0] | |
| h = checkpoint.get("h", {}) | |
| hidden_channels = h.get("hidden_channels", 256) | |
| pitch_bins = h.get("pitch_bins", 448) | |
| speakers = checkpoint.get("speakers", [f"Speaker {i}" for i in range(n_speakers)]) | |
| if target_speaker >= n_speakers: | |
| target_speaker = 0 | |
| net_g = ConverterNetwork( | |
| phone_extractor, pitch_estimator, | |
| n_speakers=n_speakers, | |
| pitch_bins=pitch_bins, | |
| hidden_channels=hidden_channels, | |
| vq_topk=h.get("vq_topk", 4), | |
| training_time_vq=h.get("training_time_vq", "none"), | |
| phone_noise_ratio=h.get("phone_noise_ratio", 0.5), | |
| floor_noise_level=h.get("floor_noise_level", 1e-3), | |
| ).to(device).eval() | |
| net_g.load_state_dict(checkpoint["net_g"], strict=False) | |
| net_g.enable_hook() | |
| del checkpoint | |
| if progress: | |
| progress(0.5, "Converting voice...") | |
| # Load audio at 16kHz | |
| audio_path = source_audio if isinstance(source_audio, str) else source_audio.name | |
| audio, sr = librosa.load(audio_path, sr=16000, mono=True) | |
| audio_tensor = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(0).to(device) | |
| # Pad to multiple of 160 (phone extractor stride) | |
| original_length = audio_tensor.shape[-1] | |
| if original_length % 160 != 0: | |
| pad_len = 160 - original_length % 160 | |
| audio_tensor = F.pad(audio_tensor, (0, pad_len)) | |
| # Convert | |
| with torch.inference_mode(): | |
| y_hat = net_g( | |
| audio_tensor, | |
| torch.tensor([target_speaker], device=device), | |
| torch.tensor([formant_shift], device=device), | |
| torch.tensor([float(pitch_shift)], device=device), | |
| ) | |
| # Output is 24kHz, trim to match input duration | |
| output_length = original_length // 160 * 240 # 16kHz→24kHz frame ratio | |
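| # e.g. 1 s of input (16000 samples) -> 100 frames -> 24000 output samples (1 s at 24 kHz) | |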
| output = y_hat.squeeze().cpu().numpy()[:output_length] | |
| # Save | |
| fd, output_path = tempfile.mkstemp(suffix=".wav") | |
| os.close(fd) | |
| sf.write(output_path, output, 24000) | |
| # Cleanup | |
| purge_memory(net_g, phone_extractor, pitch_estimator) | |
| speaker_name = speakers[target_speaker] if target_speaker < len(speakers) else f"Speaker {target_speaker}" | |
| return output_path, f"Converted using Beatrice v2 | 24kHz | Speaker: {speaker_name} | Pitch: {pitch_shift:+d} | Formant: {formant_shift:+.1f}" | |
| except Exception as e: | |
| logger.exception("Beatrice inference error") | |
| return None, f"Error: {str(e)}" | |
| # ============================================================ | |
| # GRADIO UI - Gradio 6 Compatible | |
| # ============================================================ | |
| def train_ui( | |
| audio_file, | |
| model_name: str, | |
| epochs: int, | |
| batch_size: int, | |
| sample_rate: int, | |
| f0_method: str = "rmvpe", | |
| progress=gr.Progress() | |
| ): | |
| """Training function for Gradio UI - Generator for live log updates""" | |
| if audio_file is None: | |
| yield None, None, "❌ Please upload training audio" | |
| return | |
| if not model_name or model_name.strip() == "": | |
| yield None, None, "❌ Please enter a model name" | |
| return | |
| # Check if CUDA available | |
| has_cuda = torch.cuda.is_available() | |
| device_info = "GPU (CUDA)" if has_cuda else "CPU" | |
| # Log accumulator for live updates | |
| logs = [] | |
| try: | |
| model_name = sanitize_model_name(model_name) | |
| output_dir = f"trained_models/{model_name}" | |
| data_dir = f"{output_dir}/data" | |
| # Preprocessing phase | |
| logs.append(f"🚀 Starting on {device_info}") | |
| logs.append(f"📂 Output: {output_dir}") | |
| yield None, None, "\n".join(logs) | |
| progress(0.1, "Preprocessing...") | |
| logs.append("🔄 Preprocessing audio...") | |
| yield None, None, "\n".join(logs) | |
| audio_path = audio_file if isinstance(audio_file, str) else audio_file.name | |
| result = preprocess_audio_for_training(audio_path, data_dir, target_sr=sample_rate, f0_method=f0_method) | |
| if result is None: | |
| logs.append("❌ Preprocessing failed - no valid audio chunks") | |
| yield None, None, "\n".join(logs) | |
| return | |
| logs.append("✅ Preprocessing complete") | |
| logs.append(f"🏋️ Training {epochs} epochs...") | |
| logs.append("─" * 40) | |
| yield None, None, "\n".join(logs) | |
| # Training phase - iterate over generator for live updates | |
| ckpt = None | |
| idx = None | |
| for msg, path, index in train_rvc_generator( | |
| data_dir=data_dir, | |
| output_dir=output_dir, | |
| epochs=epochs, | |
| batch_size=batch_size, | |
| lr=1e-5, | |
| target_sr=sample_rate, | |
| progress_callback=progress | |
| ): | |
| logs.append(msg) | |
| yield None, None, "\n".join(logs) | |
| if path: | |
| ckpt = path | |
| if index: | |
| idx = index | |
| if ckpt: | |
| logs.append("─" * 40) | |
| logs.append(f"✅ Training complete!") | |
| logs.append(f"📦 Model: {ckpt}") | |
| if idx: | |
| logs.append(f"📦 Index: {idx}") | |
| progress(1.0, "Done!") | |
| yield ckpt, idx, "\n".join(logs) | |
| else: | |
| logs.append("❌ Training failed") | |
| yield None, None, "\n".join(logs) | |
| except Exception as e: | |
| logger.exception("Training error") | |
| logs.append(f"❌ Error: {str(e)}") | |
| yield None, None, "\n".join(logs) | |
| # ============================================================ | |
| # SPLIT-AND-STITCH BACKGROUND PROCESSOR | |
| # | |
| # Architecture rationale: | |
| # Heavy RVC inference on a 2-core CPU can push 10–14 GB RAM for | |
| # long audio. The solution is to: | |
| # 1. Chop the input into CHUNK_SEC-second slices with an | |
| # OVERLAP_SEC-second overlap on each side. The overlap gives | |
| # the vocoder context at boundaries so there are no clicks. | |
| # 2. Process each chunk independently through the full RVC | |
| # pipeline (RMVPE + FAISS + vocoder). After each chunk, | |
| # call purge_memory() so RAM never accumulates. | |
| # 3. Cross-fade adjacent chunks over the overlap region using a | |
| # linear fade-out/fade-in (equal-power is not needed here | |
| # because RVC output is already bandlimited). The crossfade | |
| # is the only "small render" step — it's just numpy slicing | |
| # and a linear ramp, not another model pass. | |
| # 4. Concatenate the stitched segments and write the final WAV. | |
| # | |
| # Non-negotiable settings preserved: | |
| # - f0_method is always forwarded as-is (RMVPE by default). | |
| # - index_file is always forwarded (FAISS stays active). | |
| # - Model sample rate (40k/48k) is respected — chunks are saved | |
| # at 16k for processing and the stitched output is at tgt_sr. | |
| # ============================================================ | |
| CHUNK_SEC = 30 # seconds per chunk fed to RVC (keeps RAM under ~4 GB/chunk) | |
| OVERLAP_SEC = 0.4 # seconds of overlap for crossfade seam (at tgt_sr) | |
| MIN_CHUNK_SEC = 5 # don't split files shorter than this | |
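| # Worked example with the defaults above and a 40 kHz RVC model (illustrative): | |
| # input side : chunk = 30 s × 16 000 = 480 000 samples, overlap = 0.4 s × 16 000 = 6 400 | |
| # hop = 480 000 - 6 400 = 473 600 samples, i.e. ~29.6 s of new audio per chunk | |
| # output side: overlap = 0.4 s × 40 000 = 16 000 samples crossfaded at each seam | |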
| def _crossfade_join(seg_a: np.ndarray, seg_b: np.ndarray, | |
| overlap_samples: int) -> np.ndarray: | |
| """ | |
| Overlap-add two mono float32 segments. | |
| seg_a: …audio… [overlap_samples tail] | |
| seg_b: [overlap_samples head] …audio… | |
| Returns the seamlessly stitched result. | |
| """ | |
| if overlap_samples <= 0 or len(seg_a) < overlap_samples or len(seg_b) < overlap_samples: | |
| return np.concatenate([seg_a, seg_b]) | |
| fade_out = np.linspace(1.0, 0.0, overlap_samples, dtype=np.float32) | |
| fade_in = np.linspace(0.0, 1.0, overlap_samples, dtype=np.float32) | |
| blended = seg_a[-overlap_samples:] * fade_out + seg_b[:overlap_samples] * fade_in | |
| return np.concatenate([seg_a[:-overlap_samples], blended, seg_b[overlap_samples:]]) | |
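| # Sanity check (illustrative, not executed): the two linear ramps sum to 1.0 at every | |
| # overlap sample, so a constant signal passes through the seam unchanged, e.g.: | |
| # a = np.ones(1000, dtype=np.float32); b = np.ones(1000, dtype=np.float32) | |
| # out = _crossfade_join(a, b, 200) | |
| # assert len(out) == 1800 and np.allclose(out, 1.0) | |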
| def convert_voice_chunked( | |
| source_audio, | |
| model_file, | |
| index_file=None, | |
| pitch_shift: int = 0, | |
| f0_method: str = "rmvpe", # RMVPE is non-negotiable — forwarded as-is | |
| index_rate: float = 0.75, # FAISS active — forwarded as-is | |
| protect: float = 0.33, | |
| volume_envelope: float = 1.0, | |
| progress=None, | |
| chunk_sec: int = CHUNK_SEC, | |
| overlap_sec: float = OVERLAP_SEC, | |
| ) -> Tuple[str, str]: | |
| """ | |
| Split-and-stitch wrapper around convert_voice(). | |
| For audio of MIN_CHUNK_SEC seconds or shorter, falls straight | |
| through to convert_voice() with no splitting overhead. | |
| For longer audio: | |
| • Loads the full waveform once at 16 kHz (source SR for RVC). | |
| • Slices into overlapping chunks at the 16k level. | |
| • Each chunk is written to a temp WAV, run through convert_voice(), | |
| and the result is loaded back as a numpy array. | |
| • purge_memory() is called after every chunk so RAM is returned | |
| to the OS via malloc_trim() before the next chunk starts. | |
| • Chunks are crossfaded and concatenated. | |
| • The final stitched audio is written to a single output WAV. | |
| """ | |
| if source_audio is None: | |
| return None, "Please upload source audio" | |
| if model_file is None: | |
| return None, "Please upload RVC model (.pth)" | |
| # ------------------------------------------------------------------ | |
| # 1. Load source audio once to check duration | |
| # ------------------------------------------------------------------ | |
| try: | |
| audio_full, _ = librosa.load(source_audio, sr=16000, mono=True) | |
| except Exception as e: | |
| return None, f"Failed to load audio: {e}" | |
| duration_sec = len(audio_full) / 16000.0 | |
| # Short file — no splitting needed, avoids overhead | |
| if duration_sec <= MIN_CHUNK_SEC: | |
| purge_memory(audio_full) | |
| return convert_voice( | |
| source_audio, model_file, index_file, | |
| pitch_shift, f0_method, index_rate, protect, volume_envelope, | |
| progress if progress is not None else gr.Progress() | |
| ) | |
| # ------------------------------------------------------------------ | |
| # 2. Determine chunk boundaries (in 16k samples) | |
| # ------------------------------------------------------------------ | |
| sr_in = 16000 | |
| chunk_samp = int(chunk_sec * sr_in) | |
| overlap_samp = int(overlap_sec * sr_in) | |
| hop_samp = chunk_samp - overlap_samp # non-overlapping step | |
| starts = list(range(0, len(audio_full), hop_samp)) | |
| n_chunks = len(starts) | |
| logger.info(f"[Chunked] {duration_sec:.1f}s → {n_chunks} chunks " | |
| f"({chunk_sec}s each, {overlap_sec}s overlap)") | |
| # ------------------------------------------------------------------ | |
| # 3. Process each chunk through convert_voice() | |
| # ------------------------------------------------------------------ | |
| stitched_segments: list[np.ndarray] = [] | |
| tgt_sr = None # learned from first chunk output | |
| overlap_out_samp = 0 # overlap at target SR (learned after first chunk) | |
| for i, start in enumerate(starts): | |
| end = min(start + chunk_samp, len(audio_full)) | |
| chunk = audio_full[start:end] | |
| if progress is not None: | |
| try: | |
| progress((i + 0.5) / n_chunks, | |
| f"Processing chunk {i+1}/{n_chunks}…") | |
| except Exception: | |
| pass | |
| # Write chunk to temp WAV | |
| fd, chunk_path = tempfile.mkstemp(suffix=".wav") | |
| os.close(fd) | |
| sf.write(chunk_path, chunk, sr_in) | |
| try: | |
| # Run full RVC pipeline on this chunk (RMVPE + FAISS inside) | |
| out_path, status = convert_voice( | |
| chunk_path, model_file, index_file, | |
| pitch_shift, f0_method, index_rate, protect, volume_envelope, | |
| # Reuse the outer progress bar; fall back to a fresh gr.Progress() if none given | |
| gr.Progress() if progress is None else progress | |
| ) | |
| except Exception as e: | |
| logger.warning(f"[Chunked] Chunk {i+1} failed: {e}") | |
| # Skip this chunk; a silence gap is better than a crash | |
| # (the temp WAV is removed by the finally block below) | |
| purge_memory(chunk) | |
| continue | |
| finally: | |
| try: os.unlink(chunk_path) | |
| except OSError: pass | |
| if out_path is None: | |
| logger.warning(f"[Chunked] Chunk {i+1} returned no output: {status}") | |
| purge_memory(chunk) | |
| continue | |
| # Load the converted chunk | |
| chunk_out, chunk_sr = sf.read(out_path, dtype="float32") | |
| try: os.unlink(out_path) | |
| except OSError: pass | |
| # Learn target SR from first successful chunk | |
| if tgt_sr is None: | |
| tgt_sr = chunk_sr | |
| # Scale overlap to target SR | |
| overlap_out_samp = int(overlap_sec * tgt_sr) | |
| stitched_segments.append(chunk_out) | |
| # Critical: free everything before next chunk loads the model | |
| purge_memory(chunk, chunk_out) | |
| # ------------------------------------------------------------------ | |
| # 4. Stitch segments with crossfade | |
| # ------------------------------------------------------------------ | |
| if not stitched_segments: | |
| return None, "All chunks failed — no output produced" | |
| result = stitched_segments[0] | |
| for seg in stitched_segments[1:]: | |
| result = _crossfade_join(result, seg, overlap_out_samp) | |
| # Final normalization (matches convert_voice behaviour) | |
| audio_max = np.abs(result).max() / 0.99 | |
| if audio_max > 1.0: | |
| result /= audio_max | |
| # ------------------------------------------------------------------ | |
| # 5. Write stitched output | |
| # ------------------------------------------------------------------ | |
| final_sr = tgt_sr if tgt_sr is not None else 40000 | |
| fd, output_path = tempfile.mkstemp(suffix=".wav") | |
| os.close(fd) | |
| sf.write(output_path, result, final_sr) | |
| purge_memory(result, audio_full) | |
| logger.info(f"[Chunked] Stitched output: {len(stitched_segments)} chunks → {output_path}") | |
| return output_path, ( | |
| f"Chunked conversion: {n_chunks} chunks stitched | " | |
| f"sr={final_sr} | pitch={pitch_shift:+d} | f0={f0_method}" | |
| ) | |
| with gr.Blocks() as demo: | |
| gr.Markdown(f"# 🎤 Voice Conversion (RVC + Beatrice)\nInference: CPU • Training: {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}") | |
| with gr.Tabs(): | |
| # ==================== TAB 1: VOICE CONVERSION ==================== | |
| with gr.Tab("🎵 Voice Conversion"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| source_audio = gr.Audio(label="Source Audio", type="filepath") | |
| gr.Markdown("### Model") | |
| model_type = gr.Radio( | |
| ["RVC v2", "Beatrice v2"], value="RVC v2", | |
| label="Model Type", | |
| info="RVC: .pth files | Beatrice: .pt.gz files" | |
| ) | |
| # RVC model inputs | |
| with gr.Group(visible=True) as rvc_model_group: | |
| _available_models = list_rvc_models() | |
| rvc_model_dropdown = gr.Dropdown( | |
| choices=_available_models, | |
| label="Select Voice Model", | |
| info="Models auto-loaded from weights/ folder", | |
| value="model_kunni.pth" if "model_kunni.pth" in _available_models else (_available_models[0] if _available_models else None), | |
| ) | |
| with gr.Row(): | |
| model_file = gr.File(label="OR Upload New (.pth)", file_types=[".pth"]) | |
| load_example_btn = gr.Button("Load Example (Benee)", size="sm") | |
| index_file = gr.File(label="Index File (.index) - Optional", file_types=[".index"]) | |
| # Beatrice model inputs | |
| with gr.Group(visible=False) as beatrice_model_group: | |
| beatrice_model_file = gr.File(label="Beatrice Model (.pt.gz)", file_types=[".gz"]) | |
| with gr.Row(): | |
| beatrice_target_speaker = gr.Number(value=0, label="Target Speaker", precision=0) | |
| beatrice_formant_shift = gr.Slider(-2, 2, value=0.0, step=0.5, label="Formant Shift") | |
| with gr.Row(): | |
| pitch_shift = gr.Slider(-12, 12, value=0, step=1, label="Pitch (semitones)") | |
| f0_method = gr.Radio(["rmvpe", "pm", "harvest"], value="rmvpe", label="F0 Method", visible=True) | |
| with gr.Row(visible=True) as rvc_extra_options: | |
| index_rate = gr.Slider(0, 1, value=0.75, step=0.05, label="Index Rate") | |
| protect = gr.Slider(0, 0.5, value=0.33, step=0.01, label="Protect (voiceless consonants)") | |
| convert_btn = gr.Button("Convert", variant="primary") | |
| with gr.Column(): | |
| output_audio = gr.Audio(label="Converted Audio", type="filepath") | |
| output_info = gr.Textbox(label="Status", lines=2) | |
| def update_model_type(model_type_val): | |
| is_rvc = model_type_val == "RVC v2" | |
| return ( | |
| gr.update(visible=is_rvc), # rvc_model_group | |
| gr.update(visible=not is_rvc), # beatrice_model_group | |
| gr.update(visible=is_rvc), # f0_method | |
| gr.update(visible=is_rvc), # rvc_extra_options | |
| ) | |
| model_type.change( | |
| update_model_type, | |
| [model_type], | |
| [rvc_model_group, beatrice_model_group, f0_method, rvc_extra_options] | |
| ) | |
| load_example_btn.click( | |
| load_example_model, | |
| [], | |
| [model_file, index_file, output_info] | |
| ) | |
| def convert_unified(source, m_type, rvc_dropdown_model, rvc_model, rvc_index, beat_model, | |
| beat_speaker, beat_formant, pitch, f0, idx_rate, prot, | |
| progress=gr.Progress()): | |
| if m_type == "RVC v2": | |
| # Resolve model: prefer uploaded file, fall back to dropdown selection | |
| resolved_model = rvc_model | |
| resolved_index = rvc_index | |
| if resolved_model is None and rvc_dropdown_model: | |
| pth_path = Path("weights") / rvc_dropdown_model | |
| # Wrap in a simple object so convert_voice can call .name on it | |
| class _FileObj: | |
| def __init__(self, p): self.name = str(p) | |
| resolved_model = _FileObj(pth_path) | |
| # Auto-hunt for matching .index in weights/ | |
| if resolved_index is None: | |
| stem = pth_path.stem | |
| index_path = Path("weights") / f"{stem}.index" | |
| if index_path.exists(): | |
| resolved_index = _FileObj(index_path) | |
| # Use the split-and-stitch processor for RVC. | |
| # It falls through to plain convert_voice() for short clips | |
| # and auto-chunks long audio to prevent OOM on HF Spaces. | |
| return convert_voice_chunked(source, resolved_model, resolved_index, pitch, f0, idx_rate, prot, progress=progress) | |
| else: | |
| return convert_voice_beatrice( | |
| source, beat_model, | |
| target_speaker=int(beat_speaker), | |
| pitch_shift=int(pitch), | |
| formant_shift=float(beat_formant), | |
| progress=progress | |
| ) | |
| convert_btn.click( | |
| convert_unified, | |
| [source_audio, model_type, rvc_model_dropdown, model_file, index_file, beatrice_model_file, | |
| beatrice_target_speaker, beatrice_formant_shift, | |
| pitch_shift, f0_method, index_rate, protect], | |
| [output_audio, output_info], | |
| api_name="convert", | |
| concurrency_limit=1, | |
| ) | |
| gr.Markdown("**Models:** [HuggingFace](https://huggingface.co/models?search=rvc) | [Weights.gg](https://weights.gg)") | |
| # ==================== TAB 2: TRAINING ==================== | |
| with gr.Tab("🏋️ Training"): | |
| # GPU Warning | |
| gpu_status = "🟢 GPU Available (CUDA)" if torch.cuda.is_available() else "🟡 CPU Only (Training will be slow)" | |
| gr.Markdown(f""" | |
| ### Training Status: {gpu_status} | |
| {'**GPU detected!** Training will use CUDA acceleration.' if torch.cuda.is_available() else '**No GPU detected.** Training will run on CPU (~30 sec/epoch). For faster training, run locally with CUDA GPU.'} | |
| """) | |
| # Trainer selector - Beatrice always available (downloads assets from HF) | |
| trainer_selector = gr.Dropdown( | |
| choices=["RVC v2", "Beatrice v2"], | |
| value="RVC v2", | |
| label="Trainer", | |
| info="RVC v2: general purpose | Beatrice v2: low latency (~50ms)" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| train_audio = gr.Audio(label="Training Audio", type="filepath") | |
| train_model_name = gr.Textbox(label="Model Name", placeholder="my_voice", value="my_voice") | |
| # RVC-specific options | |
| with gr.Group(visible=True) as rvc_options: | |
| with gr.Row(): | |
| train_epochs = gr.Slider(1, 500, value=50, step=1, label="Epochs") | |
| train_batch = gr.Slider(1, 8, value=2, step=1, label="Batch Size") | |
| with gr.Row(): | |
| train_sr = gr.Radio([32000, 40000, 48000], value=40000, label="Sample Rate") | |
| train_f0 = gr.Radio(["rmvpe", "pm", "harvest"], value="rmvpe", label="F0 Method") | |
| # Beatrice-specific options | |
| with gr.Group(visible=False) as beatrice_options: | |
| with gr.Row(): | |
| beatrice_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs (30 recommended)") | |
| beatrice_batch = gr.Slider(1, 64, value=8, step=1, label="Batch Size") | |
| beatrice_resume = gr.Checkbox(label="Resume from checkpoint", value=False) | |
| train_btn = gr.Button("Start Training", variant="primary") | |
| with gr.Column(): | |
| train_output_model = gr.File(label="Trained Model (.pth)") | |
| train_output_index = gr.File(label="Index File (.index)") | |
| train_status = gr.Textbox(label="Training Status", lines=6) | |
| # Toggle visibility based on trainer selection | |
| def update_trainer_options(trainer): | |
| is_rvc = trainer == "RVC v2" | |
| return gr.update(visible=is_rvc), gr.update(visible=not is_rvc) | |
| trainer_selector.change( | |
| update_trainer_options, | |
| [trainer_selector], | |
| [rvc_options, beatrice_options] | |
| ) | |
| # Unified training function | |
| def train_unified(trainer, audio, name, rvc_epochs, rvc_batch, rvc_sr, rvc_f0, | |
| beat_epochs, beat_batch, beat_resume, progress=gr.Progress()): | |
| if trainer == "RVC v2": | |
| # Use generator for live updates (yields 3-tuples: model, index, status) | |
| for result in train_ui(audio, name, rvc_epochs, rvc_batch, rvc_sr, rvc_f0, progress): | |
| yield result | |
| else: | |
| # Beatrice training (embedded - downloads assets from HF) | |
| if audio is None: | |
| yield None, None, "❌ Please upload training audio" | |
| return | |
| if not name or name.strip() == "": | |
| yield None, None, "❌ Please enter a model name" | |
| return | |
| try: | |
| name = sanitize_model_name(name) | |
| output_dir = f"trained_models/beatrice_{name}" | |
| data_dir = f"{output_dir}/training_data" | |
| logs = [] | |
| # Preprocess audio into speaker chunks | |
| logs.append(f"🚀 Starting Beatrice training on {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}") | |
| yield None, None, "\n".join(logs) | |
| progress(0.05, "Preprocessing audio...") | |
| audio_path = audio if isinstance(audio, str) else audio.name | |
| chunks, _ = preprocess_audio_for_beatrice(audio_path, data_dir, name) | |
| if chunks == 0: | |
| yield None, None, "❌ Preprocessing failed - no valid audio chunks" | |
| return | |
| logs.append(f"✅ Preprocessed {chunks} audio chunks") | |
| logs.append("─" * 40) | |
| yield None, None, "\n".join(logs) | |
| # Train using embedded generator | |
| ckpt = None | |
| for msg, path in train_beatrice_generator( | |
| data_dir=data_dir, | |
| output_dir=output_dir, | |
| epochs=beat_epochs, | |
| batch_size=beat_batch, | |
| resume=beat_resume, | |
| progress_callback=progress | |
| ): | |
| logs.append(msg) | |
| yield None, None, "\n".join(logs) | |
| if path: | |
| ckpt = path | |
| if ckpt: | |
| logs.append("─" * 40) | |
| logs.append(f"✅ Training complete!") | |
| logs.append(f"📦 Model: {ckpt}") | |
| progress(1.0, "Done!") | |
| yield ckpt, None, "\n".join(logs) | |
| else: | |
| logs.append("❌ Training failed") | |
| yield None, None, "\n".join(logs) | |
| except Exception as e: | |
| logger.exception("Beatrice training error") | |
| yield None, None, f"❌ Error: {str(e)}" | |
| train_btn.click( | |
| train_unified, | |
| [trainer_selector, train_audio, train_model_name, | |
| train_epochs, train_batch, train_sr, train_f0, | |
| beatrice_epochs, beatrice_batch, beatrice_resume], | |
| [train_output_model, train_output_index, train_status], | |
| api_name="train", | |
| concurrency_limit=1, | |
| ) | |
| gr.Markdown(""" | |
| --- | |
| ### Training Tips | |
| **RVC v2:** | |
| - 50-100 epochs for quick test, 200-500 for quality | |
| - CPU training: ~30 sec/epoch (100 epochs ≈ 50 min) | |
| **Beatrice v2:** | |
| - 20-50 epochs recommended, GPU recommended (CPU works but slow) | |
| - Pretrained assets downloaded automatically from HuggingFace | |
| - Output: .pt.gz checkpoint (use in Voice Conversion tab) | |
| - Lower latency (~50ms vs ~100ms) | |
| ### CLI | |
| ```bash | |
| python app.py train -a voice.mp3 -o ./model --epochs 100 | |
| python app.py train-beatrice -a voice.mp3 -o ./beatrice_model --epochs 30 | |
| python app.py infer -i input.wav -m beatrice_model.pt.gz -o output.wav | |
| ``` | |
| """) | |
| def cli_convert(args): | |
| """CLI mode conversion - supports both RVC and Beatrice models""" | |
| print(f"Converting: {args.input}") | |
| print(f"Model: {args.model}") | |
| # Auto-detect model type from extension or --type flag | |
| model_type = getattr(args, 'type', None) | |
| model_path = args.model | |
| index_path = getattr(args, 'index', None) | |
| if args.example: | |
| print("Downloading example model...") | |
| model_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_MODEL_FILE) | |
| index_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_INDEX_FILE) | |
| model_type = "rvc" | |
| print(f"Model: {model_path}") | |
| if not model_path: | |
| print("Error: No model specified. Use -m MODEL.pth or --example") | |
| sys.exit(1) | |
| # Auto-detect type from extension if not specified | |
| if model_type is None: | |
| if model_path.endswith('.pt.gz') or model_path.endswith('.gz'): | |
| model_type = "beatrice" | |
| else: | |
| model_type = "rvc" | |
| if model_type == "beatrice": | |
| # Beatrice inference | |
| print(f"Using Beatrice v2 inference") | |
| output_path, status = convert_voice_beatrice( | |
| source_audio=args.input, | |
| model_file=model_path, | |
| target_speaker=getattr(args, 'speaker', 0), | |
| pitch_shift=args.pitch, | |
| formant_shift=getattr(args, 'formant_shift', 0.0), | |
| ) | |
| else: | |
| # RVC inference | |
| class FileObj: | |
| def __init__(self, path): | |
| self.name = path | |
| model_file = FileObj(model_path) | |
| index_file = FileObj(index_path) if index_path else None | |
| output_path, status = convert_voice( | |
| source_audio=args.input, | |
| model_file=model_file, | |
| index_file=index_file, | |
| pitch_shift=args.pitch, | |
| f0_method=args.f0, | |
| index_rate=args.index_rate, | |
| progress=lambda *a, **k: None | |
| ) | |
| if output_path: | |
| shutil.copy(output_path, args.output) | |
| print(f"Output: {args.output}") | |
| print(status) | |
| else: | |
| print(f"Failed: {status}") | |
| sys.exit(1) | |
| def cli_train_beatrice(args): | |
| """CLI Beatrice training mode (embedded - downloads assets from HF)""" | |
| print(f"=== Beatrice v2 Training (Embedded) ===") | |
| print(f"Input audio: {args.audio}") | |
| print(f"Output dir: {args.output}") | |
| # Preprocess audio - Beatrice expects: data_dir/speaker_name/*.wav | |
| data_dir = os.path.join(args.output, "training_data") | |
| speaker_name = os.path.splitext(os.path.basename(args.audio))[0] | |
| print(f"\n[1/2] Preprocessing audio for Beatrice...") | |
| chunks, _ = preprocess_audio_for_beatrice(args.audio, data_dir, speaker_name) | |
| if chunks == 0: | |
| print("Preprocessing failed - no valid audio chunks") | |
| sys.exit(1) | |
| print(f"Created {chunks} audio chunks") | |
| # Train using embedded code | |
| print(f"\n[2/2] Training Beatrice model ({args.epochs} epochs)...") | |
| ckpt = None | |
| for msg, path in train_beatrice_generator( | |
| data_dir=data_dir, | |
| output_dir=args.output, | |
| epochs=args.epochs, | |
| batch_size=args.batch, | |
| resume=args.resume, | |
| ): | |
| print(msg) | |
| if path: | |
| ckpt = path | |
| if ckpt: | |
| print(f"\nTraining complete!") | |
| print(f"Model: {ckpt}") | |
| else: | |
| print("Training failed!") | |
| sys.exit(1) | |
| def cli_train(args): | |
| """CLI training mode""" | |
| print(f"=== RVC Training ===") | |
| print(f"Input audio: {args.audio}") | |
| print(f"Output dir: {args.output}") | |
| # Create temp dir for preprocessing | |
| data_dir = f"{args.output}/data" | |
| # Preprocess | |
| print("\n[1/2] Preprocessing audio...") | |
| result = preprocess_audio_for_training(args.audio, data_dir, target_sr=args.sr, f0_method=args.f0) | |
| if result is None: | |
| print("Preprocessing failed!") | |
| sys.exit(1) | |
| # Train | |
| print(f"\n[2/2] Training for {args.epochs} epochs...") | |
| ckpt, idx = train_rvc( | |
| data_dir=data_dir, | |
| output_dir=args.output, | |
| epochs=args.epochs, | |
| batch_size=args.batch, | |
| lr=args.lr, | |
| target_sr=args.sr | |
| ) | |
| if ckpt: | |
| print(f"\nTraining complete!") | |
| print(f"Model saved: {ckpt}") | |
| if idx: | |
| print(f"Index saved: {idx}") | |
| else: | |
| print("Training failed!") | |
| sys.exit(1) | |
| if __name__ == "__main__": | |
| # Check if any CLI args (besides script name) | |
| if len(sys.argv) > 1: | |
| parser = argparse.ArgumentParser( | |
| description="RVC Voice Conversion - Inference and Training", | |
| formatter_class=argparse.RawDescriptionHelpFormatter, | |
| ) | |
| subparsers = parser.add_subparsers(dest="command", help="Commands") | |
| # Inference subcommand | |
| infer_parser = subparsers.add_parser("infer", help="Voice conversion inference (RVC + Beatrice)", | |
| epilog=""" | |
| Examples: | |
| python app.py infer -i voice.wav -m model.pth -o output.wav | |
| python app.py infer -i voice.wav -m beatrice_model.pt.gz -o output.wav --type beatrice | |
| python app.py infer -i voice.wav --example -o output.wav | |
| """) | |
| infer_parser.add_argument("-i", "--input", required=True, help="Input audio file") | |
| infer_parser.add_argument("-o", "--output", required=True, help="Output audio file") | |
| infer_parser.add_argument("-m", "--model", help="Model file (.pth for RVC, .pt.gz for Beatrice)") | |
| infer_parser.add_argument("--type", choices=["rvc", "beatrice"], default=None, help="Model type (auto-detected from extension)") | |
| infer_parser.add_argument("--index", help="Index file (.index) - RVC only") | |
| infer_parser.add_argument("--example", action="store_true", help="Use example model (Benee-RVC)") | |
| infer_parser.add_argument("-p", "--pitch", type=int, default=0, help="Pitch shift (-12 to 12)") | |
| infer_parser.add_argument("--f0", choices=["rmvpe", "pm", "harvest"], default="rmvpe", help="F0 method (RVC only)") | |
| infer_parser.add_argument("--index-rate", type=float, default=0.75, help="Index rate 0-1 (RVC only)") | |
| infer_parser.add_argument("--speaker", type=int, default=0, help="Target speaker index (Beatrice only)") | |
| infer_parser.add_argument("--formant-shift", type=float, default=0.0, help="Formant shift -2 to 2 (Beatrice only)") | |
| # Training subcommand (RVC) | |
| train_parser = subparsers.add_parser("train", help="Train RVC model", | |
| epilog=""" | |
| Examples: | |
| python app.py train -a voice.mp3 -o ./my_model --epochs 5 | |
| python app.py train -a dataset.wav -o ./trained --epochs 10 --batch 4 | |
| """) | |
| train_parser.add_argument("-a", "--audio", required=True, help="Training audio file") | |
| train_parser.add_argument("-o", "--output", required=True, help="Output directory for model") | |
| train_parser.add_argument("--epochs", type=int, default=5, help="Number of epochs (default: 5)") | |
| train_parser.add_argument("--batch", type=int, default=2, help="Batch size (default: 2)") | |
| train_parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate (default: 1e-5)") | |
| train_parser.add_argument("--sr", type=int, default=40000, help="Sample rate (default: 40000)") | |
| train_parser.add_argument("--f0", choices=["rmvpe", "pm", "harvest"], default="rmvpe", help="F0 method (default: rmvpe)") | |
| # Beatrice v2 Training subcommand | |
| beatrice_parser = subparsers.add_parser("train-beatrice", help="Train Beatrice v2 model (downloads assets from HF)", | |
| epilog=""" | |
| Examples: | |
| python app.py train-beatrice -a voice.mp3 -o ./beatrice_model --epochs 20 | |
| python app.py train-beatrice -a dataset.wav -o ./trained --epochs 50 --batch 8 --resume | |
| Pretrained assets are downloaded automatically from HuggingFace. | |
| """) | |
| beatrice_parser.add_argument("-a", "--audio", required=True, help="Training audio file") | |
| beatrice_parser.add_argument("-o", "--output", required=True, help="Output directory for model") | |
| beatrice_parser.add_argument("--epochs", type=int, default=20, help="Number of epochs (default: 20)") | |
| beatrice_parser.add_argument("--batch", type=int, default=24, help="Batch size (default: 24, reduce for less VRAM)") | |
| beatrice_parser.add_argument("--resume", action="store_true", help="Resume from checkpoint") | |
| args = parser.parse_args() | |
| if args.command == "infer": | |
| cli_convert(args) | |
| elif args.command == "train": | |
| cli_train(args) | |
| elif args.command == "train-beatrice": | |
| cli_train_beatrice(args) | |
| else: | |
| parser.print_help() | |
| else: | |
| # ============================================================ | |
| # GRADIO QUEUE — Stability config for 2-core CPU HF Spaces | |
| # | |
| # Why these settings matter: | |
| # | |
| # max_size=3 | |
| # Hard cap on pending jobs. On a 2-core CPU, RVC can take | |
| # 3–10 minutes per job. Without a cap, users pile up and | |
| # the Space exhausts RAM while holding dozens of audio | |
| # blobs in the queue. 3 is a safe upper bound: one running | |
| # + two waiting. Any new request beyond that gets a clear | |
| # "queue full" HTTP 503 instead of a silent timeout. | |
| # | |
| # default_concurrency_limit=1 | |
| # Only one inference job runs at a time. Two simultaneous | |
| # RVC jobs on a 2-core CPU would each get 1 thread, which | |
| # is slower than one job using both cores sequentially. | |
| # Serial execution is strictly faster here. | |
| # | |
| # The convert_btn.click already has concurrency_limit=1 set | |
| # at the event level; this queue-level setting enforces it | |
| # globally (including any API calls) so nothing bypasses it. | |
| # | |
| # status_update_rate="auto" | |
| # Gradio sends SSE heartbeats to the browser on this cadence. | |
| # "auto" (≈ 1 Hz) is enough to keep the WebSocket / SSE | |
| # connection alive across long jobs without flooding the | |
| # 2-core CPU with keep-alive overhead. | |
| # ============================================================ | |
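| # | |
| # Example API call against the named "/convert" endpoint (a sketch; assumes the | |
| # gradio_client package and a deployed Space id, both hypothetical here): | |
| # from gradio_client import Client, handle_file | |
| # client = Client("user/space-name") # hypothetical Space id | |
| # out_path, status = client.predict( | |
| # handle_file("input.wav"), # source audio | |
| # "RVC v2", None, # model type, dropdown selection | |
| # handle_file("model.pth"), None, # uploaded .pth (hypothetical), index file | |
| # None, 0, 0.0, # beatrice model / speaker / formant (unused) | |
| # 0, "rmvpe", 0.75, 0.33, # pitch, f0 method, index rate, protect | |
| # api_name="/convert") | |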
| demo.queue( | |
| max_size=3, | |
| default_concurrency_limit=1, | |
| status_update_rate="auto", | |
| ).launch( | |
| mcp_server=True, | |
| show_error=True, | |
| ssr_mode=False, | |
| ) |