""" RVC + Beatrice v2 Voice Conversion - Single-file app for HuggingFace Spaces RVC-Project + Beatrice v2 (fierce-cats/beatrice-trainer), consolidated into single file - Inference: RVC v2 (.pth) + Beatrice v2 (.pt.gz), CPU or GPU - Training: RVC v2 + Beatrice v2, GPU recommended Usage: CLI: python app.py infer -i input.wav -m model.pth -o output.wav python app.py infer -i input.wav -m beatrice.pt.gz -o output.wav Gradio: python app.py """ import os import sys # MPS fallback for macOS os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import argparse import gc import gzip import json as json_module import logging import math import re import shutil import tempfile import warnings # Suppress known harmless warnings from HF Spaces / torch internals warnings.filterwarnings("ignore", message=".*torch.distributed.reduce_op.*", category=FutureWarning) warnings.filterwarnings("ignore", message=".*torch.nn.utils.weight_norm.*", category=FutureWarning) from collections import defaultdict from fractions import Fraction from functools import partial from pathlib import Path from random import Random from typing import Optional, List, Tuple, Union, BinaryIO, Literal, Sequence, Iterable, Callable import gradio as gr import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import Conv1d, ConvTranspose1d from torch.nn.utils import weight_norm, remove_weight_norm import librosa import pyworld import soundfile as sf import torchaudio from scipy import signal from huggingface_hub import hf_hub_download from tqdm.auto import tqdm # 48 Hz high-pass filter to remove low-frequency artifacts (same as Applio) FILTER_ORDER = 5 CUTOFF_FREQUENCY = 48 # Hz SAMPLE_RATE = 16000 # Hz bh, ah = signal.butter(N=FILTER_ORDER, Wn=CUTOFF_FREQUENCY, btype="high", fs=SAMPLE_RATE) def sanitize_model_name(name: str) -> str: """Sanitize model name for safe use in file paths""" name = os.path.basename(name.strip()) name = re.sub(r'[^\w\-.]', '_', name) return name or "unnamed_model" def list_rvc_models() -> list: """Scan the weights/ directory and return a sorted list of .pth model filenames.""" weights_dir = Path("weights") if not weights_dir.exists(): return [] return sorted([p.name for p in weights_dir.glob("*.pth")]) # Default example model DEFAULT_MODEL_REPO = "audo/Benee-RVC" DEFAULT_MODEL_FILE = "BENEE8000.pth" DEFAULT_INDEX_FILE = "added_IVF1054_Flat_nprobe_8.index" # RVC v2 pretrained weights from official repo RVC_PRETRAINED_REPO = "lj1995/VoiceConversionWebUI" RVC_PRETRAINED_V2 = { # Generator with f0 (pitch) "f0G48k": "pretrained_v2/f0G48k.pth", "f0G40k": "pretrained_v2/f0G40k.pth", "f0G32k": "pretrained_v2/f0G32k.pth", # Discriminator with f0 "f0D48k": "pretrained_v2/f0D48k.pth", "f0D40k": "pretrained_v2/f0D40k.pth", "f0D32k": "pretrained_v2/f0D32k.pth", # Generator without f0 "G48k": "pretrained_v2/G48k.pth", "G40k": "pretrained_v2/G40k.pth", "G32k": "pretrained_v2/G32k.pth", # Discriminator without f0 "D48k": "pretrained_v2/D48k.pth", "D40k": "pretrained_v2/D40k.pth", "D32k": "pretrained_v2/D32k.pth", } def download_pretrained_rvc(name: str) -> str: """Download RVC v2 pretrained weights from HuggingFace""" if name not in RVC_PRETRAINED_V2: raise ValueError(f"Unknown pretrained: {name}. Available: {list(RVC_PRETRAINED_V2.keys())}") filepath = RVC_PRETRAINED_V2[name] logger.info(f"Downloading pretrained {name} from {RVC_PRETRAINED_REPO}...") return hf_hub_download(repo_id=RVC_PRETRAINED_REPO, filename=filepath) # Beatrice v2 pretrained assets BEATRICE_REPO = "fierce-cats/beatrice-trainer" BEATRICE_PRETRAINED = { "phone_extractor": "assets/pretrained/122_checkpoint_03000000.pt", "pitch_estimator": "assets/pretrained/104_3_checkpoint_00300000.pt", "pretrained_model": "assets/pretrained/151_checkpoint_libritts_r_200_02750000.pt.gz", } def download_beatrice_asset(name: str) -> str: """Download Beatrice v2 pretrained asset from HuggingFace""" if name not in BEATRICE_PRETRAINED: raise ValueError(f"Unknown asset: {name}. Available: {list(BEATRICE_PRETRAINED.keys())}") filepath = BEATRICE_PRETRAINED[name] logger.info(f"Downloading Beatrice asset {name} from {BEATRICE_REPO}...") return hf_hub_download(repo_id=BEATRICE_REPO, filename=filepath) def download_beatrice_augmentation(): """Download Beatrice augmentation assets (noise + IR) - optional for training""" try: from huggingface_hub import snapshot_download cache_dir = snapshot_download(repo_id=BEATRICE_REPO, allow_patterns=["assets/noise/*", "assets/ir/*"]) noise_dir = os.path.join(cache_dir, "assets", "noise") ir_dir = os.path.join(cache_dir, "assets", "ir") if os.path.isdir(noise_dir) and os.path.isdir(ir_dir): return noise_dir, ir_dir return None, None except Exception as e: logger.warning(f"Could not download augmentation assets: {e}") return None, None def load_pretrained_weights(model: nn.Module, pretrained_path: str) -> None: """Load pretrained weights into model, handling speaker embedding mismatch""" logger.info(f"Loading pretrained weights: {pretrained_path}") state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True) # Handle different checkpoint formats if "model" in state_dict: state_dict = state_dict["model"] # Filter out mismatched keys, but handle emb_g specially model_state = model.state_dict() filtered_state = {} skipped = [] for k, v in state_dict.items(): if k in model_state: if v.shape == model_state[k].shape: filtered_state[k] = v elif k == "emb_g.weight": # Initialize our speaker embedding with mean of pretrained embeddings # This gives a much better starting point than random initialization mean_emb = v.mean(dim=0, keepdim=True) # [1, 256] num_speakers = model_state[k].shape[0] filtered_state[k] = mean_emb.expand(num_speakers, -1).clone() logger.info(f"Initialized emb_g from pretrained mean ({v.shape[0]} -> {num_speakers} speakers)") else: skipped.append(f"{k}: {v.shape} vs {model_state[k].shape}") else: skipped.append(f"{k}: not in model") if skipped: logger.info(f"Skipped {len(skipped)} mismatched keys") model.load_state_dict(filtered_state, strict=False) logger.info(f"Loaded {len(filtered_state)}/{len(state_dict)} pretrained weights") logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Device selection: # - Inference: Always CPU (HF Spaces free tier, also works everywhere) # - Training: GPU if available for speed, CPU fallback device = torch.device("cpu") # For inference train_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # For training logger.info(f"Inference device: {device}") logger.info(f"Training device: {train_device}") # ============================================================ # CPU OPTIMIZATION — Locked 2-core HuggingFace Spaces config # ============================================================ # Restrict PyTorch to exactly 2 physical cores. # OpenMP and MKL must both be capped before any tensor ops fire. _CPU_CORES = 2 torch.set_num_threads(_CPU_CORES) torch.set_num_interop_threads(_CPU_CORES) os.environ["OMP_NUM_THREADS"] = str(_CPU_CORES) os.environ["MKL_NUM_THREADS"] = str(_CPU_CORES) os.environ["OPENBLAS_NUM_THREADS"] = str(_CPU_CORES) os.environ["VECLIB_MAXIMUM_THREADS"] = str(_CPU_CORES) os.environ["NUMEXPR_NUM_THREADS"] = str(_CPU_CORES) # torch.inference_mode is heavier than no_grad but also frees the # autograd graph eagerly, which helps on a memory-constrained CPU. # Enable oneDNN graph fusion (fuses conv+bn, linear+relu etc. into # single kernels — measurable speedup on Intel Xeon VMs). torch.backends.mkldnn.enabled = True try: torch.jit.enable_onednn_fusion(True) except Exception: pass logger.info(f"PyTorch CPU threads: {torch.get_num_threads()} (interop={torch.get_num_interop_threads()})") # ============================================================ # MEMORY MANAGEMENT — purge_memory() # Call this between every heavy operation to prevent OOM on # the 16 GB HuggingFace Spaces free-tier CPU instance. # ============================================================ import ctypes, platform def purge_memory(*tensors_or_arrays): """ Aggressively free memory after a heavy generation step. Pass any tensors / numpy arrays that should be deleted. The function: 1. Deletes every passed object from caller scope. 2. Runs Python gc (two passes: first collects cycles, second collects anything the first pass freed). 3. On Linux (HuggingFace Spaces), calls malloc_trim(0) via ctypes so glibc returns freed pages to the OS immediately. Without this, RSS can stay high even after gc.collect(). 4. Clears CUDA cache if a GPU is somehow available. """ for obj in tensors_or_arrays: try: del obj except Exception: pass # Two-pass gc: cycles first, then their referents gc.collect() gc.collect() # Return glibc memory to the OS (Linux only — HF Spaces is Linux) if platform.system() == "Linux": try: ctypes.CDLL("libc.so.6").malloc_trim(0) except Exception: pass if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() # ============================================================ # COMMONS - Helper functions from infer/lib/infer_pack/commons.py # ============================================================ def init_weights(m, mean=0.0, std=0.01): classname = m.__class__.__name__ if classname.find("Conv") != -1: m.weight.data.normal_(mean, std) def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None): if max_length is None: max_length = length.max() x = torch.arange(max_length, dtype=length.dtype, device=length.device) return x.unsqueeze(0) < length.unsqueeze(1) @torch.jit.script def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): n_channels_int = n_channels[0] in_act = input_a + input_b t_act = torch.tanh(in_act[:, :n_channels_int, :]) s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) return t_act * s_act def slice_segments(x, ids_str, segment_size=4): """Slice segments from tensor""" ret = torch.zeros_like(x[:, :, :segment_size]) for i in range(x.size(0)): idx_str = ids_str[i] idx_end = idx_str + segment_size ret[i] = x[i, :, idx_str:idx_end] return ret def slice_segments2(x, ids_str, segment_size=4): """Slice segments from 2D tensor""" ret = torch.zeros_like(x[:, :segment_size]) for i in range(x.size(0)): idx_str = ids_str[i] idx_end = idx_str + segment_size ret[i] = x[i, idx_str:idx_end] return ret def rand_slice_segments(x, x_lengths=None, segment_size=4): """Random slice segments""" b, d, t = x.size() if x_lengths is None: x_lengths = t ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=1) ids_str = (torch.rand([b], device=x.device) * ids_str_max.float()).long() ret = slice_segments(x, ids_str, segment_size) return ret, ids_str # ============================================================ # MODULES - From infer/lib/infer_pack/modules.py # ============================================================ LRELU_SLOPE = 0.1 class LayerNorm(nn.Module): def __init__(self, channels, eps=1e-5): super().__init__() self.channels = channels self.eps = eps self.gamma = nn.Parameter(torch.ones(channels)) self.beta = nn.Parameter(torch.zeros(channels)) def forward(self, x): x = x.transpose(1, -1) x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) return x.transpose(1, -1) class WN(nn.Module): def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): super().__init__() assert kernel_size % 2 == 1 self.hidden_channels = hidden_channels self.kernel_size = (kernel_size,) self.dilation_rate = dilation_rate self.n_layers = n_layers self.gin_channels = gin_channels self.p_dropout = float(p_dropout) self.in_layers = nn.ModuleList() self.res_skip_layers = nn.ModuleList() self.drop = nn.Dropout(float(p_dropout)) if gin_channels != 0: cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) self.cond_layer = weight_norm(cond_layer, name="weight") for i in range(n_layers): dilation = dilation_rate ** i padding = int((kernel_size * dilation - dilation) / 2) in_layer = nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding) in_layer = weight_norm(in_layer, name="weight") self.in_layers.append(in_layer) if i < n_layers - 1: res_skip_channels = 2 * hidden_channels else: res_skip_channels = hidden_channels res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1) res_skip_layer = weight_norm(res_skip_layer, name="weight") self.res_skip_layers.append(res_skip_layer) def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None): output = torch.zeros_like(x) n_channels_tensor = torch.IntTensor([self.hidden_channels]) if g is not None: g = self.cond_layer(g) for i, (in_layer, res_skip_layer) in enumerate(zip(self.in_layers, self.res_skip_layers)): x_in = in_layer(x) if g is not None: cond_offset = i * 2 * self.hidden_channels g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] else: g_l = torch.zeros_like(x_in) acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) acts = self.drop(acts) res_skip_acts = res_skip_layer(acts) if i < self.n_layers - 1: res_acts = res_skip_acts[:, :self.hidden_channels, :] x = (x + res_acts) * x_mask output = output + res_skip_acts[:, self.hidden_channels:, :] else: output = output + res_skip_acts return output * x_mask def remove_weight_norm(self): if self.gin_channels != 0: remove_weight_norm(self.cond_layer) for l in self.in_layers: remove_weight_norm(l) for l in self.res_skip_layers: remove_weight_norm(l) class ResBlock1(nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super().__init__() self.convs1 = nn.ModuleList([ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))), ]) self.convs1.apply(init_weights) self.convs2 = nn.ModuleList([ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), ]) self.convs2.apply(init_weights) self.lrelu_slope = LRELU_SLOPE def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None): for c1, c2 in zip(self.convs1, self.convs2): xt = F.leaky_relu(x, self.lrelu_slope) if x_mask is not None: xt = xt * x_mask xt = c1(xt) xt = F.leaky_relu(xt, self.lrelu_slope) if x_mask is not None: xt = xt * x_mask xt = c2(xt) x = xt + x if x_mask is not None: x = x * x_mask return x def remove_weight_norm(self): for l in self.convs1: remove_weight_norm(l) for l in self.convs2: remove_weight_norm(l) class ResBlock2(nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3)): super().__init__() self.convs = nn.ModuleList([ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), ]) self.convs.apply(init_weights) self.lrelu_slope = LRELU_SLOPE def forward(self, x, x_mask: Optional[torch.Tensor] = None): for c in self.convs: xt = F.leaky_relu(x, self.lrelu_slope) if x_mask is not None: xt = xt * x_mask xt = c(xt) x = xt + x if x_mask is not None: x = x * x_mask return x def remove_weight_norm(self): for l in self.convs: remove_weight_norm(l) class Flip(nn.Module): def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): x = torch.flip(x, [1]) if not reverse: logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) return x, logdet else: return x, torch.zeros([1], device=x.device) class ResidualCouplingLayer(nn.Module): def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False): assert channels % 2 == 0 super().__init__() self.channels = channels self.hidden_channels = hidden_channels self.half_channels = channels // 2 self.mean_only = mean_only self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=float(p_dropout), gin_channels=gin_channels) self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) self.post.weight.data.zero_() self.post.bias.data.zero_() def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): x0, x1 = torch.split(x, [self.half_channels] * 2, 1) h = self.pre(x0) * x_mask h = self.enc(h, x_mask, g=g) stats = self.post(h) * x_mask if not self.mean_only: m, logs = torch.split(stats, [self.half_channels] * 2, 1) else: m = stats logs = torch.zeros_like(m) if not reverse: x1 = m + x1 * torch.exp(logs) * x_mask x = torch.cat([x0, x1], 1) logdet = torch.sum(logs, [1, 2]) return x, logdet else: x1 = (x1 - m) * torch.exp(-logs) * x_mask x = torch.cat([x0, x1], 1) return x, torch.zeros([1]) def remove_weight_norm(self): self.enc.remove_weight_norm() # ============================================================ # ATTENTIONS - From infer/lib/infer_pack/attentions.py # ============================================================ class MultiHeadAttention(nn.Module): def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, proximal_bias=False, proximal_init=False): super().__init__() assert channels % n_heads == 0 self.channels = channels self.out_channels = out_channels self.n_heads = n_heads self.p_dropout = p_dropout self.window_size = window_size self.heads_share = heads_share self.proximal_bias = proximal_bias self.proximal_init = proximal_init self.k_channels = channels // n_heads self.conv_q = nn.Conv1d(channels, channels, 1) self.conv_k = nn.Conv1d(channels, channels, 1) self.conv_v = nn.Conv1d(channels, channels, 1) self.conv_o = nn.Conv1d(channels, out_channels, 1) self.drop = nn.Dropout(p_dropout) if window_size is not None: n_heads_rel = 1 if heads_share else n_heads rel_stddev = self.k_channels ** -0.5 self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) nn.init.xavier_uniform_(self.conv_q.weight) nn.init.xavier_uniform_(self.conv_k.weight) nn.init.xavier_uniform_(self.conv_v.weight) if proximal_init: with torch.no_grad(): self.conv_k.weight.copy_(self.conv_q.weight) self.conv_k.bias.copy_(self.conv_q.bias) def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): q = self.conv_q(x) k = self.conv_k(c) v = self.conv_v(c) x, _ = self.attention(q, k, v, mask=attn_mask) x = self.conv_o(x) return x def attention(self, query, key, value, mask=None): b, d, t_s = key.size() t_t = query.size(2) query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) if self.window_size is not None: key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) scores_local = self._relative_position_to_absolute_position(rel_logits) scores = scores + scores_local if self.proximal_bias: scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) if mask is not None: scores = scores.masked_fill(mask == 0, -1e4) p_attn = F.softmax(scores, dim=-1) p_attn = self.drop(p_attn) output = torch.matmul(p_attn, value) if self.window_size is not None: relative_weights = self._absolute_position_to_relative_position(p_attn) value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) output = output.transpose(2, 3).contiguous().view(b, d, t_t) return output, p_attn def _matmul_with_relative_values(self, x, y): return torch.matmul(x, y.unsqueeze(0)) def _matmul_with_relative_keys(self, x, y): return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) def _get_relative_embeddings(self, relative_embeddings, length): pad_length = max(length - (self.window_size + 1), 0) slice_start_position = max((self.window_size + 1) - length, 0) slice_end_position = slice_start_position + 2 * length - 1 if pad_length > 0: padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) else: padded_relative_embeddings = relative_embeddings return padded_relative_embeddings[:, slice_start_position:slice_end_position] def _relative_position_to_absolute_position(self, x): batch, heads, length, _ = x.size() x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]) x_flat = x.view([batch, heads, length * 2 * length]) x_flat = F.pad(x_flat, [0, int(length) - 1, 0, 0, 0, 0]) x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:] return x_final def _absolute_position_to_relative_position(self, x): batch, heads, length, _ = x.size() x = F.pad(x, [0, int(length) - 1, 0, 0, 0, 0, 0, 0]) x_flat = x.view([batch, heads, int(length ** 2) + int(length * (length - 1))]) x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] return x_final def _attention_bias_proximal(self, length): r = torch.arange(length, dtype=torch.float32) diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) class FFN(nn.Module): def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.filter_channels = filter_channels self.kernel_size = kernel_size self.p_dropout = p_dropout self.causal = causal self.is_activation = activation == "gelu" self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) self.drop = nn.Dropout(p_dropout) def forward(self, x: torch.Tensor, x_mask: torch.Tensor): x = self.conv_1(self._padding(x, x_mask)) if self.is_activation: x = x * torch.sigmoid(1.702 * x) else: x = torch.relu(x) x = self.drop(x) x = self.conv_2(self._padding(x, x_mask)) return x * x_mask def _padding(self, x, x_mask): if self.causal: if self.kernel_size == 1: return x * x_mask pad_l = self.kernel_size - 1 return F.pad(x * x_mask, [pad_l, 0, 0, 0, 0, 0]) else: if self.kernel_size == 1: return x * x_mask pad_l = (self.kernel_size - 1) // 2 pad_r = self.kernel_size // 2 return F.pad(x * x_mask, [pad_l, pad_r, 0, 0, 0, 0]) class Encoder(nn.Module): def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs): super().__init__() self.hidden_channels = hidden_channels self.n_layers = int(n_layers) self.drop = nn.Dropout(p_dropout) self.attn_layers = nn.ModuleList() self.norm_layers_1 = nn.ModuleList() self.ffn_layers = nn.ModuleList() self.norm_layers_2 = nn.ModuleList() for i in range(self.n_layers): self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) self.norm_layers_1.append(LayerNorm(hidden_channels)) self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) self.norm_layers_2.append(LayerNorm(hidden_channels)) def forward(self, x, x_mask): attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) x = x * x_mask for attn, norm1, ffn, norm2 in zip(self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2): y = attn(x, x, attn_mask) y = self.drop(y) x = norm1(x + y) y = ffn(x, x_mask) y = self.drop(y) x = norm2(x + y) return x * x_mask # ============================================================ # MODELS - From infer/lib/infer_pack/models.py # ============================================================ sr2sr = {"32k": 32000, "40k": 40000, "48k": 48000} class TextEncoder256(nn.Module): def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True): super().__init__() self.out_channels = out_channels self.hidden_channels = hidden_channels self.emb_phone = nn.Linear(256, hidden_channels) self.lrelu = nn.LeakyReLU(0.1, inplace=True) if f0: self.emb_pitch = nn.Embedding(256, hidden_channels) self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor): if pitch is None: x = self.emb_phone(phone) else: x = self.emb_phone(phone) + self.emb_pitch(pitch) x = x * math.sqrt(self.hidden_channels) x = self.lrelu(x) x = torch.transpose(x, 1, -1) x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype) x = self.encoder(x * x_mask, x_mask) stats = self.proj(x) * x_mask m, logs = torch.split(stats, self.out_channels, dim=1) return m, logs, x_mask class TextEncoder768(nn.Module): def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True): super().__init__() self.out_channels = out_channels self.hidden_channels = hidden_channels self.emb_phone = nn.Linear(768, hidden_channels) self.lrelu = nn.LeakyReLU(0.1, inplace=True) if f0: self.emb_pitch = nn.Embedding(256, hidden_channels) self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor): if pitch is None: x = self.emb_phone(phone) else: x = self.emb_phone(phone) + self.emb_pitch(pitch) x = x * math.sqrt(self.hidden_channels) x = self.lrelu(x) x = torch.transpose(x, 1, -1) x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype) x = self.encoder(x * x_mask, x_mask) stats = self.proj(x) * x_mask m, logs = torch.split(stats, self.out_channels, dim=1) return m, logs, x_mask class ResidualCouplingBlock(nn.Module): def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): super().__init__() self.n_flows = n_flows self.flows = nn.ModuleList() for i in range(n_flows): self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) self.flows.append(Flip()) def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): if not reverse: for flow in self.flows: x, _ = flow(x, x_mask, g=g, reverse=reverse) else: for flow in self.flows[::-1]: x, _ = flow.forward(x, x_mask, g=g, reverse=reverse) return x def remove_weight_norm(self): for i in range(self.n_flows): self.flows[i * 2].remove_weight_norm() class PosteriorEncoder(nn.Module): def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0): super().__init__() self.out_channels = out_channels self.pre = nn.Conv1d(in_channels, hidden_channels, 1) self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None): x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x = self.pre(x) * x_mask x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask m, logs = torch.split(stats, self.out_channels, dim=1) z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask return z, m, logs, x_mask def remove_weight_norm(self): self.enc.remove_weight_norm() class Generator(nn.Module): def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): super().__init__() self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) resblock_class = ResBlock1 if resblock == "1" else ResBlock2 self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2))) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): self.resblocks.append(resblock_class(ch, k, d)) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) self.ups.apply(init_weights) if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None): x = self.conv_pre(x) if g is not None: x = x + self.cond(g) for i in range(self.num_upsamples): x = F.leaky_relu(x, LRELU_SLOPE) x = self.ups[i](x) xs = None for j in range(self.num_kernels): if xs is None: xs = self.resblocks[i * self.num_kernels + j](x) else: xs += self.resblocks[i * self.num_kernels + j](x) x = xs / self.num_kernels x = F.leaky_relu(x) x = self.conv_post(x) x = torch.tanh(x) return x def remove_weight_norm(self): for l in self.ups: remove_weight_norm(l) for l in self.resblocks: l.remove_weight_norm() class SineGen(nn.Module): def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0): super().__init__() self.sine_amp = sine_amp self.noise_std = noise_std self.harmonic_num = harmonic_num self.dim = harmonic_num + 1 self.sampling_rate = samp_rate self.voiced_threshold = voiced_threshold def _f02uv(self, f0): uv = torch.ones_like(f0) uv = uv * (f0 > self.voiced_threshold) return uv.float() def forward(self, f0: torch.Tensor, upp: int): with torch.no_grad(): f0 = f0[:, None].transpose(1, 2) f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) f0_buf[:, :, 0] = f0[:, :, 0] for idx in range(self.harmonic_num): f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2) rad_values = (f0_buf / self.sampling_rate) % 1 rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device) rand_ini[:, 0] = 0 rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini tmp_over_one = torch.cumsum(rad_values, 1) tmp_over_one *= upp tmp_over_one = F.interpolate(tmp_over_one.transpose(2, 1), scale_factor=float(upp), mode="linear", align_corners=True).transpose(2, 1) rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1) tmp_over_one %= 1 tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 cumsum_shift = torch.zeros_like(rad_values) cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi) sine_waves = sine_waves * self.sine_amp uv = self._f02uv(f0) uv = F.interpolate(uv.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1) noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 noise = noise_amp * torch.randn_like(sine_waves) sine_waves = sine_waves * uv + noise return sine_waves, uv, noise class SourceModuleHnNSF(nn.Module): def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0, is_half=False): super().__init__() self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod) self.l_linear = nn.Linear(harmonic_num + 1, 1) self.l_tanh = nn.Tanh() def forward(self, x: torch.Tensor, upp: int = 1): sine_wavs, uv, _ = self.l_sin_gen(x, upp) sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype) sine_merge = self.l_tanh(self.l_linear(sine_wavs)) return sine_merge, None, None class GeneratorNSF(nn.Module): def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, is_half=False): super().__init__() self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates)) self.m_source = SourceModuleHnNSF(sampling_rate=sr, harmonic_num=0, is_half=is_half) self.noise_convs = nn.ModuleList() self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) resblock_class = ResBlock1 if resblock == "1" else ResBlock2 self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): c_cur = upsample_initial_channel // (2 ** (i + 1)) self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2))) if i + 1 < len(upsample_rates): stride_f0 = math.prod(upsample_rates[i + 1:]) self.noise_convs.append(Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) else: self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): self.resblocks.append(resblock_class(ch, k, d)) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) self.ups.apply(init_weights) if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) self.upp = math.prod(upsample_rates) self.lrelu_slope = LRELU_SLOPE def forward(self, x, f0, g: Optional[torch.Tensor] = None): har_source, _, _ = self.m_source(f0, self.upp) har_source = har_source.transpose(1, 2) x = self.conv_pre(x) if g is not None: x = x + self.cond(g) for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)): if i < self.num_upsamples: x = F.leaky_relu(x, self.lrelu_slope) x = ups(x) x_source = noise_convs(har_source) x = x + x_source xs = None l = [i * self.num_kernels + j for j in range(self.num_kernels)] for j, resblock in enumerate(self.resblocks): if j in l: if xs is None: xs = resblock(x) else: xs += resblock(x) x = xs / self.num_kernels x = F.leaky_relu(x) x = self.conv_post(x) x = torch.tanh(x) return x def remove_weight_norm(self): for l in self.ups: remove_weight_norm(l) for l in self.resblocks: l.remove_weight_norm() # Synthesizer classes for different model versions class SynthesizerTrnMs256NSFsid(nn.Module): """RVC v1 model with f0""" def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, **kwargs): super().__init__() if isinstance(sr, str): sr = sr2sr[sr] self.segment_size = segment_size self.gin_channels = gin_channels self.spk_embed_dim = spk_embed_dim self.enc_p = TextEncoder256(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs.get("is_half", False)) self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) @torch.jit.export def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: torch.Tensor, nsff0: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): g = self.emb_g(sid).unsqueeze(-1) m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask if rate is not None: head = int(z_p.shape[2] * (1 - rate.item())) z_p = z_p[:, :, head:] x_mask = x_mask[:, :, head:] nsff0 = nsff0[:, head:] z = self.flow(z_p, x_mask, g=g, reverse=True) o = self.dec(z * x_mask, nsff0, g=g) return o, x_mask, (z, z_p, m_p, logs_p) class SynthesizerTrnMs768NSFsid(nn.Module): """RVC v2 model with f0""" def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, **kwargs): super().__init__() if isinstance(sr, str): sr = sr2sr[sr] self.segment_size = segment_size self.gin_channels = gin_channels self.spk_embed_dim = spk_embed_dim self.enc_p = TextEncoder768(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs.get("is_half", False)) self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) def forward(self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds): """Training forward pass""" g = self.emb_g(ds).unsqueeze(-1) m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) z_p = self.flow(z, y_mask, g=g) z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size) pitchf = slice_segments2(pitchf, ids_slice, self.segment_size) o = self.dec(z_slice, pitchf, g=g) return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) @torch.jit.export def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: torch.Tensor, nsff0: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): g = self.emb_g(sid).unsqueeze(-1) m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask if rate is not None: head = int(z_p.shape[2] * (1.0 - rate.item())) z_p = z_p[:, :, head:] x_mask = x_mask[:, :, head:] nsff0 = nsff0[:, head:] z = self.flow(z_p, x_mask, g=g, reverse=True) o = self.dec(z * x_mask, nsff0, g=g) return o, x_mask, (z, z_p, m_p, logs_p) class SynthesizerTrnMs256NSFsid_nono(nn.Module): """RVC v1 model without f0""" def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs): super().__init__() self.segment_size = segment_size self.gin_channels = gin_channels self.enc_p = TextEncoder256(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), f0=False) self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) @torch.jit.export def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): g = self.emb_g(sid).unsqueeze(-1) m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask if rate is not None: head = int(z_p.shape[2] * (1.0 - rate.item())) z_p = z_p[:, :, head:] x_mask = x_mask[:, :, head:] z = self.flow(z_p, x_mask, g=g, reverse=True) o = self.dec(z * x_mask, g=g) return o, x_mask, (z, z_p, m_p, logs_p) class SynthesizerTrnMs768NSFsid_nono(nn.Module): """RVC v2 model without f0""" def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs): super().__init__() self.segment_size = segment_size self.gin_channels = gin_channels self.enc_p = TextEncoder768(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), f0=False) self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) @torch.jit.export def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): g = self.emb_g(sid).unsqueeze(-1) m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask if rate is not None: head = int(z_p.shape[2] * (1.0 - rate.item())) z_p = z_p[:, :, head:] x_mask = x_mask[:, :, head:] z = self.flow(z_p, x_mask, g=g, reverse=True) o = self.dec(z * x_mask, g=g) return o, x_mask, (z, z_p, m_p, logs_p) # ============================================================ # DISCRIMINATOR - For training # ============================================================ class DiscriminatorS(nn.Module): def __init__(self, use_spectral_norm=False): super().__init__() norm_f = nn.utils.spectral_norm if use_spectral_norm else weight_norm self.convs = nn.ModuleList([ norm_f(Conv1d(1, 16, 15, 1, padding=7)), norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), ]) self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) def forward(self, x): fmap = [] for l in self.convs: x = l(x) x = F.leaky_relu(x, 0.1) fmap.append(x) x = self.conv_post(x) fmap.append(x) x = torch.flatten(x, 1, -1) return x, fmap class DiscriminatorP(nn.Module): def __init__(self, period, use_spectral_norm=False): super().__init__() self.period = period norm_f = nn.utils.spectral_norm if use_spectral_norm else weight_norm self.convs = nn.ModuleList([ norm_f(nn.Conv2d(1, 32, (5, 1), (3, 1), padding=(2, 0))), norm_f(nn.Conv2d(32, 128, (5, 1), (3, 1), padding=(2, 0))), norm_f(nn.Conv2d(128, 512, (5, 1), (3, 1), padding=(2, 0))), norm_f(nn.Conv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0))), norm_f(nn.Conv2d(1024, 1024, (5, 1), 1, padding=(2, 0))), ]) self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) def forward(self, x): fmap = [] b, c, t = x.shape if t % self.period != 0: n_pad = self.period - (t % self.period) x = F.pad(x, (0, n_pad), "reflect") t = t + n_pad x = x.view(b, c, t // self.period, self.period) for l in self.convs: x = l(x) x = F.leaky_relu(x, 0.1) fmap.append(x) x = self.conv_post(x) fmap.append(x) x = torch.flatten(x, 1, -1) return x, fmap class MultiPeriodDiscriminator(nn.Module): def __init__(self, use_spectral_norm=False): super().__init__() periods = [2, 3, 5, 7, 11, 17, 23, 37] # 8 periods for v2 pretrained (9 total discriminators) self.discriminators = nn.ModuleList( [DiscriminatorS(use_spectral_norm)] + [DiscriminatorP(p, use_spectral_norm) for p in periods] ) def forward(self, y, y_hat): y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] for d in self.discriminators: y_d_r, fmap_r = d(y) y_d_g, fmap_g = d(y_hat) y_d_rs.append(y_d_r) y_d_gs.append(y_d_g) fmap_rs.append(fmap_r) fmap_gs.append(fmap_g) return y_d_rs, y_d_gs, fmap_rs, fmap_gs # ============================================================ # TRAINING LOSSES # ============================================================ def feature_loss(fmap_r, fmap_g): loss = 0 for dr, dg in zip(fmap_r, fmap_g): for rl, gl in zip(dr, dg): loss += torch.mean(torch.abs(rl.float().detach() - gl.float())) return loss * 2 def discriminator_loss(disc_real_outputs, disc_generated_outputs): loss = 0 for dr, dg in zip(disc_real_outputs, disc_generated_outputs): loss += torch.mean((1 - dr.float()) ** 2) + torch.mean(dg.float() ** 2) return loss def generator_loss(disc_outputs): loss = 0 for dg in disc_outputs: loss += torch.mean((1 - dg.float()) ** 2) return loss def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): z_p, logs_q, m_p, logs_p, z_mask = [x.float() for x in [z_p, logs_q, m_p, logs_p, z_mask]] kl = logs_p - logs_q - 0.5 + 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) return torch.sum(kl * z_mask) / torch.sum(z_mask) # ============================================================ # HUBERT EXTRACTION - Using torchaudio bundle # ============================================================ # ContentVec model for v1 (256-dim) and HuBERT for v2 (768-dim) _contentvec_model = None # For v1 models (256-dim output) _hubert_model = None # For v2 models (768-dim output) _hubert_bundle = None CONTENTVEC_REPO = "IAHispano/Applio" CONTENTVEC_MODEL = "Resources/embedders/contentvec/pytorch_model.bin" CONTENTVEC_CONFIG = "Resources/embedders/contentvec/config.json" def load_contentvec(): """Load ContentVec model from HuggingFace for v1 models (256-dim output)""" global _contentvec_model if _contentvec_model is None: try: from transformers import HubertModel, HubertConfig logger.info("Loading ContentVec model from HuggingFace...") # Download model files model_path = hf_hub_download(repo_id=CONTENTVEC_REPO, filename=CONTENTVEC_MODEL) config_path = hf_hub_download(repo_id=CONTENTVEC_REPO, filename=CONTENTVEC_CONFIG) # Create model with final_proj layer class HubertModelWithFinalProj(HubertModel): def __init__(self, config): super().__init__(config) self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) config = HubertConfig.from_pretrained(config_path) _contentvec_model = HubertModelWithFinalProj(config) state_dict = torch.load(model_path, map_location="cpu", weights_only=True) _contentvec_model.load_state_dict(state_dict) _contentvec_model.to(device).eval() logger.info(f"ContentVec loaded: hidden={config.hidden_size}, proj={config.classifier_proj_size}") except Exception as e: logger.warning(f"Failed to load ContentVec: {e}, falling back to torchaudio HuBERT") _contentvec_model = None return _contentvec_model def load_hubert(): """Load HuBERT model via torchaudio for v2 models (768-dim output)""" global _hubert_model, _hubert_bundle if _hubert_model is None: import torchaudio logger.info("Loading HuBERT model via torchaudio...") _hubert_bundle = torchaudio.pipelines.HUBERT_BASE _hubert_model = _hubert_bundle.get_model().to(device) _hubert_model.eval() logger.info("HuBERT model loaded") return _hubert_model, _hubert_bundle def extract_hubert_features(audio: np.ndarray, sr: int = 16000, version: str = "v2") -> torch.Tensor: """Extract ContentVec features from audio (same as Applio) v1 models: Use ContentVec with final_proj (256-dim) v2 models: Use ContentVec without final_proj (768-dim) """ audio = audio.astype(np.float32) if np.abs(audio).max() > 1.0: audio = audio / np.abs(audio).max() inputs = torch.from_numpy(audio).unsqueeze(0).to(device) # Use ContentVec for ALL versions (same as Applio) contentvec = load_contentvec() if contentvec is not None: with torch.no_grad(): output = contentvec(inputs) if version == "v1": # v1: use final_proj for 256-dim feats = contentvec.final_proj(output.last_hidden_state) else: # v2: use raw hidden state (768-dim) feats = output.last_hidden_state return feats # Fallback to torchaudio HuBERT if ContentVec not available logger.warning("ContentVec not available, using torchaudio HuBERT (results may be degraded)") hubert, bundle = load_hubert() with torch.no_grad(): features, _ = hubert.extract_features(inputs) layer_idx = 11 if version == "v2" else 8 feats = features[min(layer_idx, len(features)-1)] if version == "v1": proj = nn.Linear(768, 256, bias=False).to(device) with torch.no_grad(): w = torch.zeros(256, 768) for i in range(256): w[i, i*3:(i+1)*3] = 1/3 proj.weight.copy_(w) feats = proj(feats) return feats # ============================================================ # F0 EXTRACTION # ============================================================ def extract_f0_pm(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0) -> Tuple[np.ndarray, np.ndarray]: """Extract F0 using parselmouth (pm method)""" import parselmouth p_len = audio.shape[0] // 160 + 1 f0_min = 65 f0_max = 1100 l_pad = int(np.ceil(1.5 / f0_min * 16000)) r_pad = l_pad + 1 s = parselmouth.Sound(np.pad(audio, (l_pad, r_pad)), 16000).to_pitch_ac( time_step=0.01, voicing_threshold=0.6, pitch_floor=f0_min, pitch_ceiling=f0_max, ) f0 = s.selected_array["frequency"] if len(f0) < p_len: f0 = np.pad(f0, (0, p_len - len(f0))) f0 = f0[:p_len] f0 *= pow(2, f0_up_key / 12) return f0_to_coarse(f0) def extract_f0_harvest(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0) -> Tuple[np.ndarray, np.ndarray]: """Extract F0 using pyworld harvest""" import pyworld from scipy import signal as scipy_signal f0, t = pyworld.harvest(audio.astype(np.double), fs=16000, f0_ceil=1100, f0_floor=50, frame_period=10) f0 = scipy_signal.medfilt(f0, 3) f0 *= pow(2, f0_up_key / 12) return f0_to_coarse(f0) def f0_to_coarse(f0: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Convert f0 to coarse representation""" f0_min = 50 f0_max = 1100 f0_mel_min = 1127 * np.log(1 + f0_min / 700) f0_mel_max = 1127 * np.log(1 + f0_max / 700) f0bak = f0.copy() f0_mel = 1127 * np.log(1 + f0 / 700) f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1 f0_mel[f0_mel <= 1] = 1 f0_mel[f0_mel > 255] = 255 f0_coarse = np.rint(f0_mel).astype(np.int32) return f0_coarse, f0bak # ============================================================ # RMVPE F0 EXTRACTION (from Applio - IAHispano/Applio) # ============================================================ class RMVPE_ConvBlockRes(nn.Module): def __init__(self, in_channels, out_channels, momentum=0.01): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), ) self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) if in_channels != out_channels else None def forward(self, x): r = self.conv(x) return r + self.shortcut(x) if self.shortcut else r + x class RMVPE_ResEncoderBlock(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): super().__init__() self.conv = nn.ModuleList([RMVPE_ConvBlockRes(in_channels, out_channels, momentum)]) for _ in range(n_blocks - 1): self.conv.append(RMVPE_ConvBlockRes(out_channels, out_channels, momentum)) self.kernel_size = kernel_size if kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size) def forward(self, x): for c in self.conv: x = c(x) return (x, self.pool(x)) if self.kernel_size is not None else x class RMVPE_Encoder(nn.Module): def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): super().__init__() self.n_encoders = n_encoders self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) self.layers = nn.ModuleList() for _ in range(n_encoders): self.layers.append(RMVPE_ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum)) in_channels = out_channels out_channels *= 2 in_size //= 2 self.out_size = in_size self.out_channel = out_channels def forward(self, x): concat_tensors = [] x = self.bn(x) for layer in self.layers: t, x = layer(x) concat_tensors.append(t) return x, concat_tensors class RMVPE_Intermediate(nn.Module): def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): super().__init__() self.layers = nn.ModuleList([RMVPE_ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)]) for _ in range(n_inters - 1): self.layers.append(RMVPE_ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) def forward(self, x): for layer in self.layers: x = layer(x) return x class RMVPE_ResDecoderBlock(nn.Module): def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): super().__init__() out_padding = (0, 1) if stride == (1, 2) else (1, 1) self.conv1 = nn.Sequential( nn.ConvTranspose2d(in_channels, out_channels, (3, 3), stride, (1, 1), out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), ) self.conv2 = nn.ModuleList([RMVPE_ConvBlockRes(out_channels * 2, out_channels, momentum)]) for _ in range(n_blocks - 1): self.conv2.append(RMVPE_ConvBlockRes(out_channels, out_channels, momentum)) def forward(self, x, concat_tensor): x = self.conv1(x) x = torch.cat((x, concat_tensor), dim=1) for c in self.conv2: x = c(x) return x class RMVPE_Decoder(nn.Module): def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): super().__init__() self.layers = nn.ModuleList() for _ in range(n_decoders): out_channels = in_channels // 2 self.layers.append(RMVPE_ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) in_channels = out_channels self.n_decoders = n_decoders def forward(self, x, concat_tensors): for i in range(self.n_decoders): x = self.layers[i](x, concat_tensors[-1 - i]) return x class RMVPE_DeepUnet(nn.Module): def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): super().__init__() self.encoder = RMVPE_Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels) self.intermediate = RMVPE_Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) self.decoder = RMVPE_Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) def forward(self, x): x, concat_tensors = self.encoder(x) x = self.intermediate(x) x = self.decoder(x, concat_tensors) return x class RMVPE_BiGRU(nn.Module): def __init__(self, input_features, hidden_features, num_layers): super().__init__() self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) def forward(self, x): return self.gru(x)[0] RMVPE_N_MELS = 128 RMVPE_N_CLASS = 360 class RMVPE_E2E(nn.Module): def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): super().__init__() self.unet = RMVPE_DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) if n_gru: self.fc = nn.Sequential( RMVPE_BiGRU(3 * 128, 256, n_gru), nn.Linear(512, RMVPE_N_CLASS), nn.Dropout(0.25), nn.Sigmoid(), ) else: self.fc = nn.Sequential(nn.Linear(3 * RMVPE_N_MELS, RMVPE_N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) def forward(self, mel): mel = mel.transpose(-1, -2).unsqueeze(1) x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) return self.fc(x) class RMVPE_MelSpectrogram(nn.Module): def __init__(self, n_mel_channels=128, sample_rate=16000, win_length=1024, hop_length=160, n_fft=None, mel_fmin=30, mel_fmax=8000, clamp=1e-5): super().__init__() from librosa.filters import mel as librosa_mel n_fft = win_length if n_fft is None else n_fft self.hann_window = {} mel_basis = librosa_mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True) self.register_buffer("mel_basis", torch.from_numpy(mel_basis).float()) self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length self.clamp = clamp def forward(self, audio, keyshift=0, speed=1, center=True): factor = 2 ** (keyshift / 12) n_fft_new = int(np.round(self.n_fft * factor)) win_length_new = int(np.round(self.win_length * factor)) hop_length_new = int(np.round(self.hop_length * speed)) key = f"{keyshift}_{audio.device}" if key not in self.hann_window: self.hann_window[key] = torch.hann_window(win_length_new).to(audio.device) fft = torch.stft(audio, n_fft=n_fft_new, hop_length=hop_length_new, win_length=win_length_new, window=self.hann_window[key], center=center, return_complex=True) magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) if keyshift != 0: size = self.n_fft // 2 + 1 resize = magnitude.size(1) if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) magnitude = magnitude[:, :size, :] * self.win_length / win_length_new mel_output = torch.matmul(self.mel_basis, magnitude) return torch.log(torch.clamp(mel_output, min=self.clamp)) _rmvpe_model = None def load_rmvpe(): """Download and load RMVPE model for f0 extraction""" global _rmvpe_model if _rmvpe_model is None: logger.info("Downloading RMVPE model...") rmvpe_path = hf_hub_download(repo_id="IAHispano/Applio", filename="Resources/predictors/rmvpe.pt") model = RMVPE_E2E(4, 1, (2, 2)) ckpt = torch.load(rmvpe_path, map_location="cpu", weights_only=True) model.load_state_dict(ckpt) model.eval().to(device) mel_extractor = RMVPE_MelSpectrogram().to(device) cents_mapping = 20 * np.arange(RMVPE_N_CLASS) + 1997.3794084376191 _rmvpe_model = (model, mel_extractor, np.pad(cents_mapping, (4, 4))) logger.info("RMVPE model loaded") return _rmvpe_model def extract_f0_rmvpe(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0, thred: float = 0.03) -> Tuple[np.ndarray, np.ndarray]: """Extract F0 using RMVPE (best quality, neural network based)""" model, mel_extractor, cents_mapping = load_rmvpe() audio_t = torch.from_numpy(audio).float().to(device).unsqueeze(0) mel = mel_extractor(audio_t, center=True) del audio_t # mel2hidden with chunking with torch.no_grad(): n_frames = mel.shape[-1] mel_padded = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect") chunks = [] for start in range(0, mel_padded.shape[-1], 32000): end = min(start + 32000, mel_padded.shape[-1]) chunks.append(model(mel_padded[..., start:end])) hidden = torch.cat(chunks, dim=1)[:, :n_frames].squeeze(0).cpu().numpy() # Decode hidden to f0 center = np.argmax(hidden, axis=1) salience = np.pad(hidden, ((0, 0), (4, 4))) center += 4 todo_salience = [] todo_cents = [] for idx in range(salience.shape[0]): s, e = center[idx] - 4, center[idx] + 5 todo_salience.append(salience[idx, s:e]) todo_cents.append(cents_mapping[s:e]) todo_salience = np.array(todo_salience) todo_cents = np.array(todo_cents) cents_pred = np.sum(todo_salience * todo_cents, 1) / np.sum(todo_salience, 1) cents_pred[np.max(salience, axis=1) <= thred] = 0 f0 = 10 * (2 ** (cents_pred / 1200)) f0[f0 == 10] = 0 f0 *= pow(2, f0_up_key / 12) return f0_to_coarse(f0) # ============================================================ # MODEL LOADING # ============================================================ _model_cache = {} def load_rvc_model(model_path: str): """Load RVC model and auto-detect version""" if model_path in _model_cache: return _model_cache[model_path] logger.info(f"Loading RVC model: {model_path}") try: cpt = torch.load(model_path, map_location="cpu", weights_only=True) except Exception: logger.warning("Model requires unsafe loading - may be an older format") cpt = torch.load(model_path, map_location="cpu", weights_only=False) weight_key = None for key in ["weight", "model", "state_dict", "net_g"]: if key in cpt: weight_key = key break if weight_key is None: raise ValueError(f"Cannot find model weights. Keys: {list(cpt.keys())}") config = cpt.get("config", None) if config is None: config = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000] logger.warning("No config found, using v2 defaults") version = cpt.get("version", "v1") if_f0 = cpt.get("f0", 1) if weight_key in cpt: emb_weight = cpt[weight_key].get("emb_g.weight") if emb_weight is not None: config[-3] = emb_weight.shape[0] sr = config[-1] if isinstance(config[-1], int) else 40000 if version == "v1": model_class = SynthesizerTrnMs256NSFsid if if_f0 == 1 else SynthesizerTrnMs256NSFsid_nono else: model_class = SynthesizerTrnMs768NSFsid if if_f0 == 1 else SynthesizerTrnMs768NSFsid_nono model = model_class( spec_channels=config[0], segment_size=config[1], inter_channels=config[2], hidden_channels=config[3], filter_channels=config[4], n_heads=config[5], n_layers=config[6], kernel_size=config[7], p_dropout=config[8], resblock=config[9], resblock_kernel_sizes=config[10], resblock_dilation_sizes=config[11], upsample_rates=config[12], upsample_initial_channel=config[13], upsample_kernel_sizes=config[14], spk_embed_dim=config[15], gin_channels=config[16], sr=sr, is_half=False ) model.load_state_dict(cpt[weight_key], strict=False) model.eval().to(device) _model_cache[model_path] = (model, sr, version, if_f0) logger.info(f"Model loaded: version={version}, f0={if_f0}, sr={sr}") return model, sr, version, if_f0 # ============================================================ # TRAINING - Simplified for CPU testing # ============================================================ def spectrogram_torch(y, n_fft, hop_size, win_size, center=False): """Compute spectrogram""" hann_window = torch.hann_window(win_size).to(y.device) y = F.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect').squeeze(1) spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6) return spec # Mel spectrogram for training loss _mel_basis_cache = {} def spec_to_mel_torch(spec, n_fft=2048, num_mels=125, sampling_rate=40000, fmin=0, fmax=None): """Convert spectrogram to mel spectrogram""" from librosa.filters import mel as librosa_mel_fn global _mel_basis_cache if fmax is None: fmax = sampling_rate // 2 key = f"{n_fft}_{num_mels}_{sampling_rate}_{fmin}_{fmax}_{spec.dtype}_{spec.device}" if key not in _mel_basis_cache: mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) _mel_basis_cache[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) melspec = torch.matmul(_mel_basis_cache[key], spec) melspec = torch.log(torch.clamp(melspec, min=1e-5)) # Log-amplitude return melspec def preprocess_audio_for_training(audio_path: str, output_dir: str, target_sr: int = 40000, f0_method: str = "rmvpe"): """Preprocess audio file for training - slice and extract features""" import scipy.signal as signal os.makedirs(output_dir, exist_ok=True) os.makedirs(f"{output_dir}/wavs", exist_ok=True) os.makedirs(f"{output_dir}/hubert", exist_ok=True) os.makedirs(f"{output_dir}/f0", exist_ok=True) logger.info(f"Preprocessing: {audio_path}") # Load and resample audio audio, sr = librosa.load(audio_path, sr=target_sr, mono=True) # High-pass filter bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=target_sr) audio = signal.lfilter(bh, ah, audio) # Slice into chunks (3.7 seconds with 0.3 overlap) chunk_size = int(3.7 * target_sr) hop = int(3.4 * target_sr) chunks = [] for i, start in enumerate(range(0, len(audio) - chunk_size, hop)): chunk = audio[start:start + chunk_size] # Normalize max_val = np.abs(chunk).max() if max_val > 0.01: # Skip silence chunk = chunk / max_val * 0.9 chunks.append((i, chunk)) if not chunks: logger.warning("No valid audio chunks found") return None logger.info(f"Created {len(chunks)} chunks") # Save chunks and extract features manifest = [] for idx, chunk in chunks: # Save wav wav_path = f"{output_dir}/wavs/{idx:04d}.wav" sf.write(wav_path, chunk, target_sr) # Resample to 16k for HuBERT chunk_16k = librosa.resample(chunk, orig_sr=target_sr, target_sr=16000) # Extract HuBERT features feats = extract_hubert_features(chunk_16k, sr=16000, version="v2") hubert_path = f"{output_dir}/hubert/{idx:04d}.npy" np.save(hubert_path, feats.squeeze(0).cpu().numpy()) # Extract F0 if f0_method == "rmvpe": f0_coarse, f0 = extract_f0_rmvpe(chunk_16k, 16000, 0) elif f0_method == "harvest": f0_coarse, f0 = extract_f0_harvest(chunk_16k, 16000, 0) else: f0_coarse, f0 = extract_f0_pm(chunk_16k, 16000, 0) f0_path = f"{output_dir}/f0/{idx:04d}.npy" np.save(f0_path, np.stack([f0_coarse, f0], axis=0)) manifest.append(f"{idx:04d}") # Save manifest with open(f"{output_dir}/manifest.txt", "w") as f: f.write("\n".join(manifest)) logger.info(f"Preprocessing complete: {len(manifest)} samples") return output_dir def train_rvc_generator( data_dir: str, output_dir: str, epochs: int = 10, batch_size: int = 2, lr: float = 1e-5, # Lower LR prevents overfitting on small data target_sr: int = 40000, progress_callback=None ): """Generator version of train_rvc - yields (epoch_msg, ckpt_path) tuples""" logger.info(f"Starting training: {data_dir} -> {output_dir}") os.makedirs(output_dir, exist_ok=True) # Load manifest with open(f"{data_dir}/manifest.txt") as f: samples = [l.strip() for l in f if l.strip()] if len(samples) < 1: logger.error("No training samples found") return None logger.info(f"Training with {len(samples)} samples") # Model config (v2 40k defaults) config = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 1, 256, target_sr] # Create models (v2 only - 768-dim HuBERT features) net_g = SynthesizerTrnMs768NSFsid( spec_channels=config[0], segment_size=config[1], inter_channels=config[2], hidden_channels=config[3], filter_channels=config[4], n_heads=config[5], n_layers=config[6], kernel_size=config[7], p_dropout=config[8], resblock=config[9], resblock_kernel_sizes=config[10], resblock_dilation_sizes=config[11], upsample_rates=config[12], upsample_initial_channel=config[13], upsample_kernel_sizes=config[14], spk_embed_dim=config[15], gin_channels=config[16], sr=target_sr ).to(train_device) net_d = MultiPeriodDiscriminator().to(train_device) logger.info(f"Training on device: {train_device}") # Download and load pretrained weights (essential for good results) sr_key = f"{target_sr // 1000}k" # e.g., "40k" try: pretrain_g_path = download_pretrained_rvc(f"f0G{sr_key}") pretrain_d_path = download_pretrained_rvc(f"f0D{sr_key}") load_pretrained_weights(net_g, pretrain_g_path) load_pretrained_weights(net_d, pretrain_d_path) except Exception as e: logger.warning(f"Failed to load pretrained weights: {e}") logger.warning("Training from scratch (results may be poor)") # Optimizers (after loading pretrained weights) optim_g = torch.optim.AdamW(net_g.parameters(), lr=lr, betas=(0.8, 0.99)) optim_d = torch.optim.AdamW(net_d.parameters(), lr=lr, betas=(0.8, 0.99)) # LR scheduler (matches Applio - exponential decay) lr_decay = 0.999875 scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=lr_decay) scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=lr_decay) net_g.train() net_d.train() # Training loop for epoch in range(epochs): total_loss_g, total_loss_d = 0, 0 np.random.shuffle(samples) for i in range(0, len(samples), batch_size): batch_samples = samples[i:i+batch_size] # Load batch data wavs, huberts, f0s = [], [], [] for s in batch_samples: wav, _ = librosa.load(f"{data_dir}/wavs/{s}.wav", sr=target_sr, mono=True) hubert = np.load(f"{data_dir}/hubert/{s}.npy") # Upsample 50Hz -> 100Hz using interpolation (same as inference) hubert_t = torch.from_numpy(hubert).unsqueeze(0).permute(0, 2, 1) # (1, 768, seq) hubert_t = F.interpolate(hubert_t, scale_factor=2, mode='linear', align_corners=False) hubert = hubert_t.permute(0, 2, 1).squeeze(0).numpy() # (seq*2, 768) f0_data = np.load(f"{data_dir}/f0/{s}.npy") wavs.append(wav) huberts.append(hubert) f0s.append(f0_data) # Compute spectrogram first to get target length max_wav_len = max(len(w) for w in wavs) wav_batch = np.zeros((len(wavs), max_wav_len)) for j, w in enumerate(wavs): wav_batch[j, :len(w)] = w wav_t = torch.FloatTensor(wav_batch).unsqueeze(1).to(train_device) spec = spectrogram_torch(wav_t.squeeze(1), 2048, 400, 2048) spec_len = spec.shape[2] # Target length for all features # Pad/truncate features to match spec length exactly hubert_batch = np.zeros((len(huberts), spec_len, huberts[0].shape[1])) f0_batch = np.zeros((len(f0s), spec_len)) f0f_batch = np.zeros((len(f0s), spec_len)) for j, (h, f) in enumerate(zip(huberts, f0s)): # Truncate or pad HuBERT to spec_len h_len = min(h.shape[0], spec_len) hubert_batch[j, :h_len] = h[:h_len] # Truncate or pad F0 to spec_len f0_len = min(f.shape[1], spec_len) f0_batch[j, :f0_len] = f[0, :f0_len] f0f_batch[j, :f0_len] = f[1, :f0_len] # To tensors - all features now have spec_len hubert_t = torch.FloatTensor(hubert_batch).to(train_device) f0_t = torch.LongTensor(f0_batch.astype(np.int64)).to(train_device) f0f_t = torch.FloatTensor(f0f_batch).to(train_device) lengths_t = torch.LongTensor([spec_len] * len(batch_samples)).to(train_device) sid_t = torch.LongTensor([0] * len(batch_samples)).to(train_device) spec_lengths = torch.LongTensor([spec_len] * len(batch_samples)).to(train_device) # Forward pass generator # Args: phone, phone_lengths, pitch, pitchf, y (spec), y_lengths, ds try: y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = net_g( hubert_t, lengths_t, f0_t, f0f_t, spec, spec_lengths, sid_t ) except Exception as e: logger.warning(f"Generator forward failed: {e}") continue # Slice wav at same position model generated (CRITICAL for proper loss) # ids_slice is in latent space, multiply by hop_length to get waveform position hop_length = 400 segment_size_wav = 32 * hop_length # segment_size in latent * hop_length y = slice_segments(wav_t, ids_slice * hop_length, segment_size_wav) # Discriminator forward y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat.detach()) # Discriminator loss loss_d = discriminator_loss(y_d_rs, y_d_gs) optim_d.zero_grad() loss_d.backward() optim_d.step() # Generator loss y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat) loss_gen = generator_loss(y_d_gs) loss_fm = feature_loss(fmap_rs, fmap_gs) loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) # Mel spectrogram loss (crucial for quality) # Config: n_fft=2048, hop=400, win=2048, n_mels=125, fmin=0, fmax=None y_mel = spec_to_mel_torch(spectrogram_torch(y.squeeze(1), 2048, 400, 2048), n_fft=2048, num_mels=125, sampling_rate=target_sr, fmin=0, fmax=None) y_hat_mel = spec_to_mel_torch(spectrogram_torch(y_hat.squeeze(1), 2048, 400, 2048), n_fft=2048, num_mels=125, sampling_rate=target_sr, fmin=0, fmax=None) # Align lengths if needed min_len = min(y_mel.shape[2], y_hat_mel.shape[2]) loss_mel = F.l1_loss(y_mel[:, :, :min_len], y_hat_mel[:, :, :min_len]) * 45 # c_mel = 45 loss_g = loss_gen + loss_fm + loss_mel + loss_kl optim_g.zero_grad() loss_g.backward() optim_g.step() total_loss_g += loss_g.item() total_loss_d += loss_d.item() avg_loss_g = total_loss_g / max(1, len(samples) // batch_size) avg_loss_d = total_loss_d / max(1, len(samples) // batch_size) epoch_msg = f"Epoch {epoch+1}/{epochs} - G: {avg_loss_g:.2f}, D: {avg_loss_d:.2f}" logger.info(epoch_msg) # Update progress callback if provided if progress_callback: progress_pct = 0.30 + (0.65 * (epoch + 1) / epochs) progress_callback(progress_pct, epoch_msg) # Yield epoch message for live UI updates yield epoch_msg, None, None # Step LR schedulers scheduler_g.step() scheduler_d.step() # Save checkpoint ckpt_path = f"{output_dir}/model.pth" torch.save({ "weight": net_g.state_dict(), "config": config, "version": "v2", # v2 only "f0": 1, }, ckpt_path) logger.info(f"Saved checkpoint: {ckpt_path}") # Generate index file for better speaker similarity index_path = None try: import faiss hubert_dir = f"{data_dir}/hubert" npys = [] for name in sorted(os.listdir(hubert_dir)): if name.endswith('.npy'): phone = np.load(os.path.join(hubert_dir, name)) npys.append(phone) if npys: big_npy = np.concatenate(npys, axis=0) n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39) n_ivf = max(1, n_ivf) # Ensure at least 1 index = faiss.index_factory(big_npy.shape[1], f"IVF{n_ivf},Flat") index.train(big_npy) index.add(big_npy) index_path = f"{output_dir}/model.index" faiss.write_index(index, index_path) logger.info(f"Saved index: {index_path}") except Exception as e: logger.warning(f"Failed to generate index: {e}") # Cleanup training models to free memory purge_memory(net_g, net_d, optim_g, optim_d, scheduler_g, scheduler_d) yield "Training complete!", ckpt_path, index_path def train_rvc( data_dir: str, output_dir: str, epochs: int = 10, batch_size: int = 2, lr: float = 1e-5, # Lower LR prevents overfitting on small data target_sr: int = 40000, progress_callback=None ): """Non-generator wrapper for CLI use - returns (checkpoint_path, index_path)""" ckpt = None idx = None for msg, path, index in train_rvc_generator(data_dir, output_dir, epochs, batch_size, lr, target_sr, progress_callback): if path: ckpt = path if index: idx = index return ckpt, idx # ============================================================ # INFERENCE # ============================================================ def convert_voice( source_audio: str, model_file, index_file=None, pitch_shift: int = 0, f0_method: str = "pm", index_rate: float = 0.5, protect: float = 0.33, volume_envelope: float = 1.0, progress=gr.Progress() ) -> Tuple[str, str]: """Convert voice using RVC model (Applio-compatible pipeline).""" try: if source_audio is None: return None, "Please upload source audio" if model_file is None: return None, "Please upload RVC model (.pth)" model_path = model_file.name if hasattr(model_file, 'name') else model_file progress(0.1, "Loading model...") model, tgt_sr, version, if_f0 = load_rvc_model(model_path) progress(0.2, "Loading audio...") audio, sr = librosa.load(source_audio, sr=16000, mono=True) # Apply 48Hz high-pass filter (critical - removes low-frequency artifacts) audio = signal.filtfilt(bh, ah, audio) # Normalize audio audio_max = np.abs(audio).max() / 0.95 if audio_max > 1: audio /= audio_max # Pipeline constants (same as Applio) window = 160 # Critical for feature/pitch alignment x_pad = 1 # Padding in seconds t_pad = 16000 * x_pad # Padding in samples # Pad audio audio_pad = np.pad(audio, (t_pad, t_pad), mode="reflect") p_len = audio_pad.shape[0] // window progress(0.3, "Extracting features...") feats = extract_hubert_features(audio_pad, sr=16000, version=version) # Save original features for protect mechanism feats0 = feats.clone() if if_f0 == 1 and protect < 0.5 else None # Index retrieval (speaker similarity) if index_file is not None and index_rate > 0: try: import faiss index_path = index_file.name if hasattr(index_file, 'name') else index_file progress(0.4, "Loading index...") index = faiss.read_index(index_path) big_npy = index.reconstruct_n(0, index.ntotal) npy = feats[0].cpu().numpy().astype("float32") score, ix = index.search(npy, k=8) weight = np.square(1 / score) weight /= weight.sum(axis=1, keepdims=True) npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) feats = torch.from_numpy(npy).unsqueeze(0).to(device) * index_rate + (1 - index_rate) * feats except Exception as e: logger.warning(f"Index retrieval failed: {e}") # Feature upsampling by 2x feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) # Adjust length based on audio p_len = min(audio_pad.shape[0] // window, feats.shape[1]) pitch, pitchf = None, None if if_f0 == 1: progress(0.5, f"Extracting F0 ({f0_method})...") if f0_method == "rmvpe": pitch, pitchf = extract_f0_rmvpe(audio_pad, 16000, pitch_shift) elif f0_method == "harvest": pitch, pitchf = extract_f0_harvest(audio_pad, 16000, pitch_shift) else: pitch, pitchf = extract_f0_pm(audio_pad, 16000, pitch_shift) pitch = pitch[:p_len] pitchf = pitchf[:p_len] # Upsample feats0 for protect if feats0 is not None: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) # Apply protect mechanism (preserve original features for unvoiced segments) if protect < 0.5 and feats0 is not None: pitchf_tensor = torch.from_numpy(pitchf).float().to(device) pitchff = pitchf_tensor.clone() pitchff[pitchf_tensor > 0] = 1 pitchff[pitchf_tensor < 1] = protect pitchff = pitchff.unsqueeze(0).unsqueeze(-1) feats = feats[:, :p_len, :] * pitchff + feats0[:, :p_len, :] * (1 - pitchff) if len(pitch) < p_len: pitch = np.pad(pitch, (0, p_len - len(pitch))) pitchf = np.pad(pitchf, (0, p_len - len(pitchf))) pitch = torch.LongTensor(pitch).unsqueeze(0).to(device) pitchf = torch.FloatTensor(pitchf).unsqueeze(0).to(device) p_len_tensor = torch.LongTensor([p_len]).to(device) sid = torch.LongTensor([0]).to(device) progress(0.7, "Running inference...") with torch.no_grad(): if if_f0 == 1: audio_out = model.infer(feats[:, :p_len, :], p_len_tensor, pitch, pitchf, sid)[0][0, 0].data.cpu().float().numpy() else: audio_out = model.infer(feats[:, :p_len, :], p_len_tensor, sid)[0][0, 0].data.cpu().float().numpy() # Remove padding from output t_pad_tgt = int(t_pad * tgt_sr / 16000) if len(audio_out) > 2 * t_pad_tgt: audio_out = audio_out[t_pad_tgt:-t_pad_tgt] # RMS mixing - match volume dynamics of source audio if volume_envelope != 1.0: try: source_at_tgt_sr = librosa.resample(audio, orig_sr=16000, target_sr=tgt_sr) frame_len = tgt_sr // 2 * 2 hop_len = tgt_sr // 2 rms_source = librosa.feature.rms(y=source_at_tgt_sr, frame_length=frame_len, hop_length=hop_len) rms_output = librosa.feature.rms(y=audio_out, frame_length=frame_len, hop_length=hop_len) rms_source = F.interpolate( torch.from_numpy(rms_source).float().unsqueeze(0), size=audio_out.shape[0], mode="linear" ).squeeze() rms_output = F.interpolate( torch.from_numpy(rms_output).float().unsqueeze(0), size=audio_out.shape[0], mode="linear" ).squeeze() rms_output = torch.maximum(rms_output, torch.zeros_like(rms_output) + 1e-6) # Applio formula: target * (source^(1-rate) * output^(rate-1)) audio_out = audio_out * (torch.pow(rms_source, 1 - volume_envelope) * torch.pow(rms_output, volume_envelope - 1)).numpy() except Exception as e: logger.warning(f"RMS mixing failed: {e}") # Final normalization audio_max = np.abs(audio_out).max() / 0.99 if audio_max > 1: audio_out /= audio_max progress(0.9, "Saving output...") fd, output_path = tempfile.mkstemp(suffix=".wav") os.close(fd) sf.write(output_path, audio_out, tgt_sr) # Aggressive memory purge after inference — frees glibc arena on Linux _model_cache.clear() _cleanup_args = [model, feats, audio_out, audio, audio_pad] if feats0 is not None: _cleanup_args.append(feats0) purge_memory(*_cleanup_args) return output_path, f"Converted: {version}, sr={tgt_sr}, pitch={pitch_shift:+d}" except Exception as e: logger.exception("Conversion failed") _model_cache.clear() purge_memory() return None, f"Error: {str(e)}" # ============================================================ # DEFAULT MODEL DOWNLOAD # ============================================================ def load_example_model(): """Download and load the default example model from HuggingFace""" import shutil try: logger.info(f"Downloading example model from {DEFAULT_MODEL_REPO}...") model_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_MODEL_FILE) index_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_INDEX_FILE) # Gradio 6 requires files to be in allowed directories (cwd or /tmp) # Copy from HF cache to temp directory temp_dir = tempfile.mkdtemp() temp_model = os.path.join(temp_dir, DEFAULT_MODEL_FILE) temp_index = os.path.join(temp_dir, DEFAULT_INDEX_FILE) shutil.copy2(model_path, temp_model) shutil.copy2(index_path, temp_index) return temp_model, temp_index, f"Loaded: {DEFAULT_MODEL_REPO}" except Exception as e: logger.exception("Failed to download example model") return None, None, f"Error: {str(e)}" # ============================================================ # BEATRICE V2 MODEL # ============================================================ def beatrice_load_audio(file, **kwargs): """Load audio using soundfile directly (for Beatrice dataset)""" data, sr = sf.read(file, dtype='float32') # soundfile returns (samples, channels), convert to torch (channels, samples) wav = torch.from_numpy(data) if wav.ndim == 1: wav = wav.unsqueeze(0) # mono -> (1, samples) else: wav = wav.T # (samples, channels) -> (channels, samples) return wav, sr class AttrDict(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__dict__ = self def dump_params(params: torch.Tensor, f: BinaryIO): if params is None: return if params.dtype == torch.bfloat16: f.write( params.detach() .clone() .float() .view(torch.short) .numpy() .ravel()[1::2] .tobytes() ) else: f.write(params.detach().numpy().ravel().tobytes()) f.flush() def dump_layer(layer: nn.Module, f: BinaryIO): dump = partial(dump_params, f=f) if hasattr(layer, "dump"): layer.dump(f) elif isinstance(layer, (nn.Linear, nn.Conv1d, nn.LayerNorm)): dump(layer.weight) dump(layer.bias) elif isinstance(layer, nn.MultiheadAttention): embed_dim = layer.embed_dim num_heads = layer.num_heads # [3 * embed_dim, embed_dim] in_proj_weight = layer.in_proj_weight.data.clone() in_proj_weight[: 2 * embed_dim] *= 1.0 / math.sqrt( math.sqrt(embed_dim // num_heads) ) in_proj_weight = in_proj_weight.view( 3, num_heads, embed_dim // num_heads, embed_dim ) # [num_heads, 3, embed_dim / num_heads, embed_dim] in_proj_weight = in_proj_weight.transpose(0, 1) # [3 * embed_dim] in_proj_bias = layer.in_proj_bias.data.clone() in_proj_bias[: 2 * embed_dim] *= 1.0 / math.sqrt( math.sqrt(embed_dim // num_heads) ) in_proj_bias = in_proj_bias.view(3, num_heads, embed_dim // num_heads) # [num_heads, 3, embed_dim / num_heads] in_proj_bias = in_proj_bias.transpose(0, 1) dump(in_proj_weight) dump(in_proj_bias) dump(layer.out_proj.weight) dump(layer.out_proj.bias) elif isinstance(layer, nn.Embedding): dump(layer.weight) elif isinstance(layer, nn.Parameter): dump(layer) elif isinstance(layer, nn.ModuleList): for layer_i in layer: dump_layer(layer_i, f) else: assert False, layer class CausalConv1d(nn.Conv1d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, delay: int = 0, ): padding = (kernel_size - 1) * dilation - delay self.trim = (kernel_size - 1) * dilation - 2 * delay if self.trim < 0: raise ValueError super().__init__( in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, ) def forward(self, input: torch.Tensor) -> torch.Tensor: result = super().forward(input) if self.trim == 0: return result else: return result[:, :, : -self.trim] class WSConv1d(CausalConv1d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, delay: int = 0, ): super().__init__( in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=groups, bias=bias, delay=delay, ) self.weight.data.normal_( 0.0, math.sqrt(1.0 / (in_channels * kernel_size // groups)) ) if bias: self.bias.data.zero_() self.gain = nn.Parameter(torch.ones((out_channels, 1, 1))) def standardized_weight(self) -> torch.Tensor: var, mean = torch.var_mean(self.weight, [1, 2], keepdim=True) scale = ( self.gain * ( self.in_channels * self.kernel_size[0] // self.groups * var + 1e-8 ).rsqrt() ) return scale * (self.weight - mean) def forward(self, input: torch.Tensor) -> torch.Tensor: result = F.conv1d( input, self.standardized_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups, ) if self.trim == 0: return result else: return result[:, :, : -self.trim] def merge_weights(self): self.weight.data[:] = self.standardized_weight().detach() self.gain.data.fill_(1.0) class WSLinear(nn.Linear): def __init__(self, in_features: int, out_features: int, bias: bool = True): super().__init__(in_features, out_features, bias) self.weight.data.normal_(0.0, math.sqrt(1.0 / in_features)) self.bias.data.zero_() self.gain = nn.Parameter(torch.ones((out_features, 1))) def standardized_weight(self) -> torch.Tensor: var, mean = torch.var_mean(self.weight, 1, keepdim=True) scale = self.gain * (self.in_features * var + 1e-8).rsqrt() return scale * (self.weight - mean) def forward(self, input: torch.Tensor) -> torch.Tensor: return F.linear(input, self.standardized_weight(), self.bias) def merge_weights(self): self.weight.data[:] = self.standardized_weight().detach() self.gain.data.fill_(1.0) class CrossAttention(nn.Module): def __init__( self, qk_channels: int, vo_channels: int, num_heads: int, in_q_channels: int, in_kv_channels: int, out_channels: int, dropout: float = 0.0, ): super().__init__() assert qk_channels % num_heads == 0 self.qk_channels = qk_channels self.vo_channels = vo_channels self.num_heads = num_heads self.in_q_channels = in_q_channels self.in_kv_channels = in_kv_channels self.out_channels = out_channels self.dropout = dropout self.head_qk_channels = qk_channels // num_heads self.head_vo_channels = vo_channels // num_heads self.q_projection = nn.Linear(in_q_channels, qk_channels) self.q_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_q_channels)) self.q_projection.bias.data.zero_() self.kv_projection = nn.Linear(in_kv_channels, qk_channels + vo_channels) self.kv_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_kv_channels)) self.kv_projection.bias.data.zero_() self.out_projection = nn.Linear(vo_channels, out_channels) self.out_projection.weight.data.normal_(0.0, math.sqrt(1.0 / vo_channels)) self.out_projection.bias.data.zero_() def forward( self, q: torch.Tensor, kv: torch.Tensor, ) -> torch.Tensor: # q: [batch_size, q_length, in_q_channels] # kv: [batch_size, kv_length, in_kv_channels] batch_size, q_length, _ = q.size() _, kv_length, _ = kv.size() # [batch_size, q_length, qk_channels] q = self.q_projection(q) # [batch_size, kv_length, qk_channels + vo_channels] kv = self.kv_projection(kv) # [batch_size, kv_length, qk_channels], [batch_size, kv_length, vo_channels] k, v = kv.split([self.qk_channels, self.vo_channels], dim=2) q = q.view( batch_size, q_length, self.num_heads, self.head_qk_channels ).transpose(1, 2) k = k.view( batch_size, kv_length, self.num_heads, self.head_qk_channels ).transpose(1, 2) v = v.view( batch_size, kv_length, self.num_heads, self.head_vo_channels ).transpose(1, 2) # [batch_size, num_heads, q_length, head_vo_channels] attn_out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout) # [batch_size, q_length, vo_channels] attn_out = ( attn_out.transpose(1, 2) .contiguous() .view(batch_size, q_length, self.vo_channels) ) # [batch_size, q_length, out_channels] attn_out = self.out_projection(attn_out) return attn_out def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError q_projection_weight = self.q_projection.weight.data.clone() q_projection_bias = self.q_projection.bias.data.clone() q_projection_weight *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) q_projection_bias *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) dump_params(q_projection_weight, f) dump_params(q_projection_bias, f) dump_layer(self.out_projection, f) def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump_kv(f) return if not hasattr(f, "write"): raise TypeError kv_projection_weight = self.kv_projection.weight.data.clone() kv_projection_bias = self.kv_projection.bias.data.clone() k_projection_weight, v_projection_weight = kv_projection_weight.split( [self.qk_channels, self.vo_channels] ) k_projection_bias, v_projection_bias = kv_projection_bias.split( [self.qk_channels, self.vo_channels] ) k_projection_weight *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) k_projection_bias *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) # [qk_channels, in_kv_channels] -> [num_heads, head_qk_channels, in_kv_channels] k_projection_weight = k_projection_weight.view( self.num_heads, self.head_qk_channels, self.in_kv_channels ) # [qk_channels] -> [num_heads, head_qk_channels] k_projection_bias = k_projection_bias.view( self.num_heads, self.head_qk_channels ) # [vo_channels, in_kv_channels] -> [num_heads, head_vo_channels, in_kv_channels] v_projection_weight = v_projection_weight.view( self.num_heads, self.head_vo_channels, self.in_kv_channels ) # [vo_channels] -> [num_heads, head_vo_channels] v_projection_bias = v_projection_bias.view( self.num_heads, self.head_vo_channels ) for i in range(self.num_heads): # [head_qk_channels, in_kv_channels] dump_params(k_projection_weight[i], f) # [head_vo_channels, in_kv_channels] dump_params(v_projection_weight[i], f) for i in range(self.num_heads): # [head_qk_channels] dump_params(k_projection_bias[i], f) # [head_vo_channels] dump_params(v_projection_bias[i], f) class ConvNeXtBlock(nn.Module): def __init__( self, channels: int, intermediate_channels: int, layer_scale_init_value: float, kernel_size: int = 7, use_weight_standardization: bool = False, enable_scaling: bool = False, pre_scale: float = 1.0, post_scale: float = 1.0, use_mha: bool = False, cross_attention: bool = False, num_heads: int = 4, attention_dropout: float = 0.1, attention_channels: Optional[int] = None, kv_channels: Optional[int] = None, ): super().__init__() self.use_weight_standardization = use_weight_standardization self.enable_scaling = enable_scaling self.use_mha = use_mha self.cross_attention = cross_attention if use_mha: self.attn_norm = nn.LayerNorm(channels) if cross_attention: self.mha = CrossAttention( qk_channels=attention_channels, vo_channels=attention_channels, num_heads=num_heads, in_q_channels=channels, in_kv_channels=kv_channels, out_channels=channels, dropout=attention_dropout, ) else: # self-attention assert attention_channels is None assert kv_channels is None self.mha = nn.MultiheadAttention( embed_dim=channels, num_heads=num_heads, dropout=attention_dropout, batch_first=True, ) self.dwconv = CausalConv1d( channels, channels, kernel_size=kernel_size, groups=channels ) self.norm = nn.LayerNorm(channels) self.pwconv1 = nn.Linear(channels, intermediate_channels) self.pwconv2 = nn.Linear(intermediate_channels, channels) self.gamma = nn.Parameter(torch.full((channels,), layer_scale_init_value)) self.dwconv.weight.data.normal_(0.0, math.sqrt(1.0 / kernel_size)) self.dwconv.bias.data.zero_() self.pwconv1.weight.data.normal_(0.0, math.sqrt(2.0 / channels)) self.pwconv1.bias.data.zero_() self.pwconv2.weight.data.normal_(0.0, math.sqrt(1.0 / intermediate_channels)) self.pwconv2.bias.data.zero_() if use_weight_standardization: self.norm = nn.Identity() self.dwconv = WSConv1d(channels, channels, kernel_size, groups=channels) self.pwconv1 = WSLinear(channels, intermediate_channels) self.pwconv2 = WSLinear(intermediate_channels, channels) del self.gamma if enable_scaling: self.register_buffer("pre_scale", torch.tensor(pre_scale)) self.register_buffer("post_scale", torch.tensor(post_scale)) self.post_scale_weight = nn.Parameter(torch.ones(())) def forward( self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, kv: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.use_mha: batch_size, channels, length = x.size() if self.cross_attention: assert kv is not None else: assert kv is None assert length % 4 == 0 identity = x if self.cross_attention: # kv: [batch_size, kv_length, kv_channels] x = x.transpose(1, 2) x = self.attn_norm(x) x = self.mha(x, kv) x = x.transpose(1, 2) else: x = x.view(batch_size, channels, length // 4, 4) x = x.permute(0, 3, 2, 1) x = x.reshape(batch_size * 4, length // 4, channels) x = self.attn_norm(x) x, _ = self.mha( x, x, x, attn_mask=attn_mask, is_causal=True, need_weights=False ) x = x.view(batch_size, 4, length // 4, channels) x = x.permute(0, 3, 2, 1) x = x.reshape(batch_size, channels, length) x += identity identity = x if self.enable_scaling: x = x * self.pre_scale x = self.dwconv(x) x = x.transpose(1, 2) x = self.norm(x) x = self.pwconv1(x) x = F.gelu(x, approximate="tanh") x = self.pwconv2(x) if not self.use_weight_standardization: x *= self.gamma if self.enable_scaling: x *= self.post_scale * self.post_scale_weight x = x.transpose(1, 2) x += identity return x def merge_weights(self): if self.use_mha: if self.cross_attention: assert isinstance(self.mha, CrossAttention) self.mha.q_projection.bias.data += torch.mv( self.mha.q_projection.weight.data, self.attn_norm.bias.data ) self.mha.q_projection.weight.data *= self.attn_norm.weight.data[None, :] self.attn_norm.bias.data[:] = 0.0 self.attn_norm.weight.data[:] = 1.0 else: # self-attention assert isinstance(self.mha, nn.MultiheadAttention) self.mha.in_proj_bias.data += torch.mv( self.mha.in_proj_weight.data, self.attn_norm.bias.data ) self.mha.in_proj_weight.data *= self.attn_norm.weight.data[None, :] self.attn_norm.bias.data[:] = 0.0 self.attn_norm.weight.data[:] = 1.0 if self.use_weight_standardization: self.dwconv.merge_weights() self.pwconv1.merge_weights() self.pwconv2.merge_weights() else: self.pwconv1.bias.data += torch.mv( self.pwconv1.weight.data, self.norm.bias.data ) self.pwconv1.weight.data *= self.norm.weight.data[None, :] self.norm.bias.data[:] = 0.0 self.norm.weight.data[:] = 1.0 self.pwconv2.weight.data *= self.gamma.data[:, None] self.pwconv2.bias.data *= self.gamma.data self.gamma.data[:] = 1.0 if self.enable_scaling: self.dwconv.weight.data *= self.pre_scale.data self.pre_scale.data.fill_(1.0) self.pwconv2.weight.data *= ( self.post_scale.data * self.post_scale_weight.data ) self.pwconv2.bias.data *= self.post_scale.data * self.post_scale_weight.data self.post_scale.data.fill_(1.0) self.post_scale_weight.data.fill_(1.0) def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError if self.use_mha: dump_layer(self.mha, f) dump_layer(self.dwconv, f) dump_layer(self.pwconv1, f) dump_layer(self.pwconv2, f) class ConvNeXtStack(nn.Module): def __init__( self, in_channels: int, channels: int, intermediate_channels: int, n_blocks: int, delay: int, embed_kernel_size: int, kernel_size: int, use_weight_standardization: bool = False, enable_scaling: bool = False, use_mha: bool = False, cross_attention: bool = False, kv_channels: Optional[int] = None, ): super().__init__() assert delay * 2 + 1 <= embed_kernel_size assert not (use_weight_standardization and use_mha) # 未対応 self.use_weight_standardization = use_weight_standardization self.use_mha = use_mha self.cross_attention = cross_attention self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay) self.norm = nn.LayerNorm(channels) self.convnext = nn.ModuleList() for i in range(n_blocks): pre_scale = 1.0 / math.sqrt(1.0 + i / n_blocks) if enable_scaling else 1.0 post_scale = 1.0 / math.sqrt(n_blocks) if enable_scaling else 1.0 block = ConvNeXtBlock( channels=channels, intermediate_channels=intermediate_channels, layer_scale_init_value=1.0 / n_blocks, kernel_size=kernel_size, use_weight_standardization=use_weight_standardization, enable_scaling=enable_scaling, pre_scale=pre_scale, post_scale=post_scale, use_mha=use_mha, cross_attention=cross_attention, num_heads=4, attention_dropout=0.1, attention_channels=kv_channels, kv_channels=kv_channels, ) self.convnext.append(block) self.final_layer_norm = nn.LayerNorm(channels) self.embed.weight.data.normal_( 0.0, math.sqrt(0.5 / (embed_kernel_size * in_channels)) ) self.embed.bias.data.zero_() if use_weight_standardization: self.embed = WSConv1d(in_channels, channels, embed_kernel_size, delay=delay) self.norm = nn.Identity() self.final_layer_norm = nn.Identity() def forward( self, x: torch.Tensor, kv: Optional[torch.Tensor] = None ) -> torch.Tensor: x = self.embed(x) x = self.norm(x.transpose(1, 2)).transpose(1, 2) if self.use_mha and not self.cross_attention: pad_length = -x.size(2) % 4 if pad_length: x = F.pad(x, (0, pad_length)) t40 = x.size(2) // 4 attn_mask = torch.ones((t40, t40), dtype=torch.bool, device=x.device).triu( 1 ) else: attn_mask = None for conv_block in self.convnext: x = conv_block(x, attn_mask=attn_mask, kv=kv) if self.use_mha and not self.cross_attention and pad_length: x = x[:, :, :-pad_length] x = self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2) return x def merge_weights(self): if self.use_weight_standardization: self.embed.merge_weights() for conv_block in self.convnext: conv_block.merge_weights() def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.embed, f) if not self.use_weight_standardization: dump_layer(self.norm, f) dump_layer(self.convnext, f) if not self.use_weight_standardization: dump_layer(self.final_layer_norm, f) def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump_kv(f) return if not hasattr(f, "write"): raise TypeError assert self.use_mha and self.cross_attention for conv_block in self.convnext: if not conv_block.use_mha or not conv_block.cross_attention: continue assert isinstance(conv_block, ConvNeXtBlock) assert hasattr(conv_block, "mha") assert isinstance(conv_block.mha, CrossAttention) conv_block.mha.dump_kv(f) class FeatureExtractor(nn.Module): def __init__(self, hidden_channels: int): super().__init__() # fmt: off self.conv0 = weight_norm(nn.Conv1d(1, hidden_channels // 8, 10, 5, bias=False)) self.conv1 = weight_norm(nn.Conv1d(hidden_channels // 8, hidden_channels // 4, 3, 2, bias=False)) self.conv2 = weight_norm(nn.Conv1d(hidden_channels // 4, hidden_channels // 2, 3, 2, bias=False)) self.conv3 = weight_norm(nn.Conv1d(hidden_channels // 2, hidden_channels, 3, 2, bias=False)) self.conv4 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 3, 2, bias=False)) self.conv5 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 2, 2, bias=False)) # fmt: on def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [batch_size, 1, wav_length] wav_length = x.size(2) if wav_length % 160 != 0: warnings.warn("wav_length % 160 != 0") x = F.pad(x, (40, 40)) x = F.gelu(self.conv0(x), approximate="tanh") x = F.gelu(self.conv1(x), approximate="tanh") x = F.gelu(self.conv2(x), approximate="tanh") x = F.gelu(self.conv3(x), approximate="tanh") x = F.gelu(self.conv4(x), approximate="tanh") x = F.gelu(self.conv5(x), approximate="tanh") # [batch_size, hidden_channels, wav_length / 160] return x def remove_weight_norm(self): remove_weight_norm(self.conv0) remove_weight_norm(self.conv1) remove_weight_norm(self.conv2) remove_weight_norm(self.conv3) remove_weight_norm(self.conv4) remove_weight_norm(self.conv5) def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.conv0, f) dump_layer(self.conv1, f) dump_layer(self.conv2, f) dump_layer(self.conv3, f) dump_layer(self.conv4, f) dump_layer(self.conv5, f) class FeatureProjection(nn.Module): def __init__(self, channels: int): super().__init__() self.norm = nn.LayerNorm(channels) self.dropout = nn.Dropout(0.1) def forward(self, x: torch.Tensor) -> torch.Tensor: # [batch_size, channels, length] x = self.norm(x.transpose(1, 2)).transpose(1, 2) x = self.dropout(x) return x class PhoneExtractor(nn.Module): def __init__( self, phone_channels: int = 128, hidden_channels: int = 128, backbone_embed_kernel_size: int = 9, kernel_size: int = 17, n_blocks: int = 20, ): super().__init__() self.feature_extractor = FeatureExtractor(hidden_channels) self.feature_projection = FeatureProjection(hidden_channels) self.backbone = ConvNeXtStack( in_channels=hidden_channels, channels=hidden_channels, intermediate_channels=hidden_channels * 3, n_blocks=n_blocks, delay=0, embed_kernel_size=backbone_embed_kernel_size, kernel_size=kernel_size, use_mha=True, ) self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1)) def forward( self, x: torch.Tensor, return_stats: bool = True ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: # x: [batch_size, 1, wav_length] stats = {} # [batch_size, 1, wav_length] -> [batch_size, feature_extractor_hidden_channels, length] x = self.feature_extractor(x) if return_stats: stats["feature_norm"] = x.detach().norm(dim=1).mean() # [batch_size, feature_extractor_hidden_channels, length] -> [batch_size, hidden_channels, length] x = self.feature_projection(x) # [batch_size, hidden_channels, length] x = self.backbone(x) # [batch_size, hidden_channels, length] -> [batch_size, phone_channels, length] phone = self.head(F.gelu(x, approximate="tanh")) results = [phone] if return_stats: stats["code_norm"] = phone.detach().norm(dim=1).mean() results.append(stats) if len(results) == 1: return results[0] return tuple(results) @torch.inference_mode() def units(self, x: torch.Tensor) -> torch.Tensor: # x: [batch_size, 1, wav_length] # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length] phone = self.forward(x, return_stats=False) # [batch_size, phone_channels, length] -> [batch_size, length, phone_channels] phone = phone.transpose(1, 2) # [batch_size, length, phone_channels] return phone def remove_weight_norm(self): self.feature_extractor.remove_weight_norm() remove_weight_norm(self.head) def merge_weights(self): self.backbone.merge_weights() self.backbone.embed.bias.data += ( ( self.feature_projection.norm.bias.data[None, :, None] * self.backbone.embed.weight.data # [o, i, k] ) .sum(1) .sum(1) ) self.backbone.embed.weight.data *= self.feature_projection.norm.weight.data[ None, :, None ] self.feature_projection.norm.bias.data[:] = 0.0 self.feature_projection.norm.weight.data[:] = 1.0 def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.feature_extractor, f) dump_layer(self.backbone, f) dump_layer(self.head, f) class VectorQuantizer(nn.Module): def __init__( self, n_speakers: int, codebook_size: int, channels: int, topk: int = 4, training_time_vq: Literal["none", "self", "random"] = "none", ): super().__init__() assert 1 <= topk <= codebook_size self.n_speakers = n_speakers self.codebook_size = codebook_size self.channels = channels self.topk = topk self.training_time_vq = training_time_vq self.register_buffer( "codebooks", torch.empty(n_speakers, codebook_size, channels, dtype=torch.half), ) self.codebooks: torch.Tensor # VQ の適用箇所を変更しやすいように hook にしている self._hook_handle: Optional[torch.utils.hooks.RemovableHandle] = None self.target_speaker_ids: Optional[torch.Tensor] = None def _hook(_, __, output): return self(output, self.target_speaker_ids) self._hook_fn = _hook @torch.no_grad() def build_codebooks( self, collector_func: Callable, target_layer: nn.Module, inputs: Sequence[Iterable[torch.Tensor]], kmeans_n_iters: int = 50, ): assert len(inputs) == self.n_speakers assert self._hook_handle is None, "hook already installed" device = next(self.buffers()).device for spk_id, inps in enumerate(tqdm(inputs, desc="Building codebooks")): activations: list[torch.Tensor] = [] # TODO: データ多すぎる場合に間引く処理をする def _collect(_, __, output): # output: [batch_size, channels, length] activations.append(output.detach()) handle = target_layer.register_forward_hook(_collect) for x in inps: collector_func(x.to(device)) handle.remove() if not activations: raise RuntimeError(f"No activation collected for speaker {spk_id}") # [n_data, channels] activations: torch.Tensor = torch.cat( [ a.transpose(1, 2).reshape(a.size(0) * a.size(2), self.channels) for a in activations ] ) activations = activations.float() activations = F.normalize(activations, dim=1, eps=1e-6) # [codebook_size, channels] centers = ( self._kmeans_plus_plus(activations, self.codebook_size, kmeans_n_iters) if activations.size(0) >= self.codebook_size else self._pad_replicate(activations, self.codebook_size) ) self.codebooks[spk_id] = centers.to(self.codebooks.dtype) def forward( self, x: torch.Tensor, speaker_ids: Optional[torch.Tensor] = None ) -> torch.Tensor: batch_size, channels, length = x.size() assert channels == self.channels device = x.device dtype = x.dtype if self.training: if self.training_time_vq == "none": return x elif self.training_time_vq == "self": if self.target_speaker_ids is None: raise ValueError("target_speaker_ids is not set") elif self.training_time_vq == "random": speaker_ids = torch.randint( 0, self.n_speakers, (batch_size,), device=device ) else: raise ValueError(f"Unknown training_time_vq: {self.training_time_vq}") else: if speaker_ids is None: return x speaker_ids = speaker_ids.to(device) # [batch_size, channels, length] → [batch_size, length, channels] q = F.normalize(x, dim=1, eps=1e-6) codes = self.codebooks[speaker_ids].to(q.dtype) # [batch_size, length, codebook_size] sim = torch.einsum("bcl,bkc->blk", q, codes) # [batch_size, length, topk] _, topk_idx = sim.topk(self.topk, dim=-1) # [batch_size, length, codebook_size, channels] expanded_codes = codes[:, None, :, :].expand(-1, length, -1, -1) # [batch_size, length, topk, channels] expanded_topk_idx = topk_idx[:, :, :, None].expand(-1, -1, -1, channels) # [batch_size, length, topk, channels] gathered = expanded_codes.gather(2, expanded_topk_idx) # [batch_size, length, channels] gathered = gathered.mean(2) # [batch_size, channels, length] return gathered.transpose(1, 2).to(dtype) def enable_hook(self, target_layer: nn.Module): if self._hook_handle is not None: raise RuntimeError("hook already installed") self._hook_handle = target_layer.register_forward_hook(self._hook_fn) def disable_hook(self): if self._hook_handle is None: raise RuntimeError("hook not installed") self._hook_handle.remove() self._hook_handle = None def set_target_speaker_ids(self, speaker_ids: Optional[torch.Tensor]): # この話者が使われる条件は forward() を参照 self.target_speaker_ids = speaker_ids @staticmethod def _pad_replicate(x: torch.Tensor, n: int) -> torch.Tensor: # データ数が n に満たないとき適当に複製して埋める idx = torch.arange(n, device=x.device) % x.size(0) return x[idx] @staticmethod def _kmeans_plus_plus( x: torch.Tensor, n_clusters: int, n_iters: int = 50 ) -> torch.Tensor: n_data, _ = x.size() center_indices = [torch.randint(0, n_data, ()).item()] min_distances = torch.full((n_data,), math.inf, device=x.device) for _ in range(1, n_clusters): last_center_index = center_indices[-1] min_distances = min_distances.minimum( torch.cdist(x, x[last_center_index : last_center_index + 1]) .float() .square_() .squeeze_(1) ) probs = min_distances / (min_distances.sum() + 1e-12) center_indices.append(torch.multinomial(probs, 1).item()) centers = x[center_indices] del min_distances, probs for _ in range(n_iters): distances = torch.cdist(x, centers) # [n_data, n_clusters] labels = distances.argmin(1) # [n_data] # [n_clusters, dim] new_centers = torch.zeros_like(centers).index_add_(0, labels, x) # [n_clusters] counts = labels.bincount(minlength=n_clusters) if (counts == 0).sum().item() != 0: # TODO: 割り当てがないクラスタの処理 warnings.warn("Some clusters have no assigned data points.") new_centers /= counts[:, None].clamp_(min=1).float() centers = new_centers return centers def extract_pitch_features( y: torch.Tensor, # [..., wav_length] hop_length: int = 160, # 10ms win_length: int = 560, # 35ms max_corr_period: int = 256, # 16ms, 62.5Hz (16000 / 256) corr_win_length: int = 304, # 19ms instfreq_features_cutoff_bin: int = 64, # 1828Hz (16000 * 64 / 560) ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert max_corr_period + corr_win_length == win_length # パディングする padding_length = (win_length - hop_length) // 2 y = F.pad(y, (padding_length, padding_length)) # フレームにする # [..., win_length, n_frames] y_frames = y.unfold(-1, win_length, hop_length).transpose_(-2, -1) # 複素スペクトログラム # Complex[..., (win_length // 2 + 1), n_frames] spec: torch.Tensor = torch.fft.rfft(y_frames, n=win_length, dim=-2) # Complex[..., instfreq_features_cutoff_bin, n_frames] spec = spec[..., :instfreq_features_cutoff_bin, :] # 対数パワースペクトログラム log_power_spec = spec.abs().add_(1e-5).log10_() # 瞬時位相の時間差分 # 時刻 0 の値は 0 delta_spec = spec[..., :, 1:] * spec[..., :, :-1].conj() delta_spec /= delta_spec.abs().add_(1e-5) delta_spec = torch.cat( [torch.zeros_like(delta_spec[..., :, :1]), delta_spec], dim=-1 ) # [..., instfreq_features_cutoff_bin * 3, n_frames] instfreq_features = torch.cat( [log_power_spec, delta_spec.real, delta_spec.imag], dim=-2 ) # 自己相関 # 元々これに 2.0 / corr_win_length を掛けて使おうと思っていたが、 # この値は振幅の 2 乗に比例していて、NN に入力するために良い感じに分散を # 標準化する方法が思いつかなかったのでやめた flipped_y_frames = y_frames.flip((-2,)) a = torch.fft.rfft(flipped_y_frames, n=win_length, dim=-2) b = torch.fft.rfft(y_frames[..., -corr_win_length:, :], n=win_length, dim=-2) # [..., max_corr_period, n_frames] corr = torch.fft.irfft(a * b, n=win_length, dim=-2)[..., corr_win_length:, :] # エネルギー項 energy = flipped_y_frames.square_().cumsum_(-2) energy0 = energy[..., corr_win_length - 1 : corr_win_length, :] energy = energy[..., corr_win_length:, :] - energy[..., :-corr_win_length, :] # Difference function corr_diff = (energy0 + energy).sub_(corr.mul_(2.0)) assert corr_diff.min() >= -1e-3, corr_diff.min() corr_diff.clamp_(min=0.0) # 計算誤差対策 # 標準化 corr_diff *= 2.0 / corr_win_length corr_diff.sqrt_() # 変換モデルへの入力用のエネルギー energy = ( (y_frames * torch.signal.windows.cosine(win_length, device=y.device)[..., None]) .square_() .sum(-2, keepdim=True) ) energy.clamp_(min=1e-3).log10_() # >= -3, 振幅 1 の正弦波なら大体 2.15 energy *= 0.5 # >= -1.5, 振幅 1 の正弦波なら大体 1.07, 1 の差は振幅で 20dB の差 return ( instfreq_features, # [..., instfreq_features_cutoff_bin * 3, n_frames] corr_diff, # [..., max_corr_period, n_frames] energy, # [..., 1, n_frames] ) class PitchEstimator(nn.Module): def __init__( self, input_instfreq_channels: int = 192, input_corr_channels: int = 256, pitch_bins: int = 448, channels: int = 192, intermediate_channels: int = 192 * 2, n_blocks: int = 9, delay: int = 1, # 10ms, 特徴抽出と合わせると 22.5ms embed_kernel_size: int = 3, kernel_size: int = 33, pitch_bins_per_octave: int = 96, ): super().__init__() self.pitch_bins_per_octave = pitch_bins_per_octave self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1) self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1) self.corr_embed_0 = nn.Conv1d(input_corr_channels, channels, 1) self.corr_embed_1 = nn.Conv1d(channels, channels, 1) self.backbone = ConvNeXtStack( channels, channels, intermediate_channels, n_blocks, delay, embed_kernel_size, kernel_size, enable_scaling=True, ) self.head = nn.Conv1d(channels, pitch_bins, 1) def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # wav: [batch_size, 1, wav_length] # [batch_size, input_instfreq_channels, length], # [batch_size, input_corr_channels, length] with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): instfreq_features, corr_diff, energy = extract_pitch_features( wav.squeeze(1), hop_length=160, win_length=560, max_corr_period=256, corr_win_length=304, instfreq_features_cutoff_bin=64, ) instfreq_features = F.gelu( self.instfreq_embed_0(instfreq_features), approximate="tanh" ) instfreq_features = self.instfreq_embed_1(instfreq_features) corr_diff = F.gelu(self.corr_embed_0(corr_diff), approximate="tanh") corr_diff = self.corr_embed_1(corr_diff) # [batch_size, channels, length] x = F.gelu(instfreq_features + corr_diff, approximate="tanh") x = self.backbone(x) # [batch_size, pitch_bins, length] x = self.head(x) return x, energy def sample_pitch( self, pitch: torch.Tensor, band_width: int = 4, return_features: bool = False ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: # pitch: [batch_size, pitch_bins, length] # 返されるピッチの値には 0 は含まれない batch_size, pitch_bins, length = pitch.size() pitch = pitch.softmax(1) if return_features: unvoiced_proba = pitch[:, :1, :].clone() pitch[:, 0, :] = -100.0 pitch = ( pitch.transpose(1, 2).contiguous().view(batch_size * length, 1, pitch_bins) ) band_pitch = F.conv1d( pitch, torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width), ) # [batch_size * length, 1, pitch_bins - band_width + 1] -> Long[batch_size * length, 1] quantized_band_pitch = band_pitch.argmax(2) if return_features: # [batch_size * length, 1] band_proba = band_pitch.gather(2, quantized_band_pitch[:, :, None]) # [batch_size * length, 1] half_pitch_band_proba = band_pitch.gather( 2, (quantized_band_pitch - self.pitch_bins_per_octave).clamp_(min=1)[ :, :, None ], ) half_pitch_band_proba[ quantized_band_pitch <= self.pitch_bins_per_octave ] = 0.0 half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view( batch_size, 1, length ) # [batch_size * length, 1] double_pitch_band_proba = band_pitch.gather( 2, (quantized_band_pitch + self.pitch_bins_per_octave).clamp_( max=pitch_bins - band_width )[:, :, None], ) double_pitch_band_proba[ quantized_band_pitch > pitch_bins - band_width - self.pitch_bins_per_octave ] = 0.0 double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view( batch_size, 1, length ) # Long[1, pitch_bins] mask = torch.arange(pitch_bins, device=pitch.device)[None, :] # bool[batch_size * length, pitch_bins] mask = (quantized_band_pitch <= mask) & ( mask < quantized_band_pitch + band_width ) # Long[batch_size, length] quantized_pitch = (pitch.squeeze(1) * mask).argmax(1).view(batch_size, length) if return_features: features = torch.cat( [unvoiced_proba, half_pitch_proba, double_pitch_proba], dim=1 ) # Long[batch_size, length], [batch_size, 3, length] return quantized_pitch, features else: return quantized_pitch def merge_weights(self): self.backbone.merge_weights() def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.instfreq_embed_0, f) dump_layer(self.instfreq_embed_1, f) dump_layer(self.corr_embed_0, f) dump_layer(self.corr_embed_1, f) dump_layer(self.backbone, f) dump_layer(self.head, f) def overlap_add( ir_amp: torch.Tensor, ir_phase: torch.Tensor, window: torch.Tensor, pitch: torch.Tensor, hop_length: int = 240, delay: int = 0, sr: float = 24000.0, ) -> torch.Tensor: batch_size, ir_length, length = ir_amp.size() ir_length = (ir_length - 1) * 2 assert ir_phase.size() == ir_amp.size() assert window.size() == (ir_length,), (window.size(), ir_amp.size()) assert pitch.size() == (batch_size, length * hop_length) assert 0 <= delay < ir_length, (delay, ir_length) # 正規化角周波数 [2π rad] normalized_freq = pitch / sr # 初期位相 [2π rad] をランダムに設定 normalized_freq[:, 0] = torch.rand(batch_size, device=pitch.device) with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): phase = (normalized_freq.double().cumsum_(1) % 1.0).float() # 重ねる箇所を求める # [n_pitchmarks], [n_pitchmarks] indices0, indices1 = torch.nonzero(phase[:, :-1] > phase[:, 1:], as_tuple=True) # 重ねる箇所の小数部分 (位相の遅れ) を求める numer = 1.0 - phase[indices0, indices1] # [n_pitchmarks] fractional_part = numer / (numer + phase[indices0, indices1 + 1]) # 重ねる値を求める # Complex[n_pitchmarks, ir_length / 2 + 1] ir_amp = ir_amp[indices0, :, indices1 // hop_length] ir_phase = ir_phase[indices0, :, indices1 // hop_length] # 位相遅れの量 [rad] # [n_pitchmarks, ir_length / 2 + 1] delay_phase = ( torch.arange(ir_length // 2 + 1, device=pitch.device, dtype=torch.float32)[ None, : ] * (-math.tau / ir_length) * fractional_part[:, None] ) # Complex[n_pitchmarks, ir_length / 2 + 1] spec = torch.polar(ir_amp, ir_phase + delay_phase) # [n_pitchmarks, ir_length] ir = torch.fft.irfft(spec, n=ir_length, dim=1) ir *= window # 加算する値をサンプル単位にばらす # [n_pitchmarks * ir_length] ir = ir.ravel() # Long[n_pitchmarks * ir_length] indices0 = indices0[:, None].expand(-1, ir_length).ravel() # Long[n_pitchmarks * ir_length] indices1 = ( indices1[:, None] + torch.arange(ir_length, device=pitch.device) ).ravel() # overlap-add する overlap_added_signal = torch.zeros( (batch_size, length * hop_length + ir_length), device=pitch.device ) overlap_added_signal.index_put_((indices0, indices1), ir, accumulate=True) overlap_added_signal = overlap_added_signal[:, delay : -ir_length + delay] return overlap_added_signal def generate_noise( aperiodicity: torch.Tensor, delay: int = 0 ) -> tuple[torch.Tensor, torch.Tensor]: # aperiodicity: [batch_size, hop_length, length] batch_size, hop_length, length = aperiodicity.size() excitation = torch.rand( batch_size, (length + 1) * hop_length, device=aperiodicity.device ) excitation -= 0.5 n_fft = 2 * hop_length # 矩形窓で分析 # Complex[batch_size, hop_length + 1, length] noise = torch.stft( excitation, n_fft=n_fft, hop_length=hop_length, window=torch.ones(n_fft, device=excitation.device), center=False, return_complex=True, ) assert noise.size(2) == aperiodicity.size(2) noise[:, 0, :] = 0.0 noise[:, 1:, :] *= aperiodicity # ハン窓で合成 # torch.istft は最適合成窓が使われるので使えないことに注意 # [batch_size, 2 * hop_length, length] noise = torch.fft.irfft(noise, n=2 * hop_length, dim=1) noise *= torch.hann_window(2 * hop_length, device=noise.device)[None, :, None] # [batch_size, (length + 1) * hop_length] noise = F.fold( noise, (1, (length + 1) * hop_length), (1, 2 * hop_length), stride=(1, hop_length), ).squeeze_((1, 2)) assert delay < hop_length noise = noise[:, delay : -hop_length + delay] excitation = excitation[:, delay : -hop_length + delay] return noise, excitation # [batch_size, length * hop_length] D4C_PREVENT_ZERO_DIVISION = True # False にすると本家の処理 def interp(x: torch.Tensor, y: torch.Tensor, xi: torch.Tensor) -> torch.Tensor: # x が単調増加で等間隔と仮定 # 外挿は起こらないと仮定 x = torch.as_tensor(x) y = torch.as_tensor(y) xi = torch.as_tensor(xi) if xi.ndim < y.ndim: diff_ndim = y.ndim - xi.ndim xi = xi.view(tuple([1] * diff_ndim) + xi.size()) if xi.size()[:-1] != y.size()[:-1]: xi = xi.expand(y.size()[:-1] + (xi.size(-1),)) assert (x.min(-1).values == x[..., 0]).all() assert (x.max(-1).values == x[..., -1]).all() assert (xi.min(-1).values >= x[..., 0]).all() assert (xi.max(-1).values <= x[..., -1]).all() delta_x = (x[..., -1].double() - x[..., 0].double()) / (x.size(-1) - 1.0) delta_x = delta_x.to(x.dtype) xi = (xi - x[..., :1]).div_(delta_x[..., None]) xi_base = xi.floor() xi_fraction = xi.sub_(xi_base) xi_base = xi_base.long() delta_y = y.diff(dim=-1, append=y[..., -1:]) yi = y.gather(-1, xi_base) + delta_y.gather(-1, xi_base) * xi_fraction return yi def linear_smoothing( group_delay: torch.Tensor, sr: int, n_fft: int, width: torch.Tensor ) -> torch.Tensor: group_delay = torch.as_tensor(group_delay) assert group_delay.size(-1) == n_fft // 2 + 1 width = torch.as_tensor(width) boundary = (width.max() * n_fft / sr).long() + 1 dtype = group_delay.dtype device = group_delay.device fft_resolution = sr / n_fft mirroring_freq_axis = ( torch.arange(-boundary, n_fft // 2 + 1 + boundary, dtype=dtype, device=device) .add(0.5) .mul(fft_resolution) ) if group_delay.ndim == 1: mirroring_spec = F.pad( group_delay[None], (boundary, boundary), mode="reflect" ).squeeze_(0) elif group_delay.ndim >= 4: shape = group_delay.size() mirroring_spec = F.pad( group_delay.view(math.prod(shape[:-1]), group_delay.size(-1)), (boundary, boundary), mode="reflect", ).view(shape[:-1] + (shape[-1] + 2 * boundary,)) else: mirroring_spec = F.pad(group_delay, (boundary, boundary), mode="reflect") mirroring_segment = mirroring_spec.mul(fft_resolution).cumsum_(-1) center_freq = torch.arange(n_fft // 2 + 1, dtype=dtype, device=device).mul_( fft_resolution ) low_freq = center_freq - width[..., None] * 0.5 high_freq = center_freq + width[..., None] * 0.5 levels = interp( mirroring_freq_axis, mirroring_segment, torch.cat([low_freq, high_freq], dim=-1) ) low_levels, high_levels = levels.split([n_fft // 2 + 1] * 2, dim=-1) smoothed = (high_levels - low_levels).div_(width[..., None]) return smoothed def dc_correction( spec: torch.Tensor, sr: int, n_fft: int, f0: torch.Tensor ) -> torch.Tensor: spec = torch.as_tensor(spec) f0 = torch.as_tensor(f0) dtype = spec.dtype device = spec.device upper_limit = 2 + (f0 * (n_fft / sr)).long() max_upper_limit = upper_limit.max() upper_limit_mask = ( torch.arange(max_upper_limit - 1, device=device) < (upper_limit - 1)[..., None] ) low_freq_axis = torch.arange(max_upper_limit + 1, dtype=dtype, device=device) * ( sr / n_fft ) low_freq_replica = interp( f0[..., None] - low_freq_axis.flip(-1), spec[..., : max_upper_limit + 1].flip(-1), low_freq_axis[..., : max_upper_limit - 1] * upper_limit_mask, ) output = spec.clone() output[..., : max_upper_limit - 1] += low_freq_replica * upper_limit_mask return output def nuttall(n: int, device: torch.types.Device) -> torch.Tensor: t = torch.linspace(0, math.tau, n, device=device) coefs = torch.tensor([0.355768, -0.487396, 0.144232, -0.012604], device=device) terms = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device) cos_matrix = (terms[:, None] * t).cos_() # [4, n] window = coefs.matmul(cos_matrix) return window def get_windowed_waveform( x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, half_window_length_ratio: float, window_type: Literal["hann", "blackman"], n_fft: int, ) -> tuple[torch.Tensor, torch.Tensor]: x = torch.as_tensor(x) f0 = torch.as_tensor(f0) position = torch.as_tensor(position) current_sample = position * sr # [...] half_window_length = (half_window_length_ratio * sr / f0).add_(0.5).long() # [..., fft_size] base_index = -half_window_length[..., None] + torch.arange(n_fft, device=x.device) base_index_mask = base_index <= half_window_length[..., None] # [..., fft_size] safe_index = ((current_sample + 0.501).long()[..., None] + base_index).clamp_( 0, x.size(-1) - 1 ) # [..., fft_size] time_axis = base_index.to(x.dtype).div_(half_window_length_ratio) # [...] normalized_f0 = math.pi / sr * f0 # [..., fft_size] phase = time_axis.mul_(normalized_f0[..., None]) if window_type == "hann": window = phase.cos_().mul_(0.5).add_(0.5) elif window_type == "blackman": window = phase.mul(2.0).cos_().mul_(0.08).add_(phase.cos().mul_(0.5)).add_(0.42) else: assert False window *= base_index_mask prefix_shape = tuple( max(x_size, i_size) for x_size, i_size in zip(x.size(), safe_index.size()) )[:-1] waveform = ( x.expand(prefix_shape + (-1,)) .gather(-1, safe_index.expand(prefix_shape + (-1,))) .mul_(window) ) if not D4C_PREVENT_ZERO_DIVISION: waveform += torch.randn_like(window).mul_(1e-12) waveform *= base_index_mask waveform -= window * waveform.sum(-1, keepdim=True).div_( window.sum(-1, keepdim=True) ) return waveform, window def get_centroid(x: torch.Tensor, n_fft: int) -> torch.Tensor: x = torch.as_tensor(x) if D4C_PREVENT_ZERO_DIVISION: x = x / x.norm(dim=-1, keepdim=True).clamp(min=6e-8) else: x = x / x.norm(dim=-1, keepdim=True) spec0 = torch.fft.rfft(x, n_fft) spec1 = torch.fft.rfft( x * torch.arange(1, x.size(-1) + 1, dtype=x.dtype, device=x.device).div_(n_fft), n_fft, ) centroid = spec0.real * spec1.real + spec0.imag * spec1.imag return centroid def get_static_centroid( x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int ) -> torch.Tensor: """First step: calculation of temporally static parameters on basis of group delay""" x1, _ = get_windowed_waveform( x, sr, f0, position + 0.25 / f0, 2.0, "blackman", n_fft ) x2, _ = get_windowed_waveform( x, sr, f0, position - 0.25 / f0, 2.0, "blackman", n_fft ) centroid1 = get_centroid(x1, n_fft) centroid2 = get_centroid(x2, n_fft) return dc_correction(centroid1 + centroid2, sr, n_fft, f0) def get_smoothed_power_spec( x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int ) -> tuple[torch.Tensor, torch.Tensor]: x = torch.as_tensor(x) f0 = torch.as_tensor(f0) x, window = get_windowed_waveform(x, sr, f0, position, 2.0, "hann", n_fft) window_weight = window.square().sum(-1, keepdim=True) rms = x.square().sum(-1, keepdim=True).div_(window_weight).sqrt_() if D4C_PREVENT_ZERO_DIVISION: x = x / (rms * math.sqrt(n_fft)).clamp_(min=6e-8) smoothed_power_spec = torch.fft.rfft(x, n_fft).abs().square_() smoothed_power_spec = dc_correction(smoothed_power_spec, sr, n_fft, f0) smoothed_power_spec = linear_smoothing(smoothed_power_spec, sr, n_fft, f0) return smoothed_power_spec, rms.detach().squeeze(-1) def get_static_group_delay( static_centroid: torch.Tensor, smoothed_power_spec: torch.Tensor, sr: int, f0: torch.Tensor, n_fft: int, ) -> torch.Tensor: """Second step: calculation of parameter shaping""" if D4C_PREVENT_ZERO_DIVISION: smoothed_power_spec = smoothed_power_spec.clamp(min=6e-8) static_group_delay = static_centroid / smoothed_power_spec # t_g static_group_delay = linear_smoothing( static_group_delay, sr, n_fft, f0 * 0.5 ) # t_gs smoothed_group_delay = linear_smoothing(static_group_delay, sr, n_fft, f0) # t_gb static_group_delay = static_group_delay - smoothed_group_delay # t_D return static_group_delay def get_coarse_aperiodicity( group_delay: torch.Tensor, sr: int, n_fft: int, freq_interval: int, n_aperiodicities: int, window: torch.Tensor, ) -> torch.Tensor: """Third step: estimation of band-aperiodicity""" group_delay = torch.as_tensor(group_delay) window = torch.as_tensor(window) boundary = int(round(n_fft * 8 / window.size(-1))) half_window_length = window.size(-1) // 2 coarse_aperiodicity = torch.empty( group_delay.size()[:-1] + (n_aperiodicities,), dtype=group_delay.dtype, device=group_delay.device, ) for i in range(n_aperiodicities): center = freq_interval * (i + 1) * n_fft // sr segment = ( group_delay[ ..., center - half_window_length : center + half_window_length + 1 ] * window ) power_spec: torch.Tensor = torch.fft.rfft(segment, n_fft).abs().square_() cumulative_power_spec = power_spec.sort(-1).values.cumsum_(-1) if D4C_PREVENT_ZERO_DIVISION: cumulative_power_spec.clamp_(min=6e-8) coarse_aperiodicity[..., i] = ( cumulative_power_spec[..., n_fft // 2 - boundary - 1] / cumulative_power_spec[..., -1] ) coarse_aperiodicity.log10_().mul_(10.0) return coarse_aperiodicity def d4c_love_train( x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, threshold: float ) -> int: x = torch.as_tensor(x) position = torch.as_tensor(position) f0: torch.Tensor = torch.as_tensor(f0) vuv = f0 != 0 lowest_f0 = 40 f0 = f0.clamp(min=lowest_f0) n_fft = 1 << (3 * sr // lowest_f0).bit_length() boundary0 = (100 * n_fft - 1) // sr + 1 boundary1 = (4000 * n_fft - 1) // sr + 1 boundary2 = (7900 * n_fft - 1) // sr + 1 waveform, _ = get_windowed_waveform(x, sr, f0, position, 1.5, "blackman", n_fft) power_spec = torch.fft.rfft(waveform, n_fft).abs().square_() power_spec[..., : boundary0 + 1] = 0.0 cumulative_spec = power_spec.cumsum_(-1) vuv = vuv & ( cumulative_spec[..., boundary1] > threshold * cumulative_spec[..., boundary2] ) return vuv def d4c_general_body( x: torch.Tensor, sr: int, f0: torch.Tensor, freq_interval: int, position: torch.Tensor, n_fft: int, n_aperiodicities: int, window: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: static_centroid = get_static_centroid(x, sr, f0, position, n_fft) smoothed_power_spec, rms = get_smoothed_power_spec(x, sr, f0, position, n_fft) static_group_delay = get_static_group_delay( static_centroid, smoothed_power_spec, sr, f0, n_fft ) coarse_aperiodicity = get_coarse_aperiodicity( static_group_delay, sr, n_fft, freq_interval, n_aperiodicities, window ) coarse_aperiodicity.add_((f0[..., None] - 100.0).div_(50.0)).clamp_(max=0.0) return coarse_aperiodicity, rms def d4c( x: torch.Tensor, f0: torch.Tensor, t: torch.Tensor, sr: int, threshold: float = 0.85, n_fft_spec: Optional[int] = None, coarse_only: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """Adapted from https://github.com/tuanad121/Python-WORLD/blob/master/world/d4c.py""" FLOOR_F0 = 71 FLOOR_F0_D4C = 47 UPPER_LIMIT = 15000 FREQ_INTERVAL = 3000 assert sr == int(sr) sr = int(sr) assert sr % 2 == 0 x = torch.as_tensor(x) f0 = torch.as_tensor(f0) temporal_positions = torch.as_tensor(t) n_fft_d4c = 1 << (4 * sr // FLOOR_F0_D4C).bit_length() if n_fft_spec is None: n_fft_spec = 1 << (3 * sr // FLOOR_F0).bit_length() n_aperiodicities = min(UPPER_LIMIT, sr // 2 - FREQ_INTERVAL) // FREQ_INTERVAL assert n_aperiodicities >= 1 window_length = FREQ_INTERVAL * n_fft_d4c // sr * 2 + 1 window = nuttall(window_length, device=x.device) freq_axis = torch.arange(n_fft_spec // 2 + 1, device=x.device) * (sr / n_fft_spec) coarse_aperiodicity, rms = d4c_general_body( x[..., None, :], sr, f0.clamp(min=FLOOR_F0_D4C), FREQ_INTERVAL, temporal_positions, n_fft_d4c, n_aperiodicities, window, ) if coarse_only: return coarse_aperiodicity, rms even_coarse_axis = ( torch.arange(n_aperiodicities + 3, device=x.device) * FREQ_INTERVAL ) assert even_coarse_axis[-2] <= sr // 2 < even_coarse_axis[-1], sr coarse_axis_low = ( torch.arange(n_aperiodicities + 1, dtype=torch.float, device=x.device) * FREQ_INTERVAL ) aperiodicity_low = interp( coarse_axis_low, F.pad(coarse_aperiodicity, (1, 0), value=-60.0), freq_axis[freq_axis < n_aperiodicities * FREQ_INTERVAL], ) coarse_axis_high = torch.tensor( [n_aperiodicities * FREQ_INTERVAL, sr * 0.5], device=x.device ) aperiodicity_high = interp( coarse_axis_high, F.pad(coarse_aperiodicity[..., -1:], (0, 1), value=-1e-12), freq_axis[freq_axis >= n_aperiodicities * FREQ_INTERVAL], ) aperiodicity = torch.cat([aperiodicity_low, aperiodicity_high], -1) aperiodicity = 10.0 ** (aperiodicity / 20.0) vuv = d4c_love_train(x[..., None, :], sr, f0, temporal_positions, threshold) aperiodicity = torch.where(vuv[..., None], aperiodicity, 1 - 1e-12) return aperiodicity, coarse_aperiodicity class Vocoder(nn.Module): def __init__( self, channels: int, speaker_embedding_channels: int = 128, hop_length: int = 240, n_pre_blocks: int = 4, out_sample_rate: float = 24000.0, ): super().__init__() self.hop_length = hop_length self.out_sample_rate = out_sample_rate self.prenet = ConvNeXtStack( in_channels=channels, channels=channels, intermediate_channels=channels * 2, n_blocks=n_pre_blocks, delay=2, # 20ms 遅延 embed_kernel_size=7, kernel_size=33, enable_scaling=True, use_mha=True, cross_attention=True, kv_channels=speaker_embedding_channels, ) self.ir_generator = ConvNeXtStack( in_channels=channels, channels=channels, intermediate_channels=channels * 2, n_blocks=2, delay=0, embed_kernel_size=3, kernel_size=33, use_weight_standardization=True, enable_scaling=True, ) self.ir_generator_post = WSConv1d(channels, 512, 1) self.register_buffer("ir_scale", torch.tensor(1.0)) self.ir_window = nn.Parameter(torch.ones(512)) self.aperiodicity_generator = ConvNeXtStack( in_channels=channels, channels=channels, intermediate_channels=channels * 2, n_blocks=1, delay=0, embed_kernel_size=3, kernel_size=33, use_weight_standardization=True, enable_scaling=True, ) self.aperiodicity_generator_post = WSConv1d(channels, hop_length, 1, bias=False) self.register_buffer("aperiodicity_scale", torch.tensor(0.005)) self.post_filter_generator = ConvNeXtStack( in_channels=channels, channels=channels, intermediate_channels=channels * 2, n_blocks=1, delay=0, embed_kernel_size=3, kernel_size=33, use_weight_standardization=True, enable_scaling=True, ) self.post_filter_generator_post = WSConv1d(channels, 512, 1, bias=False) self.register_buffer("post_filter_scale", torch.tensor(0.01)) def forward( self, x: torch.Tensor, pitch: torch.Tensor, speaker_embedding: torch.Tensor ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: # x: [batch_size, channels, length] # pitch: [batch_size, length] # speaker_embedding: [batch_size, speaker_embedding_length, speaker_embedding_channels] batch_size, _, length = x.size() x = self.prenet(x, speaker_embedding) ir = self.ir_generator(x) ir = F.silu(ir, inplace=True) # [batch_size, 512, length] ir = self.ir_generator_post(ir) ir *= self.ir_scale ir_amp = ir[:, : ir.size(1) // 2 + 1, :].exp() ir_phase = F.pad(ir[:, ir.size(1) // 2 + 1 :, :], (0, 0, 1, 1)) ir_phase[:, 1::2, :] += math.pi # TODO: 直流成分が正の値しか取れないのを修正する # 最近傍補間 # [batch_size, length * hop_length] pitch = torch.repeat_interleave(pitch, self.hop_length, dim=1) # [batch_size, length * hop_length] periodic_signal = overlap_add( ir_amp, ir_phase, self.ir_window, pitch, self.hop_length, delay=0, sr=self.out_sample_rate, ) aperiodicity = self.aperiodicity_generator(x) aperiodicity = F.silu(aperiodicity, inplace=True) # [batch_size, hop_length, length] aperiodicity = self.aperiodicity_generator_post(aperiodicity) aperiodicity *= self.aperiodicity_scale # [batch_size, length * hop_length], [batch_size, length * hop_length] aperiodic_signal, noise_excitation = generate_noise(aperiodicity, delay=0) post_filter = self.post_filter_generator(x) post_filter = F.silu(post_filter, inplace=True) # [batch_size, 512, length] post_filter = self.post_filter_generator_post(post_filter) post_filter *= self.post_filter_scale post_filter[:, 0, :] += 1.0 # [batch_size, length, 512] post_filter = post_filter.transpose(1, 2) with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): periodic_signal = periodic_signal.float() aperiodic_signal = aperiodic_signal.float() post_filter = post_filter.float() post_filter = torch.fft.rfft(post_filter, n=768) # [batch_size, length, 768] periodic_signal = torch.fft.irfft( torch.fft.rfft( periodic_signal.view(batch_size, length, self.hop_length), n=768 ) * post_filter, n=768, ) aperiodic_signal = torch.fft.irfft( torch.fft.rfft( aperiodic_signal.view(batch_size, length, self.hop_length), n=768 ) * post_filter, n=768, ) periodic_signal = F.fold( periodic_signal.transpose(1, 2), (1, (length - 1) * self.hop_length + 768), (1, 768), stride=(1, self.hop_length), ).squeeze_((1, 2)) aperiodic_signal = F.fold( aperiodic_signal.transpose(1, 2), (1, (length - 1) * self.hop_length + 768), (1, 768), stride=(1, self.hop_length), ).squeeze_((1, 2)) periodic_signal = periodic_signal[:, 120 : 120 + length * self.hop_length] aperiodic_signal = aperiodic_signal[:, 120 : 120 + length * self.hop_length] noise_excitation = noise_excitation[:, 120:] # TODO: compensation の正確さが怪しくなってくる。今も本当に必要なのか? # [batch_size, 1, length * hop_length] y_g_hat = (periodic_signal + aperiodic_signal)[:, None, :] return y_g_hat, { "periodic_signal": periodic_signal.detach(), "aperiodic_signal": aperiodic_signal.detach(), "noise_excitation": noise_excitation.detach(), } def merge_weights(self): self.prenet.merge_weights() self.ir_generator.merge_weights() self.ir_generator_post.merge_weights() self.aperiodicity_generator.merge_weights() self.aperiodicity_generator_post.merge_weights() self.ir_generator_post.weight.data *= self.ir_scale self.ir_generator_post.bias.data *= self.ir_scale self.ir_scale.fill_(1.0) self.aperiodicity_generator_post.weight.data *= self.aperiodicity_scale self.aperiodicity_scale.fill_(1.0) self.post_filter_generator.merge_weights() self.post_filter_generator_post.merge_weights() self.post_filter_generator_post.weight.data *= self.post_filter_scale self.post_filter_scale.fill_(1.0) def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.prenet, f) dump_layer(self.ir_generator, f) dump_layer(self.ir_generator_post, f) dump_layer(self.ir_window, f) dump_layer(self.aperiodicity_generator, f) dump_layer(self.aperiodicity_generator_post, f) dump_layer(self.post_filter_generator, f) dump_layer(self.post_filter_generator_post, f) def compute_loudness( x: torch.Tensor, sr: int, win_lengths: list[int] ) -> list[torch.Tensor]: # x: [batch_size, wav_length] assert x.ndim == 2 n_fft = 2048 chunk_length = n_fft // 2 n_taps = chunk_length + 1 results = [] with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): if not hasattr(compute_loudness, "filter"): compute_loudness.filter = {} if sr not in compute_loudness.filter: ir = torch.zeros(n_taps, device=x.device, dtype=torch.double) ir[0] = 0.5 ir = torchaudio.functional.treble_biquad( ir, sr, 4.0, 1500.0, 1.0 / math.sqrt(2) ) ir = torchaudio.functional.highpass_biquad(ir, sr, 38.0, 0.5) ir *= 2.0 compute_loudness.filter[sr] = torch.fft.rfft(ir, n=n_fft).to( torch.complex64 ) x = x.float() wav_length = x.size(-1) if wav_length % chunk_length != 0: x = F.pad(x, (0, chunk_length - wav_length % chunk_length)) padded_wav_length = x.size(-1) x = x.view(x.size()[:-1] + (padded_wav_length // chunk_length, chunk_length)) x = torch.fft.irfft( torch.fft.rfft(x, n=n_fft) * compute_loudness.filter[sr], n=n_fft, ) x = F.fold( x.transpose(-2, -1), (1, padded_wav_length + chunk_length), (1, n_fft), stride=(1, chunk_length), ).squeeze_((-3, -2))[..., :wav_length] x.square_() for win_length in win_lengths: hop_length = win_length // 4 # [..., n_frames] energy = ( x.unfold(-1, win_length, hop_length) .matmul(torch.hann_window(win_length, device=x.device)) .add_(win_length / 4.0 * 1e-5) .log10_() ) # フィルタリング後の波形が振幅 1 の正弦波なら大体 log10(win_length/4), 1 の差は 10dB の差 results.append(energy) return results def beatrice_slice_segments( x: torch.Tensor, start_indices: torch.Tensor, segment_length: int ) -> torch.Tensor: batch_size, channels, _ = x.size() # [batch_size, 1, segment_size] indices = start_indices[:, None, None] + torch.arange( segment_length, device=start_indices.device ) # [batch_size, channels, segment_size] indices = indices.expand(batch_size, channels, segment_length) return x.gather(2, indices) class ConverterNetwork(nn.Module): def __init__( self, phone_extractor: PhoneExtractor, pitch_estimator: PitchEstimator, n_speakers: int, pitch_bins: int, hidden_channels: int, vq_topk: int = 4, training_time_vq: Literal["none", "self", "random"] = "none", phone_noise_ratio: int = 0.5, floor_noise_level: float = 1e-3, ): super().__init__() self.frozen_modules = { "phone_extractor": phone_extractor.eval().requires_grad_(False), "pitch_estimator": pitch_estimator.eval().requires_grad_(False), } self.pitch_bins = pitch_bins self.phone_noise_ratio = phone_noise_ratio self.floor_noise_level = floor_noise_level self.out_sample_rate = out_sample_rate = 24000 phone_channels = 128 self.vq = VectorQuantizer( n_speakers=n_speakers, codebook_size=512, channels=phone_channels, topk=vq_topk, training_time_vq=training_time_vq, ) self.embed_phone = nn.Conv1d(phone_channels, hidden_channels, 1) self.embed_phone.weight.data.normal_(0.0, math.sqrt(2.0 / (256 * 5))) self.embed_phone.bias.data.zero_() self.embed_quantized_pitch = nn.Embedding(pitch_bins, hidden_channels) phase = ( torch.arange(pitch_bins, dtype=torch.float)[:, None] * ( torch.arange(0, hidden_channels, 2, dtype=torch.float) * (-math.log(10000.0) / hidden_channels) ).exp_() ) self.embed_quantized_pitch.weight.data[:, 0::2] = phase.sin() self.embed_quantized_pitch.weight.data[:, 1::2] = phase.cos_() self.embed_quantized_pitch.weight.data *= math.sqrt(4.0 / 5.0) self.embed_quantized_pitch.weight.requires_grad_(False) self.embed_pitch_features = nn.Conv1d(4, hidden_channels, 1) self.embed_pitch_features.weight.data.normal_(0.0, math.sqrt(2.0 / (4 * 5))) self.embed_pitch_features.bias.data.zero_() self.embed_speaker = nn.Embedding(n_speakers, hidden_channels) self.embed_speaker.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) self.embed_formant_shift = nn.Embedding(9, hidden_channels) self.embed_formant_shift.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) self.key_value_speaker_embedding_length = 384 self.key_value_speaker_embedding_channels = 128 self.key_value_speaker_embedding = nn.Embedding( n_speakers, self.key_value_speaker_embedding_length * self.key_value_speaker_embedding_channels, ) self.key_value_speaker_embedding.weight.data[0].normal_() self.key_value_speaker_embedding.weight.data[1:] = ( self.key_value_speaker_embedding.weight.data[0] ) self.vocoder = Vocoder( channels=hidden_channels, speaker_embedding_channels=self.key_value_speaker_embedding_channels, hop_length=out_sample_rate // 100, n_pre_blocks=4, out_sample_rate=out_sample_rate, ) self.melspectrograms = nn.ModuleList() for win_length, n_mels in [ (32, 5), (64, 10), (128, 20), (256, 40), (512, 80), (1024, 160), (2048, 320), ]: self.melspectrograms.append( torchaudio.transforms.MelSpectrogram( sample_rate=out_sample_rate, n_fft=win_length, win_length=win_length, hop_length=win_length // 4, n_mels=n_mels, power=2, norm="slaney", mel_scale="slaney", ) ) def initialize_vq(self, inputs: Sequence[Iterable[torch.Tensor]]): collector_func = self.frozen_modules["phone_extractor"].units target_layer = self.frozen_modules["phone_extractor"].head self.vq.build_codebooks( collector_func, target_layer, inputs, ) self.vq.enable_hook(target_layer) def enable_hook(self): target_layer = self.frozen_modules["phone_extractor"].head self.vq.enable_hook(target_layer) def _get_resampler( self, orig_freq, new_freq, device, cache={} ) -> torchaudio.transforms.Resample: key = orig_freq, new_freq if key in cache: return cache[key] resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to( device, non_blocking=True ) cache[key] = resampler return resampler def forward( self, x: torch.Tensor, target_speaker_id: torch.Tensor, formant_shift_semitone: torch.Tensor, pitch_shift_semitone: Optional[torch.Tensor] = None, slice_start_indices: Optional[torch.Tensor] = None, slice_segment_length: Optional[int] = None, return_stats: bool = False, ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: # x: [batch_size, 1, wav_length] # target_speaker_id: Long[batch_size] # formant_shift_semitone: [batch_size] # pitch_shift_semitone: [batch_size] # slice_start_indices: [batch_size] batch_size, _, _ = x.size() self.vq.set_target_speaker_ids(target_speaker_id) with torch.inference_mode(): phone_extractor: PhoneExtractor = self.frozen_modules["phone_extractor"] pitch_estimator: PitchEstimator = self.frozen_modules["pitch_estimator"] # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length] phone = phone_extractor.units(x).transpose(1, 2) if self.training and self.phone_noise_ratio != 0.0: phone *= (1.0 - self.phone_noise_ratio) / phone.square().mean( 1, keepdim=True ).sqrt_() noise = torch.randn_like(phone) noise *= ( self.phone_noise_ratio / noise.square().mean(1, keepdim=True).sqrt_() ) phone += noise # F.rms_norm は PyTorch >= 2.4 が必要 phone *= ( 1.0 / phone.square() .mean(1, keepdim=True) .add_(torch.finfo(torch.float).eps) .sqrt_() ) # [batch_size, 1, wav_length] -> [batch_size, pitch_bins, length], [batch_size, 1, length] pitch, energy = pitch_estimator(x) # augmentation if self.training: # [batch_size, pitch_bins - 1] weights = pitch.softmax(1)[:, 1:, :].mean(2) # [batch_size] mean_pitch = ( weights * torch.arange( 1, self.embed_quantized_pitch.num_embeddings, device=weights.device, ) ).sum(1) / weights.sum(1) mean_pitch = mean_pitch.round_().long() target_pitch = torch.randint_like(mean_pitch, 64, 257) shift = target_pitch - mean_pitch shift_ratio = ( 2.0 ** (shift.float() / pitch_estimator.pitch_bins_per_octave) ).tolist() shift = [] interval_length = 100 # 1s interval_zeros = torch.zeros( (1, 1, interval_length * 160), device=x.device ) concatenated_shifted_x = [] offsets = [0] torch.backends.cudnn.benchmark = False for i in range(batch_size): shift_ratio_i = shift_ratio[i] shift_ratio_fraction_i = Fraction.from_float( shift_ratio_i ).limit_denominator(30) shift_numer_i = shift_ratio_fraction_i.numerator shift_denom_i = shift_ratio_fraction_i.denominator shift_ratio_i = shift_numer_i / shift_denom_i shift_i = int( round( math.log2(shift_ratio_i) * pitch_estimator.pitch_bins_per_octave ) ) shift.append(shift_i) shift_ratio[i] = shift_ratio_i # [1, 1, wav_length / shift_ratio] with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): shifted_x_i = self._get_resampler( shift_numer_i, shift_denom_i, x.device )(x[i])[None] if shifted_x_i.size(2) % 160 != 0: shifted_x_i = F.pad( shifted_x_i, (0, 160 - shifted_x_i.size(2) % 160), mode="reflect", ) assert shifted_x_i.size(2) % 160 == 0 offsets.append( offsets[-1] + interval_length + shifted_x_i.size(2) // 160 ) concatenated_shifted_x.extend([interval_zeros, shifted_x_i]) if offsets[-1] % 256 != 0: # 長さが同じ方が何かのキャッシュが効いて早くなるようなので # 適当に 256 の倍数になるようにパディングして長さのパターン数を減らす concatenated_shifted_x.append( torch.zeros( (1, 1, (256 - offsets[-1] % 256) * 160), device=x.device ) ) # [batch_size, 1, sum(wav_length) + batch_size * 16000] concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2) assert concatenated_shifted_x.size(2) % (256 * 160) == 0 # [1, pitch_bins, length / shift_ratio], [1, 1, length / shift_ratio] concatenated_pitch, concatenated_energy = pitch_estimator( concatenated_shifted_x ) for i in range(batch_size): shift_i = shift[i] shift_ratio_i = shift_ratio[i] left = offsets[i] + interval_length right = offsets[i + 1] pitch_i = concatenated_pitch[:, :, left:right] energy_i = concatenated_energy[:, :, left:right] pitch_i = F.interpolate( pitch_i, scale_factor=shift_ratio_i, mode="linear", align_corners=False, ) energy_i = F.interpolate( energy_i, scale_factor=shift_ratio_i, mode="linear", align_corners=False, ) assert pitch_i.size(2) == energy_i.size(2) assert abs(pitch_i.size(2) - pitch.size(2)) <= 10 length = min(pitch_i.size(2), pitch.size(2)) if shift_i > 0: pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] pitch[i : i + 1, 1:-shift_i, :length] = pitch_i[ :, 1 + shift_i :, :length ] pitch[i : i + 1, -shift_i:, :length] = -10.0 elif shift_i < 0: pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] pitch[i : i + 1, 1 : 1 - shift_i, :length] = -10.0 pitch[i : i + 1, 1 - shift_i :, :length] = pitch_i[ :, 1:shift_i, :length ] energy[i : i + 1, :, :length] = energy_i[:, :, :length] torch.backends.cudnn.benchmark = True # [batch_size, pitch_bins, length] -> Long[batch_size, length], [batch_size, 3, length] quantized_pitch, pitch_features = pitch_estimator.sample_pitch( pitch, return_features=True ) if pitch_shift_semitone is not None: quantized_pitch = torch.where( quantized_pitch == 0, quantized_pitch, ( quantized_pitch + ( pitch_shift_semitone[:, None] * (pitch_estimator.pitch_bins_per_octave / 12.0) ) .round_() .long() ).clamp_(1, self.pitch_bins - 1), ) pitch = 55.0 * 2.0 ** ( quantized_pitch.float() / pitch_estimator.pitch_bins_per_octave ) # phone が 2.5ms 先読みしているのに対して、 # energy は 12.5ms, pitch_features は 22.5ms 先読みしているので、 # ずらして phone に合わせる energy = F.pad(energy[:, :, :-1], (1, 0), mode="reflect") quantized_pitch = F.pad(quantized_pitch[:, :-2], (2, 0), mode="reflect") pitch_features = F.pad(pitch_features[:, :, :-2], (2, 0), mode="reflect") # [batch_size, 1, length], [batch_size, 3, length] -> [batch_size, 4, length] pitch_features = torch.cat([energy, pitch_features], dim=1) formant_shift_indices = ( ((formant_shift_semitone + 2.0) * 2.0).round_().long() ) phone = phone.clone() quantized_pitch = quantized_pitch.clone() pitch_features = pitch_features.clone() formant_shift_indices = formant_shift_indices.clone() pitch = pitch.clone() # [batch_sise, hidden_channels, length] x = ( self.embed_phone(phone) + self.embed_quantized_pitch(quantized_pitch).transpose(1, 2) + self.embed_pitch_features(pitch_features) + ( self.embed_speaker(target_speaker_id)[:, :, None] + self.embed_formant_shift(formant_shift_indices)[:, :, None] ) ) if slice_start_indices is not None: assert slice_segment_length is not None # [batch_size, hidden_channels, length] -> [batch_size, hidden_channels, segment_length] x = beatrice_slice_segments(x, slice_start_indices, slice_segment_length) x = F.silu(x, inplace=True) speaker_embedding = self.key_value_speaker_embedding(target_speaker_id).view( batch_size, self.key_value_speaker_embedding_length, self.key_value_speaker_embedding_channels, ) # [batch_size, hidden_channels, segment_length] -> [batch_size, 1, segment_length * 240] y_g_hat, stats = self.vocoder(x, pitch, speaker_embedding) stats["pitch"] = pitch if return_stats: return y_g_hat, stats else: return y_g_hat def _normalize_melsp(self, x): return x.clamp(min=1e-10).log_() def forward_and_compute_loss( self, noisy_wavs_16k: torch.Tensor, target_speaker_id: torch.Tensor, formant_shift_semitone: torch.Tensor, slice_start_indices: torch.Tensor, slice_segment_length: int, y_all: torch.Tensor, enable_loss_ap: bool = False, ) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float], ]: # noisy_wavs_16k: [batch_size, 1, wav_length] # target_speaker_id: Long[batch_size] # formant_shift_semitone: [batch_size] # slice_start_indices: [batch_size] # slice_segment_length: int # y_all: [batch_size, 1, wav_length] stats = {} loss_mel = 0.0 loss_loudness = 0.0 loudness_win_lengths = [512, 1024, 2048, 4096] # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240] y_hat_all, intermediates = self( noisy_wavs_16k, target_speaker_id, formant_shift_semitone, return_stats=True, ) y_hat_all = y_hat_all.detach().where(y_all == 0.0, y_hat_all) with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): periodic_signal = intermediates["periodic_signal"].float() aperiodic_signal = intermediates["aperiodic_signal"].float() noise_excitation = intermediates["noise_excitation"].float() periodic_signal = periodic_signal[:, : noise_excitation.size(1)] aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)] y_hat_all = y_hat_all.float() floor_noise = torch.randn_like(y_all) * self.floor_noise_level y_all = y_all + floor_noise y_hat_all += floor_noise y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)] y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)] y_loudness = compute_loudness( y_all_truncated, self.out_sample_rate, loudness_win_lengths ) y_hat_loudness = compute_loudness( y_hat_all_truncated, self.out_sample_rate, loudness_win_lengths ) for win_length, y_loudness_i, y_hat_loudness_i in zip( loudness_win_lengths, y_loudness, y_hat_loudness ): loss_loudness_i = F.mse_loss(y_hat_loudness_i, y_loudness_i) loss_loudness += loss_loudness_i * math.sqrt(win_length) stats[f"loss_loudness_{win_length}"] = loss_loudness_i.item() for melspectrogram in self.melspectrograms: melsp_periodic_signal = melspectrogram(periodic_signal) melsp_aperiodic_signal = melspectrogram(aperiodic_signal) melsp_noise_excitation = melspectrogram(noise_excitation) # [1, n_mels, 1] # 1/6 ... [-0.5, 0.5] の一様乱数の平均パワー # 3/8 ... ハン窓をかけた時のパワー減衰 # 0.5 ... 謎 reference_melsp = melspectrogram.mel_scale( torch.full( (1, melspectrogram.n_fft // 2 + 1, 1), (1 / 6) * (3 / 8) * 0.5 * melspectrogram.win_length, device=noisy_wavs_16k.device, ) ) aperiodic_ratio = melsp_aperiodic_signal / ( melsp_periodic_signal + melsp_aperiodic_signal + 1e-5 ) compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5) melsp_y_hat = melspectrogram(y_hat_all_truncated) melsp_y_hat = melsp_y_hat * ( (1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio ) y_hat_mel = self._normalize_melsp(melsp_y_hat) y_mel = self._normalize_melsp(melspectrogram(y_all_truncated)) loss_mel_i = F.l1_loss(y_hat_mel, y_mel) loss_mel += loss_mel_i stats[ f"loss_mel_{melspectrogram.win_length}_{melspectrogram.n_mels}" ] = loss_mel_i.item() loss_mel /= len(self.melspectrograms) if enable_loss_ap: t = ( torch.arange(intermediates["pitch"].size(1), device=y_all.device) * 0.01 + 0.005 ) y_coarse_aperiodicity, y_rms = d4c( y_all.squeeze(1), intermediates["pitch"], t, self.vocoder.out_sample_rate, coarse_only=True, ) y_coarse_aperiodicity = 10.0 ** (y_coarse_aperiodicity / 10.0) y_hat_coarse_aperiodicity, y_hat_rms = d4c( y_hat_all.squeeze(1), intermediates["pitch"], t, self.vocoder.out_sample_rate, coarse_only=True, ) y_hat_coarse_aperiodicity = 10.0 ** (y_hat_coarse_aperiodicity / 10.0) rms = torch.maximum(y_rms, y_hat_rms) loss_ap = F.mse_loss( y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none" ) loss_ap *= (rms / (rms + 1e-3) * (rms > 1e-5))[:, :, None] loss_ap = loss_ap.mean() else: loss_ap = torch.tensor(0.0) # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] y_hat = beatrice_slice_segments( y_hat_all, slice_start_indices * 240, slice_segment_length * 240 ) # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] y = beatrice_slice_segments(y_all, slice_start_indices * 240, slice_segment_length * 240) return y, y_hat, y_hat_all, loss_loudness, loss_mel, loss_ap, stats def merge_weights(self): self.vocoder.merge_weights() def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.embed_phone, f) dump_layer(self.embed_quantized_pitch, f) dump_layer(self.embed_pitch_features, f) dump_layer(self.vocoder, f) def dump_speaker_embeddings(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump_speaker_embeddings(f) return if not hasattr(f, "write"): raise TypeError dump_params(self.vq.codebooks, f) dump_layer(self.embed_speaker, f) dump_layer(self.embed_formant_shift, f) dump_layer(self.key_value_speaker_embedding, f) def dump_embedding_setter(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump_embedding_setter(f) return if not hasattr(f, "write"): raise TypeError self.vocoder.prenet.dump_kv(f) # Discriminator def _normalize(tensor: torch.Tensor, dim: int) -> torch.Tensor: denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-6) return tensor / denom class SANConv2d(nn.Conv2d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1, bias: bool = True, padding_mode="zeros", device=None, dtype=None, ): super().__init__( in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation, groups=1, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype, ) scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-6) self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight)) self.scale = nn.parameter.Parameter(scale.view(out_channels)) if bias: self.bias = nn.parameter.Parameter( torch.zeros(in_channels, device=device, dtype=dtype) ) else: self.register_parameter("bias", None) def forward( self, input: torch.Tensor, flg_san_train: bool = False ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if self.bias is not None: input = input + self.bias.view(self.in_channels, 1, 1) normalized_weight = self._get_normalized_weight() scale = self.scale.view(self.out_channels, 1, 1) if flg_san_train: out_fun = F.conv2d( input, normalized_weight.detach(), None, self.stride, self.padding, self.dilation, self.groups, ) out_dir = F.conv2d( input.detach(), normalized_weight, None, self.stride, self.padding, self.dilation, self.groups, ) out = out_fun * scale, out_dir * scale.detach() else: out = F.conv2d( input, normalized_weight, None, self.stride, self.padding, self.dilation, self.groups, ) out = out * scale return out @torch.no_grad() def normalize_weight(self): self.weight.data = self._get_normalized_weight() def _get_normalized_weight(self) -> torch.Tensor: return _normalize(self.weight, dim=[1, 2, 3]) def get_padding(kernel_size: int, dilation: int = 1) -> int: return (kernel_size * dilation - dilation) // 2 class BeatriceDiscriminatorP(nn.Module): def __init__( self, period: int, kernel_size: int = 5, stride: int = 3, san: bool = False ): super().__init__() self.period = period self.san = san # fmt: off self.convs = nn.ModuleList([ weight_norm(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), weight_norm(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), weight_norm(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), weight_norm(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), weight_norm(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, (get_padding(kernel_size, 1), 0))), ]) # fmt: on if san: self.conv_post = SANConv2d(1024, 1, (3, 1), 1, (1, 0)) else: self.conv_post = weight_norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0))) def forward( self, x: torch.Tensor, flg_san_train: bool = False ) -> tuple[ Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor] ]: fmap = [] b, c, t = x.shape if t % self.period != 0: n_pad = self.period - (t % self.period) x = F.pad(x, (0, n_pad), "reflect") t = t + n_pad x = x.view(b, c, t // self.period, self.period) for conv in self.convs: x = conv(x) x = F.silu(x, inplace=True) fmap.append(x) if self.san: x = self.conv_post(x, flg_san_train=flg_san_train) else: x = self.conv_post(x) if flg_san_train: x_fun, x_dir = x fmap.append(x_fun) x_fun = torch.flatten(x_fun, 1, -1) x_dir = torch.flatten(x_dir, 1, -1) x = x_fun, x_dir else: fmap.append(x) x = torch.flatten(x, 1, -1) return x, fmap class BeatriceDiscriminatorR(nn.Module): def __init__(self, resolution: int, san: bool = False): super().__init__() self.resolution = resolution self.san = san assert len(self.resolution) == 3 self.convs = nn.ModuleList( [ weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), ] ) if san: self.conv_post = SANConv2d(32, 1, (3, 3), padding=(1, 1)) else: self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) def forward( self, x: torch.Tensor, flg_san_train: bool = False ) -> tuple[ Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor] ]: fmap = [] x = self._spectrogram(x).unsqueeze(1) for conv in self.convs: x = conv(x) x = F.silu(x, inplace=True) fmap.append(x) if self.san: x = self.conv_post(x, flg_san_train=flg_san_train) else: x = self.conv_post(x) if flg_san_train: x_fun, x_dir = x fmap.append(x_fun) x_fun = torch.flatten(x_fun, 1, -1) x_dir = torch.flatten(x_dir, 1, -1) x = x_fun, x_dir else: fmap.append(x) x = torch.flatten(x, 1, -1) return x, fmap def _spectrogram(self, x: torch.Tensor) -> torch.Tensor: n_fft, hop_length, win_length = self.resolution x = F.pad( x, ((n_fft - hop_length) // 2, (n_fft - hop_length) // 2), mode="reflect" ).squeeze(1) with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): mag = torch.stft( x.float(), n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=torch.ones(win_length, device=x.device), center=False, return_complex=True, ).abs() return mag class BeatriceMultiPeriodDiscriminator(nn.Module): def __init__(self, san: bool = False): super().__init__() resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]] periods = [2, 3, 5, 7, 11] self.discriminators = nn.ModuleList( [BeatriceDiscriminatorR(r, san=san) for r in resolutions] + [BeatriceDiscriminatorP(p, san=san) for p in periods] ) self.discriminator_names = [f"R_{n}_{h}_{w}" for n, h, w in resolutions] + [ f"P_{p}" for p in periods ] self.san = san def forward( self, y: torch.Tensor, y_hat: torch.Tensor, flg_san_train: bool = False ) -> tuple[ list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]], list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]], list[list[torch.Tensor]], list[list[torch.Tensor]], ]: batch_size = y.size(0) concatenated_y_y_hat = torch.cat([y, y_hat]) y_d_rs = [] y_d_gs = [] fmap_rs = [] fmap_gs = [] for d in self.discriminators: if flg_san_train: (y_d_fun, y_d_dir), fmap = d( concatenated_y_y_hat, flg_san_train=flg_san_train ) y_d_r_fun, y_d_g_fun = torch.split(y_d_fun, batch_size) y_d_r_dir, y_d_g_dir = torch.split(y_d_dir, batch_size) y_d_r = y_d_r_fun, y_d_r_dir y_d_g = y_d_g_fun, y_d_g_dir else: y_d, fmap = d(concatenated_y_y_hat, flg_san_train=flg_san_train) y_d_r, y_d_g = torch.split(y_d, batch_size) fmap_r = [] fmap_g = [] for fm in fmap: fm_r, fm_g = torch.split(fm, batch_size) fmap_r.append(fm_r) fmap_g.append(fm_g) y_d_rs.append(y_d_r) y_d_gs.append(y_d_g) fmap_rs.append(fmap_r) fmap_gs.append(fmap_g) return y_d_rs, y_d_gs, fmap_rs, fmap_gs def forward_and_compute_loss( self, y: torch.Tensor, y_hat: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]]: y_d_rs, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=self.san) stats = {} assert len(y_d_gs) == len(y_d_rs) == len(self.discriminators) with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): # discriminator loss d_loss = 0.0 for dr, dg, name in zip(y_d_rs, y_d_gs, self.discriminator_names): if self.san: dr_fun, dr_dir = map(lambda x: x.float(), dr) dg_fun, dg_dir = map(lambda x: x.float(), dg) r_loss_fun = F.softplus(1.0 - dr_fun).square().mean() g_loss_fun = F.softplus(dg_fun).square().mean() r_loss_dir = F.softplus(1.0 - dr_dir).square().mean() g_loss_dir = -F.softplus(1.0 - dg_dir).square().mean() r_loss = r_loss_fun + r_loss_dir g_loss = g_loss_fun + g_loss_dir else: dr = dr.float() dg = dg.float() r_loss = (1.0 - dr).square().mean() g_loss = dg.square().mean() stats[f"{name}_dr_loss"] = r_loss.item() stats[f"{name}_dg_loss"] = g_loss.item() d_loss += r_loss + g_loss # adversarial loss adv_loss = 0.0 for dg, name in zip(y_d_gs, self.discriminator_names): if self.san: dg_fun = dg[0].float() g_loss = F.softplus(1.0 - dg_fun).square().mean() else: dg = dg.float() g_loss = (1.0 - dg).square().mean() stats[f"{name}_gg_loss"] = g_loss.item() adv_loss += g_loss # feature mathcing loss fm_loss = 0.0 for fr, fg, name in zip(fmap_rs, fmap_gs, self.discriminator_names): fm_loss_i = 0.0 for j, (r, g) in enumerate(zip(fr, fg)): fm_loss_ij = (r.detach().float() - g.float()).abs().mean() stats[f"~{name}_fm_loss_{j}"] = fm_loss_ij.item() fm_loss_i += fm_loss_ij stats[f"{name}_fm_loss"] = fm_loss_i.item() fm_loss += fm_loss_i return d_loss, adv_loss, fm_loss, stats class GradBalancer: """Adapted from https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py""" def __init__( self, weights: dict[str, float], rescale_grads: bool = True, total_norm: float = 1.0, ema_decay: float = 0.999, per_batch_item: bool = True, ): self.weights = weights self.per_batch_item = per_batch_item self.total_norm = total_norm self.ema_decay = ema_decay self.rescale_grads = rescale_grads self.ema_total: dict[str, float] = defaultdict(float) self.ema_fix: dict[str, float] = defaultdict(float) def backward( self, losses: dict[str, torch.Tensor], input: torch.Tensor, scaler: Optional[torch.amp.GradScaler] = None, skip_update_ema: bool = False, ) -> dict[str, float]: stats = {} if skip_update_ema: assert len(losses) == len(self.ema_total) ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()} else: # 各 loss に対して d loss / d input とそのノルムを計算する norms = {} grads = {} for name, loss in losses.items(): if scaler is not None: loss = scaler.scale(loss) (grad,) = torch.autograd.grad(loss, [input], retain_graph=True) if not grad.isfinite().all(): input.backward(grad) return {} grad = grad.detach() / (1.0 if scaler is None else scaler.get_scale()) if self.per_batch_item: dims = tuple(range(1, grad.dim())) ema_norm = grad.norm(dim=dims).mean() else: ema_norm = grad.norm() norms[name] = float(ema_norm) grads[name] = grad # ノルムの移動平均を計算する for key, value in norms.items(): self.ema_total[key] = self.ema_total[key] * self.ema_decay + value self.ema_fix[key] = self.ema_fix[key] * self.ema_decay + 1.0 ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()} # ログを取る total_ema_norm = sum(ema_norms.values()) for k, ema_norm in ema_norms.items(): stats[f"grad_norm_value_{k}"] = ema_norm stats[f"grad_norm_ratio_{k}"] = ema_norm / (total_ema_norm + 1e-12) # loss の係数の比率を計算する if self.rescale_grads: total_weights = sum([self.weights[k] for k in ema_norms]) ratios = {k: w / total_weights for k, w in self.weights.items()} # 勾配を修正する loss = 0.0 for name, ema_norm in ema_norms.items(): if self.rescale_grads: scale = ratios[name] * self.total_norm / (ema_norm + 1e-12) else: scale = self.weights[name] loss += (losses if skip_update_ema else grads)[name] * scale if scaler is not None: loss = scaler.scale(loss) if skip_update_ema: (loss,) = torch.autograd.grad(loss, [input]) input.backward(loss) return stats def state_dict(self) -> dict[str, dict[str, float]]: return { "ema_total": dict(self.ema_total), "ema_fix": dict(self.ema_fix), } def load_state_dict(self, state_dict): self.ema_total = defaultdict(float, state_dict["ema_total"]) self.ema_fix = defaultdict(float, state_dict["ema_fix"]) class QualityTester(nn.Module): def __init__(self): super().__init__() self.utmos = torch.hub.load( "tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True ).eval() @torch.inference_mode() def compute_mos(self, wav: torch.Tensor) -> dict[str, list[float]]: res = {"utmos": self.utmos(wav, sr=16000).tolist()} return res def test( self, converted_wav: torch.Tensor, source_wav: torch.Tensor ) -> dict[str, list[float]]: # [batch_size, wav_length] res = {} res.update(self.compute_mos(converted_wav)) return res def test_many( self, converted_wavs: list[torch.Tensor], source_wavs: list[torch.Tensor] ) -> tuple[dict[str, float], dict[str, list[float]]]: # list[batch_size, wav_length] results = defaultdict(list) assert len(converted_wavs) == len(source_wavs) for converted_wav, source_wav in zip(converted_wavs, source_wavs): res = self.test(converted_wav, source_wav) for metric_name, value in res.items(): results[metric_name].extend(value) return { metric_name: sum(values) / len(values) for metric_name, values in results.items() }, results def compute_grad_norm( model: nn.Module, return_stats: bool = False ) -> Union[float, dict[str, float]]: total_norm = 0.0 stats = {} for name, p in model.named_parameters(): if p.grad is None: continue param_norm = p.grad.data.norm().item() if not math.isfinite(param_norm): param_norm = p.grad.data.float().norm().item() total_norm += param_norm * param_norm if return_stats: stats[f"grad_norm_{name}"] = param_norm total_norm = math.sqrt(total_norm) if return_stats: return total_norm, stats else: return total_norm def compute_mean_f0( files: list[Path], method: Literal["dio", "harvest"] = "dio" ) -> float: sum_log_f0 = 0.0 n_frames = 0 for file in files: wav, sr = beatrice_load_audio(file) if method == "dio": f0, _ = pyworld.dio(wav.ravel().numpy().astype(np.float64), sr) elif method == "harvest": f0, _ = pyworld.harvest(wav.ravel().numpy().astype(np.float64), sr) else: raise ValueError(f"Invalid method: {method}") f0 = f0[f0 > 0] sum_log_f0 += float(np.log(f0).sum()) n_frames += len(f0) if n_frames == 0: return math.nan mean_log_f0 = sum_log_f0 / n_frames return math.exp(mean_log_f0) def get_resampler( sr_before: int, sr_after: int, device="cpu", cache={} ) -> torchaudio.transforms.Resample: if not isinstance(device, str): device = str(device) if (sr_before, sr_after, device) not in cache: cache[(sr_before, sr_after, device)] = torchaudio.transforms.Resample( sr_before, sr_after ).to(device) return cache[(sr_before, sr_after, device)] def convolve(signal: torch.Tensor, ir: torch.Tensor) -> torch.Tensor: n = 1 << (signal.size(-1) + ir.size(-1) - 2).bit_length() res = torch.fft.irfft(torch.fft.rfft(signal, n=n) * torch.fft.rfft(ir, n=n), n=n) return res[..., : signal.size(-1)] def random_formant_shift( wav: torch.Tensor, sample_rate: int, formant_shift_semitone_min: float = -3.0, formant_shift_semitone_max: float = 3.0, ) -> torch.Tensor: assert wav.ndim == 2 assert wav.size(0) == 1 device = wav.device hop_length = 256 # [wav_length] wav_np = wav.ravel().double().cpu().numpy() f0, t = pyworld.dio( wav_np, sample_rate, f0_floor=55, f0_ceil=1400, frame_period=hop_length * 1000 / sample_rate, ) f0 = pyworld.stonemask(wav_np, f0, t, sample_rate) world_sp = pyworld.cheaptrick(wav_np, f0, t, sample_rate) world_sp = ( torch.from_numpy(world_sp).float().to(device).sqrt_()[None] ) # [1, length, n_fft // 2 + 1] n_fft = win_length = (world_sp.size(2) - 1) * 2 window = torch.hann_window(win_length, device=device) # [1, n_fft // 2 + 1, length] stft_sp = torch.stft( wav, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, return_complex=True, ) assert world_sp.size(1) == stft_sp.size(2), (world_sp.size(), stft_sp.size()) assert world_sp.size(2) == stft_sp.size(1), (world_sp.size(), stft_sp.size()) shift_semitones = ( torch.rand(()).item() * (formant_shift_semitone_max - formant_shift_semitone_min) + formant_shift_semitone_min ) shift_ratio = 2.0 ** (shift_semitones / 12.0) shifted_world_sp = F.interpolate( world_sp, scale_factor=shift_ratio, mode="linear", align_corners=True ) if shifted_world_sp.size(2) > n_fft // 2 + 1: shifted_world_sp = shifted_world_sp[:, :, : n_fft // 2 + 1] elif shifted_world_sp.size(2) < n_fft // 2 + 1: shifted_world_sp = F.pad( shifted_world_sp, (0, n_fft // 2 + 1 - shifted_world_sp.size(2)) ) ratio = ((shifted_world_sp + 1e-5) / (world_sp + 1e-5)).clamp(0.1, 10.0) stft_sp *= ratio.transpose(-2, -1) # [1, n_fft // 2 + 1, length] out = torch.istft( stft_sp, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=wav.size(-1), ) return out def random_filter(audio: torch.Tensor) -> torch.Tensor: assert audio.ndim == 2 ab = torch.rand(audio.size(0), 6) * 0.75 - 0.375 a, b = ab[:, :3], ab[:, 3:] a[:, 0] = 1.0 b[:, 0] = 1.0 audio = torchaudio.functional.lfilter(audio, a, b, clamp=False) return audio def get_noise( n_samples: int, sample_rate: float, files: list[Union[str, bytes, os.PathLike]] ) -> torch.Tensor: resample_augmentation_candidates = [0.9, 0.95, 1.0, 1.05, 1.1] wavs = [] current_length = 0 while current_length < n_samples: idx_files = torch.randint(0, len(files), ()) file = files[idx_files] wav, sr = beatrice_load_audio(file) assert wav.size(0) == 1 augmented_sample_rate = int( round( sample_rate * resample_augmentation_candidates[ torch.randint(0, len(resample_augmentation_candidates), ()) ] ) ) resampler = get_resampler(sr, augmented_sample_rate) wav = resampler(wav) wav = random_filter(wav) wav *= 0.99 / (wav.abs().max() + 1e-5) wavs.append(wav) current_length += wav.size(1) start = torch.randint(0, current_length - n_samples + 1, ()) wav = torch.cat(wavs, dim=1)[:, start : start + n_samples] assert wav.size() == (1, n_samples), wav.size() return wav def get_butterworth_lpf( cutoff_freq: float, sample_rate: int, cache={} ) -> tuple[torch.Tensor, torch.Tensor]: if (cutoff_freq, sample_rate) not in cache: q = math.sqrt(0.5) omega = math.tau * cutoff_freq / sample_rate cos_omega = math.cos(omega) alpha = math.sin(omega) / (2.0 * q) b1 = (1.0 - cos_omega) / (1.0 + alpha) b0 = b1 * 0.5 a1 = -2.0 * cos_omega / (1.0 + alpha) a2 = (1.0 - alpha) / (1.0 + alpha) cache[(cutoff_freq, sample_rate)] = ( torch.tensor([b0, b1, b0]), torch.tensor([1.0, a1, a2]), ) return cache[(cutoff_freq, sample_rate)] def augment_audio( clean: torch.Tensor, sample_rate: int, noise_files: list[Union[str, bytes, os.PathLike]], ir_files: list[Union[str, bytes, os.PathLike]], snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0], formant_shift_probability: float = 0.5, formant_shift_semitone_min: float = -3.0, formant_shift_semitone_max: float = 3.0, reverb_probability: float = 0.5, lpf_probability: float = 0.2, lpf_cutoff_freq_candidates: list[float] = [2000.0, 3000.0, 4000.0, 6000.0], ) -> torch.Tensor: # [1, wav_length] assert clean.size(0) == 1 n_samples = clean.size(1) original_clean_rms = clean.square().mean().sqrt_() # clean をフォルマントシフトする if torch.rand(()) < formant_shift_probability: clean = random_formant_shift( clean, sample_rate, formant_shift_semitone_min, formant_shift_semitone_max ) # noise を取得して clean と concat する noise = get_noise(n_samples, sample_rate, noise_files) signals = torch.cat([clean, noise]) # clean, noise に異なるランダムフィルタをかける signals = random_filter(signals) # clean, noise にリバーブをかける if torch.rand(()) < reverb_probability: ir_file = ir_files[torch.randint(0, len(ir_files), ())] ir, sr = beatrice_load_audio(ir_file) assert ir.size() == (2, sr), ir.size() assert sr == sample_rate, (sr, sample_rate) signals = convolve(signals, ir) # clean, noise に同じ LPF をかける if torch.rand(()) < lpf_probability: if signals.abs().max() > 0.8: signals /= signals.abs().max() * 1.25 cutoff_freq = lpf_cutoff_freq_candidates[ torch.randint(0, len(lpf_cutoff_freq_candidates), ()) ] b, a = get_butterworth_lpf(cutoff_freq, sample_rate) signals = torchaudio.functional.lfilter(signals, a, b, clamp=False) # clean の音量を合わせる clean, noise = signals clean_rms = clean.square().mean().sqrt_() clean *= original_clean_rms / clean_rms if len(snr_candidates) >= 1: # clean, noise の音量をピークを重視して取る clean_level = clean.square().square_().mean().sqrt_().sqrt_() noise_level = noise.square().square_().mean().sqrt_().sqrt_() # SNR snr = snr_candidates[torch.randint(0, len(snr_candidates), ())] # noisy を生成 noisy = clean + noise * ( 0.1 ** (snr / 20.0) * clean_level / (noise_level + 1e-5) ) return noisy class WavDataset(torch.utils.data.Dataset): def __init__( self, audio_files: list[tuple[Path, int]], in_sample_rate: int = 16000, out_sample_rate: int = 24000, wav_length: int = 4 * 24000, # 4s segment_length: int = 100, # 1s noise_files: Optional[list[Union[str, bytes, os.PathLike]]] = None, ir_files: Optional[list[Union[str, bytes, os.PathLike]]] = None, augmentation_snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0], augmentation_formant_shift_probability: float = 0.5, augmentation_formant_shift_semitone_min: float = -3.0, augmentation_formant_shift_semitone_max: float = 3.0, augmentation_reverb_probability: float = 0.5, augmentation_lpf_probability: float = 0.2, augmentation_lpf_cutoff_freq_candidates: list[float] = [ 2000.0, 3000.0, 4000.0, 6000.0, ], ): self.audio_files = audio_files self.in_sample_rate = in_sample_rate self.out_sample_rate = out_sample_rate self.wav_length = wav_length self.segment_length = segment_length self.noise_files = noise_files self.ir_files = ir_files self.augmentation_snr_candidates = augmentation_snr_candidates self.augmentation_formant_shift_probability = ( augmentation_formant_shift_probability ) self.augmentation_formant_shift_semitone_min = ( augmentation_formant_shift_semitone_min ) self.augmentation_formant_shift_semitone_max = ( augmentation_formant_shift_semitone_max ) self.augmentation_reverb_probability = augmentation_reverb_probability self.augmentation_lpf_probability = augmentation_lpf_probability self.augmentation_lpf_cutoff_freq_candidates = ( augmentation_lpf_cutoff_freq_candidates ) if (noise_files is None) is not (ir_files is None): raise ValueError("noise_files and ir_files must be both None or not None") self.in_hop_length = in_sample_rate // 100 self.out_hop_length = out_sample_rate // 100 # 10ms 刻み def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int, int]: file, speaker_id = self.audio_files[index] clean_wav, sample_rate = beatrice_load_audio(file) if clean_wav.size(0) != 1: ch = torch.randint(0, clean_wav.size(0), ()) clean_wav = clean_wav[ch : ch + 1] formant_shift_candidates = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0] formant_shift = formant_shift_candidates[ torch.randint(0, len(formant_shift_candidates), ()).item() ] resampler_fraction = Fraction( sample_rate / self.out_sample_rate * 2.0 ** (formant_shift / 12.0) ).limit_denominator(300) clean_wav = get_resampler( resampler_fraction.numerator, resampler_fraction.denominator )(clean_wav) assert clean_wav.size(0) == 1 assert clean_wav.size(1) != 0 clean_wav = F.pad(clean_wav, (self.wav_length, self.wav_length)) if self.noise_files is None: noisy_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)( clean_wav ) else: clean_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)( clean_wav ) noisy_wav_16k = augment_audio( clean_wav_16k, self.in_sample_rate, self.noise_files, self.ir_files, self.augmentation_snr_candidates, self.augmentation_formant_shift_probability, self.augmentation_formant_shift_semitone_min, self.augmentation_formant_shift_semitone_max, self.augmentation_reverb_probability, self.augmentation_lpf_probability, self.augmentation_lpf_cutoff_freq_candidates, ) clean_wav = clean_wav.squeeze_(0) noisy_wav_16k = noisy_wav_16k.squeeze_(0) # 音量をランダマイズする amplitude = torch.rand(()).item() * 0.899 + 0.1 factor = amplitude / clean_wav.abs().max() clean_wav *= factor noisy_wav_16k *= factor while noisy_wav_16k.abs().max() >= 1.0: clean_wav *= 0.5 noisy_wav_16k *= 0.5 return clean_wav, noisy_wav_16k, speaker_id, formant_shift def __len__(self) -> int: return len(self.audio_files) def collate( self, batch: list[tuple[torch.Tensor, torch.Tensor, int, int]] ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: assert self.wav_length % self.out_hop_length == 0 length = self.wav_length // self.out_hop_length clean_wavs = [] noisy_wavs = [] slice_starts = [] speaker_ids = [] formant_shifts = [] for clean_wav, noisy_wav, speaker_id, formant_shift in batch: # 発声部分をランダムに 1 箇所選ぶ (voiced,) = clean_wav.nonzero(as_tuple=True) assert voiced.numel() != 0 center = voiced[torch.randint(0, voiced.numel(), ()).item()].item() # 発声部分が中央にくるように、スライス区間を選ぶ slice_start = center - self.segment_length * self.out_hop_length // 2 assert slice_start >= 0 # スライス区間が含まれるように、ランダムに wav_length の長さを切り出す r = torch.randint(0, length - self.segment_length + 1, ()).item() offset = slice_start - r * self.out_hop_length clean_wavs.append(clean_wav[offset : offset + self.wav_length]) offset_in_sample_rate = int( round(offset * self.in_sample_rate / self.out_sample_rate) ) noisy_wavs.append( noisy_wav[ offset_in_sample_rate : offset_in_sample_rate + length * self.in_hop_length ] ) slice_start = r slice_starts.append(slice_start) speaker_ids.append(speaker_id) formant_shifts.append(formant_shift) clean_wavs = torch.stack(clean_wavs) noisy_wavs = torch.stack(noisy_wavs) slice_starts = torch.tensor(slice_starts) speaker_ids = torch.tensor(speaker_ids) formant_shifts = torch.tensor(formant_shifts) return ( clean_wavs, # [batch_size, wav_length] noisy_wavs, # [batch_size, wav_length] slice_starts, # Long[batch_size] speaker_ids, # Long[batch_size] formant_shifts, # Long[batch_size] ) AUDIO_FILE_SUFFIXES = { ".wav", ".aif", ".aiff", ".fla", ".flac", ".oga", ".ogg", ".opus", ".mp3", } def get_compressed_optimizer_state_dict( optimizer: torch.optim.Optimizer, ) -> dict: state_dict = {} for k0, v0 in optimizer.state_dict().items(): if k0 != "state": state_dict[k0] = v0 continue state_dict[k0] = {} for k1, v1 in v0.items(): state_dict[k0][k1] = {} for k2, v2 in v1.items(): if isinstance(v2, torch.Tensor): state_dict[k0][k1][k2] = v2.bfloat16() assert state_dict[k0][k1][k2].isfinite().all() else: state_dict[k0][k1][k2] = v2 return state_dict def get_decompressed_optimizer_state_dict(compressed_state_dict: dict) -> dict: state_dict = {} for k0, v0 in compressed_state_dict.items(): if k0 != "state": state_dict[k0] = v0 continue state_dict[k0] = {} for k1, v1 in v0.items(): state_dict[k0][k1] = {} for k2, v2 in v1.items(): if isinstance(v2, torch.Tensor): state_dict[k0][k1][k2] = v2.float() assert state_dict[k0][k1][k2].isfinite().all() else: state_dict[k0][k1][k2] = v2 return state_dict # ============================================================ # BEATRICE V2 TRAINING - Embedded (downloads assets from HuggingFace) # ============================================================ BEATRICE_AUDIO_FILE_SUFFIXES = {".wav", ".aif", ".aiff", ".fla", ".flac", ".oga", ".ogg", ".opus", ".mp3"} def preprocess_audio_for_beatrice(audio_path: str, output_dir: str, speaker_name: str = "speaker"): """Preprocess audio for Beatrice training using silence-based splitting""" # Create speaker directory structure required by Beatrice speaker_dir = os.path.join(output_dir, speaker_name) os.makedirs(speaker_dir, exist_ok=True) # Load audio at 16kHz (Beatrice input sample rate) audio, sr = librosa.load(audio_path, sr=16000, mono=True) # Simple silence-based splitting (RMS threshold) chunk_size = int(4.0 * sr) # 4 second chunks hop = int(3.5 * sr) # 0.5s overlap threshold = 0.01 # RMS threshold chunks_saved = 0 for i, start in enumerate(range(0, len(audio) - chunk_size, hop)): chunk = audio[start:start + chunk_size] rms = np.sqrt(np.mean(chunk ** 2)) if rms > threshold: # Skip silence # Normalize max_val = np.abs(chunk).max() if max_val > 0: chunk = chunk / max_val * 0.9 chunk_path = os.path.join(speaker_dir, f"{speaker_name}_{chunks_saved:04d}.wav") sf.write(chunk_path, chunk, sr) chunks_saved += 1 logger.info(f"Beatrice preprocessing: {chunks_saved} chunks saved to {speaker_dir}") return chunks_saved, output_dir def train_beatrice_generator( data_dir: str, output_dir: str, epochs: int = 30, batch_size: int = 8, lr_g: float = 5e-5, lr_d: float = 5e-5, use_augmentation: bool = False, resume: bool = False, progress_callback=None, ): """Train Beatrice v2 model - generator yielding (message, model_path) tuples""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Download pretrained yield "Downloading pretrained models...", None phone_extractor_path = download_beatrice_asset("phone_extractor") pitch_estimator_path = download_beatrice_asset("pitch_estimator") pretrained_model_path = download_beatrice_asset("pretrained_model") # Discover speakers from directory structure # Expected: data_dir/speaker_name/*.wav speakers = [] training_filelist = [] speaker_audio_files = [] for speaker_dir in sorted(Path(data_dir).iterdir()): if not speaker_dir.is_dir(): continue candidates = [f for f in sorted(speaker_dir.rglob("*")) if f.is_file() and f.suffix.lower() in BEATRICE_AUDIO_FILE_SUFFIXES] if candidates: speaker_id = len(speakers) speakers.append(speaker_dir.name) training_filelist.extend([(f, speaker_id) for f in candidates]) speaker_audio_files.append(candidates) n_speakers = len(speakers) if n_speakers == 0: yield "Error: No speakers found in data directory", None return yield f"Found {n_speakers} speaker(s), {len(training_filelist)} files", None # Augmentation assets (optional) noise_files = None ir_files = None if use_augmentation: try: noise_dir, ir_dir = download_beatrice_augmentation() if noise_dir and ir_dir: noise_files = sorted(list(Path(noise_dir).rglob("*.wav")) + list(Path(noise_dir).rglob("*.flac"))) ir_files = sorted(list(Path(ir_dir).rglob("*.wav")) + list(Path(ir_dir).rglob("*.flac"))) if noise_files and ir_files: yield f"Loaded augmentation: {len(noise_files)} noise, {len(ir_files)} IR files", None else: noise_files = None ir_files = None except Exception as e: yield f"Warning: Could not load augmentation assets: {e}", None # Build models yield "Building models...", None phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False) pe_ckpt = torch.load(phone_extractor_path, map_location="cpu", weights_only=True) phone_extractor.load_state_dict(pe_ckpt["phone_extractor"], strict=False) del pe_ckpt pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False) pi_ckpt = torch.load(pitch_estimator_path, map_location="cpu", weights_only=True) pitch_estimator.load_state_dict(pi_ckpt["pitch_estimator"]) del pi_ckpt hidden_channels = 256 pitch_bins = 448 net_g = ConverterNetwork( phone_extractor, pitch_estimator, n_speakers=n_speakers, pitch_bins=pitch_bins, hidden_channels=hidden_channels, vq_topk=4, training_time_vq="none", phone_noise_ratio=0.5, floor_noise_level=1e-3, ).to(device) net_d = BeatriceMultiPeriodDiscriminator(san=True).to(device) # Optimizers optim_g = torch.optim.AdamW(net_g.parameters(), lr_g, betas=(0.8, 0.99), eps=1e-6) optim_d = torch.optim.AdamW(net_d.parameters(), lr_d, betas=(0.8, 0.99), eps=1e-6) grad_scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda") grad_balancer = GradBalancer( weights={ "loss_loudness": 1.0, "loss_mel": 45.0, "loss_adv": 1.0, "loss_fm": 2.0, }, ema_decay=0.999, ) initial_iteration = 0 os.makedirs(output_dir, exist_ok=True) # Load pretrained or resume if resume: latest_ckpt = os.path.join(output_dir, "checkpoint_latest.pt.gz") if os.path.isfile(latest_ckpt): yield "Resuming from checkpoint...", None with gzip.open(latest_ckpt, "rb") as f: ckpt = torch.load(f, map_location="cpu", weights_only=True) net_g.load_state_dict(ckpt["net_g"], strict=False) # Filter discriminator for shape mismatches net_d_state = net_d.state_dict() filtered_d = {k: v for k, v in ckpt["net_d"].items() if k in net_d_state and v.shape == net_d_state[k].shape} net_d.load_state_dict(filtered_d, strict=False) optim_g.load_state_dict(get_decompressed_optimizer_state_dict(ckpt["optim_g"])) optim_d.load_state_dict(get_decompressed_optimizer_state_dict(ckpt["optim_d"])) if "grad_balancer" in ckpt: grad_balancer.load_state_dict(ckpt["grad_balancer"]) if "grad_scaler" in ckpt: grad_scaler.load_state_dict(ckpt["grad_scaler"]) initial_iteration = ckpt.get("iteration", 0) del ckpt else: yield "No checkpoint found, starting fresh with pretrained", None resume = False if not resume: yield "Loading pretrained weights...", None with gzip.open(pretrained_model_path, "rb") as f: pretrained_ckpt = torch.load(f, map_location="cpu", weights_only=True) # Adapt pretrained for our n_speakers initial_speaker_emb = pretrained_ckpt["net_g"]["embed_speaker.weight"][:1] pretrained_ckpt["net_g"]["embed_speaker.weight"] = initial_speaker_emb[[0] * n_speakers] initial_kv_emb = pretrained_ckpt["net_g"]["key_value_speaker_embedding.weight"][:1] pretrained_ckpt["net_g"]["key_value_speaker_embedding.weight"] = initial_kv_emb[[0] * n_speakers] pretrained_ckpt["net_g"]["vq.codebooks"] = pretrained_ckpt["net_g"]["vq.codebooks"][[0] * n_speakers] net_g.load_state_dict(pretrained_ckpt["net_g"], strict=False) # Filter discriminator state dict for shape mismatches (pretrained may use san=False) net_d_state = net_d.state_dict() filtered_d = {k: v for k, v in pretrained_ckpt["net_d"].items() if k in net_d_state and v.shape == net_d_state[k].shape} net_d.load_state_dict(filtered_d, strict=False) logger.info(f"Loaded {len(filtered_d)}/{len(pretrained_ckpt['net_d'])} discriminator weights") # Don't load grad_balancer/grad_scaler from pretrained - our loss weights may differ # These will be re-initialized fresh for fine-tuning del pretrained_ckpt # Build VQ codebooks yield "Building VQ codebooks...", None def wav_iterator(files): for file in files: wav, sr = beatrice_load_audio(file) wav = wav.to(device) if sr != 16000: wav = get_resampler(sr, 16000, str(device))(wav) yield wav[:, None, :] if resume: net_g.enable_hook() else: net_g.initialize_vq([wav_iterator(files) for files in speaker_audio_files]) # Dataset dataset = WavDataset( training_filelist, in_sample_rate=16000, out_sample_rate=24000, wav_length=96000, segment_length=100, noise_files=noise_files, ir_files=ir_files, ) _num_workers = min(4, os.cpu_count() or 1) effective_batch = min(batch_size, len(training_filelist)) dataloader = torch.utils.data.DataLoader( dataset, num_workers=_num_workers, collate_fn=dataset.collate, shuffle=True, batch_size=effective_batch, pin_memory=True, drop_last=len(training_filelist) > effective_batch, persistent_workers=_num_workers > 0, ) # Calculate steps steps_per_epoch = max(1, len(training_filelist) // batch_size) total_steps = epochs * steps_per_epoch warmup_steps = min(total_steps // 4, 5000) # LR scheduler with warmup def lr_lambda(step): if step < warmup_steps: return step / max(1, warmup_steps) return 0.999 ** (step - warmup_steps) scheduler_g = torch.optim.lr_scheduler.LambdaLR(optim_g, lr_lambda) scheduler_d = torch.optim.lr_scheduler.LambdaLR(optim_d, lr_lambda) # Advance schedulers if resuming with warnings.catch_warnings(): warnings.filterwarnings("ignore", message=r"Detected call of `lr_scheduler\.step\(\)") for _ in range(initial_iteration + 1): scheduler_g.step() scheduler_d.step() net_g.train() net_d.train() yield f"Training {total_steps} steps ({epochs} epochs x {steps_per_epoch} steps/epoch)", None # Training loop step = initial_iteration data_iter = None ckpt_path = None for epoch in range(epochs): epoch_loss_g = 0.0 epoch_loss_d = 0.0 epoch_steps = 0 for batch_idx in range(steps_per_epoch): if data_iter is None: data_iter = iter(dataloader) batch = next(data_iter, None) if batch is None: data_iter = iter(dataloader) batch = next(data_iter, None) if batch is None: break clean_wavs, noisy_wavs_16k, slice_starts, speaker_ids, formant_shifts = \ [x.to(device, non_blocking=True) for x in batch] with torch.amp.autocast(device.type, enabled=device.type == "cuda"): # Generator forward y, y_hat, y_hat_for_backward, loss_loudness, loss_mel, loss_ap, gen_stats = \ net_g.forward_and_compute_loss( noisy_wavs_16k[:, None, :], speaker_ids, formant_shifts, slice_start_indices=slice_starts, slice_segment_length=100, y_all=clean_wavs[:, None, :], ) # Discriminator forward loss_disc, loss_adv, loss_fm, disc_stats = \ net_d.forward_and_compute_loss(y, y_hat) # Discriminator backward optim_d.zero_grad(set_to_none=True) grad_scaler.scale(loss_disc).backward(retain_graph=True, inputs=list(net_d.parameters())) grad_scaler.unscale_(optim_d) # Generator backward optim_g.zero_grad(set_to_none=True) grad_balancer.backward( {"loss_loudness": loss_loudness, "loss_mel": loss_mel, "loss_adv": loss_adv, "loss_fm": loss_fm}, y_hat_for_backward, grad_scaler, skip_update_ema=step > 10 and step % 5 != 0, ) grad_scaler.unscale_(optim_g) # Update grad_scaler.step(optim_g) grad_scaler.step(optim_d) grad_scaler.update() optim_g.zero_grad(set_to_none=True) optim_d.zero_grad(set_to_none=True) scheduler_g.step() scheduler_d.step() epoch_loss_g += loss_mel.item() epoch_loss_d += loss_disc.item() epoch_steps += 1 step += 1 if progress_callback: progress_callback(step / total_steps) avg_loss_g = epoch_loss_g / max(1, epoch_steps) avg_loss_d = epoch_loss_d / max(1, epoch_steps) yield f"Epoch {epoch+1}/{epochs} | G loss: {avg_loss_g:.4f} | D loss: {avg_loss_d:.4f} | LR: {scheduler_g.get_last_lr()[0]:.2e}", None # Save checkpoint periodically if (epoch + 1) % max(1, epochs // 5) == 0 or epoch == epochs - 1: ckpt_path = os.path.join(output_dir, f"checkpoint_{step:08d}.pt.gz") with gzip.open(ckpt_path, "wb") as f: torch.save({ "iteration": step, "net_g": net_g.state_dict(), "phone_extractor": phone_extractor.state_dict(), "pitch_estimator": pitch_estimator.state_dict(), "net_d": {k: v.half() for k, v in net_d.state_dict().items()}, "optim_g": get_compressed_optimizer_state_dict(optim_g), "optim_d": get_compressed_optimizer_state_dict(optim_d), "grad_balancer": grad_balancer.state_dict(), "grad_scaler": grad_scaler.state_dict(), "h": { "hidden_channels": hidden_channels, "pitch_bins": pitch_bins, "vq_topk": 4, "training_time_vq": "none", "phone_noise_ratio": 0.5, "floor_noise_level": 1e-3, "san": True, }, "speakers": speakers, }, f) shutil.copy(ckpt_path, os.path.join(output_dir, "checkpoint_latest.pt.gz")) yield f"Saved checkpoint: {ckpt_path}", ckpt_path # Cleanup purge_memory(net_g, net_d, optim_g, optim_d, phone_extractor, pitch_estimator) yield "Training complete!", ckpt_path def convert_voice_beatrice( source_audio, model_file, target_speaker: int = 0, pitch_shift: int = 0, formant_shift: float = 0.0, progress=None, ): """Convert voice using Beatrice v2 model Args: source_audio: Path to source audio file model_file: Path to Beatrice checkpoint (.pt.gz) or file object with .name target_speaker: Target speaker index pitch_shift: Pitch shift in semitones formant_shift: Formant shift in semitones (-2 to 2) progress: Gradio progress callback Returns: (output_path, status_message) tuple """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Get model path if hasattr(model_file, 'name'): model_path = model_file.name elif isinstance(model_file, str): model_path = model_file else: return None, "Invalid model file" if not model_path or not os.path.exists(model_path): return None, f"Model file not found: {model_path}" try: if progress: progress(0.1, "Loading models...") # Download pretrained assets (phone extractor + pitch estimator) phone_extractor_path = download_beatrice_asset("phone_extractor") pitch_estimator_path = download_beatrice_asset("pitch_estimator") # Build phone extractor phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False) pe_ckpt = torch.load(phone_extractor_path, map_location="cpu", weights_only=True) phone_extractor.load_state_dict(pe_ckpt["phone_extractor"], strict=False) del pe_ckpt # Build pitch estimator pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False) pi_ckpt = torch.load(pitch_estimator_path, map_location="cpu", weights_only=True) pitch_estimator.load_state_dict(pi_ckpt["pitch_estimator"]) del pi_ckpt if progress: progress(0.3, "Loading trained model...") # Load trained checkpoint with gzip.open(model_path, "rb") as f: checkpoint = torch.load(f, map_location="cpu", weights_only=True) # Determine model params from checkpoint n_speakers = checkpoint["net_g"]["embed_speaker.weight"].shape[0] h = checkpoint.get("h", {}) hidden_channels = h.get("hidden_channels", 256) pitch_bins = h.get("pitch_bins", 448) speakers = checkpoint.get("speakers", [f"Speaker {i}" for i in range(n_speakers)]) if target_speaker >= n_speakers: target_speaker = 0 net_g = ConverterNetwork( phone_extractor, pitch_estimator, n_speakers=n_speakers, pitch_bins=pitch_bins, hidden_channels=hidden_channels, vq_topk=h.get("vq_topk", 4), training_time_vq=h.get("training_time_vq", "none"), phone_noise_ratio=h.get("phone_noise_ratio", 0.5), floor_noise_level=h.get("floor_noise_level", 1e-3), ).to(device).eval() net_g.load_state_dict(checkpoint["net_g"], strict=False) net_g.enable_hook() del checkpoint if progress: progress(0.5, "Converting voice...") # Load audio at 16kHz audio_path = source_audio if isinstance(source_audio, str) else source_audio audio, sr = librosa.load(audio_path, sr=16000, mono=True) audio_tensor = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(0).to(device) # Pad to multiple of 160 (phone extractor stride) original_length = audio_tensor.shape[-1] if original_length % 160 != 0: pad_len = 160 - original_length % 160 audio_tensor = F.pad(audio_tensor, (0, pad_len)) # Convert with torch.inference_mode(): y_hat = net_g( audio_tensor, torch.tensor([target_speaker], device=device), torch.tensor([formant_shift], device=device), torch.tensor([float(pitch_shift)], device=device), ) # Output is 24kHz, trim to match input duration output_length = original_length // 160 * 240 # 16kHz→24kHz frame ratio output = y_hat.squeeze().cpu().numpy()[:output_length] # Save fd, output_path = tempfile.mkstemp(suffix=".wav") os.close(fd) sf.write(output_path, output, 24000) # Cleanup purge_memory(net_g, phone_extractor, pitch_estimator) speaker_name = speakers[target_speaker] if target_speaker < len(speakers) else f"Speaker {target_speaker}" return output_path, f"Converted using Beatrice v2 | 24kHz | Speaker: {speaker_name} | Pitch: {pitch_shift:+d} | Formant: {formant_shift:+.1f}" except Exception as e: logger.exception("Beatrice inference error") return None, f"Error: {str(e)}" # ============================================================ # GRADIO UI - Gradio 6 Compatible # ============================================================ def train_ui( audio_file, model_name: str, epochs: int, batch_size: int, sample_rate: int, f0_method: str = "rmvpe", progress=gr.Progress() ): """Training function for Gradio UI - Generator for live log updates""" if audio_file is None: yield None, None, "❌ Please upload training audio" return if not model_name or model_name.strip() == "": yield None, None, "❌ Please enter a model name" return # Check if CUDA available has_cuda = torch.cuda.is_available() device_info = "GPU (CUDA)" if has_cuda else "CPU" # Log accumulator for live updates logs = [] try: model_name = sanitize_model_name(model_name) output_dir = f"trained_models/{model_name}" data_dir = f"{output_dir}/data" # Preprocessing phase logs.append(f"🚀 Starting on {device_info}") logs.append(f"📂 Output: {output_dir}") yield None, None, "\n".join(logs) progress(0.1, "Preprocessing...") logs.append("🔄 Preprocessing audio...") yield None, None, "\n".join(logs) audio_path = audio_file if isinstance(audio_file, str) else audio_file.name result = preprocess_audio_for_training(audio_path, data_dir, target_sr=sample_rate, f0_method=f0_method) if result is None: logs.append("❌ Preprocessing failed - no valid audio chunks") yield None, None, "\n".join(logs) return logs.append("✅ Preprocessing complete") logs.append(f"🏋️ Training {epochs} epochs...") logs.append("─" * 40) yield None, None, "\n".join(logs) # Training phase - iterate over generator for live updates ckpt = None idx = None for msg, path, index in train_rvc_generator( data_dir=data_dir, output_dir=output_dir, epochs=epochs, batch_size=batch_size, lr=1e-5, target_sr=sample_rate, progress_callback=progress ): logs.append(msg) yield None, None, "\n".join(logs) if path: ckpt = path if index: idx = index if ckpt: logs.append("─" * 40) logs.append(f"✅ Training complete!") logs.append(f"📦 Model: {ckpt}") if idx: logs.append(f"📦 Index: {idx}") progress(1.0, "Done!") yield ckpt, idx, "\n".join(logs) else: logs.append("❌ Training failed") yield None, None, "\n".join(logs) except Exception as e: logger.exception("Training error") logs.append(f"❌ Error: {str(e)}") yield None, None, "\n".join(logs) # ============================================================ # SPLIT-AND-STITCH BACKGROUND PROCESSOR # # Architecture rationale: # Heavy RVC inference on a 2-core CPU can push 10–14 GB RAM for # long audio. The solution is to: # 1. Chop the input into CHUNK_SEC-second slices with a # OVERLAP_SEC-second overlap on each side. The overlap gives # the vocoder context at boundaries so there are no clicks. # 2. Process each chunk independently through the full RVC # pipeline (RMVPE + FAISS + vocoder). After each chunk, # call purge_memory() so RAM never accumulates. # 3. Cross-fade adjacent chunks over the overlap region using a # linear fade-out/fade-in (equal-power is not needed here # because RVC output is already bandlimited). The crossfade # is the only "small render" step — it's just numpy slicing # and a linear ramp, not another model pass. # 4. Concatenate the stitched segments and write the final WAV. # # Non-negotiable settings preserved: # - f0_method is always forwarded as-is (RMVPE by default). # - index_file is always forwarded (FAISS stays active). # - Model sample rate (40k/48k) is respected — chunks are saved # at 16k for processing and the stitched output is at tgt_sr. # ============================================================ CHUNK_SEC = 30 # seconds per chunk fed to RVC (keeps RAM under ~4 GB/chunk) OVERLAP_SEC = 0.4 # seconds of overlap for crossfade seam (at tgt_sr) MIN_CHUNK_SEC = 5 # don't split files shorter than this def _crossfade_join(seg_a: np.ndarray, seg_b: np.ndarray, overlap_samples: int) -> np.ndarray: """ Overlap-add two mono float32 segments. seg_a: …audio… [overlap_samples tail] seg_b: [overlap_samples head] …audio… Returns the seamlessly stitched result. """ if overlap_samples <= 0 or len(seg_a) < overlap_samples or len(seg_b) < overlap_samples: return np.concatenate([seg_a, seg_b]) fade_out = np.linspace(1.0, 0.0, overlap_samples, dtype=np.float32) fade_in = np.linspace(0.0, 1.0, overlap_samples, dtype=np.float32) blended = seg_a[-overlap_samples:] * fade_out + seg_b[:overlap_samples] * fade_in return np.concatenate([seg_a[:-overlap_samples], blended, seg_b[overlap_samples:]]) def convert_voice_chunked( source_audio, model_file, index_file=None, pitch_shift: int = 0, f0_method: str = "rmvpe", # RMVPE is non-negotiable — forwarded as-is index_rate: float = 0.75, # FAISS active — forwarded as-is protect: float = 0.33, volume_envelope: float = 1.0, progress=None, chunk_sec: int = CHUNK_SEC, overlap_sec: float = OVERLAP_SEC, ) -> Tuple[str, str]: """ Split-and-stitch wrapper around convert_voice(). For audio shorter than MIN_CHUNK_SEC seconds, falls straight through to convert_voice() with no splitting overhead. For longer audio: • Loads the full waveform once at 16 kHz (source SR for RVC). • Slices into overlapping chunks at the 16k level. • Each chunk is written to a temp WAV, run through convert_voice(), and the result is loaded back as a numpy array. • purge_memory() is called after every chunk so RAM is returned to the OS via malloc_trim() before the next chunk starts. • Chunks are crossfaded and concatenated. • The final stitched audio is written to a single output WAV. """ if source_audio is None: return None, "Please upload source audio" if model_file is None: return None, "Please upload RVC model (.pth)" # ------------------------------------------------------------------ # 1. Load source audio once to check duration # ------------------------------------------------------------------ try: audio_full, _ = librosa.load(source_audio, sr=16000, mono=True) except Exception as e: return None, f"Failed to load audio: {e}" duration_sec = len(audio_full) / 16000.0 # Short file — no splitting needed, avoids overhead if duration_sec <= MIN_CHUNK_SEC: purge_memory(audio_full) return convert_voice( source_audio, model_file, index_file, pitch_shift, f0_method, index_rate, protect, volume_envelope, progress if progress is not None else gr.Progress() ) # ------------------------------------------------------------------ # 2. Determine chunk boundaries (in 16k samples) # ------------------------------------------------------------------ sr_in = 16000 chunk_samp = int(chunk_sec * sr_in) overlap_samp = int(overlap_sec * sr_in) hop_samp = chunk_samp - overlap_samp # non-overlapping step starts = list(range(0, len(audio_full), hop_samp)) n_chunks = len(starts) logger.info(f"[Chunked] {duration_sec:.1f}s → {n_chunks} chunks " f"({chunk_sec}s each, {overlap_sec}s overlap)") # ------------------------------------------------------------------ # 3. Process each chunk through convert_voice() # ------------------------------------------------------------------ stitched_segments: list[np.ndarray] = [] tgt_sr = None # learned from first chunk output overlap_out_samp = 0 # overlap at target SR (learned after first chunk) for i, start in enumerate(starts): end = min(start + chunk_samp, len(audio_full)) chunk = audio_full[start:end] if progress is not None: try: progress((i + 0.5) / n_chunks, f"Processing chunk {i+1}/{n_chunks}…") except Exception: pass # Write chunk to temp WAV fd, chunk_path = tempfile.mkstemp(suffix=".wav") os.close(fd) sf.write(chunk_path, chunk, sr_in) try: # Run full RVC pipeline on this chunk (RMVPE + FAISS inside) out_path, status = convert_voice( chunk_path, model_file, index_file, pitch_shift, f0_method, index_rate, protect, volume_envelope, # Suppress inner progress — we own the bar gr.Progress() if progress is None else progress ) except Exception as e: logger.warning(f"[Chunked] Chunk {i+1} failed: {e}") # Remove temp and skip — silence gap is better than crash try: os.unlink(chunk_path) except OSError: pass purge_memory(chunk) continue finally: try: os.unlink(chunk_path) except OSError: pass if out_path is None: logger.warning(f"[Chunked] Chunk {i+1} returned no output: {status}") purge_memory(chunk) continue # Load the converted chunk chunk_out, chunk_sr = sf.read(out_path, dtype="float32") try: os.unlink(out_path) except OSError: pass # Learn target SR from first successful chunk if tgt_sr is None: tgt_sr = chunk_sr # Scale overlap to target SR overlap_out_samp = int(overlap_sec * tgt_sr) stitched_segments.append(chunk_out) # Critical: free everything before next chunk loads the model purge_memory(chunk, chunk_out) # ------------------------------------------------------------------ # 4. Stitch segments with crossfade # ------------------------------------------------------------------ if not stitched_segments: return None, "All chunks failed — no output produced" result = stitched_segments[0] for seg in stitched_segments[1:]: result = _crossfade_join(result, seg, overlap_out_samp) # Final normalization (matches convert_voice behaviour) audio_max = np.abs(result).max() / 0.99 if audio_max > 1.0: result /= audio_max # ------------------------------------------------------------------ # 5. Write stitched output # ------------------------------------------------------------------ final_sr = tgt_sr if tgt_sr is not None else 40000 fd, output_path = tempfile.mkstemp(suffix=".wav") os.close(fd) sf.write(output_path, result, final_sr) purge_memory(result, audio_full) logger.info(f"[Chunked] Stitched output: {len(stitched_segments)} chunks → {output_path}") return output_path, ( f"Chunked conversion: {n_chunks} chunks stitched | " f"sr={final_sr} | pitch={pitch_shift:+d} | f0={f0_method}" ) with gr.Blocks() as demo: gr.Markdown(f"# 🎤 Voice Conversion (RVC + Beatrice)\nInference: CPU • Training: {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}") with gr.Tabs(): # ==================== TAB 1: VOICE CONVERSION ==================== with gr.Tab("🎵 Voice Conversion"): with gr.Row(): with gr.Column(): source_audio = gr.Audio(label="Source Audio", type="filepath") gr.Markdown("### Model") model_type = gr.Radio( ["RVC v2", "Beatrice v2"], value="RVC v2", label="Model Type", info="RVC: .pth files | Beatrice: .pt.gz files" ) # RVC model inputs with gr.Group(visible=True) as rvc_model_group: _available_models = list_rvc_models() rvc_model_dropdown = gr.Dropdown( choices=_available_models, label="Select Voice Model", info="Models auto-loaded from weights/ folder", value="model_kunni.pth" if "model_kunni.pth" in _available_models else (_available_models[0] if _available_models else None), ) with gr.Row(): model_file = gr.File(label="OR Upload New (.pth)", file_types=[".pth"]) load_example_btn = gr.Button("Load Example (Benee)", size="sm") index_file = gr.File(label="Index File (.index) - Optional", file_types=[".index"]) # Beatrice model inputs with gr.Group(visible=False) as beatrice_model_group: beatrice_model_file = gr.File(label="Beatrice Model (.pt.gz)", file_types=[".gz"]) with gr.Row(): beatrice_target_speaker = gr.Number(value=0, label="Target Speaker", precision=0) beatrice_formant_shift = gr.Slider(-2, 2, value=0.0, step=0.5, label="Formant Shift") with gr.Row(): pitch_shift = gr.Slider(-12, 12, value=0, step=1, label="Pitch (semitones)") f0_method = gr.Radio(["rmvpe", "pm", "harvest"], value="rmvpe", label="F0 Method", visible=True) with gr.Row(visible=True) as rvc_extra_options: index_rate = gr.Slider(0, 1, value=0.75, step=0.05, label="Index Rate") protect = gr.Slider(0, 0.5, value=0.33, step=0.01, label="Protect (voiceless consonants)") convert_btn = gr.Button("Convert", variant="primary") with gr.Column(): output_audio = gr.Audio(label="Converted Audio", type="filepath") output_info = gr.Textbox(label="Status", lines=2) def update_model_type(model_type_val): is_rvc = model_type_val == "RVC v2" return ( gr.update(visible=is_rvc), # rvc_model_group gr.update(visible=not is_rvc), # beatrice_model_group gr.update(visible=is_rvc), # f0_method gr.update(visible=is_rvc), # rvc_extra_options ) model_type.change( update_model_type, [model_type], [rvc_model_group, beatrice_model_group, f0_method, rvc_extra_options] ) load_example_btn.click( load_example_model, [], [model_file, index_file, output_info] ) def convert_unified(source, m_type, rvc_dropdown_model, rvc_model, rvc_index, beat_model, beat_speaker, beat_formant, pitch, f0, idx_rate, prot, progress=gr.Progress()): if m_type == "RVC v2": # Resolve model: prefer uploaded file, fall back to dropdown selection resolved_model = rvc_model resolved_index = rvc_index if resolved_model is None and rvc_dropdown_model: pth_path = Path("weights") / rvc_dropdown_model # Wrap in a simple object so convert_voice can call .name on it class _FileObj: def __init__(self, p): self.name = str(p) resolved_model = _FileObj(pth_path) # Auto-hunt for matching .index in weights/ if resolved_index is None: stem = pth_path.stem index_path = Path("weights") / f"{stem}.index" if index_path.exists(): resolved_index = _FileObj(index_path) # Use the split-and-stitch processor for RVC. # It falls through to plain convert_voice() for short clips # and auto-chunks long audio to prevent OOM on HF Spaces. return convert_voice_chunked(source, resolved_model, resolved_index, pitch, f0, idx_rate, prot, progress=progress) else: return convert_voice_beatrice( source, beat_model, target_speaker=int(beat_speaker), pitch_shift=int(pitch), formant_shift=float(beat_formant), progress=progress ) convert_btn.click( convert_unified, [source_audio, model_type, rvc_model_dropdown, model_file, index_file, beatrice_model_file, beatrice_target_speaker, beatrice_formant_shift, pitch_shift, f0_method, index_rate, protect], [output_audio, output_info], api_name="convert", concurrency_limit=1, ) gr.Markdown("**Models:** [HuggingFace](https://huggingface.co/models?search=rvc) | [Weights.gg](https://weights.gg)") # ==================== TAB 2: TRAINING ==================== with gr.Tab("🏋️ Training"): # GPU Warning gpu_status = "🟢 GPU Available (CUDA)" if torch.cuda.is_available() else "🟡 CPU Only (Training will be slow)" gr.Markdown(f""" ### Training Status: {gpu_status} {'**GPU detected!** Training will use CUDA acceleration.' if torch.cuda.is_available() else '**No GPU detected.** Training will run on CPU (~30 sec/epoch). For faster training, run locally with CUDA GPU.'} """) # Trainer selector - Beatrice always available (downloads assets from HF) trainer_selector = gr.Dropdown( choices=["RVC v2", "Beatrice v2"], value="RVC v2", label="Trainer", info="RVC v2: general purpose | Beatrice v2: low latency (~50ms)" ) with gr.Row(): with gr.Column(): train_audio = gr.Audio(label="Training Audio", type="filepath") train_model_name = gr.Textbox(label="Model Name", placeholder="my_voice", value="my_voice") # RVC-specific options with gr.Group(visible=True) as rvc_options: with gr.Row(): train_epochs = gr.Slider(1, 500, value=50, step=1, label="Epochs") train_batch = gr.Slider(1, 8, value=2, step=1, label="Batch Size") with gr.Row(): train_sr = gr.Radio([32000, 40000, 48000], value=40000, label="Sample Rate") train_f0 = gr.Radio(["rmvpe", "pm", "harvest"], value="rmvpe", label="F0 Method") # Beatrice-specific options with gr.Group(visible=False) as beatrice_options: with gr.Row(): beatrice_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs (30 recommended)") beatrice_batch = gr.Slider(1, 64, value=8, step=1, label="Batch Size") beatrice_resume = gr.Checkbox(label="Resume from checkpoint", value=False) train_btn = gr.Button("Start Training", variant="primary") with gr.Column(): train_output_model = gr.File(label="Trained Model (.pth)") train_output_index = gr.File(label="Index File (.index)") train_status = gr.Textbox(label="Training Status", lines=6) # Toggle visibility based on trainer selection def update_trainer_options(trainer): is_rvc = trainer == "RVC v2" return gr.update(visible=is_rvc), gr.update(visible=not is_rvc) trainer_selector.change( update_trainer_options, [trainer_selector], [rvc_options, beatrice_options] ) # Unified training function def train_unified(trainer, audio, name, rvc_epochs, rvc_batch, rvc_sr, rvc_f0, beat_epochs, beat_batch, beat_resume, progress=gr.Progress()): if trainer == "RVC v2": # Use generator for live updates (yields 3-tuples: model, index, status) for result in train_ui(audio, name, rvc_epochs, rvc_batch, rvc_sr, rvc_f0, progress): yield result else: # Beatrice training (embedded - downloads assets from HF) if audio is None: yield None, None, "❌ Please upload training audio" return if not name or name.strip() == "": yield None, None, "❌ Please enter a model name" return try: name = sanitize_model_name(name) output_dir = f"trained_models/beatrice_{name}" data_dir = f"{output_dir}/training_data" logs = [] # Preprocess audio into speaker chunks logs.append(f"🚀 Starting Beatrice training on {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}") yield None, None, "\n".join(logs) progress(0.05, "Preprocessing audio...") audio_path = audio if isinstance(audio, str) else audio.name chunks, _ = preprocess_audio_for_beatrice(audio_path, data_dir, name) if chunks == 0: yield None, None, "❌ Preprocessing failed - no valid audio chunks" return logs.append(f"✅ Preprocessed {chunks} audio chunks") logs.append("─" * 40) yield None, None, "\n".join(logs) # Train using embedded generator ckpt = None for msg, path in train_beatrice_generator( data_dir=data_dir, output_dir=output_dir, epochs=beat_epochs, batch_size=beat_batch, resume=beat_resume, progress_callback=progress ): logs.append(msg) yield None, None, "\n".join(logs) if path: ckpt = path if ckpt: logs.append("─" * 40) logs.append(f"✅ Training complete!") logs.append(f"📦 Model: {ckpt}") progress(1.0, "Done!") yield ckpt, None, "\n".join(logs) else: logs.append("❌ Training failed") yield None, None, "\n".join(logs) except Exception as e: logger.exception("Beatrice training error") yield None, None, f"❌ Error: {str(e)}" train_btn.click( train_unified, [trainer_selector, train_audio, train_model_name, train_epochs, train_batch, train_sr, train_f0, beatrice_epochs, beatrice_batch, beatrice_resume], [train_output_model, train_output_index, train_status], api_name="train", concurrency_limit=1, ) gr.Markdown(""" --- ### Training Tips **RVC v2:** - 50-100 epochs for quick test, 200-500 for quality - CPU training: ~30 sec/epoch (100 epochs ≈ 50 min) **Beatrice v2:** - 20-50 epochs recommended, GPU recommended (CPU works but slow) - Pretrained assets downloaded automatically from HuggingFace - Output: .pt.gz checkpoint (use in Voice Conversion tab) - Lower latency (~50ms vs ~100ms) ### CLI ```bash python app.py train -a voice.mp3 -o ./model --epochs 100 python app.py train-beatrice -a voice.mp3 -o ./beatrice_model --epochs 30 python app.py infer -i input.wav -m beatrice_model.pt.gz -o output.wav ``` """) def cli_convert(args): """CLI mode conversion - supports both RVC and Beatrice models""" print(f"Converting: {args.input}") print(f"Model: {args.model}") # Auto-detect model type from extension or --type flag model_type = getattr(args, 'type', None) model_path = args.model index_path = getattr(args, 'index', None) if args.example: print("Downloading example model...") model_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_MODEL_FILE) index_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_INDEX_FILE) model_type = "rvc" print(f"Model: {model_path}") if not model_path: print("Error: No model specified. Use -m MODEL.pth or --example") sys.exit(1) # Auto-detect type from extension if not specified if model_type is None: if model_path.endswith('.pt.gz') or model_path.endswith('.gz'): model_type = "beatrice" else: model_type = "rvc" if model_type == "beatrice": # Beatrice inference print(f"Using Beatrice v2 inference") output_path, status = convert_voice_beatrice( source_audio=args.input, model_file=model_path, target_speaker=getattr(args, 'speaker', 0), pitch_shift=args.pitch, formant_shift=getattr(args, 'formant_shift', 0.0), ) else: # RVC inference class FileObj: def __init__(self, path): self.name = path model_file = FileObj(model_path) index_file = FileObj(index_path) if index_path else None output_path, status = convert_voice( source_audio=args.input, model_file=model_file, index_file=index_file, pitch_shift=args.pitch, f0_method=args.f0, index_rate=args.index_rate, progress=lambda *a, **k: None ) if output_path: shutil.copy(output_path, args.output) print(f"Output: {args.output}") print(status) else: print(f"Failed: {status}") sys.exit(1) def cli_train_beatrice(args): """CLI Beatrice training mode (embedded - downloads assets from HF)""" print(f"=== Beatrice v2 Training (Embedded) ===") print(f"Input audio: {args.audio}") print(f"Output dir: {args.output}") # Preprocess audio - Beatrice expects: data_dir/speaker_name/*.wav data_dir = os.path.join(args.output, "training_data") speaker_name = os.path.splitext(os.path.basename(args.audio))[0] print(f"\n[1/2] Preprocessing audio for Beatrice...") chunks, _ = preprocess_audio_for_beatrice(args.audio, data_dir, speaker_name) if chunks == 0: print("Preprocessing failed - no valid audio chunks") sys.exit(1) print(f"Created {chunks} audio chunks") # Train using embedded code print(f"\n[2/2] Training Beatrice model ({args.epochs} epochs)...") ckpt = None for msg, path in train_beatrice_generator( data_dir=data_dir, output_dir=args.output, epochs=args.epochs, batch_size=args.batch, resume=args.resume, ): print(msg) if path: ckpt = path if ckpt: print(f"\nTraining complete!") print(f"Model: {ckpt}") else: print("Training failed!") sys.exit(1) def cli_train(args): """CLI training mode""" print(f"=== RVC Training ===") print(f"Input audio: {args.audio}") print(f"Output dir: {args.output}") # Create temp dir for preprocessing data_dir = f"{args.output}/data" # Preprocess print("\n[1/2] Preprocessing audio...") result = preprocess_audio_for_training(args.audio, data_dir, target_sr=args.sr, f0_method=args.f0) if result is None: print("Preprocessing failed!") sys.exit(1) # Train print(f"\n[2/2] Training for {args.epochs} epochs...") ckpt, idx = train_rvc( data_dir=data_dir, output_dir=args.output, epochs=args.epochs, batch_size=args.batch, lr=args.lr, target_sr=args.sr ) if ckpt: print(f"\nTraining complete!") print(f"Model saved: {ckpt}") if idx: print(f"Index saved: {idx}") else: print("Training failed!") sys.exit(1) if __name__ == "__main__": # Check if any CLI args (besides script name) if len(sys.argv) > 1: parser = argparse.ArgumentParser( description="RVC Voice Conversion - Inference and Training", formatter_class=argparse.RawDescriptionHelpFormatter, ) subparsers = parser.add_subparsers(dest="command", help="Commands") # Inference subcommand infer_parser = subparsers.add_parser("infer", help="Voice conversion inference (RVC + Beatrice)", epilog=""" Examples: python app.py infer -i voice.wav -m model.pth -o output.wav python app.py infer -i voice.wav -m beatrice_model.pt.gz -o output.wav --type beatrice python app.py infer -i voice.wav --example -o output.wav """) infer_parser.add_argument("-i", "--input", required=True, help="Input audio file") infer_parser.add_argument("-o", "--output", required=True, help="Output audio file") infer_parser.add_argument("-m", "--model", help="Model file (.pth for RVC, .pt.gz for Beatrice)") infer_parser.add_argument("--type", choices=["rvc", "beatrice"], default=None, help="Model type (auto-detected from extension)") infer_parser.add_argument("--index", help="Index file (.index) - RVC only") infer_parser.add_argument("--example", action="store_true", help="Use example model (Benee-RVC)") infer_parser.add_argument("-p", "--pitch", type=int, default=0, help="Pitch shift (-12 to 12)") infer_parser.add_argument("--f0", choices=["rmvpe", "pm", "harvest"], default="rmvpe", help="F0 method (RVC only)") infer_parser.add_argument("--index-rate", type=float, default=0.75, help="Index rate 0-1 (RVC only)") infer_parser.add_argument("--speaker", type=int, default=0, help="Target speaker index (Beatrice only)") infer_parser.add_argument("--formant-shift", type=float, default=0.0, help="Formant shift -2 to 2 (Beatrice only)") # Training subcommand (RVC) train_parser = subparsers.add_parser("train", help="Train RVC model", epilog=""" Examples: python app.py train -a voice.mp3 -o ./my_model --epochs 5 python app.py train -a dataset.wav -o ./trained --epochs 10 --batch 4 """) train_parser.add_argument("-a", "--audio", required=True, help="Training audio file") train_parser.add_argument("-o", "--output", required=True, help="Output directory for model") train_parser.add_argument("--epochs", type=int, default=5, help="Number of epochs (default: 5)") train_parser.add_argument("--batch", type=int, default=2, help="Batch size (default: 2)") train_parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate (default: 1e-5)") train_parser.add_argument("--sr", type=int, default=40000, help="Sample rate (default: 40000)") train_parser.add_argument("--f0", choices=["rmvpe", "pm", "harvest"], default="rmvpe", help="F0 method (default: rmvpe)") # Beatrice v2 Training subcommand beatrice_parser = subparsers.add_parser("train-beatrice", help="Train Beatrice v2 model (downloads assets from HF)", epilog=""" Examples: python app.py train-beatrice -a voice.mp3 -o ./beatrice_model --epochs 20 python app.py train-beatrice -a dataset.wav -o ./trained --epochs 50 --batch 8 --resume Pretrained assets are downloaded automatically from HuggingFace. """) beatrice_parser.add_argument("-a", "--audio", required=True, help="Training audio file") beatrice_parser.add_argument("-o", "--output", required=True, help="Output directory for model") beatrice_parser.add_argument("--epochs", type=int, default=20, help="Number of epochs (default: 20)") beatrice_parser.add_argument("--batch", type=int, default=24, help="Batch size (default: 24, reduce for less VRAM)") beatrice_parser.add_argument("--resume", action="store_true", help="Resume from checkpoint") args = parser.parse_args() if args.command == "infer": cli_convert(args) elif args.command == "train": cli_train(args) elif args.command == "train-beatrice": cli_train_beatrice(args) else: parser.print_help() else: # ============================================================ # GRADIO QUEUE — Stability config for 2-core CPU HF Spaces # # Why these settings matter: # # max_size=3 # Hard cap on pending jobs. On a 2-core CPU, RVC can take # 3–10 minutes per job. Without a cap, users pile up and # the Space exhausts RAM while holding dozens of audio # blobs in the queue. 3 is a safe upper bound: one running # + two waiting. Any new request beyond that gets a clear # "queue full" HTTP 503 instead of a silent timeout. # # default_concurrency_limit=1 # Only one inference job runs at a time. Two simultaneous # RVC jobs on a 2-core CPU would each get 1 thread, which # is slower than one job using both cores sequentially. # Serial execution is strictly faster here. # # The convert_btn.click already has concurrency_limit=1 set # at the event level; this queue-level setting enforces it # globally (including any API calls) so nothing bypasses it. # # status_update_rate="auto" # Gradio sends SSE heartbeats to the browser on this cadence. # "auto" (≈ 1 Hz) is enough to keep the WebSocket / SSE # connection alive across long jobs without flooding the # 2-core CPU with keep-alive overhead. # ============================================================ demo.queue( max_size=3, default_concurrency_limit=1, status_update_rate="auto", ).launch( mcp_server=True, show_error=True, ssr_mode=False, )