diff --git "a/app.py" "b/app.py" new file mode 100644--- /dev/null +++ "b/app.py" @@ -0,0 +1,6583 @@ +""" +RVC + Beatrice v2 Voice Conversion - Single-file app for HuggingFace Spaces +RVC-Project + Beatrice v2 (fierce-cats/beatrice-trainer), consolidated into single file + +- Inference: RVC v2 (.pth) + Beatrice v2 (.pt.gz), CPU or GPU +- Training: RVC v2 + Beatrice v2, GPU recommended + +Usage: + CLI: python app.py infer -i input.wav -m model.pth -o output.wav + python app.py infer -i input.wav -m beatrice.pt.gz -o output.wav + Gradio: python app.py +""" +import os +import sys + +# MPS fallback for macOS +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +import argparse +import gc +import gzip +import json as json_module +import logging +import math +import re +import shutil +import tempfile +import warnings + +# Suppress known harmless warnings from HF Spaces / torch internals +warnings.filterwarnings("ignore", message=".*torch.distributed.reduce_op.*", category=FutureWarning) +warnings.filterwarnings("ignore", message=".*torch.nn.utils.weight_norm.*", category=FutureWarning) +from collections import defaultdict +from fractions import Fraction +from functools import partial +from pathlib import Path +from random import Random +from typing import Optional, List, Tuple, Union, BinaryIO, Literal, Sequence, Iterable, Callable + +import gradio as gr +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm +import librosa +import pyworld +import soundfile as sf +import torchaudio +from scipy import signal +from huggingface_hub import hf_hub_download +from tqdm.auto import tqdm + +# 48 Hz high-pass filter to remove low-frequency artifacts (same as Applio) +FILTER_ORDER = 5 +CUTOFF_FREQUENCY = 48 # Hz +SAMPLE_RATE = 16000 # Hz +bh, ah = signal.butter(N=FILTER_ORDER, Wn=CUTOFF_FREQUENCY, btype="high", fs=SAMPLE_RATE) + +def sanitize_model_name(name: str) -> str: + """Sanitize model name for safe use in file paths""" + name = os.path.basename(name.strip()) + name = re.sub(r'[^\w\-.]', '_', name) + return name or "unnamed_model" + +# Default example model +DEFAULT_MODEL_REPO = "audo/Benee-RVC" +DEFAULT_MODEL_FILE = "BENEE8000.pth" +DEFAULT_INDEX_FILE = "added_IVF1054_Flat_nprobe_8.index" + +# RVC v2 pretrained weights from official repo +RVC_PRETRAINED_REPO = "lj1995/VoiceConversionWebUI" +RVC_PRETRAINED_V2 = { + # Generator with f0 (pitch) + "f0G48k": "pretrained_v2/f0G48k.pth", + "f0G40k": "pretrained_v2/f0G40k.pth", + "f0G32k": "pretrained_v2/f0G32k.pth", + # Discriminator with f0 + "f0D48k": "pretrained_v2/f0D48k.pth", + "f0D40k": "pretrained_v2/f0D40k.pth", + "f0D32k": "pretrained_v2/f0D32k.pth", + # Generator without f0 + "G48k": "pretrained_v2/G48k.pth", + "G40k": "pretrained_v2/G40k.pth", + "G32k": "pretrained_v2/G32k.pth", + # Discriminator without f0 + "D48k": "pretrained_v2/D48k.pth", + "D40k": "pretrained_v2/D40k.pth", + "D32k": "pretrained_v2/D32k.pth", +} + +def download_pretrained_rvc(name: str) -> str: + """Download RVC v2 pretrained weights from HuggingFace""" + if name not in RVC_PRETRAINED_V2: + raise ValueError(f"Unknown pretrained: {name}. Available: {list(RVC_PRETRAINED_V2.keys())}") + filepath = RVC_PRETRAINED_V2[name] + logger.info(f"Downloading pretrained {name} from {RVC_PRETRAINED_REPO}...") + return hf_hub_download(repo_id=RVC_PRETRAINED_REPO, filename=filepath) + +# Beatrice v2 pretrained assets +BEATRICE_REPO = "fierce-cats/beatrice-trainer" +BEATRICE_PRETRAINED = { + "phone_extractor": "assets/pretrained/122_checkpoint_03000000.pt", + "pitch_estimator": "assets/pretrained/104_3_checkpoint_00300000.pt", + "pretrained_model": "assets/pretrained/151_checkpoint_libritts_r_200_02750000.pt.gz", +} + +def download_beatrice_asset(name: str) -> str: + """Download Beatrice v2 pretrained asset from HuggingFace""" + if name not in BEATRICE_PRETRAINED: + raise ValueError(f"Unknown asset: {name}. Available: {list(BEATRICE_PRETRAINED.keys())}") + filepath = BEATRICE_PRETRAINED[name] + logger.info(f"Downloading Beatrice asset {name} from {BEATRICE_REPO}...") + return hf_hub_download(repo_id=BEATRICE_REPO, filename=filepath) + +def download_beatrice_augmentation(): + """Download Beatrice augmentation assets (noise + IR) - optional for training""" + try: + from huggingface_hub import snapshot_download + cache_dir = snapshot_download(repo_id=BEATRICE_REPO, allow_patterns=["assets/noise/*", "assets/ir/*"]) + noise_dir = os.path.join(cache_dir, "assets", "noise") + ir_dir = os.path.join(cache_dir, "assets", "ir") + if os.path.isdir(noise_dir) and os.path.isdir(ir_dir): + return noise_dir, ir_dir + return None, None + except Exception as e: + logger.warning(f"Could not download augmentation assets: {e}") + return None, None + +def load_pretrained_weights(model: nn.Module, pretrained_path: str) -> None: + """Load pretrained weights into model, handling speaker embedding mismatch""" + logger.info(f"Loading pretrained weights: {pretrained_path}") + state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True) + # Handle different checkpoint formats + if "model" in state_dict: + state_dict = state_dict["model"] + + # Filter out mismatched keys, but handle emb_g specially + model_state = model.state_dict() + filtered_state = {} + skipped = [] + for k, v in state_dict.items(): + if k in model_state: + if v.shape == model_state[k].shape: + filtered_state[k] = v + elif k == "emb_g.weight": + # Initialize our speaker embedding with mean of pretrained embeddings + # This gives a much better starting point than random initialization + mean_emb = v.mean(dim=0, keepdim=True) # [1, 256] + num_speakers = model_state[k].shape[0] + filtered_state[k] = mean_emb.expand(num_speakers, -1).clone() + logger.info(f"Initialized emb_g from pretrained mean ({v.shape[0]} -> {num_speakers} speakers)") + else: + skipped.append(f"{k}: {v.shape} vs {model_state[k].shape}") + else: + skipped.append(f"{k}: not in model") + + if skipped: + logger.info(f"Skipped {len(skipped)} mismatched keys") + + model.load_state_dict(filtered_state, strict=False) + logger.info(f"Loaded {len(filtered_state)}/{len(state_dict)} pretrained weights") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Device selection: +# - Inference: Always CPU (HF Spaces free tier, also works everywhere) +# - Training: GPU if available for speed, CPU fallback +device = torch.device("cpu") # For inference +train_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # For training + +logger.info(f"Inference device: {device}") +logger.info(f"Training device: {train_device}") + +# ============================================================ +# COMMONS - Helper functions from infer/lib/infer_pack/commons.py +# ============================================================ + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + +def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + return t_act * s_act + +def slice_segments(x, ids_str, segment_size=4): + """Slice segments from tensor""" + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + +def slice_segments2(x, ids_str, segment_size=4): + """Slice segments from 2D tensor""" + ret = torch.zeros_like(x[:, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, idx_str:idx_end] + return ret + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + """Random slice segments""" + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=1) + ids_str = (torch.rand([b], device=x.device) * ids_str_max.float()).long() + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + +# ============================================================ +# MODULES - From infer/lib/infer_pack/modules.py +# ============================================================ + +LRELU_SLOPE = 0.1 + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + +class WN(nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super().__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = float(p_dropout) + + self.in_layers = nn.ModuleList() + self.res_skip_layers = nn.ModuleList() + self.drop = nn.Dropout(float(p_dropout)) + + if gin_channels != 0: + cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding) + in_layer = weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i, (in_layer, res_skip_layer) in enumerate(zip(self.in_layers, self.res_skip_layers)): + x_in = in_layer(x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + res_skip_acts = res_skip_layer(acts) + + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, :self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels:, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + remove_weight_norm(self.cond_layer) + for l in self.in_layers: + remove_weight_norm(l) + for l in self.res_skip_layers: + remove_weight_norm(l) + +class ResBlock1(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))), + ]) + self.convs1.apply(init_weights) + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), + ]) + self.convs2.apply(init_weights) + self.lrelu_slope = LRELU_SLOPE + + def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, self.lrelu_slope) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, self.lrelu_slope) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + +class ResBlock2(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), + ]) + self.convs.apply(init_weights) + self.lrelu_slope = LRELU_SLOPE + + def forward(self, x, x_mask: Optional[torch.Tensor] = None): + for c in self.convs: + xt = F.leaky_relu(x, self.lrelu_slope) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + +class Flip(nn.Module): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x, torch.zeros([1], device=x.device) + +class ResidualCouplingLayer(nn.Module): + def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False): + assert channels % 2 == 0 + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=float(p_dropout), gin_channels=gin_channels) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x, torch.zeros([1]) + + def remove_weight_norm(self): + self.enc.remove_weight_norm() + +# ============================================================ +# ATTENTIONS - From infer/lib/infer_pack/attentions.py +# ============================================================ + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, proximal_bias=False, proximal_init=False): + super().__init__() + assert channels % n_heads == 0 + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels ** -0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + x, _ = self.attention(q, k, v, mask=attn_mask) + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s = key.size() + t_t = query.size(2) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + return torch.matmul(x, y.unsqueeze(0)) + + def _matmul_with_relative_keys(self, x, y): + return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + + def _get_relative_embeddings(self, relative_embeddings, length): + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) + else: + padded_relative_embeddings = relative_embeddings + return padded_relative_embeddings[:, slice_start_position:slice_end_position] + + def _relative_position_to_absolute_position(self, x): + batch, heads, length, _ = x.size() + x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]) + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, [0, int(length) - 1, 0, 0, 0, 0]) + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + batch, heads, length, _ = x.size() + x = F.pad(x, [0, int(length) - 1, 0, 0, 0, 0, 0, 0]) + x_flat = x.view([batch, heads, int(length ** 2) + int(length * (length - 1))]) + x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.causal = causal + self.is_activation = activation == "gelu" + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor): + x = self.conv_1(self._padding(x, x_mask)) + if self.is_activation: + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self._padding(x, x_mask)) + return x * x_mask + + def _padding(self, x, x_mask): + if self.causal: + if self.kernel_size == 1: + return x * x_mask + pad_l = self.kernel_size - 1 + return F.pad(x * x_mask, [pad_l, 0, 0, 0, 0, 0]) + else: + if self.kernel_size == 1: + return x * x_mask + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + return F.pad(x * x_mask, [pad_l, pad_r, 0, 0, 0, 0]) + +class Encoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.n_layers = int(n_layers) + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + for i in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for attn, norm1, ffn, norm2 in zip(self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2): + y = attn(x, x, attn_mask) + y = self.drop(y) + x = norm1(x + y) + y = ffn(x, x_mask) + y = self.drop(y) + x = norm2(x + y) + return x * x_mask + +# ============================================================ +# MODELS - From infer/lib/infer_pack/models.py +# ============================================================ + +sr2sr = {"32k": 32000, "40k": 40000, "48k": 48000} + +class TextEncoder256(nn.Module): + def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.emb_phone = nn.Linear(256, hidden_channels) + self.lrelu = nn.LeakyReLU(0.1, inplace=True) + if f0: + self.emb_pitch = nn.Embedding(256, hidden_channels) + self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor): + if pitch is None: + x = self.emb_phone(phone) + else: + x = self.emb_phone(phone) + self.emb_pitch(pitch) + x = x * math.sqrt(self.hidden_channels) + x = self.lrelu(x) + x = torch.transpose(x, 1, -1) + x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype) + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + return m, logs, x_mask + +class TextEncoder768(nn.Module): + def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.emb_phone = nn.Linear(768, hidden_channels) + self.lrelu = nn.LeakyReLU(0.1, inplace=True) + if f0: + self.emb_pitch = nn.Embedding(256, hidden_channels) + self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor): + if pitch is None: + x = self.emb_phone(phone) + else: + x = self.emb_phone(phone) + self.emb_pitch(pitch) + x = x * math.sqrt(self.hidden_channels) + x = self.lrelu(x) + x = torch.transpose(x, 1, -1) + x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype) + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + return m, logs, x_mask + +class ResidualCouplingBlock(nn.Module): + def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): + super().__init__() + self.n_flows = n_flows + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) + self.flows.append(Flip()) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in self.flows[::-1]: + x, _ = flow.forward(x, x_mask, g=g, reverse=reverse) + return x + + def remove_weight_norm(self): + for i in range(self.n_flows): + self.flows[i * 2].remove_weight_norm() + +class PosteriorEncoder(nn.Module): + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0): + super().__init__() + self.out_channels = out_channels + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None): + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + def remove_weight_norm(self): + self.enc.remove_weight_norm() + +class Generator(nn.Module): + def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): + super().__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock_class(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + +class SineGen(nn.Module): + def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0): + super().__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv.float() + + def forward(self, f0: torch.Tensor, upp: int): + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in range(self.harmonic_num): + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2) + rad_values = (f0_buf / self.sampling_rate) % 1 + rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum(rad_values, 1) + tmp_over_one *= upp + tmp_over_one = F.interpolate(tmp_over_one.transpose(2, 1), scale_factor=float(upp), mode="linear", align_corners=True).transpose(2, 1) + rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1) + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate(uv.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + +class SourceModuleHnNSF(nn.Module): + def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0, is_half=False): + super().__init__() + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod) + self.l_linear = nn.Linear(harmonic_num + 1, 1) + self.l_tanh = nn.Tanh() + + def forward(self, x: torch.Tensor, upp: int = 1): + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + return sine_merge, None, None + +class GeneratorNSF(nn.Module): + def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, is_half=False): + super().__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates)) + self.m_source = SourceModuleHnNSF(sampling_rate=sr, harmonic_num=0, is_half=is_half) + self.noise_convs = nn.ModuleList() + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + c_cur = upsample_initial_channel // (2 ** (i + 1)) + self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2))) + if i + 1 < len(upsample_rates): + stride_f0 = math.prod(upsample_rates[i + 1:]) + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock_class(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + self.upp = math.prod(upsample_rates) + self.lrelu_slope = LRELU_SLOPE + + def forward(self, x, f0, g: Optional[torch.Tensor] = None): + har_source, _, _ = self.m_source(f0, self.upp) + har_source = har_source.transpose(1, 2) + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)): + if i < self.num_upsamples: + x = F.leaky_relu(x, self.lrelu_slope) + x = ups(x) + x_source = noise_convs(har_source) + x = x + x_source + xs = None + l = [i * self.num_kernels + j for j in range(self.num_kernels)] + for j, resblock in enumerate(self.resblocks): + if j in l: + if xs is None: + xs = resblock(x) + else: + xs += resblock(x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + +# Synthesizer classes for different model versions +class SynthesizerTrnMs256NSFsid(nn.Module): + """RVC v1 model with f0""" + def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, **kwargs): + super().__init__() + if isinstance(sr, str): + sr = sr2sr[sr] + self.segment_size = segment_size + self.gin_channels = gin_channels + self.spk_embed_dim = spk_embed_dim + self.enc_p = TextEncoder256(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) + self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs.get("is_half", False)) + self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) + self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) + + @torch.jit.export + def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: torch.Tensor, nsff0: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): + g = self.emb_g(sid).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + if rate is not None: + head = int(z_p.shape[2] * (1 - rate.item())) + z_p = z_p[:, :, head:] + x_mask = x_mask[:, :, head:] + nsff0 = nsff0[:, head:] + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec(z * x_mask, nsff0, g=g) + return o, x_mask, (z, z_p, m_p, logs_p) + +class SynthesizerTrnMs768NSFsid(nn.Module): + """RVC v2 model with f0""" + def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, **kwargs): + super().__init__() + if isinstance(sr, str): + sr = sr2sr[sr] + self.segment_size = segment_size + self.gin_channels = gin_channels + self.spk_embed_dim = spk_embed_dim + self.enc_p = TextEncoder768(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout)) + self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs.get("is_half", False)) + self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) + self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) + + def forward(self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds): + """Training forward pass""" + g = self.emb_g(ds).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size) + pitchf = slice_segments2(pitchf, ids_slice, self.segment_size) + o = self.dec(z_slice, pitchf, g=g) + return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + @torch.jit.export + def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: torch.Tensor, nsff0: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): + g = self.emb_g(sid).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + if rate is not None: + head = int(z_p.shape[2] * (1.0 - rate.item())) + z_p = z_p[:, :, head:] + x_mask = x_mask[:, :, head:] + nsff0 = nsff0[:, head:] + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec(z * x_mask, nsff0, g=g) + return o, x_mask, (z, z_p, m_p, logs_p) + +class SynthesizerTrnMs256NSFsid_nono(nn.Module): + """RVC v1 model without f0""" + def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs): + super().__init__() + self.segment_size = segment_size + self.gin_channels = gin_channels + self.enc_p = TextEncoder256(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), f0=False) + self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) + self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) + + @torch.jit.export + def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): + g = self.emb_g(sid).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + if rate is not None: + head = int(z_p.shape[2] * (1.0 - rate.item())) + z_p = z_p[:, :, head:] + x_mask = x_mask[:, :, head:] + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec(z * x_mask, g=g) + return o, x_mask, (z, z_p, m_p, logs_p) + +class SynthesizerTrnMs768NSFsid_nono(nn.Module): + """RVC v2 model without f0""" + def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs): + super().__init__() + self.segment_size = segment_size + self.gin_channels = gin_channels + self.enc_p = TextEncoder768(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), f0=False) + self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels) + self.emb_g = nn.Embedding(spk_embed_dim, gin_channels) + + @torch.jit.export + def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None): + g = self.emb_g(sid).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + if rate is not None: + head = int(z_p.shape[2] * (1.0 - rate.item())) + z_p = z_p[:, :, head:] + x_mask = x_mask[:, :, head:] + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec(z * x_mask, g=g) + return o, x_mask, (z, z_p, m_p, logs_p) + +# ============================================================ +# DISCRIMINATOR - For training +# ============================================================ + +class DiscriminatorS(nn.Module): + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else weight_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + return x, fmap + +class DiscriminatorP(nn.Module): + def __init__(self, period, use_spectral_norm=False): + super().__init__() + self.period = period + norm_f = nn.utils.spectral_norm if use_spectral_norm else weight_norm + self.convs = nn.ModuleList([ + norm_f(nn.Conv2d(1, 32, (5, 1), (3, 1), padding=(2, 0))), + norm_f(nn.Conv2d(32, 128, (5, 1), (3, 1), padding=(2, 0))), + norm_f(nn.Conv2d(128, 512, (5, 1), (3, 1), padding=(2, 0))), + norm_f(nn.Conv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0))), + norm_f(nn.Conv2d(1024, 1024, (5, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + b, c, t = x.shape + if t % self.period != 0: + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + return x, fmap + +class MultiPeriodDiscriminator(nn.Module): + def __init__(self, use_spectral_norm=False): + super().__init__() + periods = [2, 3, 5, 7, 11, 17, 23, 37] # 8 periods for v2 pretrained (9 total discriminators) + self.discriminators = nn.ModuleList( + [DiscriminatorS(use_spectral_norm)] + + [DiscriminatorP(p, use_spectral_norm) for p in periods] + ) + + def forward(self, y, y_hat): + y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] + for d in self.discriminators: + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + +# ============================================================ +# TRAINING LOSSES +# ============================================================ + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl.float().detach() - gl.float())) + return loss * 2 + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + loss += torch.mean((1 - dr.float()) ** 2) + torch.mean(dg.float() ** 2) + return loss + +def generator_loss(disc_outputs): + loss = 0 + for dg in disc_outputs: + loss += torch.mean((1 - dg.float()) ** 2) + return loss + +def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + z_p, logs_q, m_p, logs_p, z_mask = [x.float() for x in [z_p, logs_q, m_p, logs_p, z_mask]] + kl = logs_p - logs_q - 0.5 + 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + return torch.sum(kl * z_mask) / torch.sum(z_mask) + +# ============================================================ +# HUBERT EXTRACTION - Using torchaudio bundle +# ============================================================ + +# ContentVec model for v1 (256-dim) and HuBERT for v2 (768-dim) +_contentvec_model = None # For v1 models (256-dim output) +_hubert_model = None # For v2 models (768-dim output) +_hubert_bundle = None + +CONTENTVEC_REPO = "IAHispano/Applio" +CONTENTVEC_MODEL = "Resources/embedders/contentvec/pytorch_model.bin" +CONTENTVEC_CONFIG = "Resources/embedders/contentvec/config.json" + +def load_contentvec(): + """Load ContentVec model from HuggingFace for v1 models (256-dim output)""" + global _contentvec_model + if _contentvec_model is None: + try: + from transformers import HubertModel, HubertConfig + logger.info("Loading ContentVec model from HuggingFace...") + + # Download model files + model_path = hf_hub_download(repo_id=CONTENTVEC_REPO, filename=CONTENTVEC_MODEL) + config_path = hf_hub_download(repo_id=CONTENTVEC_REPO, filename=CONTENTVEC_CONFIG) + + # Create model with final_proj layer + class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + config = HubertConfig.from_pretrained(config_path) + _contentvec_model = HubertModelWithFinalProj(config) + state_dict = torch.load(model_path, map_location="cpu", weights_only=True) + _contentvec_model.load_state_dict(state_dict) + _contentvec_model.to(device).eval() + logger.info(f"ContentVec loaded: hidden={config.hidden_size}, proj={config.classifier_proj_size}") + except Exception as e: + logger.warning(f"Failed to load ContentVec: {e}, falling back to torchaudio HuBERT") + _contentvec_model = None + return _contentvec_model + +def load_hubert(): + """Load HuBERT model via torchaudio for v2 models (768-dim output)""" + global _hubert_model, _hubert_bundle + if _hubert_model is None: + import torchaudio + logger.info("Loading HuBERT model via torchaudio...") + _hubert_bundle = torchaudio.pipelines.HUBERT_BASE + _hubert_model = _hubert_bundle.get_model().to(device) + _hubert_model.eval() + logger.info("HuBERT model loaded") + return _hubert_model, _hubert_bundle + +def extract_hubert_features(audio: np.ndarray, sr: int = 16000, version: str = "v2") -> torch.Tensor: + """Extract ContentVec features from audio (same as Applio) + + v1 models: Use ContentVec with final_proj (256-dim) + v2 models: Use ContentVec without final_proj (768-dim) + """ + audio = audio.astype(np.float32) + if np.abs(audio).max() > 1.0: + audio = audio / np.abs(audio).max() + + inputs = torch.from_numpy(audio).unsqueeze(0).to(device) + + # Use ContentVec for ALL versions (same as Applio) + contentvec = load_contentvec() + if contentvec is not None: + with torch.no_grad(): + output = contentvec(inputs) + if version == "v1": + # v1: use final_proj for 256-dim + feats = contentvec.final_proj(output.last_hidden_state) + else: + # v2: use raw hidden state (768-dim) + feats = output.last_hidden_state + return feats + + # Fallback to torchaudio HuBERT if ContentVec not available + logger.warning("ContentVec not available, using torchaudio HuBERT (results may be degraded)") + hubert, bundle = load_hubert() + with torch.no_grad(): + features, _ = hubert.extract_features(inputs) + layer_idx = 11 if version == "v2" else 8 + feats = features[min(layer_idx, len(features)-1)] + + if version == "v1": + proj = nn.Linear(768, 256, bias=False).to(device) + with torch.no_grad(): + w = torch.zeros(256, 768) + for i in range(256): + w[i, i*3:(i+1)*3] = 1/3 + proj.weight.copy_(w) + feats = proj(feats) + + return feats + +# ============================================================ +# F0 EXTRACTION +# ============================================================ + +def extract_f0_pm(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0) -> Tuple[np.ndarray, np.ndarray]: + """Extract F0 using parselmouth (pm method)""" + import parselmouth + + p_len = audio.shape[0] // 160 + 1 + f0_min = 65 + f0_max = 1100 + + l_pad = int(np.ceil(1.5 / f0_min * 16000)) + r_pad = l_pad + 1 + + s = parselmouth.Sound(np.pad(audio, (l_pad, r_pad)), 16000).to_pitch_ac( + time_step=0.01, voicing_threshold=0.6, pitch_floor=f0_min, pitch_ceiling=f0_max, + ) + + f0 = s.selected_array["frequency"] + if len(f0) < p_len: + f0 = np.pad(f0, (0, p_len - len(f0))) + f0 = f0[:p_len] + f0 *= pow(2, f0_up_key / 12) + + return f0_to_coarse(f0) + +def extract_f0_harvest(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0) -> Tuple[np.ndarray, np.ndarray]: + """Extract F0 using pyworld harvest""" + import pyworld + from scipy import signal as scipy_signal + + f0, t = pyworld.harvest(audio.astype(np.double), fs=16000, f0_ceil=1100, f0_floor=50, frame_period=10) + f0 = scipy_signal.medfilt(f0, 3) + f0 *= pow(2, f0_up_key / 12) + + return f0_to_coarse(f0) + +def f0_to_coarse(f0: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Convert f0 to coarse representation""" + f0_min = 50 + f0_max = 1100 + f0_mel_min = 1127 * np.log(1 + f0_min / 700) + f0_mel_max = 1127 * np.log(1 + f0_max / 700) + + f0bak = f0.copy() + f0_mel = 1127 * np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1 + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > 255] = 255 + f0_coarse = np.rint(f0_mel).astype(np.int32) + + return f0_coarse, f0bak + +# ============================================================ +# RMVPE F0 EXTRACTION (from Applio - IAHispano/Applio) +# ============================================================ + +class RMVPE_ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, momentum=0.01): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), + nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), + ) + self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) if in_channels != out_channels else None + + def forward(self, x): + r = self.conv(x) + return r + self.shortcut(x) if self.shortcut else r + x + +class RMVPE_ResEncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): + super().__init__() + self.conv = nn.ModuleList([RMVPE_ConvBlockRes(in_channels, out_channels, momentum)]) + for _ in range(n_blocks - 1): + self.conv.append(RMVPE_ConvBlockRes(out_channels, out_channels, momentum)) + self.kernel_size = kernel_size + if kernel_size is not None: + self.pool = nn.AvgPool2d(kernel_size=kernel_size) + + def forward(self, x): + for c in self.conv: + x = c(x) + return (x, self.pool(x)) if self.kernel_size is not None else x + +class RMVPE_Encoder(nn.Module): + def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): + super().__init__() + self.n_encoders = n_encoders + self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) + self.layers = nn.ModuleList() + for _ in range(n_encoders): + self.layers.append(RMVPE_ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum)) + in_channels = out_channels + out_channels *= 2 + in_size //= 2 + self.out_size = in_size + self.out_channel = out_channels + + def forward(self, x): + concat_tensors = [] + x = self.bn(x) + for layer in self.layers: + t, x = layer(x) + concat_tensors.append(t) + return x, concat_tensors + +class RMVPE_Intermediate(nn.Module): + def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): + super().__init__() + self.layers = nn.ModuleList([RMVPE_ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)]) + for _ in range(n_inters - 1): + self.layers.append(RMVPE_ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + +class RMVPE_ResDecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): + super().__init__() + out_padding = (0, 1) if stride == (1, 2) else (1, 1) + self.conv1 = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, (3, 3), stride, (1, 1), out_padding, bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), + ) + self.conv2 = nn.ModuleList([RMVPE_ConvBlockRes(out_channels * 2, out_channels, momentum)]) + for _ in range(n_blocks - 1): + self.conv2.append(RMVPE_ConvBlockRes(out_channels, out_channels, momentum)) + + def forward(self, x, concat_tensor): + x = self.conv1(x) + x = torch.cat((x, concat_tensor), dim=1) + for c in self.conv2: + x = c(x) + return x + +class RMVPE_Decoder(nn.Module): + def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): + super().__init__() + self.layers = nn.ModuleList() + for _ in range(n_decoders): + out_channels = in_channels // 2 + self.layers.append(RMVPE_ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) + in_channels = out_channels + self.n_decoders = n_decoders + + def forward(self, x, concat_tensors): + for i in range(self.n_decoders): + x = self.layers[i](x, concat_tensors[-1 - i]) + return x + +class RMVPE_DeepUnet(nn.Module): + def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): + super().__init__() + self.encoder = RMVPE_Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels) + self.intermediate = RMVPE_Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) + self.decoder = RMVPE_Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + x = self.decoder(x, concat_tensors) + return x + +class RMVPE_BiGRU(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super().__init__() + self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.gru(x)[0] + +RMVPE_N_MELS = 128 +RMVPE_N_CLASS = 360 + +class RMVPE_E2E(nn.Module): + def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): + super().__init__() + self.unet = RMVPE_DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + RMVPE_BiGRU(3 * 128, 256, n_gru), + nn.Linear(512, RMVPE_N_CLASS), nn.Dropout(0.25), nn.Sigmoid(), + ) + else: + self.fc = nn.Sequential(nn.Linear(3 * RMVPE_N_MELS, RMVPE_N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) + + def forward(self, mel): + mel = mel.transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + return self.fc(x) + +class RMVPE_MelSpectrogram(nn.Module): + def __init__(self, n_mel_channels=128, sample_rate=16000, win_length=1024, hop_length=160, n_fft=None, mel_fmin=30, mel_fmax=8000, clamp=1e-5): + super().__init__() + from librosa.filters import mel as librosa_mel + n_fft = win_length if n_fft is None else n_fft + self.hann_window = {} + mel_basis = librosa_mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True) + self.register_buffer("mel_basis", torch.from_numpy(mel_basis).float()) + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.clamp = clamp + + def forward(self, audio, keyshift=0, speed=1, center=True): + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(self.n_fft * factor)) + win_length_new = int(np.round(self.win_length * factor)) + hop_length_new = int(np.round(self.hop_length * speed)) + key = f"{keyshift}_{audio.device}" + if key not in self.hann_window: + self.hann_window[key] = torch.hann_window(win_length_new).to(audio.device) + fft = torch.stft(audio, n_fft=n_fft_new, hop_length=hop_length_new, win_length=win_length_new, + window=self.hann_window[key], center=center, return_complex=True) + magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) + if keyshift != 0: + size = self.n_fft // 2 + 1 + resize = magnitude.size(1) + if resize < size: + magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) + magnitude = magnitude[:, :size, :] * self.win_length / win_length_new + mel_output = torch.matmul(self.mel_basis, magnitude) + return torch.log(torch.clamp(mel_output, min=self.clamp)) + +_rmvpe_model = None + +def load_rmvpe(): + """Download and load RMVPE model for f0 extraction""" + global _rmvpe_model + if _rmvpe_model is None: + logger.info("Downloading RMVPE model...") + rmvpe_path = hf_hub_download(repo_id="IAHispano/Applio", filename="Resources/predictors/rmvpe.pt") + model = RMVPE_E2E(4, 1, (2, 2)) + ckpt = torch.load(rmvpe_path, map_location="cpu", weights_only=True) + model.load_state_dict(ckpt) + model.eval().to(device) + mel_extractor = RMVPE_MelSpectrogram().to(device) + cents_mapping = 20 * np.arange(RMVPE_N_CLASS) + 1997.3794084376191 + _rmvpe_model = (model, mel_extractor, np.pad(cents_mapping, (4, 4))) + logger.info("RMVPE model loaded") + return _rmvpe_model + +def extract_f0_rmvpe(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0, thred: float = 0.03) -> Tuple[np.ndarray, np.ndarray]: + """Extract F0 using RMVPE (best quality, neural network based)""" + model, mel_extractor, cents_mapping = load_rmvpe() + + audio_t = torch.from_numpy(audio).float().to(device).unsqueeze(0) + mel = mel_extractor(audio_t, center=True) + del audio_t + + # mel2hidden with chunking + with torch.no_grad(): + n_frames = mel.shape[-1] + mel_padded = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect") + chunks = [] + for start in range(0, mel_padded.shape[-1], 32000): + end = min(start + 32000, mel_padded.shape[-1]) + chunks.append(model(mel_padded[..., start:end])) + hidden = torch.cat(chunks, dim=1)[:, :n_frames].squeeze(0).cpu().numpy() + + # Decode hidden to f0 + center = np.argmax(hidden, axis=1) + salience = np.pad(hidden, ((0, 0), (4, 4))) + center += 4 + todo_salience = [] + todo_cents = [] + for idx in range(salience.shape[0]): + s, e = center[idx] - 4, center[idx] + 5 + todo_salience.append(salience[idx, s:e]) + todo_cents.append(cents_mapping[s:e]) + todo_salience = np.array(todo_salience) + todo_cents = np.array(todo_cents) + cents_pred = np.sum(todo_salience * todo_cents, 1) / np.sum(todo_salience, 1) + cents_pred[np.max(salience, axis=1) <= thred] = 0 + + f0 = 10 * (2 ** (cents_pred / 1200)) + f0[f0 == 10] = 0 + f0 *= pow(2, f0_up_key / 12) + + return f0_to_coarse(f0) + +# ============================================================ +# MODEL LOADING +# ============================================================ + +_model_cache = {} + +def load_rvc_model(model_path: str): + """Load RVC model and auto-detect version""" + if model_path in _model_cache: + return _model_cache[model_path] + + logger.info(f"Loading RVC model: {model_path}") + try: + cpt = torch.load(model_path, map_location="cpu", weights_only=True) + except Exception: + logger.warning("Model requires unsafe loading - may be an older format") + cpt = torch.load(model_path, map_location="cpu", weights_only=False) + + weight_key = None + for key in ["weight", "model", "state_dict", "net_g"]: + if key in cpt: + weight_key = key + break + if weight_key is None: + raise ValueError(f"Cannot find model weights. Keys: {list(cpt.keys())}") + + config = cpt.get("config", None) + if config is None: + config = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000] + logger.warning("No config found, using v2 defaults") + + version = cpt.get("version", "v1") + if_f0 = cpt.get("f0", 1) + + if weight_key in cpt: + emb_weight = cpt[weight_key].get("emb_g.weight") + if emb_weight is not None: + config[-3] = emb_weight.shape[0] + + sr = config[-1] if isinstance(config[-1], int) else 40000 + + if version == "v1": + model_class = SynthesizerTrnMs256NSFsid if if_f0 == 1 else SynthesizerTrnMs256NSFsid_nono + else: + model_class = SynthesizerTrnMs768NSFsid if if_f0 == 1 else SynthesizerTrnMs768NSFsid_nono + + model = model_class( + spec_channels=config[0], segment_size=config[1], inter_channels=config[2], + hidden_channels=config[3], filter_channels=config[4], n_heads=config[5], + n_layers=config[6], kernel_size=config[7], p_dropout=config[8], + resblock=config[9], resblock_kernel_sizes=config[10], + resblock_dilation_sizes=config[11], upsample_rates=config[12], + upsample_initial_channel=config[13], upsample_kernel_sizes=config[14], + spk_embed_dim=config[15], gin_channels=config[16], sr=sr, is_half=False + ) + + model.load_state_dict(cpt[weight_key], strict=False) + model.eval().to(device) + + _model_cache[model_path] = (model, sr, version, if_f0) + logger.info(f"Model loaded: version={version}, f0={if_f0}, sr={sr}") + + return model, sr, version, if_f0 + +# ============================================================ +# TRAINING - Simplified for CPU testing +# ============================================================ + +def spectrogram_torch(y, n_fft, hop_size, win_size, center=False): + """Compute spectrogram""" + hann_window = torch.hann_window(win_size).to(y.device) + y = F.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect').squeeze(1) + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6) + return spec + +# Mel spectrogram for training loss +_mel_basis_cache = {} + +def spec_to_mel_torch(spec, n_fft=2048, num_mels=125, sampling_rate=40000, fmin=0, fmax=None): + """Convert spectrogram to mel spectrogram""" + from librosa.filters import mel as librosa_mel_fn + global _mel_basis_cache + + if fmax is None: + fmax = sampling_rate // 2 + + key = f"{n_fft}_{num_mels}_{sampling_rate}_{fmin}_{fmax}_{spec.dtype}_{spec.device}" + if key not in _mel_basis_cache: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + _mel_basis_cache[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + + melspec = torch.matmul(_mel_basis_cache[key], spec) + melspec = torch.log(torch.clamp(melspec, min=1e-5)) # Log-amplitude + return melspec + +def preprocess_audio_for_training(audio_path: str, output_dir: str, target_sr: int = 40000, f0_method: str = "rmvpe"): + """Preprocess audio file for training - slice and extract features""" + import scipy.signal as signal + + os.makedirs(output_dir, exist_ok=True) + os.makedirs(f"{output_dir}/wavs", exist_ok=True) + os.makedirs(f"{output_dir}/hubert", exist_ok=True) + os.makedirs(f"{output_dir}/f0", exist_ok=True) + + logger.info(f"Preprocessing: {audio_path}") + + # Load and resample audio + audio, sr = librosa.load(audio_path, sr=target_sr, mono=True) + + # High-pass filter + bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=target_sr) + audio = signal.lfilter(bh, ah, audio) + + # Slice into chunks (3.7 seconds with 0.3 overlap) + chunk_size = int(3.7 * target_sr) + hop = int(3.4 * target_sr) + + chunks = [] + for i, start in enumerate(range(0, len(audio) - chunk_size, hop)): + chunk = audio[start:start + chunk_size] + # Normalize + max_val = np.abs(chunk).max() + if max_val > 0.01: # Skip silence + chunk = chunk / max_val * 0.9 + chunks.append((i, chunk)) + + if not chunks: + logger.warning("No valid audio chunks found") + return None + + logger.info(f"Created {len(chunks)} chunks") + + # Save chunks and extract features + manifest = [] + for idx, chunk in chunks: + # Save wav + wav_path = f"{output_dir}/wavs/{idx:04d}.wav" + sf.write(wav_path, chunk, target_sr) + + # Resample to 16k for HuBERT + chunk_16k = librosa.resample(chunk, orig_sr=target_sr, target_sr=16000) + + # Extract HuBERT features + feats = extract_hubert_features(chunk_16k, sr=16000, version="v2") + hubert_path = f"{output_dir}/hubert/{idx:04d}.npy" + np.save(hubert_path, feats.squeeze(0).cpu().numpy()) + + # Extract F0 + if f0_method == "rmvpe": + f0_coarse, f0 = extract_f0_rmvpe(chunk_16k, 16000, 0) + elif f0_method == "harvest": + f0_coarse, f0 = extract_f0_harvest(chunk_16k, 16000, 0) + else: + f0_coarse, f0 = extract_f0_pm(chunk_16k, 16000, 0) + f0_path = f"{output_dir}/f0/{idx:04d}.npy" + np.save(f0_path, np.stack([f0_coarse, f0], axis=0)) + + manifest.append(f"{idx:04d}") + + # Save manifest + with open(f"{output_dir}/manifest.txt", "w") as f: + f.write("\n".join(manifest)) + + logger.info(f"Preprocessing complete: {len(manifest)} samples") + return output_dir + +def train_rvc_generator( + data_dir: str, + output_dir: str, + epochs: int = 10, + batch_size: int = 2, + lr: float = 1e-5, # Lower LR prevents overfitting on small data + target_sr: int = 40000, + progress_callback=None +): + """Generator version of train_rvc - yields (epoch_msg, ckpt_path) tuples""" + logger.info(f"Starting training: {data_dir} -> {output_dir}") + os.makedirs(output_dir, exist_ok=True) + + # Load manifest + with open(f"{data_dir}/manifest.txt") as f: + samples = [l.strip() for l in f if l.strip()] + + if len(samples) < 1: + logger.error("No training samples found") + return None + + logger.info(f"Training with {len(samples)} samples") + + # Model config (v2 40k defaults) + config = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 1, 256, target_sr] + + # Create models (v2 only - 768-dim HuBERT features) + net_g = SynthesizerTrnMs768NSFsid( + spec_channels=config[0], segment_size=config[1], inter_channels=config[2], + hidden_channels=config[3], filter_channels=config[4], n_heads=config[5], + n_layers=config[6], kernel_size=config[7], p_dropout=config[8], + resblock=config[9], resblock_kernel_sizes=config[10], + resblock_dilation_sizes=config[11], upsample_rates=config[12], + upsample_initial_channel=config[13], upsample_kernel_sizes=config[14], + spk_embed_dim=config[15], gin_channels=config[16], sr=target_sr + ).to(train_device) + + net_d = MultiPeriodDiscriminator().to(train_device) + logger.info(f"Training on device: {train_device}") + + # Download and load pretrained weights (essential for good results) + sr_key = f"{target_sr // 1000}k" # e.g., "40k" + try: + pretrain_g_path = download_pretrained_rvc(f"f0G{sr_key}") + pretrain_d_path = download_pretrained_rvc(f"f0D{sr_key}") + load_pretrained_weights(net_g, pretrain_g_path) + load_pretrained_weights(net_d, pretrain_d_path) + except Exception as e: + logger.warning(f"Failed to load pretrained weights: {e}") + logger.warning("Training from scratch (results may be poor)") + + # Optimizers (after loading pretrained weights) + optim_g = torch.optim.AdamW(net_g.parameters(), lr=lr, betas=(0.8, 0.99)) + optim_d = torch.optim.AdamW(net_d.parameters(), lr=lr, betas=(0.8, 0.99)) + + # LR scheduler (matches Applio - exponential decay) + lr_decay = 0.999875 + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=lr_decay) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=lr_decay) + + net_g.train() + net_d.train() + + # Training loop + for epoch in range(epochs): + total_loss_g, total_loss_d = 0, 0 + np.random.shuffle(samples) + + for i in range(0, len(samples), batch_size): + batch_samples = samples[i:i+batch_size] + + # Load batch data + wavs, huberts, f0s = [], [], [] + for s in batch_samples: + wav, _ = librosa.load(f"{data_dir}/wavs/{s}.wav", sr=target_sr, mono=True) + hubert = np.load(f"{data_dir}/hubert/{s}.npy") + # Upsample 50Hz -> 100Hz using interpolation (same as inference) + hubert_t = torch.from_numpy(hubert).unsqueeze(0).permute(0, 2, 1) # (1, 768, seq) + hubert_t = F.interpolate(hubert_t, scale_factor=2, mode='linear', align_corners=False) + hubert = hubert_t.permute(0, 2, 1).squeeze(0).numpy() # (seq*2, 768) + f0_data = np.load(f"{data_dir}/f0/{s}.npy") + wavs.append(wav) + huberts.append(hubert) + f0s.append(f0_data) + + # Compute spectrogram first to get target length + max_wav_len = max(len(w) for w in wavs) + wav_batch = np.zeros((len(wavs), max_wav_len)) + for j, w in enumerate(wavs): + wav_batch[j, :len(w)] = w + wav_t = torch.FloatTensor(wav_batch).unsqueeze(1).to(train_device) + spec = spectrogram_torch(wav_t.squeeze(1), 2048, 400, 2048) + spec_len = spec.shape[2] # Target length for all features + + # Pad/truncate features to match spec length exactly + hubert_batch = np.zeros((len(huberts), spec_len, huberts[0].shape[1])) + f0_batch = np.zeros((len(f0s), spec_len)) + f0f_batch = np.zeros((len(f0s), spec_len)) + + for j, (h, f) in enumerate(zip(huberts, f0s)): + # Truncate or pad HuBERT to spec_len + h_len = min(h.shape[0], spec_len) + hubert_batch[j, :h_len] = h[:h_len] + # Truncate or pad F0 to spec_len + f0_len = min(f.shape[1], spec_len) + f0_batch[j, :f0_len] = f[0, :f0_len] + f0f_batch[j, :f0_len] = f[1, :f0_len] + + # To tensors - all features now have spec_len + hubert_t = torch.FloatTensor(hubert_batch).to(train_device) + f0_t = torch.LongTensor(f0_batch.astype(np.int64)).to(train_device) + f0f_t = torch.FloatTensor(f0f_batch).to(train_device) + lengths_t = torch.LongTensor([spec_len] * len(batch_samples)).to(train_device) + sid_t = torch.LongTensor([0] * len(batch_samples)).to(train_device) + spec_lengths = torch.LongTensor([spec_len] * len(batch_samples)).to(train_device) + + # Forward pass generator + # Args: phone, phone_lengths, pitch, pitchf, y (spec), y_lengths, ds + try: + y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = net_g( + hubert_t, lengths_t, f0_t, f0f_t, spec, spec_lengths, sid_t + ) + except Exception as e: + logger.warning(f"Generator forward failed: {e}") + continue + + # Slice wav at same position model generated (CRITICAL for proper loss) + # ids_slice is in latent space, multiply by hop_length to get waveform position + hop_length = 400 + segment_size_wav = 32 * hop_length # segment_size in latent * hop_length + y = slice_segments(wav_t, ids_slice * hop_length, segment_size_wav) + + # Discriminator forward + y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat.detach()) + + # Discriminator loss + loss_d = discriminator_loss(y_d_rs, y_d_gs) + + optim_d.zero_grad() + loss_d.backward() + optim_d.step() + + # Generator loss + y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat) + loss_gen = generator_loss(y_d_gs) + loss_fm = feature_loss(fmap_rs, fmap_gs) + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) + + # Mel spectrogram loss (crucial for quality) + # Config: n_fft=2048, hop=400, win=2048, n_mels=125, fmin=0, fmax=None + y_mel = spec_to_mel_torch(spectrogram_torch(y.squeeze(1), 2048, 400, 2048), + n_fft=2048, num_mels=125, sampling_rate=target_sr, fmin=0, fmax=None) + y_hat_mel = spec_to_mel_torch(spectrogram_torch(y_hat.squeeze(1), 2048, 400, 2048), + n_fft=2048, num_mels=125, sampling_rate=target_sr, fmin=0, fmax=None) + # Align lengths if needed + min_len = min(y_mel.shape[2], y_hat_mel.shape[2]) + loss_mel = F.l1_loss(y_mel[:, :, :min_len], y_hat_mel[:, :, :min_len]) * 45 # c_mel = 45 + + loss_g = loss_gen + loss_fm + loss_mel + loss_kl + + optim_g.zero_grad() + loss_g.backward() + optim_g.step() + + total_loss_g += loss_g.item() + total_loss_d += loss_d.item() + + avg_loss_g = total_loss_g / max(1, len(samples) // batch_size) + avg_loss_d = total_loss_d / max(1, len(samples) // batch_size) + epoch_msg = f"Epoch {epoch+1}/{epochs} - G: {avg_loss_g:.2f}, D: {avg_loss_d:.2f}" + logger.info(epoch_msg) + + # Update progress callback if provided + if progress_callback: + progress_pct = 0.30 + (0.65 * (epoch + 1) / epochs) + progress_callback(progress_pct, epoch_msg) + + # Yield epoch message for live UI updates + yield epoch_msg, None, None + + # Step LR schedulers + scheduler_g.step() + scheduler_d.step() + + # Save checkpoint + ckpt_path = f"{output_dir}/model.pth" + torch.save({ + "weight": net_g.state_dict(), + "config": config, + "version": "v2", # v2 only + "f0": 1, + }, ckpt_path) + logger.info(f"Saved checkpoint: {ckpt_path}") + + # Generate index file for better speaker similarity + index_path = None + try: + import faiss + hubert_dir = f"{data_dir}/hubert" + npys = [] + for name in sorted(os.listdir(hubert_dir)): + if name.endswith('.npy'): + phone = np.load(os.path.join(hubert_dir, name)) + npys.append(phone) + if npys: + big_npy = np.concatenate(npys, axis=0) + n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39) + n_ivf = max(1, n_ivf) # Ensure at least 1 + index = faiss.index_factory(big_npy.shape[1], f"IVF{n_ivf},Flat") + index.train(big_npy) + index.add(big_npy) + index_path = f"{output_dir}/model.index" + faiss.write_index(index, index_path) + logger.info(f"Saved index: {index_path}") + except Exception as e: + logger.warning(f"Failed to generate index: {e}") + + # Cleanup training models to free memory + del net_g, net_d, optim_g, optim_d, scheduler_g, scheduler_d + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + yield "Training complete!", ckpt_path, index_path + + +def train_rvc( + data_dir: str, + output_dir: str, + epochs: int = 10, + batch_size: int = 2, + lr: float = 1e-5, # Lower LR prevents overfitting on small data + target_sr: int = 40000, + progress_callback=None +): + """Non-generator wrapper for CLI use - returns (checkpoint_path, index_path)""" + ckpt = None + idx = None + for msg, path, index in train_rvc_generator(data_dir, output_dir, epochs, batch_size, lr, target_sr, progress_callback): + if path: + ckpt = path + if index: + idx = index + return ckpt, idx + +# ============================================================ +# INFERENCE +# ============================================================ + +def convert_voice( + source_audio: str, + model_file, + index_file=None, + pitch_shift: int = 0, + f0_method: str = "pm", + index_rate: float = 0.5, + protect: float = 0.33, + volume_envelope: float = 1.0, + progress=gr.Progress() +) -> Tuple[str, str]: + """Convert voice using RVC model (Applio-compatible pipeline).""" + try: + if source_audio is None: + return None, "Please upload source audio" + if model_file is None: + return None, "Please upload RVC model (.pth)" + + model_path = model_file.name if hasattr(model_file, 'name') else model_file + + progress(0.1, "Loading model...") + model, tgt_sr, version, if_f0 = load_rvc_model(model_path) + + progress(0.2, "Loading audio...") + audio, sr = librosa.load(source_audio, sr=16000, mono=True) + + # Apply 48Hz high-pass filter (critical - removes low-frequency artifacts) + audio = signal.filtfilt(bh, ah, audio) + + # Normalize audio + audio_max = np.abs(audio).max() / 0.95 + if audio_max > 1: + audio /= audio_max + + # Pipeline constants (same as Applio) + window = 160 # Critical for feature/pitch alignment + x_pad = 1 # Padding in seconds + t_pad = 16000 * x_pad # Padding in samples + + # Pad audio + audio_pad = np.pad(audio, (t_pad, t_pad), mode="reflect") + p_len = audio_pad.shape[0] // window + + progress(0.3, "Extracting features...") + feats = extract_hubert_features(audio_pad, sr=16000, version=version) + + # Save original features for protect mechanism + feats0 = feats.clone() if if_f0 == 1 and protect < 0.5 else None + + # Index retrieval (speaker similarity) + if index_file is not None and index_rate > 0: + try: + import faiss + index_path = index_file.name if hasattr(index_file, 'name') else index_file + progress(0.4, "Loading index...") + index = faiss.read_index(index_path) + big_npy = index.reconstruct_n(0, index.ntotal) + + npy = feats[0].cpu().numpy().astype("float32") + score, ix = index.search(npy, k=8) + weight = np.square(1 / score) + weight /= weight.sum(axis=1, keepdims=True) + npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) + feats = torch.from_numpy(npy).unsqueeze(0).to(device) * index_rate + (1 - index_rate) * feats + except Exception as e: + logger.warning(f"Index retrieval failed: {e}") + + # Feature upsampling by 2x + feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) + + # Adjust length based on audio + p_len = min(audio_pad.shape[0] // window, feats.shape[1]) + + pitch, pitchf = None, None + if if_f0 == 1: + progress(0.5, f"Extracting F0 ({f0_method})...") + if f0_method == "rmvpe": + pitch, pitchf = extract_f0_rmvpe(audio_pad, 16000, pitch_shift) + elif f0_method == "harvest": + pitch, pitchf = extract_f0_harvest(audio_pad, 16000, pitch_shift) + else: + pitch, pitchf = extract_f0_pm(audio_pad, 16000, pitch_shift) + + pitch = pitch[:p_len] + pitchf = pitchf[:p_len] + + # Upsample feats0 for protect + if feats0 is not None: + feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) + + # Apply protect mechanism (preserve original features for unvoiced segments) + if protect < 0.5 and feats0 is not None: + pitchf_tensor = torch.from_numpy(pitchf).float().to(device) + pitchff = pitchf_tensor.clone() + pitchff[pitchf_tensor > 0] = 1 + pitchff[pitchf_tensor < 1] = protect + pitchff = pitchff.unsqueeze(0).unsqueeze(-1) + feats = feats[:, :p_len, :] * pitchff + feats0[:, :p_len, :] * (1 - pitchff) + + if len(pitch) < p_len: + pitch = np.pad(pitch, (0, p_len - len(pitch))) + pitchf = np.pad(pitchf, (0, p_len - len(pitchf))) + + pitch = torch.LongTensor(pitch).unsqueeze(0).to(device) + pitchf = torch.FloatTensor(pitchf).unsqueeze(0).to(device) + + p_len_tensor = torch.LongTensor([p_len]).to(device) + sid = torch.LongTensor([0]).to(device) + + progress(0.7, "Running inference...") + with torch.no_grad(): + if if_f0 == 1: + audio_out = model.infer(feats[:, :p_len, :], p_len_tensor, pitch, pitchf, sid)[0][0, 0].data.cpu().float().numpy() + else: + audio_out = model.infer(feats[:, :p_len, :], p_len_tensor, sid)[0][0, 0].data.cpu().float().numpy() + + # Remove padding from output + t_pad_tgt = int(t_pad * tgt_sr / 16000) + if len(audio_out) > 2 * t_pad_tgt: + audio_out = audio_out[t_pad_tgt:-t_pad_tgt] + + # RMS mixing - match volume dynamics of source audio + if volume_envelope != 1.0: + try: + source_at_tgt_sr = librosa.resample(audio, orig_sr=16000, target_sr=tgt_sr) + frame_len = tgt_sr // 2 * 2 + hop_len = tgt_sr // 2 + + rms_source = librosa.feature.rms(y=source_at_tgt_sr, frame_length=frame_len, hop_length=hop_len) + rms_output = librosa.feature.rms(y=audio_out, frame_length=frame_len, hop_length=hop_len) + + rms_source = F.interpolate( + torch.from_numpy(rms_source).float().unsqueeze(0), + size=audio_out.shape[0], mode="linear" + ).squeeze() + rms_output = F.interpolate( + torch.from_numpy(rms_output).float().unsqueeze(0), + size=audio_out.shape[0], mode="linear" + ).squeeze() + rms_output = torch.maximum(rms_output, torch.zeros_like(rms_output) + 1e-6) + + # Applio formula: target * (source^(1-rate) * output^(rate-1)) + audio_out = audio_out * (torch.pow(rms_source, 1 - volume_envelope) * torch.pow(rms_output, volume_envelope - 1)).numpy() + except Exception as e: + logger.warning(f"RMS mixing failed: {e}") + + # Final normalization + audio_max = np.abs(audio_out).max() / 0.99 + if audio_max > 1: + audio_out /= audio_max + + progress(0.9, "Saving output...") + fd, output_path = tempfile.mkstemp(suffix=".wav") + os.close(fd) + sf.write(output_path, audio_out, tgt_sr) + + # Cleanup tensors and model cache to prevent OOM on HF Spaces + del model, feats, audio_out, audio, audio_pad + if feats0 is not None: + del feats0 + _model_cache.clear() + gc.collect() + + return output_path, f"Converted: {version}, sr={tgt_sr}, pitch={pitch_shift:+d}" + + except Exception as e: + logger.exception("Conversion failed") + _model_cache.clear() + gc.collect() + return None, f"Error: {str(e)}" + +# ============================================================ +# DEFAULT MODEL DOWNLOAD +# ============================================================ + +def load_example_model(): + """Download and load the default example model from HuggingFace""" + import shutil + try: + logger.info(f"Downloading example model from {DEFAULT_MODEL_REPO}...") + model_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_MODEL_FILE) + index_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_INDEX_FILE) + + # Gradio 6 requires files to be in allowed directories (cwd or /tmp) + # Copy from HF cache to temp directory + temp_dir = tempfile.mkdtemp() + temp_model = os.path.join(temp_dir, DEFAULT_MODEL_FILE) + temp_index = os.path.join(temp_dir, DEFAULT_INDEX_FILE) + shutil.copy2(model_path, temp_model) + shutil.copy2(index_path, temp_index) + + return temp_model, temp_index, f"Loaded: {DEFAULT_MODEL_REPO}" + except Exception as e: + logger.exception("Failed to download example model") + return None, None, f"Error: {str(e)}" + +# ============================================================ +# BEATRICE V2 MODEL +# ============================================================ + +def beatrice_load_audio(file, **kwargs): + """Load audio using soundfile directly (for Beatrice dataset)""" + data, sr = sf.read(file, dtype='float32') + # soundfile returns (samples, channels), convert to torch (channels, samples) + wav = torch.from_numpy(data) + if wav.ndim == 1: + wav = wav.unsqueeze(0) # mono -> (1, samples) + else: + wav = wav.T # (samples, channels) -> (channels, samples) + return wav, sr + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + + +def dump_params(params: torch.Tensor, f: BinaryIO): + if params is None: + return + if params.dtype == torch.bfloat16: + f.write( + params.detach() + .clone() + .float() + .view(torch.short) + .numpy() + .ravel()[1::2] + .tobytes() + ) + else: + f.write(params.detach().numpy().ravel().tobytes()) + f.flush() + + +def dump_layer(layer: nn.Module, f: BinaryIO): + dump = partial(dump_params, f=f) + if hasattr(layer, "dump"): + layer.dump(f) + elif isinstance(layer, (nn.Linear, nn.Conv1d, nn.LayerNorm)): + dump(layer.weight) + dump(layer.bias) + elif isinstance(layer, nn.MultiheadAttention): + embed_dim = layer.embed_dim + num_heads = layer.num_heads + # [3 * embed_dim, embed_dim] + in_proj_weight = layer.in_proj_weight.data.clone() + in_proj_weight[: 2 * embed_dim] *= 1.0 / math.sqrt( + math.sqrt(embed_dim // num_heads) + ) + in_proj_weight = in_proj_weight.view( + 3, num_heads, embed_dim // num_heads, embed_dim + ) + # [num_heads, 3, embed_dim / num_heads, embed_dim] + in_proj_weight = in_proj_weight.transpose(0, 1) + # [3 * embed_dim] + in_proj_bias = layer.in_proj_bias.data.clone() + in_proj_bias[: 2 * embed_dim] *= 1.0 / math.sqrt( + math.sqrt(embed_dim // num_heads) + ) + in_proj_bias = in_proj_bias.view(3, num_heads, embed_dim // num_heads) + # [num_heads, 3, embed_dim / num_heads] + in_proj_bias = in_proj_bias.transpose(0, 1) + dump(in_proj_weight) + dump(in_proj_bias) + dump(layer.out_proj.weight) + dump(layer.out_proj.bias) + elif isinstance(layer, nn.Embedding): + dump(layer.weight) + elif isinstance(layer, nn.Parameter): + dump(layer) + elif isinstance(layer, nn.ModuleList): + for layer_i in layer: + dump_layer(layer_i, f) + else: + assert False, layer + + +class CausalConv1d(nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + delay: int = 0, + ): + padding = (kernel_size - 1) * dilation - delay + self.trim = (kernel_size - 1) * dilation - 2 * delay + if self.trim < 0: + raise ValueError + super().__init__( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + result = super().forward(input) + if self.trim == 0: + return result + else: + return result[:, :, : -self.trim] + + +class WSConv1d(CausalConv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + delay: int = 0, + ): + super().__init__( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + delay=delay, + ) + self.weight.data.normal_( + 0.0, math.sqrt(1.0 / (in_channels * kernel_size // groups)) + ) + if bias: + self.bias.data.zero_() + self.gain = nn.Parameter(torch.ones((out_channels, 1, 1))) + + def standardized_weight(self) -> torch.Tensor: + var, mean = torch.var_mean(self.weight, [1, 2], keepdim=True) + scale = ( + self.gain + * ( + self.in_channels * self.kernel_size[0] // self.groups * var + 1e-8 + ).rsqrt() + ) + return scale * (self.weight - mean) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + result = F.conv1d( + input, + self.standardized_weight(), + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + if self.trim == 0: + return result + else: + return result[:, :, : -self.trim] + + def merge_weights(self): + self.weight.data[:] = self.standardized_weight().detach() + self.gain.data.fill_(1.0) + + +class WSLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias) + self.weight.data.normal_(0.0, math.sqrt(1.0 / in_features)) + self.bias.data.zero_() + self.gain = nn.Parameter(torch.ones((out_features, 1))) + + def standardized_weight(self) -> torch.Tensor: + var, mean = torch.var_mean(self.weight, 1, keepdim=True) + scale = self.gain * (self.in_features * var + 1e-8).rsqrt() + return scale * (self.weight - mean) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.standardized_weight(), self.bias) + + def merge_weights(self): + self.weight.data[:] = self.standardized_weight().detach() + self.gain.data.fill_(1.0) + + +class CrossAttention(nn.Module): + def __init__( + self, + qk_channels: int, + vo_channels: int, + num_heads: int, + in_q_channels: int, + in_kv_channels: int, + out_channels: int, + dropout: float = 0.0, + ): + super().__init__() + assert qk_channels % num_heads == 0 + self.qk_channels = qk_channels + self.vo_channels = vo_channels + self.num_heads = num_heads + self.in_q_channels = in_q_channels + self.in_kv_channels = in_kv_channels + self.out_channels = out_channels + self.dropout = dropout + self.head_qk_channels = qk_channels // num_heads + self.head_vo_channels = vo_channels // num_heads + self.q_projection = nn.Linear(in_q_channels, qk_channels) + self.q_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_q_channels)) + self.q_projection.bias.data.zero_() + self.kv_projection = nn.Linear(in_kv_channels, qk_channels + vo_channels) + self.kv_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_kv_channels)) + self.kv_projection.bias.data.zero_() + self.out_projection = nn.Linear(vo_channels, out_channels) + self.out_projection.weight.data.normal_(0.0, math.sqrt(1.0 / vo_channels)) + self.out_projection.bias.data.zero_() + + def forward( + self, + q: torch.Tensor, + kv: torch.Tensor, + ) -> torch.Tensor: + # q: [batch_size, q_length, in_q_channels] + # kv: [batch_size, kv_length, in_kv_channels] + batch_size, q_length, _ = q.size() + _, kv_length, _ = kv.size() + # [batch_size, q_length, qk_channels] + q = self.q_projection(q) + # [batch_size, kv_length, qk_channels + vo_channels] + kv = self.kv_projection(kv) + # [batch_size, kv_length, qk_channels], [batch_size, kv_length, vo_channels] + k, v = kv.split([self.qk_channels, self.vo_channels], dim=2) + q = q.view( + batch_size, q_length, self.num_heads, self.head_qk_channels + ).transpose(1, 2) + k = k.view( + batch_size, kv_length, self.num_heads, self.head_qk_channels + ).transpose(1, 2) + v = v.view( + batch_size, kv_length, self.num_heads, self.head_vo_channels + ).transpose(1, 2) + # [batch_size, num_heads, q_length, head_vo_channels] + attn_out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout) + # [batch_size, q_length, vo_channels] + attn_out = ( + attn_out.transpose(1, 2) + .contiguous() + .view(batch_size, q_length, self.vo_channels) + ) + # [batch_size, q_length, out_channels] + attn_out = self.out_projection(attn_out) + return attn_out + + def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump(f) + return + if not hasattr(f, "write"): + raise TypeError + + q_projection_weight = self.q_projection.weight.data.clone() + q_projection_bias = self.q_projection.bias.data.clone() + q_projection_weight *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) + q_projection_bias *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) + dump_params(q_projection_weight, f) + dump_params(q_projection_bias, f) + dump_layer(self.out_projection, f) + + def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump_kv(f) + return + if not hasattr(f, "write"): + raise TypeError + + kv_projection_weight = self.kv_projection.weight.data.clone() + kv_projection_bias = self.kv_projection.bias.data.clone() + k_projection_weight, v_projection_weight = kv_projection_weight.split( + [self.qk_channels, self.vo_channels] + ) + k_projection_bias, v_projection_bias = kv_projection_bias.split( + [self.qk_channels, self.vo_channels] + ) + k_projection_weight *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) + k_projection_bias *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels)) + # [qk_channels, in_kv_channels] -> [num_heads, head_qk_channels, in_kv_channels] + k_projection_weight = k_projection_weight.view( + self.num_heads, self.head_qk_channels, self.in_kv_channels + ) + # [qk_channels] -> [num_heads, head_qk_channels] + k_projection_bias = k_projection_bias.view( + self.num_heads, self.head_qk_channels + ) + # [vo_channels, in_kv_channels] -> [num_heads, head_vo_channels, in_kv_channels] + v_projection_weight = v_projection_weight.view( + self.num_heads, self.head_vo_channels, self.in_kv_channels + ) + # [vo_channels] -> [num_heads, head_vo_channels] + v_projection_bias = v_projection_bias.view( + self.num_heads, self.head_vo_channels + ) + for i in range(self.num_heads): + # [head_qk_channels, in_kv_channels] + dump_params(k_projection_weight[i], f) + # [head_vo_channels, in_kv_channels] + dump_params(v_projection_weight[i], f) + for i in range(self.num_heads): + # [head_qk_channels] + dump_params(k_projection_bias[i], f) + # [head_vo_channels] + dump_params(v_projection_bias[i], f) + + +class ConvNeXtBlock(nn.Module): + def __init__( + self, + channels: int, + intermediate_channels: int, + layer_scale_init_value: float, + kernel_size: int = 7, + use_weight_standardization: bool = False, + enable_scaling: bool = False, + pre_scale: float = 1.0, + post_scale: float = 1.0, + use_mha: bool = False, + cross_attention: bool = False, + num_heads: int = 4, + attention_dropout: float = 0.1, + attention_channels: Optional[int] = None, + kv_channels: Optional[int] = None, + ): + super().__init__() + self.use_weight_standardization = use_weight_standardization + self.enable_scaling = enable_scaling + self.use_mha = use_mha + self.cross_attention = cross_attention + if use_mha: + self.attn_norm = nn.LayerNorm(channels) + if cross_attention: + self.mha = CrossAttention( + qk_channels=attention_channels, + vo_channels=attention_channels, + num_heads=num_heads, + in_q_channels=channels, + in_kv_channels=kv_channels, + out_channels=channels, + dropout=attention_dropout, + ) + else: # self-attention + assert attention_channels is None + assert kv_channels is None + self.mha = nn.MultiheadAttention( + embed_dim=channels, + num_heads=num_heads, + dropout=attention_dropout, + batch_first=True, + ) + self.dwconv = CausalConv1d( + channels, channels, kernel_size=kernel_size, groups=channels + ) + self.norm = nn.LayerNorm(channels) + self.pwconv1 = nn.Linear(channels, intermediate_channels) + self.pwconv2 = nn.Linear(intermediate_channels, channels) + self.gamma = nn.Parameter(torch.full((channels,), layer_scale_init_value)) + self.dwconv.weight.data.normal_(0.0, math.sqrt(1.0 / kernel_size)) + self.dwconv.bias.data.zero_() + self.pwconv1.weight.data.normal_(0.0, math.sqrt(2.0 / channels)) + self.pwconv1.bias.data.zero_() + self.pwconv2.weight.data.normal_(0.0, math.sqrt(1.0 / intermediate_channels)) + self.pwconv2.bias.data.zero_() + if use_weight_standardization: + self.norm = nn.Identity() + self.dwconv = WSConv1d(channels, channels, kernel_size, groups=channels) + self.pwconv1 = WSLinear(channels, intermediate_channels) + self.pwconv2 = WSLinear(intermediate_channels, channels) + del self.gamma + if enable_scaling: + self.register_buffer("pre_scale", torch.tensor(pre_scale)) + self.register_buffer("post_scale", torch.tensor(post_scale)) + self.post_scale_weight = nn.Parameter(torch.ones(())) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + kv: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.use_mha: + batch_size, channels, length = x.size() + if self.cross_attention: + assert kv is not None + else: + assert kv is None + assert length % 4 == 0 + identity = x + if self.cross_attention: + # kv: [batch_size, kv_length, kv_channels] + x = x.transpose(1, 2) + x = self.attn_norm(x) + x = self.mha(x, kv) + x = x.transpose(1, 2) + else: + x = x.view(batch_size, channels, length // 4, 4) + x = x.permute(0, 3, 2, 1) + x = x.reshape(batch_size * 4, length // 4, channels) + x = self.attn_norm(x) + x, _ = self.mha( + x, x, x, attn_mask=attn_mask, is_causal=True, need_weights=False + ) + x = x.view(batch_size, 4, length // 4, channels) + x = x.permute(0, 3, 2, 1) + x = x.reshape(batch_size, channels, length) + x += identity + + identity = x + if self.enable_scaling: + x = x * self.pre_scale + x = self.dwconv(x) + x = x.transpose(1, 2) + x = self.norm(x) + x = self.pwconv1(x) + x = F.gelu(x, approximate="tanh") + x = self.pwconv2(x) + if not self.use_weight_standardization: + x *= self.gamma + if self.enable_scaling: + x *= self.post_scale * self.post_scale_weight + x = x.transpose(1, 2) + x += identity + return x + + def merge_weights(self): + if self.use_mha: + if self.cross_attention: + assert isinstance(self.mha, CrossAttention) + self.mha.q_projection.bias.data += torch.mv( + self.mha.q_projection.weight.data, self.attn_norm.bias.data + ) + self.mha.q_projection.weight.data *= self.attn_norm.weight.data[None, :] + self.attn_norm.bias.data[:] = 0.0 + self.attn_norm.weight.data[:] = 1.0 + else: # self-attention + assert isinstance(self.mha, nn.MultiheadAttention) + self.mha.in_proj_bias.data += torch.mv( + self.mha.in_proj_weight.data, self.attn_norm.bias.data + ) + self.mha.in_proj_weight.data *= self.attn_norm.weight.data[None, :] + self.attn_norm.bias.data[:] = 0.0 + self.attn_norm.weight.data[:] = 1.0 + if self.use_weight_standardization: + self.dwconv.merge_weights() + self.pwconv1.merge_weights() + self.pwconv2.merge_weights() + else: + self.pwconv1.bias.data += torch.mv( + self.pwconv1.weight.data, self.norm.bias.data + ) + self.pwconv1.weight.data *= self.norm.weight.data[None, :] + self.norm.bias.data[:] = 0.0 + self.norm.weight.data[:] = 1.0 + self.pwconv2.weight.data *= self.gamma.data[:, None] + self.pwconv2.bias.data *= self.gamma.data + self.gamma.data[:] = 1.0 + if self.enable_scaling: + self.dwconv.weight.data *= self.pre_scale.data + self.pre_scale.data.fill_(1.0) + self.pwconv2.weight.data *= ( + self.post_scale.data * self.post_scale_weight.data + ) + self.pwconv2.bias.data *= self.post_scale.data * self.post_scale_weight.data + self.post_scale.data.fill_(1.0) + self.post_scale_weight.data.fill_(1.0) + + def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump(f) + return + if not hasattr(f, "write"): + raise TypeError + + if self.use_mha: + dump_layer(self.mha, f) + dump_layer(self.dwconv, f) + dump_layer(self.pwconv1, f) + dump_layer(self.pwconv2, f) + + +class ConvNeXtStack(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + intermediate_channels: int, + n_blocks: int, + delay: int, + embed_kernel_size: int, + kernel_size: int, + use_weight_standardization: bool = False, + enable_scaling: bool = False, + use_mha: bool = False, + cross_attention: bool = False, + kv_channels: Optional[int] = None, + ): + super().__init__() + assert delay * 2 + 1 <= embed_kernel_size + assert not (use_weight_standardization and use_mha) # 未対応 + self.use_weight_standardization = use_weight_standardization + self.use_mha = use_mha + self.cross_attention = cross_attention + self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay) + self.norm = nn.LayerNorm(channels) + self.convnext = nn.ModuleList() + for i in range(n_blocks): + pre_scale = 1.0 / math.sqrt(1.0 + i / n_blocks) if enable_scaling else 1.0 + post_scale = 1.0 / math.sqrt(n_blocks) if enable_scaling else 1.0 + block = ConvNeXtBlock( + channels=channels, + intermediate_channels=intermediate_channels, + layer_scale_init_value=1.0 / n_blocks, + kernel_size=kernel_size, + use_weight_standardization=use_weight_standardization, + enable_scaling=enable_scaling, + pre_scale=pre_scale, + post_scale=post_scale, + use_mha=use_mha, + cross_attention=cross_attention, + num_heads=4, + attention_dropout=0.1, + attention_channels=kv_channels, + kv_channels=kv_channels, + ) + self.convnext.append(block) + self.final_layer_norm = nn.LayerNorm(channels) + self.embed.weight.data.normal_( + 0.0, math.sqrt(0.5 / (embed_kernel_size * in_channels)) + ) + self.embed.bias.data.zero_() + if use_weight_standardization: + self.embed = WSConv1d(in_channels, channels, embed_kernel_size, delay=delay) + self.norm = nn.Identity() + self.final_layer_norm = nn.Identity() + + def forward( + self, x: torch.Tensor, kv: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = self.embed(x) + x = self.norm(x.transpose(1, 2)).transpose(1, 2) + if self.use_mha and not self.cross_attention: + pad_length = -x.size(2) % 4 + if pad_length: + x = F.pad(x, (0, pad_length)) + t40 = x.size(2) // 4 + attn_mask = torch.ones((t40, t40), dtype=torch.bool, device=x.device).triu( + 1 + ) + else: + attn_mask = None + for conv_block in self.convnext: + x = conv_block(x, attn_mask=attn_mask, kv=kv) + if self.use_mha and not self.cross_attention and pad_length: + x = x[:, :, :-pad_length] + x = self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2) + return x + + def merge_weights(self): + if self.use_weight_standardization: + self.embed.merge_weights() + for conv_block in self.convnext: + conv_block.merge_weights() + + def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump(f) + return + if not hasattr(f, "write"): + raise TypeError + + dump_layer(self.embed, f) + if not self.use_weight_standardization: + dump_layer(self.norm, f) + dump_layer(self.convnext, f) + if not self.use_weight_standardization: + dump_layer(self.final_layer_norm, f) + + def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump_kv(f) + return + if not hasattr(f, "write"): + raise TypeError + + assert self.use_mha and self.cross_attention + for conv_block in self.convnext: + if not conv_block.use_mha or not conv_block.cross_attention: + continue + assert isinstance(conv_block, ConvNeXtBlock) + assert hasattr(conv_block, "mha") + assert isinstance(conv_block.mha, CrossAttention) + conv_block.mha.dump_kv(f) + + +class FeatureExtractor(nn.Module): + def __init__(self, hidden_channels: int): + super().__init__() + # fmt: off + self.conv0 = weight_norm(nn.Conv1d(1, hidden_channels // 8, 10, 5, bias=False)) + self.conv1 = weight_norm(nn.Conv1d(hidden_channels // 8, hidden_channels // 4, 3, 2, bias=False)) + self.conv2 = weight_norm(nn.Conv1d(hidden_channels // 4, hidden_channels // 2, 3, 2, bias=False)) + self.conv3 = weight_norm(nn.Conv1d(hidden_channels // 2, hidden_channels, 3, 2, bias=False)) + self.conv4 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 3, 2, bias=False)) + self.conv5 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 2, 2, bias=False)) + # fmt: on + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [batch_size, 1, wav_length] + wav_length = x.size(2) + if wav_length % 160 != 0: + warnings.warn("wav_length % 160 != 0") + x = F.pad(x, (40, 40)) + x = F.gelu(self.conv0(x), approximate="tanh") + x = F.gelu(self.conv1(x), approximate="tanh") + x = F.gelu(self.conv2(x), approximate="tanh") + x = F.gelu(self.conv3(x), approximate="tanh") + x = F.gelu(self.conv4(x), approximate="tanh") + x = F.gelu(self.conv5(x), approximate="tanh") + # [batch_size, hidden_channels, wav_length / 160] + return x + + def remove_weight_norm(self): + remove_weight_norm(self.conv0) + remove_weight_norm(self.conv1) + remove_weight_norm(self.conv2) + remove_weight_norm(self.conv3) + remove_weight_norm(self.conv4) + remove_weight_norm(self.conv5) + + def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump(f) + return + if not hasattr(f, "write"): + raise TypeError + + dump_layer(self.conv0, f) + dump_layer(self.conv1, f) + dump_layer(self.conv2, f) + dump_layer(self.conv3, f) + dump_layer(self.conv4, f) + dump_layer(self.conv5, f) + + +class FeatureProjection(nn.Module): + def __init__(self, channels: int): + super().__init__() + self.norm = nn.LayerNorm(channels) + self.dropout = nn.Dropout(0.1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # [batch_size, channels, length] + x = self.norm(x.transpose(1, 2)).transpose(1, 2) + x = self.dropout(x) + return x + + +class PhoneExtractor(nn.Module): + def __init__( + self, + phone_channels: int = 128, + hidden_channels: int = 128, + backbone_embed_kernel_size: int = 9, + kernel_size: int = 17, + n_blocks: int = 20, + ): + super().__init__() + self.feature_extractor = FeatureExtractor(hidden_channels) + self.feature_projection = FeatureProjection(hidden_channels) + self.backbone = ConvNeXtStack( + in_channels=hidden_channels, + channels=hidden_channels, + intermediate_channels=hidden_channels * 3, + n_blocks=n_blocks, + delay=0, + embed_kernel_size=backbone_embed_kernel_size, + kernel_size=kernel_size, + use_mha=True, + ) + self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1)) + + def forward( + self, x: torch.Tensor, return_stats: bool = True + ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: + # x: [batch_size, 1, wav_length] + + stats = {} + + # [batch_size, 1, wav_length] -> [batch_size, feature_extractor_hidden_channels, length] + x = self.feature_extractor(x) + if return_stats: + stats["feature_norm"] = x.detach().norm(dim=1).mean() + # [batch_size, feature_extractor_hidden_channels, length] -> [batch_size, hidden_channels, length] + x = self.feature_projection(x) + # [batch_size, hidden_channels, length] + x = self.backbone(x) + # [batch_size, hidden_channels, length] -> [batch_size, phone_channels, length] + phone = self.head(F.gelu(x, approximate="tanh")) + + results = [phone] + if return_stats: + stats["code_norm"] = phone.detach().norm(dim=1).mean() + results.append(stats) + + if len(results) == 1: + return results[0] + return tuple(results) + + @torch.inference_mode() + def units(self, x: torch.Tensor) -> torch.Tensor: + # x: [batch_size, 1, wav_length] + + # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length] + phone = self.forward(x, return_stats=False) + # [batch_size, phone_channels, length] -> [batch_size, length, phone_channels] + phone = phone.transpose(1, 2) + # [batch_size, length, phone_channels] + return phone + + def remove_weight_norm(self): + self.feature_extractor.remove_weight_norm() + remove_weight_norm(self.head) + + def merge_weights(self): + self.backbone.merge_weights() + + self.backbone.embed.bias.data += ( + ( + self.feature_projection.norm.bias.data[None, :, None] + * self.backbone.embed.weight.data # [o, i, k] + ) + .sum(1) + .sum(1) + ) + self.backbone.embed.weight.data *= self.feature_projection.norm.weight.data[ + None, :, None + ] + self.feature_projection.norm.bias.data[:] = 0.0 + self.feature_projection.norm.weight.data[:] = 1.0 + + def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump(f) + return + if not hasattr(f, "write"): + raise TypeError + + dump_layer(self.feature_extractor, f) + dump_layer(self.backbone, f) + dump_layer(self.head, f) + + +class VectorQuantizer(nn.Module): + def __init__( + self, + n_speakers: int, + codebook_size: int, + channels: int, + topk: int = 4, + training_time_vq: Literal["none", "self", "random"] = "none", + ): + super().__init__() + assert 1 <= topk <= codebook_size + self.n_speakers = n_speakers + self.codebook_size = codebook_size + self.channels = channels + self.topk = topk + self.training_time_vq = training_time_vq + + self.register_buffer( + "codebooks", + torch.empty(n_speakers, codebook_size, channels, dtype=torch.half), + ) + self.codebooks: torch.Tensor + + # VQ の適用箇所を変更しやすいように hook にしている + self._hook_handle: Optional[torch.utils.hooks.RemovableHandle] = None + self.target_speaker_ids: Optional[torch.Tensor] = None + + def _hook(_, __, output): + return self(output, self.target_speaker_ids) + + self._hook_fn = _hook + + @torch.no_grad() + def build_codebooks( + self, + collector_func: Callable, + target_layer: nn.Module, + inputs: Sequence[Iterable[torch.Tensor]], + kmeans_n_iters: int = 50, + ): + assert len(inputs) == self.n_speakers + assert self._hook_handle is None, "hook already installed" + device = next(self.buffers()).device + + for spk_id, inps in enumerate(tqdm(inputs, desc="Building codebooks")): + activations: list[torch.Tensor] = [] + + # TODO: データ多すぎる場合に間引く処理をする + + def _collect(_, __, output): + # output: [batch_size, channels, length] + activations.append(output.detach()) + + handle = target_layer.register_forward_hook(_collect) + for x in inps: + collector_func(x.to(device)) + handle.remove() + + if not activations: + raise RuntimeError(f"No activation collected for speaker {spk_id}") + + # [n_data, channels] + activations: torch.Tensor = torch.cat( + [ + a.transpose(1, 2).reshape(a.size(0) * a.size(2), self.channels) + for a in activations + ] + ) + activations = activations.float() + activations = F.normalize(activations, dim=1, eps=1e-6) + # [codebook_size, channels] + centers = ( + self._kmeans_plus_plus(activations, self.codebook_size, kmeans_n_iters) + if activations.size(0) >= self.codebook_size + else self._pad_replicate(activations, self.codebook_size) + ) + self.codebooks[spk_id] = centers.to(self.codebooks.dtype) + + def forward( + self, x: torch.Tensor, speaker_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: + batch_size, channels, length = x.size() + assert channels == self.channels + device = x.device + dtype = x.dtype + + if self.training: + if self.training_time_vq == "none": + return x + elif self.training_time_vq == "self": + if self.target_speaker_ids is None: + raise ValueError("target_speaker_ids is not set") + elif self.training_time_vq == "random": + speaker_ids = torch.randint( + 0, self.n_speakers, (batch_size,), device=device + ) + else: + raise ValueError(f"Unknown training_time_vq: {self.training_time_vq}") + else: + if speaker_ids is None: + return x + speaker_ids = speaker_ids.to(device) + + # [batch_size, channels, length] → [batch_size, length, channels] + q = F.normalize(x, dim=1, eps=1e-6) + codes = self.codebooks[speaker_ids].to(q.dtype) + # [batch_size, length, codebook_size] + sim = torch.einsum("bcl,bkc->blk", q, codes) + + # [batch_size, length, topk] + _, topk_idx = sim.topk(self.topk, dim=-1) + # [batch_size, length, codebook_size, channels] + expanded_codes = codes[:, None, :, :].expand(-1, length, -1, -1) + # [batch_size, length, topk, channels] + expanded_topk_idx = topk_idx[:, :, :, None].expand(-1, -1, -1, channels) + # [batch_size, length, topk, channels] + gathered = expanded_codes.gather(2, expanded_topk_idx) + # [batch_size, length, channels] + gathered = gathered.mean(2) + # [batch_size, channels, length] + return gathered.transpose(1, 2).to(dtype) + + def enable_hook(self, target_layer: nn.Module): + if self._hook_handle is not None: + raise RuntimeError("hook already installed") + self._hook_handle = target_layer.register_forward_hook(self._hook_fn) + + def disable_hook(self): + if self._hook_handle is None: + raise RuntimeError("hook not installed") + self._hook_handle.remove() + self._hook_handle = None + + def set_target_speaker_ids(self, speaker_ids: Optional[torch.Tensor]): + # この話者が使われる条件は forward() を参照 + self.target_speaker_ids = speaker_ids + + @staticmethod + def _pad_replicate(x: torch.Tensor, n: int) -> torch.Tensor: + # データ数が n に満たないとき適当に複製して埋める + idx = torch.arange(n, device=x.device) % x.size(0) + return x[idx] + + @staticmethod + def _kmeans_plus_plus( + x: torch.Tensor, n_clusters: int, n_iters: int = 50 + ) -> torch.Tensor: + n_data, _ = x.size() + center_indices = [torch.randint(0, n_data, ()).item()] + min_distances = torch.full((n_data,), math.inf, device=x.device) + for _ in range(1, n_clusters): + last_center_index = center_indices[-1] + min_distances = min_distances.minimum( + torch.cdist(x, x[last_center_index : last_center_index + 1]) + .float() + .square_() + .squeeze_(1) + ) + probs = min_distances / (min_distances.sum() + 1e-12) + center_indices.append(torch.multinomial(probs, 1).item()) + centers = x[center_indices] + del min_distances, probs + for _ in range(n_iters): + distances = torch.cdist(x, centers) # [n_data, n_clusters] + labels = distances.argmin(1) # [n_data] + # [n_clusters, dim] + new_centers = torch.zeros_like(centers).index_add_(0, labels, x) + # [n_clusters] + counts = labels.bincount(minlength=n_clusters) + if (counts == 0).sum().item() != 0: + # TODO: 割り当てがないクラスタの処理 + warnings.warn("Some clusters have no assigned data points.") + new_centers /= counts[:, None].clamp_(min=1).float() + centers = new_centers + return centers + + + +def extract_pitch_features( + y: torch.Tensor, # [..., wav_length] + hop_length: int = 160, # 10ms + win_length: int = 560, # 35ms + max_corr_period: int = 256, # 16ms, 62.5Hz (16000 / 256) + corr_win_length: int = 304, # 19ms + instfreq_features_cutoff_bin: int = 64, # 1828Hz (16000 * 64 / 560) +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert max_corr_period + corr_win_length == win_length + + # パディングする + padding_length = (win_length - hop_length) // 2 + y = F.pad(y, (padding_length, padding_length)) + + # フレームにする + # [..., win_length, n_frames] + y_frames = y.unfold(-1, win_length, hop_length).transpose_(-2, -1) + + # 複素スペクトログラム + # Complex[..., (win_length // 2 + 1), n_frames] + spec: torch.Tensor = torch.fft.rfft(y_frames, n=win_length, dim=-2) + + # Complex[..., instfreq_features_cutoff_bin, n_frames] + spec = spec[..., :instfreq_features_cutoff_bin, :] + + # 対数パワースペクトログラム + log_power_spec = spec.abs().add_(1e-5).log10_() + + # 瞬時位相の時間差分 + # 時刻 0 の値は 0 + delta_spec = spec[..., :, 1:] * spec[..., :, :-1].conj() + delta_spec /= delta_spec.abs().add_(1e-5) + delta_spec = torch.cat( + [torch.zeros_like(delta_spec[..., :, :1]), delta_spec], dim=-1 + ) + + # [..., instfreq_features_cutoff_bin * 3, n_frames] + instfreq_features = torch.cat( + [log_power_spec, delta_spec.real, delta_spec.imag], dim=-2 + ) + + # 自己相関 + # 元々これに 2.0 / corr_win_length を掛けて使おうと思っていたが、 + # この値は振幅の 2 乗に比例していて、NN に入力するために良い感じに分散を + # 標準化する方法が思いつかなかったのでやめた + flipped_y_frames = y_frames.flip((-2,)) + a = torch.fft.rfft(flipped_y_frames, n=win_length, dim=-2) + b = torch.fft.rfft(y_frames[..., -corr_win_length:, :], n=win_length, dim=-2) + # [..., max_corr_period, n_frames] + corr = torch.fft.irfft(a * b, n=win_length, dim=-2)[..., corr_win_length:, :] + + # エネルギー項 + energy = flipped_y_frames.square_().cumsum_(-2) + energy0 = energy[..., corr_win_length - 1 : corr_win_length, :] + energy = energy[..., corr_win_length:, :] - energy[..., :-corr_win_length, :] + + # Difference function + corr_diff = (energy0 + energy).sub_(corr.mul_(2.0)) + assert corr_diff.min() >= -1e-3, corr_diff.min() + corr_diff.clamp_(min=0.0) # 計算誤差対策 + + # 標準化 + corr_diff *= 2.0 / corr_win_length + corr_diff.sqrt_() + + # 変換モデルへの入力用のエネルギー + energy = ( + (y_frames * torch.signal.windows.cosine(win_length, device=y.device)[..., None]) + .square_() + .sum(-2, keepdim=True) + ) + + energy.clamp_(min=1e-3).log10_() # >= -3, 振幅 1 の正弦波なら大体 2.15 + energy *= 0.5 # >= -1.5, 振幅 1 の正弦波なら大体 1.07, 1 の差は振幅で 20dB の差 + + return ( + instfreq_features, # [..., instfreq_features_cutoff_bin * 3, n_frames] + corr_diff, # [..., max_corr_period, n_frames] + energy, # [..., 1, n_frames] + ) + + +class PitchEstimator(nn.Module): + def __init__( + self, + input_instfreq_channels: int = 192, + input_corr_channels: int = 256, + pitch_bins: int = 448, + channels: int = 192, + intermediate_channels: int = 192 * 2, + n_blocks: int = 9, + delay: int = 1, # 10ms, 特徴抽出と合わせると 22.5ms + embed_kernel_size: int = 3, + kernel_size: int = 33, + pitch_bins_per_octave: int = 96, + ): + super().__init__() + self.pitch_bins_per_octave = pitch_bins_per_octave + + self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1) + self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1) + self.corr_embed_0 = nn.Conv1d(input_corr_channels, channels, 1) + self.corr_embed_1 = nn.Conv1d(channels, channels, 1) + self.backbone = ConvNeXtStack( + channels, + channels, + intermediate_channels, + n_blocks, + delay, + embed_kernel_size, + kernel_size, + enable_scaling=True, + ) + self.head = nn.Conv1d(channels, pitch_bins, 1) + + def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # wav: [batch_size, 1, wav_length] + + # [batch_size, input_instfreq_channels, length], + # [batch_size, input_corr_channels, length] + with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): + instfreq_features, corr_diff, energy = extract_pitch_features( + wav.squeeze(1), + hop_length=160, + win_length=560, + max_corr_period=256, + corr_win_length=304, + instfreq_features_cutoff_bin=64, + ) + instfreq_features = F.gelu( + self.instfreq_embed_0(instfreq_features), approximate="tanh" + ) + instfreq_features = self.instfreq_embed_1(instfreq_features) + corr_diff = F.gelu(self.corr_embed_0(corr_diff), approximate="tanh") + corr_diff = self.corr_embed_1(corr_diff) + # [batch_size, channels, length] + x = F.gelu(instfreq_features + corr_diff, approximate="tanh") + x = self.backbone(x) + # [batch_size, pitch_bins, length] + x = self.head(x) + return x, energy + + def sample_pitch( + self, pitch: torch.Tensor, band_width: int = 4, return_features: bool = False + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + # pitch: [batch_size, pitch_bins, length] + # 返されるピッチの値には 0 は含まれない + batch_size, pitch_bins, length = pitch.size() + pitch = pitch.softmax(1) + if return_features: + unvoiced_proba = pitch[:, :1, :].clone() + pitch[:, 0, :] = -100.0 + pitch = ( + pitch.transpose(1, 2).contiguous().view(batch_size * length, 1, pitch_bins) + ) + band_pitch = F.conv1d( + pitch, + torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width), + ) + # [batch_size * length, 1, pitch_bins - band_width + 1] -> Long[batch_size * length, 1] + quantized_band_pitch = band_pitch.argmax(2) + if return_features: + # [batch_size * length, 1] + band_proba = band_pitch.gather(2, quantized_band_pitch[:, :, None]) + # [batch_size * length, 1] + half_pitch_band_proba = band_pitch.gather( + 2, + (quantized_band_pitch - self.pitch_bins_per_octave).clamp_(min=1)[ + :, :, None + ], + ) + half_pitch_band_proba[ + quantized_band_pitch <= self.pitch_bins_per_octave + ] = 0.0 + half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view( + batch_size, 1, length + ) + # [batch_size * length, 1] + double_pitch_band_proba = band_pitch.gather( + 2, + (quantized_band_pitch + self.pitch_bins_per_octave).clamp_( + max=pitch_bins - band_width + )[:, :, None], + ) + double_pitch_band_proba[ + quantized_band_pitch + > pitch_bins - band_width - self.pitch_bins_per_octave + ] = 0.0 + double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view( + batch_size, 1, length + ) + # Long[1, pitch_bins] + mask = torch.arange(pitch_bins, device=pitch.device)[None, :] + # bool[batch_size * length, pitch_bins] + mask = (quantized_band_pitch <= mask) & ( + mask < quantized_band_pitch + band_width + ) + # Long[batch_size, length] + quantized_pitch = (pitch.squeeze(1) * mask).argmax(1).view(batch_size, length) + + if return_features: + features = torch.cat( + [unvoiced_proba, half_pitch_proba, double_pitch_proba], dim=1 + ) + # Long[batch_size, length], [batch_size, 3, length] + return quantized_pitch, features + else: + return quantized_pitch + + def merge_weights(self): + self.backbone.merge_weights() + + def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump(f) + return + if not hasattr(f, "write"): + raise TypeError + + dump_layer(self.instfreq_embed_0, f) + dump_layer(self.instfreq_embed_1, f) + dump_layer(self.corr_embed_0, f) + dump_layer(self.corr_embed_1, f) + dump_layer(self.backbone, f) + dump_layer(self.head, f) + + + +def overlap_add( + ir_amp: torch.Tensor, + ir_phase: torch.Tensor, + window: torch.Tensor, + pitch: torch.Tensor, + hop_length: int = 240, + delay: int = 0, + sr: float = 24000.0, +) -> torch.Tensor: + batch_size, ir_length, length = ir_amp.size() + ir_length = (ir_length - 1) * 2 + assert ir_phase.size() == ir_amp.size() + assert window.size() == (ir_length,), (window.size(), ir_amp.size()) + assert pitch.size() == (batch_size, length * hop_length) + assert 0 <= delay < ir_length, (delay, ir_length) + # 正規化角周波数 [2π rad] + normalized_freq = pitch / sr + # 初期位相 [2π rad] をランダムに設定 + normalized_freq[:, 0] = torch.rand(batch_size, device=pitch.device) + with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): + phase = (normalized_freq.double().cumsum_(1) % 1.0).float() + # 重ねる箇所を求める + # [n_pitchmarks], [n_pitchmarks] + indices0, indices1 = torch.nonzero(phase[:, :-1] > phase[:, 1:], as_tuple=True) + # 重ねる箇所の小数部分 (位相の遅れ) を求める + numer = 1.0 - phase[indices0, indices1] + # [n_pitchmarks] + fractional_part = numer / (numer + phase[indices0, indices1 + 1]) + # 重ねる値を求める + # Complex[n_pitchmarks, ir_length / 2 + 1] + ir_amp = ir_amp[indices0, :, indices1 // hop_length] + ir_phase = ir_phase[indices0, :, indices1 // hop_length] + # 位相遅れの量 [rad] + # [n_pitchmarks, ir_length / 2 + 1] + delay_phase = ( + torch.arange(ir_length // 2 + 1, device=pitch.device, dtype=torch.float32)[ + None, : + ] + * (-math.tau / ir_length) + * fractional_part[:, None] + ) + # Complex[n_pitchmarks, ir_length / 2 + 1] + spec = torch.polar(ir_amp, ir_phase + delay_phase) + # [n_pitchmarks, ir_length] + ir = torch.fft.irfft(spec, n=ir_length, dim=1) + ir *= window + + # 加算する値をサンプル単位にばらす + # [n_pitchmarks * ir_length] + ir = ir.ravel() + # Long[n_pitchmarks * ir_length] + indices0 = indices0[:, None].expand(-1, ir_length).ravel() + # Long[n_pitchmarks * ir_length] + indices1 = ( + indices1[:, None] + torch.arange(ir_length, device=pitch.device) + ).ravel() + + # overlap-add する + overlap_added_signal = torch.zeros( + (batch_size, length * hop_length + ir_length), device=pitch.device + ) + overlap_added_signal.index_put_((indices0, indices1), ir, accumulate=True) + overlap_added_signal = overlap_added_signal[:, delay : -ir_length + delay] + + return overlap_added_signal + + +def generate_noise( + aperiodicity: torch.Tensor, delay: int = 0 +) -> tuple[torch.Tensor, torch.Tensor]: + # aperiodicity: [batch_size, hop_length, length] + batch_size, hop_length, length = aperiodicity.size() + excitation = torch.rand( + batch_size, (length + 1) * hop_length, device=aperiodicity.device + ) + excitation -= 0.5 + n_fft = 2 * hop_length + # 矩形窓で分析 + # Complex[batch_size, hop_length + 1, length] + noise = torch.stft( + excitation, + n_fft=n_fft, + hop_length=hop_length, + window=torch.ones(n_fft, device=excitation.device), + center=False, + return_complex=True, + ) + assert noise.size(2) == aperiodicity.size(2) + noise[:, 0, :] = 0.0 + noise[:, 1:, :] *= aperiodicity + # ハン窓で合成 + # torch.istft は最適合成窓が使われるので使えないことに注意 + # [batch_size, 2 * hop_length, length] + noise = torch.fft.irfft(noise, n=2 * hop_length, dim=1) + noise *= torch.hann_window(2 * hop_length, device=noise.device)[None, :, None] + # [batch_size, (length + 1) * hop_length] + noise = F.fold( + noise, + (1, (length + 1) * hop_length), + (1, 2 * hop_length), + stride=(1, hop_length), + ).squeeze_((1, 2)) + + assert delay < hop_length + noise = noise[:, delay : -hop_length + delay] + excitation = excitation[:, delay : -hop_length + delay] + return noise, excitation # [batch_size, length * hop_length] + + +D4C_PREVENT_ZERO_DIVISION = True # False にすると本家の処理 + + +def interp(x: torch.Tensor, y: torch.Tensor, xi: torch.Tensor) -> torch.Tensor: + # x が単調増加で等間隔と仮定 + # 外挿は起こらないと仮定 + x = torch.as_tensor(x) + y = torch.as_tensor(y) + xi = torch.as_tensor(xi) + if xi.ndim < y.ndim: + diff_ndim = y.ndim - xi.ndim + xi = xi.view(tuple([1] * diff_ndim) + xi.size()) + if xi.size()[:-1] != y.size()[:-1]: + xi = xi.expand(y.size()[:-1] + (xi.size(-1),)) + assert (x.min(-1).values == x[..., 0]).all() + assert (x.max(-1).values == x[..., -1]).all() + assert (xi.min(-1).values >= x[..., 0]).all() + assert (xi.max(-1).values <= x[..., -1]).all() + delta_x = (x[..., -1].double() - x[..., 0].double()) / (x.size(-1) - 1.0) + delta_x = delta_x.to(x.dtype) + xi = (xi - x[..., :1]).div_(delta_x[..., None]) + xi_base = xi.floor() + xi_fraction = xi.sub_(xi_base) + xi_base = xi_base.long() + delta_y = y.diff(dim=-1, append=y[..., -1:]) + yi = y.gather(-1, xi_base) + delta_y.gather(-1, xi_base) * xi_fraction + return yi + + +def linear_smoothing( + group_delay: torch.Tensor, sr: int, n_fft: int, width: torch.Tensor +) -> torch.Tensor: + group_delay = torch.as_tensor(group_delay) + assert group_delay.size(-1) == n_fft // 2 + 1 + width = torch.as_tensor(width) + boundary = (width.max() * n_fft / sr).long() + 1 + + dtype = group_delay.dtype + device = group_delay.device + fft_resolution = sr / n_fft + mirroring_freq_axis = ( + torch.arange(-boundary, n_fft // 2 + 1 + boundary, dtype=dtype, device=device) + .add(0.5) + .mul(fft_resolution) + ) + if group_delay.ndim == 1: + mirroring_spec = F.pad( + group_delay[None], (boundary, boundary), mode="reflect" + ).squeeze_(0) + elif group_delay.ndim >= 4: + shape = group_delay.size() + mirroring_spec = F.pad( + group_delay.view(math.prod(shape[:-1]), group_delay.size(-1)), + (boundary, boundary), + mode="reflect", + ).view(shape[:-1] + (shape[-1] + 2 * boundary,)) + else: + mirroring_spec = F.pad(group_delay, (boundary, boundary), mode="reflect") + mirroring_segment = mirroring_spec.mul(fft_resolution).cumsum_(-1) + center_freq = torch.arange(n_fft // 2 + 1, dtype=dtype, device=device).mul_( + fft_resolution + ) + low_freq = center_freq - width[..., None] * 0.5 + high_freq = center_freq + width[..., None] * 0.5 + levels = interp( + mirroring_freq_axis, mirroring_segment, torch.cat([low_freq, high_freq], dim=-1) + ) + low_levels, high_levels = levels.split([n_fft // 2 + 1] * 2, dim=-1) + smoothed = (high_levels - low_levels).div_(width[..., None]) + return smoothed + + +def dc_correction( + spec: torch.Tensor, sr: int, n_fft: int, f0: torch.Tensor +) -> torch.Tensor: + spec = torch.as_tensor(spec) + f0 = torch.as_tensor(f0) + dtype = spec.dtype + device = spec.device + + upper_limit = 2 + (f0 * (n_fft / sr)).long() + max_upper_limit = upper_limit.max() + upper_limit_mask = ( + torch.arange(max_upper_limit - 1, device=device) < (upper_limit - 1)[..., None] + ) + low_freq_axis = torch.arange(max_upper_limit + 1, dtype=dtype, device=device) * ( + sr / n_fft + ) + low_freq_replica = interp( + f0[..., None] - low_freq_axis.flip(-1), + spec[..., : max_upper_limit + 1].flip(-1), + low_freq_axis[..., : max_upper_limit - 1] * upper_limit_mask, + ) + output = spec.clone() + output[..., : max_upper_limit - 1] += low_freq_replica * upper_limit_mask + return output + + +def nuttall(n: int, device: torch.types.Device) -> torch.Tensor: + t = torch.linspace(0, math.tau, n, device=device) + coefs = torch.tensor([0.355768, -0.487396, 0.144232, -0.012604], device=device) + terms = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device) + cos_matrix = (terms[:, None] * t).cos_() # [4, n] + window = coefs.matmul(cos_matrix) + return window + + +def get_windowed_waveform( + x: torch.Tensor, + sr: int, + f0: torch.Tensor, + position: torch.Tensor, + half_window_length_ratio: float, + window_type: Literal["hann", "blackman"], + n_fft: int, +) -> tuple[torch.Tensor, torch.Tensor]: + x = torch.as_tensor(x) + f0 = torch.as_tensor(f0) + position = torch.as_tensor(position) + + current_sample = position * sr + # [...] + half_window_length = (half_window_length_ratio * sr / f0).add_(0.5).long() + # [..., fft_size] + base_index = -half_window_length[..., None] + torch.arange(n_fft, device=x.device) + base_index_mask = base_index <= half_window_length[..., None] + # [..., fft_size] + safe_index = ((current_sample + 0.501).long()[..., None] + base_index).clamp_( + 0, x.size(-1) - 1 + ) + # [..., fft_size] + time_axis = base_index.to(x.dtype).div_(half_window_length_ratio) + # [...] + normalized_f0 = math.pi / sr * f0 + # [..., fft_size] + phase = time_axis.mul_(normalized_f0[..., None]) + + if window_type == "hann": + window = phase.cos_().mul_(0.5).add_(0.5) + elif window_type == "blackman": + window = phase.mul(2.0).cos_().mul_(0.08).add_(phase.cos().mul_(0.5)).add_(0.42) + else: + assert False + window *= base_index_mask + + prefix_shape = tuple( + max(x_size, i_size) for x_size, i_size in zip(x.size(), safe_index.size()) + )[:-1] + waveform = ( + x.expand(prefix_shape + (-1,)) + .gather(-1, safe_index.expand(prefix_shape + (-1,))) + .mul_(window) + ) + if not D4C_PREVENT_ZERO_DIVISION: + waveform += torch.randn_like(window).mul_(1e-12) + waveform *= base_index_mask + waveform -= window * waveform.sum(-1, keepdim=True).div_( + window.sum(-1, keepdim=True) + ) + return waveform, window + + +def get_centroid(x: torch.Tensor, n_fft: int) -> torch.Tensor: + x = torch.as_tensor(x) + if D4C_PREVENT_ZERO_DIVISION: + x = x / x.norm(dim=-1, keepdim=True).clamp(min=6e-8) + else: + x = x / x.norm(dim=-1, keepdim=True) + spec0 = torch.fft.rfft(x, n_fft) + spec1 = torch.fft.rfft( + x * torch.arange(1, x.size(-1) + 1, dtype=x.dtype, device=x.device).div_(n_fft), + n_fft, + ) + centroid = spec0.real * spec1.real + spec0.imag * spec1.imag + return centroid + + +def get_static_centroid( + x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int +) -> torch.Tensor: + """First step: calculation of temporally static parameters on basis of group delay""" + x1, _ = get_windowed_waveform( + x, sr, f0, position + 0.25 / f0, 2.0, "blackman", n_fft + ) + x2, _ = get_windowed_waveform( + x, sr, f0, position - 0.25 / f0, 2.0, "blackman", n_fft + ) + centroid1 = get_centroid(x1, n_fft) + centroid2 = get_centroid(x2, n_fft) + return dc_correction(centroid1 + centroid2, sr, n_fft, f0) + + +def get_smoothed_power_spec( + x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int +) -> tuple[torch.Tensor, torch.Tensor]: + x = torch.as_tensor(x) + f0 = torch.as_tensor(f0) + x, window = get_windowed_waveform(x, sr, f0, position, 2.0, "hann", n_fft) + window_weight = window.square().sum(-1, keepdim=True) + rms = x.square().sum(-1, keepdim=True).div_(window_weight).sqrt_() + if D4C_PREVENT_ZERO_DIVISION: + x = x / (rms * math.sqrt(n_fft)).clamp_(min=6e-8) + smoothed_power_spec = torch.fft.rfft(x, n_fft).abs().square_() + smoothed_power_spec = dc_correction(smoothed_power_spec, sr, n_fft, f0) + smoothed_power_spec = linear_smoothing(smoothed_power_spec, sr, n_fft, f0) + return smoothed_power_spec, rms.detach().squeeze(-1) + + +def get_static_group_delay( + static_centroid: torch.Tensor, + smoothed_power_spec: torch.Tensor, + sr: int, + f0: torch.Tensor, + n_fft: int, +) -> torch.Tensor: + """Second step: calculation of parameter shaping""" + if D4C_PREVENT_ZERO_DIVISION: + smoothed_power_spec = smoothed_power_spec.clamp(min=6e-8) + static_group_delay = static_centroid / smoothed_power_spec # t_g + static_group_delay = linear_smoothing( + static_group_delay, sr, n_fft, f0 * 0.5 + ) # t_gs + smoothed_group_delay = linear_smoothing(static_group_delay, sr, n_fft, f0) # t_gb + static_group_delay = static_group_delay - smoothed_group_delay # t_D + return static_group_delay + + +def get_coarse_aperiodicity( + group_delay: torch.Tensor, + sr: int, + n_fft: int, + freq_interval: int, + n_aperiodicities: int, + window: torch.Tensor, +) -> torch.Tensor: + """Third step: estimation of band-aperiodicity""" + group_delay = torch.as_tensor(group_delay) + window = torch.as_tensor(window) + boundary = int(round(n_fft * 8 / window.size(-1))) + half_window_length = window.size(-1) // 2 + coarse_aperiodicity = torch.empty( + group_delay.size()[:-1] + (n_aperiodicities,), + dtype=group_delay.dtype, + device=group_delay.device, + ) + for i in range(n_aperiodicities): + center = freq_interval * (i + 1) * n_fft // sr + segment = ( + group_delay[ + ..., center - half_window_length : center + half_window_length + 1 + ] + * window + ) + power_spec: torch.Tensor = torch.fft.rfft(segment, n_fft).abs().square_() + cumulative_power_spec = power_spec.sort(-1).values.cumsum_(-1) + if D4C_PREVENT_ZERO_DIVISION: + cumulative_power_spec.clamp_(min=6e-8) + coarse_aperiodicity[..., i] = ( + cumulative_power_spec[..., n_fft // 2 - boundary - 1] + / cumulative_power_spec[..., -1] + ) + coarse_aperiodicity.log10_().mul_(10.0) + return coarse_aperiodicity + + +def d4c_love_train( + x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, threshold: float +) -> int: + x = torch.as_tensor(x) + position = torch.as_tensor(position) + f0: torch.Tensor = torch.as_tensor(f0) + vuv = f0 != 0 + lowest_f0 = 40 + f0 = f0.clamp(min=lowest_f0) + n_fft = 1 << (3 * sr // lowest_f0).bit_length() + boundary0 = (100 * n_fft - 1) // sr + 1 + boundary1 = (4000 * n_fft - 1) // sr + 1 + boundary2 = (7900 * n_fft - 1) // sr + 1 + waveform, _ = get_windowed_waveform(x, sr, f0, position, 1.5, "blackman", n_fft) + power_spec = torch.fft.rfft(waveform, n_fft).abs().square_() + power_spec[..., : boundary0 + 1] = 0.0 + cumulative_spec = power_spec.cumsum_(-1) + vuv = vuv & ( + cumulative_spec[..., boundary1] > threshold * cumulative_spec[..., boundary2] + ) + return vuv + + +def d4c_general_body( + x: torch.Tensor, + sr: int, + f0: torch.Tensor, + freq_interval: int, + position: torch.Tensor, + n_fft: int, + n_aperiodicities: int, + window: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + static_centroid = get_static_centroid(x, sr, f0, position, n_fft) + smoothed_power_spec, rms = get_smoothed_power_spec(x, sr, f0, position, n_fft) + static_group_delay = get_static_group_delay( + static_centroid, smoothed_power_spec, sr, f0, n_fft + ) + coarse_aperiodicity = get_coarse_aperiodicity( + static_group_delay, sr, n_fft, freq_interval, n_aperiodicities, window + ) + coarse_aperiodicity.add_((f0[..., None] - 100.0).div_(50.0)).clamp_(max=0.0) + return coarse_aperiodicity, rms + + +def d4c( + x: torch.Tensor, + f0: torch.Tensor, + t: torch.Tensor, + sr: int, + threshold: float = 0.85, + n_fft_spec: Optional[int] = None, + coarse_only: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Adapted from https://github.com/tuanad121/Python-WORLD/blob/master/world/d4c.py""" + FLOOR_F0 = 71 + FLOOR_F0_D4C = 47 + UPPER_LIMIT = 15000 + FREQ_INTERVAL = 3000 + + assert sr == int(sr) + sr = int(sr) + assert sr % 2 == 0 + x = torch.as_tensor(x) + f0 = torch.as_tensor(f0) + temporal_positions = torch.as_tensor(t) + + n_fft_d4c = 1 << (4 * sr // FLOOR_F0_D4C).bit_length() + if n_fft_spec is None: + n_fft_spec = 1 << (3 * sr // FLOOR_F0).bit_length() + n_aperiodicities = min(UPPER_LIMIT, sr // 2 - FREQ_INTERVAL) // FREQ_INTERVAL + assert n_aperiodicities >= 1 + window_length = FREQ_INTERVAL * n_fft_d4c // sr * 2 + 1 + window = nuttall(window_length, device=x.device) + freq_axis = torch.arange(n_fft_spec // 2 + 1, device=x.device) * (sr / n_fft_spec) + + coarse_aperiodicity, rms = d4c_general_body( + x[..., None, :], + sr, + f0.clamp(min=FLOOR_F0_D4C), + FREQ_INTERVAL, + temporal_positions, + n_fft_d4c, + n_aperiodicities, + window, + ) + if coarse_only: + return coarse_aperiodicity, rms + + even_coarse_axis = ( + torch.arange(n_aperiodicities + 3, device=x.device) * FREQ_INTERVAL + ) + assert even_coarse_axis[-2] <= sr // 2 < even_coarse_axis[-1], sr + coarse_axis_low = ( + torch.arange(n_aperiodicities + 1, dtype=torch.float, device=x.device) + * FREQ_INTERVAL + ) + aperiodicity_low = interp( + coarse_axis_low, + F.pad(coarse_aperiodicity, (1, 0), value=-60.0), + freq_axis[freq_axis < n_aperiodicities * FREQ_INTERVAL], + ) + coarse_axis_high = torch.tensor( + [n_aperiodicities * FREQ_INTERVAL, sr * 0.5], device=x.device + ) + aperiodicity_high = interp( + coarse_axis_high, + F.pad(coarse_aperiodicity[..., -1:], (0, 1), value=-1e-12), + freq_axis[freq_axis >= n_aperiodicities * FREQ_INTERVAL], + ) + aperiodicity = torch.cat([aperiodicity_low, aperiodicity_high], -1) + aperiodicity = 10.0 ** (aperiodicity / 20.0) + vuv = d4c_love_train(x[..., None, :], sr, f0, temporal_positions, threshold) + aperiodicity = torch.where(vuv[..., None], aperiodicity, 1 - 1e-12) + + return aperiodicity, coarse_aperiodicity + + +class Vocoder(nn.Module): + def __init__( + self, + channels: int, + speaker_embedding_channels: int = 128, + hop_length: int = 240, + n_pre_blocks: int = 4, + out_sample_rate: float = 24000.0, + ): + super().__init__() + self.hop_length = hop_length + self.out_sample_rate = out_sample_rate + + self.prenet = ConvNeXtStack( + in_channels=channels, + channels=channels, + intermediate_channels=channels * 2, + n_blocks=n_pre_blocks, + delay=2, # 20ms 遅延 + embed_kernel_size=7, + kernel_size=33, + enable_scaling=True, + use_mha=True, + cross_attention=True, + kv_channels=speaker_embedding_channels, + ) + self.ir_generator = ConvNeXtStack( + in_channels=channels, + channels=channels, + intermediate_channels=channels * 2, + n_blocks=2, + delay=0, + embed_kernel_size=3, + kernel_size=33, + use_weight_standardization=True, + enable_scaling=True, + ) + self.ir_generator_post = WSConv1d(channels, 512, 1) + self.register_buffer("ir_scale", torch.tensor(1.0)) + self.ir_window = nn.Parameter(torch.ones(512)) + self.aperiodicity_generator = ConvNeXtStack( + in_channels=channels, + channels=channels, + intermediate_channels=channels * 2, + n_blocks=1, + delay=0, + embed_kernel_size=3, + kernel_size=33, + use_weight_standardization=True, + enable_scaling=True, + ) + self.aperiodicity_generator_post = WSConv1d(channels, hop_length, 1, bias=False) + self.register_buffer("aperiodicity_scale", torch.tensor(0.005)) + self.post_filter_generator = ConvNeXtStack( + in_channels=channels, + channels=channels, + intermediate_channels=channels * 2, + n_blocks=1, + delay=0, + embed_kernel_size=3, + kernel_size=33, + use_weight_standardization=True, + enable_scaling=True, + ) + self.post_filter_generator_post = WSConv1d(channels, 512, 1, bias=False) + self.register_buffer("post_filter_scale", torch.tensor(0.01)) + + def forward( + self, x: torch.Tensor, pitch: torch.Tensor, speaker_embedding: torch.Tensor + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + # x: [batch_size, channels, length] + # pitch: [batch_size, length] + # speaker_embedding: [batch_size, speaker_embedding_length, speaker_embedding_channels] + batch_size, _, length = x.size() + + x = self.prenet(x, speaker_embedding) + ir = self.ir_generator(x) + ir = F.silu(ir, inplace=True) + # [batch_size, 512, length] + ir = self.ir_generator_post(ir) + ir *= self.ir_scale + ir_amp = ir[:, : ir.size(1) // 2 + 1, :].exp() + ir_phase = F.pad(ir[:, ir.size(1) // 2 + 1 :, :], (0, 0, 1, 1)) + ir_phase[:, 1::2, :] += math.pi + # TODO: 直流成分が正の値しか取れないのを修正する + + # 最近傍補間 + # [batch_size, length * hop_length] + pitch = torch.repeat_interleave(pitch, self.hop_length, dim=1) + + # [batch_size, length * hop_length] + periodic_signal = overlap_add( + ir_amp, + ir_phase, + self.ir_window, + pitch, + self.hop_length, + delay=0, + sr=self.out_sample_rate, + ) + + aperiodicity = self.aperiodicity_generator(x) + aperiodicity = F.silu(aperiodicity, inplace=True) + # [batch_size, hop_length, length] + aperiodicity = self.aperiodicity_generator_post(aperiodicity) + aperiodicity *= self.aperiodicity_scale + # [batch_size, length * hop_length], [batch_size, length * hop_length] + aperiodic_signal, noise_excitation = generate_noise(aperiodicity, delay=0) + + post_filter = self.post_filter_generator(x) + post_filter = F.silu(post_filter, inplace=True) + # [batch_size, 512, length] + post_filter = self.post_filter_generator_post(post_filter) + post_filter *= self.post_filter_scale + post_filter[:, 0, :] += 1.0 + # [batch_size, length, 512] + post_filter = post_filter.transpose(1, 2) + with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): + periodic_signal = periodic_signal.float() + aperiodic_signal = aperiodic_signal.float() + post_filter = post_filter.float() + post_filter = torch.fft.rfft(post_filter, n=768) + + # [batch_size, length, 768] + periodic_signal = torch.fft.irfft( + torch.fft.rfft( + periodic_signal.view(batch_size, length, self.hop_length), n=768 + ) + * post_filter, + n=768, + ) + aperiodic_signal = torch.fft.irfft( + torch.fft.rfft( + aperiodic_signal.view(batch_size, length, self.hop_length), n=768 + ) + * post_filter, + n=768, + ) + periodic_signal = F.fold( + periodic_signal.transpose(1, 2), + (1, (length - 1) * self.hop_length + 768), + (1, 768), + stride=(1, self.hop_length), + ).squeeze_((1, 2)) + aperiodic_signal = F.fold( + aperiodic_signal.transpose(1, 2), + (1, (length - 1) * self.hop_length + 768), + (1, 768), + stride=(1, self.hop_length), + ).squeeze_((1, 2)) + periodic_signal = periodic_signal[:, 120 : 120 + length * self.hop_length] + aperiodic_signal = aperiodic_signal[:, 120 : 120 + length * self.hop_length] + noise_excitation = noise_excitation[:, 120:] + + # TODO: compensation の正確さが怪しくなってくる。今も本当に必要なのか? + + # [batch_size, 1, length * hop_length] + y_g_hat = (periodic_signal + aperiodic_signal)[:, None, :] + + return y_g_hat, { + "periodic_signal": periodic_signal.detach(), + "aperiodic_signal": aperiodic_signal.detach(), + "noise_excitation": noise_excitation.detach(), + } + + def merge_weights(self): + self.prenet.merge_weights() + self.ir_generator.merge_weights() + self.ir_generator_post.merge_weights() + self.aperiodicity_generator.merge_weights() + self.aperiodicity_generator_post.merge_weights() + self.ir_generator_post.weight.data *= self.ir_scale + self.ir_generator_post.bias.data *= self.ir_scale + self.ir_scale.fill_(1.0) + self.aperiodicity_generator_post.weight.data *= self.aperiodicity_scale + self.aperiodicity_scale.fill_(1.0) + self.post_filter_generator.merge_weights() + self.post_filter_generator_post.merge_weights() + self.post_filter_generator_post.weight.data *= self.post_filter_scale + self.post_filter_scale.fill_(1.0) + + def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump(f) + return + if not hasattr(f, "write"): + raise TypeError + + dump_layer(self.prenet, f) + dump_layer(self.ir_generator, f) + dump_layer(self.ir_generator_post, f) + dump_layer(self.ir_window, f) + dump_layer(self.aperiodicity_generator, f) + dump_layer(self.aperiodicity_generator_post, f) + dump_layer(self.post_filter_generator, f) + dump_layer(self.post_filter_generator_post, f) + + +def compute_loudness( + x: torch.Tensor, sr: int, win_lengths: list[int] +) -> list[torch.Tensor]: + # x: [batch_size, wav_length] + assert x.ndim == 2 + n_fft = 2048 + chunk_length = n_fft // 2 + n_taps = chunk_length + 1 + + results = [] + with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): + if not hasattr(compute_loudness, "filter"): + compute_loudness.filter = {} + if sr not in compute_loudness.filter: + ir = torch.zeros(n_taps, device=x.device, dtype=torch.double) + ir[0] = 0.5 + ir = torchaudio.functional.treble_biquad( + ir, sr, 4.0, 1500.0, 1.0 / math.sqrt(2) + ) + ir = torchaudio.functional.highpass_biquad(ir, sr, 38.0, 0.5) + ir *= 2.0 + compute_loudness.filter[sr] = torch.fft.rfft(ir, n=n_fft).to( + torch.complex64 + ) + + x = x.float() + wav_length = x.size(-1) + if wav_length % chunk_length != 0: + x = F.pad(x, (0, chunk_length - wav_length % chunk_length)) + padded_wav_length = x.size(-1) + x = x.view(x.size()[:-1] + (padded_wav_length // chunk_length, chunk_length)) + x = torch.fft.irfft( + torch.fft.rfft(x, n=n_fft) * compute_loudness.filter[sr], + n=n_fft, + ) + x = F.fold( + x.transpose(-2, -1), + (1, padded_wav_length + chunk_length), + (1, n_fft), + stride=(1, chunk_length), + ).squeeze_((-3, -2))[..., :wav_length] + + x.square_() + for win_length in win_lengths: + hop_length = win_length // 4 + # [..., n_frames] + energy = ( + x.unfold(-1, win_length, hop_length) + .matmul(torch.hann_window(win_length, device=x.device)) + .add_(win_length / 4.0 * 1e-5) + .log10_() + ) + # フィルタリング後の波形が振幅 1 の正弦波なら大体 log10(win_length/4), 1 の差は 10dB の差 + results.append(energy) + return results + + +def beatrice_slice_segments( + x: torch.Tensor, start_indices: torch.Tensor, segment_length: int +) -> torch.Tensor: + batch_size, channels, _ = x.size() + # [batch_size, 1, segment_size] + indices = start_indices[:, None, None] + torch.arange( + segment_length, device=start_indices.device + ) + # [batch_size, channels, segment_size] + indices = indices.expand(batch_size, channels, segment_length) + return x.gather(2, indices) + + +class ConverterNetwork(nn.Module): + def __init__( + self, + phone_extractor: PhoneExtractor, + pitch_estimator: PitchEstimator, + n_speakers: int, + pitch_bins: int, + hidden_channels: int, + vq_topk: int = 4, + training_time_vq: Literal["none", "self", "random"] = "none", + phone_noise_ratio: int = 0.5, + floor_noise_level: float = 1e-3, + ): + super().__init__() + self.frozen_modules = { + "phone_extractor": phone_extractor.eval().requires_grad_(False), + "pitch_estimator": pitch_estimator.eval().requires_grad_(False), + } + self.pitch_bins = pitch_bins + self.phone_noise_ratio = phone_noise_ratio + self.floor_noise_level = floor_noise_level + self.out_sample_rate = out_sample_rate = 24000 + phone_channels = 128 + self.vq = VectorQuantizer( + n_speakers=n_speakers, + codebook_size=512, + channels=phone_channels, + topk=vq_topk, + training_time_vq=training_time_vq, + ) + self.embed_phone = nn.Conv1d(phone_channels, hidden_channels, 1) + self.embed_phone.weight.data.normal_(0.0, math.sqrt(2.0 / (256 * 5))) + self.embed_phone.bias.data.zero_() + self.embed_quantized_pitch = nn.Embedding(pitch_bins, hidden_channels) + phase = ( + torch.arange(pitch_bins, dtype=torch.float)[:, None] + * ( + torch.arange(0, hidden_channels, 2, dtype=torch.float) + * (-math.log(10000.0) / hidden_channels) + ).exp_() + ) + self.embed_quantized_pitch.weight.data[:, 0::2] = phase.sin() + self.embed_quantized_pitch.weight.data[:, 1::2] = phase.cos_() + self.embed_quantized_pitch.weight.data *= math.sqrt(4.0 / 5.0) + self.embed_quantized_pitch.weight.requires_grad_(False) + self.embed_pitch_features = nn.Conv1d(4, hidden_channels, 1) + self.embed_pitch_features.weight.data.normal_(0.0, math.sqrt(2.0 / (4 * 5))) + self.embed_pitch_features.bias.data.zero_() + self.embed_speaker = nn.Embedding(n_speakers, hidden_channels) + self.embed_speaker.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) + self.embed_formant_shift = nn.Embedding(9, hidden_channels) + self.embed_formant_shift.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) + + self.key_value_speaker_embedding_length = 384 + self.key_value_speaker_embedding_channels = 128 + self.key_value_speaker_embedding = nn.Embedding( + n_speakers, + self.key_value_speaker_embedding_length + * self.key_value_speaker_embedding_channels, + ) + self.key_value_speaker_embedding.weight.data[0].normal_() + self.key_value_speaker_embedding.weight.data[1:] = ( + self.key_value_speaker_embedding.weight.data[0] + ) + + self.vocoder = Vocoder( + channels=hidden_channels, + speaker_embedding_channels=self.key_value_speaker_embedding_channels, + hop_length=out_sample_rate // 100, + n_pre_blocks=4, + out_sample_rate=out_sample_rate, + ) + self.melspectrograms = nn.ModuleList() + for win_length, n_mels in [ + (32, 5), + (64, 10), + (128, 20), + (256, 40), + (512, 80), + (1024, 160), + (2048, 320), + ]: + self.melspectrograms.append( + torchaudio.transforms.MelSpectrogram( + sample_rate=out_sample_rate, + n_fft=win_length, + win_length=win_length, + hop_length=win_length // 4, + n_mels=n_mels, + power=2, + norm="slaney", + mel_scale="slaney", + ) + ) + + def initialize_vq(self, inputs: Sequence[Iterable[torch.Tensor]]): + collector_func = self.frozen_modules["phone_extractor"].units + target_layer = self.frozen_modules["phone_extractor"].head + + self.vq.build_codebooks( + collector_func, + target_layer, + inputs, + ) + self.vq.enable_hook(target_layer) + + def enable_hook(self): + target_layer = self.frozen_modules["phone_extractor"].head + self.vq.enable_hook(target_layer) + + def _get_resampler( + self, orig_freq, new_freq, device, cache={} + ) -> torchaudio.transforms.Resample: + key = orig_freq, new_freq + if key in cache: + return cache[key] + resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to( + device, non_blocking=True + ) + cache[key] = resampler + return resampler + + def forward( + self, + x: torch.Tensor, + target_speaker_id: torch.Tensor, + formant_shift_semitone: torch.Tensor, + pitch_shift_semitone: Optional[torch.Tensor] = None, + slice_start_indices: Optional[torch.Tensor] = None, + slice_segment_length: Optional[int] = None, + return_stats: bool = False, + ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: + # x: [batch_size, 1, wav_length] + # target_speaker_id: Long[batch_size] + # formant_shift_semitone: [batch_size] + # pitch_shift_semitone: [batch_size] + # slice_start_indices: [batch_size] + + batch_size, _, _ = x.size() + self.vq.set_target_speaker_ids(target_speaker_id) + + with torch.inference_mode(): + phone_extractor: PhoneExtractor = self.frozen_modules["phone_extractor"] + pitch_estimator: PitchEstimator = self.frozen_modules["pitch_estimator"] + # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length] + phone = phone_extractor.units(x).transpose(1, 2) + + if self.training and self.phone_noise_ratio != 0.0: + phone *= (1.0 - self.phone_noise_ratio) / phone.square().mean( + 1, keepdim=True + ).sqrt_() + noise = torch.randn_like(phone) + noise *= ( + self.phone_noise_ratio + / noise.square().mean(1, keepdim=True).sqrt_() + ) + phone += noise + # F.rms_norm は PyTorch >= 2.4 が必要 + phone *= ( + 1.0 + / phone.square() + .mean(1, keepdim=True) + .add_(torch.finfo(torch.float).eps) + .sqrt_() + ) + + # [batch_size, 1, wav_length] -> [batch_size, pitch_bins, length], [batch_size, 1, length] + pitch, energy = pitch_estimator(x) + # augmentation + if self.training: + # [batch_size, pitch_bins - 1] + weights = pitch.softmax(1)[:, 1:, :].mean(2) + # [batch_size] + mean_pitch = ( + weights + * torch.arange( + 1, + self.embed_quantized_pitch.num_embeddings, + device=weights.device, + ) + ).sum(1) / weights.sum(1) + mean_pitch = mean_pitch.round_().long() + target_pitch = torch.randint_like(mean_pitch, 64, 257) + shift = target_pitch - mean_pitch + shift_ratio = ( + 2.0 ** (shift.float() / pitch_estimator.pitch_bins_per_octave) + ).tolist() + shift = [] + interval_length = 100 # 1s + interval_zeros = torch.zeros( + (1, 1, interval_length * 160), device=x.device + ) + concatenated_shifted_x = [] + offsets = [0] + torch.backends.cudnn.benchmark = False + for i in range(batch_size): + shift_ratio_i = shift_ratio[i] + shift_ratio_fraction_i = Fraction.from_float( + shift_ratio_i + ).limit_denominator(30) + shift_numer_i = shift_ratio_fraction_i.numerator + shift_denom_i = shift_ratio_fraction_i.denominator + shift_ratio_i = shift_numer_i / shift_denom_i + shift_i = int( + round( + math.log2(shift_ratio_i) + * pitch_estimator.pitch_bins_per_octave + ) + ) + shift.append(shift_i) + shift_ratio[i] = shift_ratio_i + # [1, 1, wav_length / shift_ratio] + with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): + shifted_x_i = self._get_resampler( + shift_numer_i, shift_denom_i, x.device + )(x[i])[None] + if shifted_x_i.size(2) % 160 != 0: + shifted_x_i = F.pad( + shifted_x_i, + (0, 160 - shifted_x_i.size(2) % 160), + mode="reflect", + ) + assert shifted_x_i.size(2) % 160 == 0 + offsets.append( + offsets[-1] + interval_length + shifted_x_i.size(2) // 160 + ) + concatenated_shifted_x.extend([interval_zeros, shifted_x_i]) + if offsets[-1] % 256 != 0: + # 長さが同じ方が何かのキャッシュが効いて早くなるようなので + # 適当に 256 の倍数になるようにパディングして長さのパターン数を減らす + concatenated_shifted_x.append( + torch.zeros( + (1, 1, (256 - offsets[-1] % 256) * 160), device=x.device + ) + ) + # [batch_size, 1, sum(wav_length) + batch_size * 16000] + concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2) + assert concatenated_shifted_x.size(2) % (256 * 160) == 0 + # [1, pitch_bins, length / shift_ratio], [1, 1, length / shift_ratio] + concatenated_pitch, concatenated_energy = pitch_estimator( + concatenated_shifted_x + ) + for i in range(batch_size): + shift_i = shift[i] + shift_ratio_i = shift_ratio[i] + left = offsets[i] + interval_length + right = offsets[i + 1] + pitch_i = concatenated_pitch[:, :, left:right] + energy_i = concatenated_energy[:, :, left:right] + pitch_i = F.interpolate( + pitch_i, + scale_factor=shift_ratio_i, + mode="linear", + align_corners=False, + ) + energy_i = F.interpolate( + energy_i, + scale_factor=shift_ratio_i, + mode="linear", + align_corners=False, + ) + assert pitch_i.size(2) == energy_i.size(2) + assert abs(pitch_i.size(2) - pitch.size(2)) <= 10 + length = min(pitch_i.size(2), pitch.size(2)) + + if shift_i > 0: + pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] + pitch[i : i + 1, 1:-shift_i, :length] = pitch_i[ + :, 1 + shift_i :, :length + ] + pitch[i : i + 1, -shift_i:, :length] = -10.0 + elif shift_i < 0: + pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] + pitch[i : i + 1, 1 : 1 - shift_i, :length] = -10.0 + pitch[i : i + 1, 1 - shift_i :, :length] = pitch_i[ + :, 1:shift_i, :length + ] + energy[i : i + 1, :, :length] = energy_i[:, :, :length] + torch.backends.cudnn.benchmark = True + + # [batch_size, pitch_bins, length] -> Long[batch_size, length], [batch_size, 3, length] + quantized_pitch, pitch_features = pitch_estimator.sample_pitch( + pitch, return_features=True + ) + if pitch_shift_semitone is not None: + quantized_pitch = torch.where( + quantized_pitch == 0, + quantized_pitch, + ( + quantized_pitch + + ( + pitch_shift_semitone[:, None] + * (pitch_estimator.pitch_bins_per_octave / 12.0) + ) + .round_() + .long() + ).clamp_(1, self.pitch_bins - 1), + ) + pitch = 55.0 * 2.0 ** ( + quantized_pitch.float() / pitch_estimator.pitch_bins_per_octave + ) + # phone が 2.5ms 先読みしているのに対して、 + # energy は 12.5ms, pitch_features は 22.5ms 先読みしているので、 + # ずらして phone に合わせる + energy = F.pad(energy[:, :, :-1], (1, 0), mode="reflect") + quantized_pitch = F.pad(quantized_pitch[:, :-2], (2, 0), mode="reflect") + pitch_features = F.pad(pitch_features[:, :, :-2], (2, 0), mode="reflect") + # [batch_size, 1, length], [batch_size, 3, length] -> [batch_size, 4, length] + pitch_features = torch.cat([energy, pitch_features], dim=1) + formant_shift_indices = ( + ((formant_shift_semitone + 2.0) * 2.0).round_().long() + ) + + phone = phone.clone() + quantized_pitch = quantized_pitch.clone() + pitch_features = pitch_features.clone() + formant_shift_indices = formant_shift_indices.clone() + pitch = pitch.clone() + + # [batch_sise, hidden_channels, length] + x = ( + self.embed_phone(phone) + + self.embed_quantized_pitch(quantized_pitch).transpose(1, 2) + + self.embed_pitch_features(pitch_features) + + ( + self.embed_speaker(target_speaker_id)[:, :, None] + + self.embed_formant_shift(formant_shift_indices)[:, :, None] + ) + ) + if slice_start_indices is not None: + assert slice_segment_length is not None + # [batch_size, hidden_channels, length] -> [batch_size, hidden_channels, segment_length] + x = beatrice_slice_segments(x, slice_start_indices, slice_segment_length) + x = F.silu(x, inplace=True) + + speaker_embedding = self.key_value_speaker_embedding(target_speaker_id).view( + batch_size, + self.key_value_speaker_embedding_length, + self.key_value_speaker_embedding_channels, + ) + + # [batch_size, hidden_channels, segment_length] -> [batch_size, 1, segment_length * 240] + y_g_hat, stats = self.vocoder(x, pitch, speaker_embedding) + stats["pitch"] = pitch + if return_stats: + return y_g_hat, stats + else: + return y_g_hat + + def _normalize_melsp(self, x): + return x.clamp(min=1e-10).log_() + + def forward_and_compute_loss( + self, + noisy_wavs_16k: torch.Tensor, + target_speaker_id: torch.Tensor, + formant_shift_semitone: torch.Tensor, + slice_start_indices: torch.Tensor, + slice_segment_length: int, + y_all: torch.Tensor, + enable_loss_ap: bool = False, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + dict[str, float], + ]: + # noisy_wavs_16k: [batch_size, 1, wav_length] + # target_speaker_id: Long[batch_size] + # formant_shift_semitone: [batch_size] + # slice_start_indices: [batch_size] + # slice_segment_length: int + # y_all: [batch_size, 1, wav_length] + + stats = {} + loss_mel = 0.0 + loss_loudness = 0.0 + loudness_win_lengths = [512, 1024, 2048, 4096] + + # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240] + y_hat_all, intermediates = self( + noisy_wavs_16k, + target_speaker_id, + formant_shift_semitone, + return_stats=True, + ) + y_hat_all = y_hat_all.detach().where(y_all == 0.0, y_hat_all) + + with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): + periodic_signal = intermediates["periodic_signal"].float() + aperiodic_signal = intermediates["aperiodic_signal"].float() + noise_excitation = intermediates["noise_excitation"].float() + periodic_signal = periodic_signal[:, : noise_excitation.size(1)] + aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)] + y_hat_all = y_hat_all.float() + floor_noise = torch.randn_like(y_all) * self.floor_noise_level + y_all = y_all + floor_noise + y_hat_all += floor_noise + y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)] + y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)] + + y_loudness = compute_loudness( + y_all_truncated, self.out_sample_rate, loudness_win_lengths + ) + y_hat_loudness = compute_loudness( + y_hat_all_truncated, self.out_sample_rate, loudness_win_lengths + ) + for win_length, y_loudness_i, y_hat_loudness_i in zip( + loudness_win_lengths, y_loudness, y_hat_loudness + ): + loss_loudness_i = F.mse_loss(y_hat_loudness_i, y_loudness_i) + loss_loudness += loss_loudness_i * math.sqrt(win_length) + stats[f"loss_loudness_{win_length}"] = loss_loudness_i.item() + + for melspectrogram in self.melspectrograms: + melsp_periodic_signal = melspectrogram(periodic_signal) + melsp_aperiodic_signal = melspectrogram(aperiodic_signal) + melsp_noise_excitation = melspectrogram(noise_excitation) + # [1, n_mels, 1] + # 1/6 ... [-0.5, 0.5] の一様乱数の平均パワー + # 3/8 ... ハン窓をかけた時のパワー減衰 + # 0.5 ... 謎 + reference_melsp = melspectrogram.mel_scale( + torch.full( + (1, melspectrogram.n_fft // 2 + 1, 1), + (1 / 6) * (3 / 8) * 0.5 * melspectrogram.win_length, + device=noisy_wavs_16k.device, + ) + ) + aperiodic_ratio = melsp_aperiodic_signal / ( + melsp_periodic_signal + melsp_aperiodic_signal + 1e-5 + ) + compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5) + + melsp_y_hat = melspectrogram(y_hat_all_truncated) + melsp_y_hat = melsp_y_hat * ( + (1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio + ) + y_hat_mel = self._normalize_melsp(melsp_y_hat) + + y_mel = self._normalize_melsp(melspectrogram(y_all_truncated)) + loss_mel_i = F.l1_loss(y_hat_mel, y_mel) + loss_mel += loss_mel_i + stats[ + f"loss_mel_{melspectrogram.win_length}_{melspectrogram.n_mels}" + ] = loss_mel_i.item() + + loss_mel /= len(self.melspectrograms) + + if enable_loss_ap: + t = ( + torch.arange(intermediates["pitch"].size(1), device=y_all.device) + * 0.01 + + 0.005 + ) + y_coarse_aperiodicity, y_rms = d4c( + y_all.squeeze(1), + intermediates["pitch"], + t, + self.vocoder.out_sample_rate, + coarse_only=True, + ) + y_coarse_aperiodicity = 10.0 ** (y_coarse_aperiodicity / 10.0) + y_hat_coarse_aperiodicity, y_hat_rms = d4c( + y_hat_all.squeeze(1), + intermediates["pitch"], + t, + self.vocoder.out_sample_rate, + coarse_only=True, + ) + y_hat_coarse_aperiodicity = 10.0 ** (y_hat_coarse_aperiodicity / 10.0) + rms = torch.maximum(y_rms, y_hat_rms) + loss_ap = F.mse_loss( + y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none" + ) + loss_ap *= (rms / (rms + 1e-3) * (rms > 1e-5))[:, :, None] + loss_ap = loss_ap.mean() + else: + loss_ap = torch.tensor(0.0) + + # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] + y_hat = beatrice_slice_segments( + y_hat_all, slice_start_indices * 240, slice_segment_length * 240 + ) + # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] + y = beatrice_slice_segments(y_all, slice_start_indices * 240, slice_segment_length * 240) + return y, y_hat, y_hat_all, loss_loudness, loss_mel, loss_ap, stats + + def merge_weights(self): + self.vocoder.merge_weights() + + def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump(f) + return + if not hasattr(f, "write"): + raise TypeError + + dump_layer(self.embed_phone, f) + dump_layer(self.embed_quantized_pitch, f) + dump_layer(self.embed_pitch_features, f) + dump_layer(self.vocoder, f) + + def dump_speaker_embeddings(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump_speaker_embeddings(f) + return + if not hasattr(f, "write"): + raise TypeError + + dump_params(self.vq.codebooks, f) + dump_layer(self.embed_speaker, f) + dump_layer(self.embed_formant_shift, f) + dump_layer(self.key_value_speaker_embedding, f) + + def dump_embedding_setter(self, f: Union[BinaryIO, str, bytes, os.PathLike]): + if isinstance(f, (str, bytes, os.PathLike)): + with open(f, "wb") as f: + self.dump_embedding_setter(f) + return + if not hasattr(f, "write"): + raise TypeError + + self.vocoder.prenet.dump_kv(f) + + +# Discriminator + + +def _normalize(tensor: torch.Tensor, dim: int) -> torch.Tensor: + denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-6) + return tensor / denom + + +class SANConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + bias: bool = True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding=padding, + dilation=dilation, + groups=1, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-6) + self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight)) + self.scale = nn.parameter.Parameter(scale.view(out_channels)) + if bias: + self.bias = nn.parameter.Parameter( + torch.zeros(in_channels, device=device, dtype=dtype) + ) + else: + self.register_parameter("bias", None) + + def forward( + self, input: torch.Tensor, flg_san_train: bool = False + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if self.bias is not None: + input = input + self.bias.view(self.in_channels, 1, 1) + normalized_weight = self._get_normalized_weight() + scale = self.scale.view(self.out_channels, 1, 1) + if flg_san_train: + out_fun = F.conv2d( + input, + normalized_weight.detach(), + None, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + out_dir = F.conv2d( + input.detach(), + normalized_weight, + None, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + out = out_fun * scale, out_dir * scale.detach() + else: + out = F.conv2d( + input, + normalized_weight, + None, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + out = out * scale + return out + + @torch.no_grad() + def normalize_weight(self): + self.weight.data = self._get_normalized_weight() + + def _get_normalized_weight(self) -> torch.Tensor: + return _normalize(self.weight, dim=[1, 2, 3]) + + +def get_padding(kernel_size: int, dilation: int = 1) -> int: + return (kernel_size * dilation - dilation) // 2 + + +class BeatriceDiscriminatorP(nn.Module): + def __init__( + self, period: int, kernel_size: int = 5, stride: int = 3, san: bool = False + ): + super().__init__() + self.period = period + self.san = san + # fmt: off + self.convs = nn.ModuleList([ + weight_norm(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), + weight_norm(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), + weight_norm(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), + weight_norm(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), + weight_norm(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, (get_padding(kernel_size, 1), 0))), + ]) + # fmt: on + if san: + self.conv_post = SANConv2d(1024, 1, (3, 1), 1, (1, 0)) + else: + self.conv_post = weight_norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0))) + + def forward( + self, x: torch.Tensor, flg_san_train: bool = False + ) -> tuple[ + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor] + ]: + fmap = [] + + b, c, t = x.shape + if t % self.period != 0: + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for conv in self.convs: + x = conv(x) + x = F.silu(x, inplace=True) + fmap.append(x) + if self.san: + x = self.conv_post(x, flg_san_train=flg_san_train) + else: + x = self.conv_post(x) + if flg_san_train: + x_fun, x_dir = x + fmap.append(x_fun) + x_fun = torch.flatten(x_fun, 1, -1) + x_dir = torch.flatten(x_dir, 1, -1) + x = x_fun, x_dir + else: + fmap.append(x) + x = torch.flatten(x, 1, -1) + return x, fmap + + +class BeatriceDiscriminatorR(nn.Module): + def __init__(self, resolution: int, san: bool = False): + super().__init__() + self.resolution = resolution + self.san = san + assert len(self.resolution) == 3 + self.convs = nn.ModuleList( + [ + weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), + weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), + weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), + weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), + weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), + ] + ) + if san: + self.conv_post = SANConv2d(32, 1, (3, 3), padding=(1, 1)) + else: + self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) + + def forward( + self, x: torch.Tensor, flg_san_train: bool = False + ) -> tuple[ + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor] + ]: + fmap = [] + + x = self._spectrogram(x).unsqueeze(1) + for conv in self.convs: + x = conv(x) + x = F.silu(x, inplace=True) + fmap.append(x) + if self.san: + x = self.conv_post(x, flg_san_train=flg_san_train) + else: + x = self.conv_post(x) + if flg_san_train: + x_fun, x_dir = x + fmap.append(x_fun) + x_fun = torch.flatten(x_fun, 1, -1) + x_dir = torch.flatten(x_dir, 1, -1) + x = x_fun, x_dir + else: + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + def _spectrogram(self, x: torch.Tensor) -> torch.Tensor: + n_fft, hop_length, win_length = self.resolution + x = F.pad( + x, ((n_fft - hop_length) // 2, (n_fft - hop_length) // 2), mode="reflect" + ).squeeze(1) + with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): + mag = torch.stft( + x.float(), + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=torch.ones(win_length, device=x.device), + center=False, + return_complex=True, + ).abs() + + return mag + + +class BeatriceMultiPeriodDiscriminator(nn.Module): + def __init__(self, san: bool = False): + super().__init__() + resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]] + periods = [2, 3, 5, 7, 11] + self.discriminators = nn.ModuleList( + [BeatriceDiscriminatorR(r, san=san) for r in resolutions] + + [BeatriceDiscriminatorP(p, san=san) for p in periods] + ) + self.discriminator_names = [f"R_{n}_{h}_{w}" for n, h, w in resolutions] + [ + f"P_{p}" for p in periods + ] + self.san = san + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, flg_san_train: bool = False + ) -> tuple[ + list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]], + list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]], + list[list[torch.Tensor]], + list[list[torch.Tensor]], + ]: + batch_size = y.size(0) + concatenated_y_y_hat = torch.cat([y, y_hat]) + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for d in self.discriminators: + if flg_san_train: + (y_d_fun, y_d_dir), fmap = d( + concatenated_y_y_hat, flg_san_train=flg_san_train + ) + y_d_r_fun, y_d_g_fun = torch.split(y_d_fun, batch_size) + y_d_r_dir, y_d_g_dir = torch.split(y_d_dir, batch_size) + y_d_r = y_d_r_fun, y_d_r_dir + y_d_g = y_d_g_fun, y_d_g_dir + else: + y_d, fmap = d(concatenated_y_y_hat, flg_san_train=flg_san_train) + y_d_r, y_d_g = torch.split(y_d, batch_size) + fmap_r = [] + fmap_g = [] + for fm in fmap: + fm_r, fm_g = torch.split(fm, batch_size) + fmap_r.append(fm_r) + fmap_g.append(fm_g) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + def forward_and_compute_loss( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]]: + y_d_rs, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=self.san) + stats = {} + assert len(y_d_gs) == len(y_d_rs) == len(self.discriminators) + with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False): + # discriminator loss + d_loss = 0.0 + for dr, dg, name in zip(y_d_rs, y_d_gs, self.discriminator_names): + if self.san: + dr_fun, dr_dir = map(lambda x: x.float(), dr) + dg_fun, dg_dir = map(lambda x: x.float(), dg) + r_loss_fun = F.softplus(1.0 - dr_fun).square().mean() + g_loss_fun = F.softplus(dg_fun).square().mean() + r_loss_dir = F.softplus(1.0 - dr_dir).square().mean() + g_loss_dir = -F.softplus(1.0 - dg_dir).square().mean() + r_loss = r_loss_fun + r_loss_dir + g_loss = g_loss_fun + g_loss_dir + else: + dr = dr.float() + dg = dg.float() + r_loss = (1.0 - dr).square().mean() + g_loss = dg.square().mean() + stats[f"{name}_dr_loss"] = r_loss.item() + stats[f"{name}_dg_loss"] = g_loss.item() + d_loss += r_loss + g_loss + # adversarial loss + adv_loss = 0.0 + for dg, name in zip(y_d_gs, self.discriminator_names): + if self.san: + dg_fun = dg[0].float() + g_loss = F.softplus(1.0 - dg_fun).square().mean() + else: + dg = dg.float() + g_loss = (1.0 - dg).square().mean() + stats[f"{name}_gg_loss"] = g_loss.item() + adv_loss += g_loss + # feature mathcing loss + fm_loss = 0.0 + for fr, fg, name in zip(fmap_rs, fmap_gs, self.discriminator_names): + fm_loss_i = 0.0 + for j, (r, g) in enumerate(zip(fr, fg)): + fm_loss_ij = (r.detach().float() - g.float()).abs().mean() + stats[f"~{name}_fm_loss_{j}"] = fm_loss_ij.item() + fm_loss_i += fm_loss_ij + stats[f"{name}_fm_loss"] = fm_loss_i.item() + fm_loss += fm_loss_i + return d_loss, adv_loss, fm_loss, stats + + + +class GradBalancer: + """Adapted from https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py""" + + def __init__( + self, + weights: dict[str, float], + rescale_grads: bool = True, + total_norm: float = 1.0, + ema_decay: float = 0.999, + per_batch_item: bool = True, + ): + self.weights = weights + self.per_batch_item = per_batch_item + self.total_norm = total_norm + self.ema_decay = ema_decay + self.rescale_grads = rescale_grads + + self.ema_total: dict[str, float] = defaultdict(float) + self.ema_fix: dict[str, float] = defaultdict(float) + + def backward( + self, + losses: dict[str, torch.Tensor], + input: torch.Tensor, + scaler: Optional[torch.amp.GradScaler] = None, + skip_update_ema: bool = False, + ) -> dict[str, float]: + stats = {} + if skip_update_ema: + assert len(losses) == len(self.ema_total) + ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()} + else: + # 各 loss に対して d loss / d input とそのノルムを計算する + norms = {} + grads = {} + for name, loss in losses.items(): + if scaler is not None: + loss = scaler.scale(loss) + (grad,) = torch.autograd.grad(loss, [input], retain_graph=True) + if not grad.isfinite().all(): + input.backward(grad) + return {} + grad = grad.detach() / (1.0 if scaler is None else scaler.get_scale()) + if self.per_batch_item: + dims = tuple(range(1, grad.dim())) + ema_norm = grad.norm(dim=dims).mean() + else: + ema_norm = grad.norm() + norms[name] = float(ema_norm) + grads[name] = grad + + # ノルムの移動平均を計算する + for key, value in norms.items(): + self.ema_total[key] = self.ema_total[key] * self.ema_decay + value + self.ema_fix[key] = self.ema_fix[key] * self.ema_decay + 1.0 + ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()} + + # ログを取る + total_ema_norm = sum(ema_norms.values()) + for k, ema_norm in ema_norms.items(): + stats[f"grad_norm_value_{k}"] = ema_norm + stats[f"grad_norm_ratio_{k}"] = ema_norm / (total_ema_norm + 1e-12) + + # loss の係数の比率を計算する + if self.rescale_grads: + total_weights = sum([self.weights[k] for k in ema_norms]) + ratios = {k: w / total_weights for k, w in self.weights.items()} + + # 勾配を修正する + loss = 0.0 + for name, ema_norm in ema_norms.items(): + if self.rescale_grads: + scale = ratios[name] * self.total_norm / (ema_norm + 1e-12) + else: + scale = self.weights[name] + loss += (losses if skip_update_ema else grads)[name] * scale + if scaler is not None: + loss = scaler.scale(loss) + if skip_update_ema: + (loss,) = torch.autograd.grad(loss, [input]) + input.backward(loss) + return stats + + def state_dict(self) -> dict[str, dict[str, float]]: + return { + "ema_total": dict(self.ema_total), + "ema_fix": dict(self.ema_fix), + } + + def load_state_dict(self, state_dict): + self.ema_total = defaultdict(float, state_dict["ema_total"]) + self.ema_fix = defaultdict(float, state_dict["ema_fix"]) + + +class QualityTester(nn.Module): + def __init__(self): + super().__init__() + self.utmos = torch.hub.load( + "tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True + ).eval() + + @torch.inference_mode() + def compute_mos(self, wav: torch.Tensor) -> dict[str, list[float]]: + res = {"utmos": self.utmos(wav, sr=16000).tolist()} + return res + + def test( + self, converted_wav: torch.Tensor, source_wav: torch.Tensor + ) -> dict[str, list[float]]: + # [batch_size, wav_length] + res = {} + res.update(self.compute_mos(converted_wav)) + return res + + def test_many( + self, converted_wavs: list[torch.Tensor], source_wavs: list[torch.Tensor] + ) -> tuple[dict[str, float], dict[str, list[float]]]: + # list[batch_size, wav_length] + results = defaultdict(list) + assert len(converted_wavs) == len(source_wavs) + for converted_wav, source_wav in zip(converted_wavs, source_wavs): + res = self.test(converted_wav, source_wav) + for metric_name, value in res.items(): + results[metric_name].extend(value) + return { + metric_name: sum(values) / len(values) + for metric_name, values in results.items() + }, results + + +def compute_grad_norm( + model: nn.Module, return_stats: bool = False +) -> Union[float, dict[str, float]]: + total_norm = 0.0 + stats = {} + for name, p in model.named_parameters(): + if p.grad is None: + continue + param_norm = p.grad.data.norm().item() + if not math.isfinite(param_norm): + param_norm = p.grad.data.float().norm().item() + total_norm += param_norm * param_norm + if return_stats: + stats[f"grad_norm_{name}"] = param_norm + total_norm = math.sqrt(total_norm) + if return_stats: + return total_norm, stats + else: + return total_norm + + +def compute_mean_f0( + files: list[Path], method: Literal["dio", "harvest"] = "dio" +) -> float: + sum_log_f0 = 0.0 + n_frames = 0 + for file in files: + wav, sr = beatrice_load_audio(file) + if method == "dio": + f0, _ = pyworld.dio(wav.ravel().numpy().astype(np.float64), sr) + elif method == "harvest": + f0, _ = pyworld.harvest(wav.ravel().numpy().astype(np.float64), sr) + else: + raise ValueError(f"Invalid method: {method}") + f0 = f0[f0 > 0] + sum_log_f0 += float(np.log(f0).sum()) + n_frames += len(f0) + if n_frames == 0: + return math.nan + mean_log_f0 = sum_log_f0 / n_frames + return math.exp(mean_log_f0) + + + +def get_resampler( + sr_before: int, sr_after: int, device="cpu", cache={} +) -> torchaudio.transforms.Resample: + if not isinstance(device, str): + device = str(device) + if (sr_before, sr_after, device) not in cache: + cache[(sr_before, sr_after, device)] = torchaudio.transforms.Resample( + sr_before, sr_after + ).to(device) + return cache[(sr_before, sr_after, device)] + + +def convolve(signal: torch.Tensor, ir: torch.Tensor) -> torch.Tensor: + n = 1 << (signal.size(-1) + ir.size(-1) - 2).bit_length() + res = torch.fft.irfft(torch.fft.rfft(signal, n=n) * torch.fft.rfft(ir, n=n), n=n) + return res[..., : signal.size(-1)] + + +def random_formant_shift( + wav: torch.Tensor, + sample_rate: int, + formant_shift_semitone_min: float = -3.0, + formant_shift_semitone_max: float = 3.0, +) -> torch.Tensor: + assert wav.ndim == 2 + assert wav.size(0) == 1 + + device = wav.device + + hop_length = 256 + + # [wav_length] + wav_np = wav.ravel().double().cpu().numpy() + f0, t = pyworld.dio( + wav_np, + sample_rate, + f0_floor=55, + f0_ceil=1400, + frame_period=hop_length * 1000 / sample_rate, + ) + f0 = pyworld.stonemask(wav_np, f0, t, sample_rate) + world_sp = pyworld.cheaptrick(wav_np, f0, t, sample_rate) + world_sp = ( + torch.from_numpy(world_sp).float().to(device).sqrt_()[None] + ) # [1, length, n_fft // 2 + 1] + + n_fft = win_length = (world_sp.size(2) - 1) * 2 + + window = torch.hann_window(win_length, device=device) + + # [1, n_fft // 2 + 1, length] + stft_sp = torch.stft( + wav, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + return_complex=True, + ) + assert world_sp.size(1) == stft_sp.size(2), (world_sp.size(), stft_sp.size()) + assert world_sp.size(2) == stft_sp.size(1), (world_sp.size(), stft_sp.size()) + + shift_semitones = ( + torch.rand(()).item() + * (formant_shift_semitone_max - formant_shift_semitone_min) + + formant_shift_semitone_min + ) + shift_ratio = 2.0 ** (shift_semitones / 12.0) + shifted_world_sp = F.interpolate( + world_sp, scale_factor=shift_ratio, mode="linear", align_corners=True + ) + + if shifted_world_sp.size(2) > n_fft // 2 + 1: + shifted_world_sp = shifted_world_sp[:, :, : n_fft // 2 + 1] + elif shifted_world_sp.size(2) < n_fft // 2 + 1: + shifted_world_sp = F.pad( + shifted_world_sp, (0, n_fft // 2 + 1 - shifted_world_sp.size(2)) + ) + + ratio = ((shifted_world_sp + 1e-5) / (world_sp + 1e-5)).clamp(0.1, 10.0) + stft_sp *= ratio.transpose(-2, -1) # [1, n_fft // 2 + 1, length] + + out = torch.istft( + stft_sp, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + length=wav.size(-1), + ) + + return out + + +def random_filter(audio: torch.Tensor) -> torch.Tensor: + assert audio.ndim == 2 + ab = torch.rand(audio.size(0), 6) * 0.75 - 0.375 + a, b = ab[:, :3], ab[:, 3:] + a[:, 0] = 1.0 + b[:, 0] = 1.0 + audio = torchaudio.functional.lfilter(audio, a, b, clamp=False) + return audio + + +def get_noise( + n_samples: int, sample_rate: float, files: list[Union[str, bytes, os.PathLike]] +) -> torch.Tensor: + resample_augmentation_candidates = [0.9, 0.95, 1.0, 1.05, 1.1] + wavs = [] + current_length = 0 + while current_length < n_samples: + idx_files = torch.randint(0, len(files), ()) + file = files[idx_files] + wav, sr = beatrice_load_audio(file) + assert wav.size(0) == 1 + augmented_sample_rate = int( + round( + sample_rate + * resample_augmentation_candidates[ + torch.randint(0, len(resample_augmentation_candidates), ()) + ] + ) + ) + resampler = get_resampler(sr, augmented_sample_rate) + wav = resampler(wav) + wav = random_filter(wav) + wav *= 0.99 / (wav.abs().max() + 1e-5) + wavs.append(wav) + current_length += wav.size(1) + start = torch.randint(0, current_length - n_samples + 1, ()) + wav = torch.cat(wavs, dim=1)[:, start : start + n_samples] + assert wav.size() == (1, n_samples), wav.size() + return wav + + +def get_butterworth_lpf( + cutoff_freq: float, sample_rate: int, cache={} +) -> tuple[torch.Tensor, torch.Tensor]: + if (cutoff_freq, sample_rate) not in cache: + q = math.sqrt(0.5) + omega = math.tau * cutoff_freq / sample_rate + cos_omega = math.cos(omega) + alpha = math.sin(omega) / (2.0 * q) + b1 = (1.0 - cos_omega) / (1.0 + alpha) + b0 = b1 * 0.5 + a1 = -2.0 * cos_omega / (1.0 + alpha) + a2 = (1.0 - alpha) / (1.0 + alpha) + cache[(cutoff_freq, sample_rate)] = ( + torch.tensor([b0, b1, b0]), + torch.tensor([1.0, a1, a2]), + ) + return cache[(cutoff_freq, sample_rate)] + + +def augment_audio( + clean: torch.Tensor, + sample_rate: int, + noise_files: list[Union[str, bytes, os.PathLike]], + ir_files: list[Union[str, bytes, os.PathLike]], + snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0], + formant_shift_probability: float = 0.5, + formant_shift_semitone_min: float = -3.0, + formant_shift_semitone_max: float = 3.0, + reverb_probability: float = 0.5, + lpf_probability: float = 0.2, + lpf_cutoff_freq_candidates: list[float] = [2000.0, 3000.0, 4000.0, 6000.0], +) -> torch.Tensor: + # [1, wav_length] + assert clean.size(0) == 1 + n_samples = clean.size(1) + + original_clean_rms = clean.square().mean().sqrt_() + + # clean をフォルマントシフトする + if torch.rand(()) < formant_shift_probability: + clean = random_formant_shift( + clean, sample_rate, formant_shift_semitone_min, formant_shift_semitone_max + ) + + # noise を取得して clean と concat する + noise = get_noise(n_samples, sample_rate, noise_files) + signals = torch.cat([clean, noise]) + + # clean, noise に異なるランダムフィルタをかける + signals = random_filter(signals) + + # clean, noise にリバーブをかける + if torch.rand(()) < reverb_probability: + ir_file = ir_files[torch.randint(0, len(ir_files), ())] + ir, sr = beatrice_load_audio(ir_file) + assert ir.size() == (2, sr), ir.size() + assert sr == sample_rate, (sr, sample_rate) + signals = convolve(signals, ir) + + # clean, noise に同じ LPF をかける + if torch.rand(()) < lpf_probability: + if signals.abs().max() > 0.8: + signals /= signals.abs().max() * 1.25 + cutoff_freq = lpf_cutoff_freq_candidates[ + torch.randint(0, len(lpf_cutoff_freq_candidates), ()) + ] + b, a = get_butterworth_lpf(cutoff_freq, sample_rate) + signals = torchaudio.functional.lfilter(signals, a, b, clamp=False) + + # clean の音量を合わせる + clean, noise = signals + clean_rms = clean.square().mean().sqrt_() + clean *= original_clean_rms / clean_rms + + if len(snr_candidates) >= 1: + # clean, noise の音量をピークを重視して取る + clean_level = clean.square().square_().mean().sqrt_().sqrt_() + noise_level = noise.square().square_().mean().sqrt_().sqrt_() + # SNR + snr = snr_candidates[torch.randint(0, len(snr_candidates), ())] + # noisy を生成 + noisy = clean + noise * ( + 0.1 ** (snr / 20.0) * clean_level / (noise_level + 1e-5) + ) + + return noisy + + +class WavDataset(torch.utils.data.Dataset): + def __init__( + self, + audio_files: list[tuple[Path, int]], + in_sample_rate: int = 16000, + out_sample_rate: int = 24000, + wav_length: int = 4 * 24000, # 4s + segment_length: int = 100, # 1s + noise_files: Optional[list[Union[str, bytes, os.PathLike]]] = None, + ir_files: Optional[list[Union[str, bytes, os.PathLike]]] = None, + augmentation_snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0], + augmentation_formant_shift_probability: float = 0.5, + augmentation_formant_shift_semitone_min: float = -3.0, + augmentation_formant_shift_semitone_max: float = 3.0, + augmentation_reverb_probability: float = 0.5, + augmentation_lpf_probability: float = 0.2, + augmentation_lpf_cutoff_freq_candidates: list[float] = [ + 2000.0, + 3000.0, + 4000.0, + 6000.0, + ], + ): + self.audio_files = audio_files + self.in_sample_rate = in_sample_rate + self.out_sample_rate = out_sample_rate + self.wav_length = wav_length + self.segment_length = segment_length + self.noise_files = noise_files + self.ir_files = ir_files + self.augmentation_snr_candidates = augmentation_snr_candidates + self.augmentation_formant_shift_probability = ( + augmentation_formant_shift_probability + ) + self.augmentation_formant_shift_semitone_min = ( + augmentation_formant_shift_semitone_min + ) + self.augmentation_formant_shift_semitone_max = ( + augmentation_formant_shift_semitone_max + ) + self.augmentation_reverb_probability = augmentation_reverb_probability + self.augmentation_lpf_probability = augmentation_lpf_probability + self.augmentation_lpf_cutoff_freq_candidates = ( + augmentation_lpf_cutoff_freq_candidates + ) + + if (noise_files is None) is not (ir_files is None): + raise ValueError("noise_files and ir_files must be both None or not None") + + self.in_hop_length = in_sample_rate // 100 + self.out_hop_length = out_sample_rate // 100 # 10ms 刻み + + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int, int]: + file, speaker_id = self.audio_files[index] + clean_wav, sample_rate = beatrice_load_audio(file) + if clean_wav.size(0) != 1: + ch = torch.randint(0, clean_wav.size(0), ()) + clean_wav = clean_wav[ch : ch + 1] + + formant_shift_candidates = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0] + formant_shift = formant_shift_candidates[ + torch.randint(0, len(formant_shift_candidates), ()).item() + ] + + resampler_fraction = Fraction( + sample_rate / self.out_sample_rate * 2.0 ** (formant_shift / 12.0) + ).limit_denominator(300) + clean_wav = get_resampler( + resampler_fraction.numerator, resampler_fraction.denominator + )(clean_wav) + + assert clean_wav.size(0) == 1 + assert clean_wav.size(1) != 0 + + clean_wav = F.pad(clean_wav, (self.wav_length, self.wav_length)) + + if self.noise_files is None: + noisy_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)( + clean_wav + ) + else: + clean_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)( + clean_wav + ) + noisy_wav_16k = augment_audio( + clean_wav_16k, + self.in_sample_rate, + self.noise_files, + self.ir_files, + self.augmentation_snr_candidates, + self.augmentation_formant_shift_probability, + self.augmentation_formant_shift_semitone_min, + self.augmentation_formant_shift_semitone_max, + self.augmentation_reverb_probability, + self.augmentation_lpf_probability, + self.augmentation_lpf_cutoff_freq_candidates, + ) + + clean_wav = clean_wav.squeeze_(0) + noisy_wav_16k = noisy_wav_16k.squeeze_(0) + + # 音量をランダマイズする + amplitude = torch.rand(()).item() * 0.899 + 0.1 + factor = amplitude / clean_wav.abs().max() + clean_wav *= factor + noisy_wav_16k *= factor + while noisy_wav_16k.abs().max() >= 1.0: + clean_wav *= 0.5 + noisy_wav_16k *= 0.5 + + return clean_wav, noisy_wav_16k, speaker_id, formant_shift + + def __len__(self) -> int: + return len(self.audio_files) + + def collate( + self, batch: list[tuple[torch.Tensor, torch.Tensor, int, int]] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + assert self.wav_length % self.out_hop_length == 0 + length = self.wav_length // self.out_hop_length + clean_wavs = [] + noisy_wavs = [] + slice_starts = [] + speaker_ids = [] + formant_shifts = [] + for clean_wav, noisy_wav, speaker_id, formant_shift in batch: + # 発声部分をランダムに 1 箇所選ぶ + (voiced,) = clean_wav.nonzero(as_tuple=True) + assert voiced.numel() != 0 + center = voiced[torch.randint(0, voiced.numel(), ()).item()].item() + # 発声部分が中央にくるように、スライス区間を選ぶ + slice_start = center - self.segment_length * self.out_hop_length // 2 + assert slice_start >= 0 + # スライス区間が含まれるように、ランダムに wav_length の長さを切り出す + r = torch.randint(0, length - self.segment_length + 1, ()).item() + offset = slice_start - r * self.out_hop_length + clean_wavs.append(clean_wav[offset : offset + self.wav_length]) + offset_in_sample_rate = int( + round(offset * self.in_sample_rate / self.out_sample_rate) + ) + noisy_wavs.append( + noisy_wav[ + offset_in_sample_rate : offset_in_sample_rate + + length * self.in_hop_length + ] + ) + slice_start = r + slice_starts.append(slice_start) + speaker_ids.append(speaker_id) + formant_shifts.append(formant_shift) + clean_wavs = torch.stack(clean_wavs) + noisy_wavs = torch.stack(noisy_wavs) + slice_starts = torch.tensor(slice_starts) + speaker_ids = torch.tensor(speaker_ids) + formant_shifts = torch.tensor(formant_shifts) + return ( + clean_wavs, # [batch_size, wav_length] + noisy_wavs, # [batch_size, wav_length] + slice_starts, # Long[batch_size] + speaker_ids, # Long[batch_size] + formant_shifts, # Long[batch_size] + ) + + + +AUDIO_FILE_SUFFIXES = { + ".wav", + ".aif", + ".aiff", + ".fla", + ".flac", + ".oga", + ".ogg", + ".opus", + ".mp3", +} + + +def get_compressed_optimizer_state_dict( + optimizer: torch.optim.Optimizer, +) -> dict: + state_dict = {} + for k0, v0 in optimizer.state_dict().items(): + if k0 != "state": + state_dict[k0] = v0 + continue + state_dict[k0] = {} + for k1, v1 in v0.items(): + state_dict[k0][k1] = {} + for k2, v2 in v1.items(): + if isinstance(v2, torch.Tensor): + state_dict[k0][k1][k2] = v2.bfloat16() + assert state_dict[k0][k1][k2].isfinite().all() + else: + state_dict[k0][k1][k2] = v2 + return state_dict + + +def get_decompressed_optimizer_state_dict(compressed_state_dict: dict) -> dict: + state_dict = {} + for k0, v0 in compressed_state_dict.items(): + if k0 != "state": + state_dict[k0] = v0 + continue + state_dict[k0] = {} + for k1, v1 in v0.items(): + state_dict[k0][k1] = {} + for k2, v2 in v1.items(): + if isinstance(v2, torch.Tensor): + state_dict[k0][k1][k2] = v2.float() + assert state_dict[k0][k1][k2].isfinite().all() + else: + state_dict[k0][k1][k2] = v2 + return state_dict + + + + +# ============================================================ +# BEATRICE V2 TRAINING - Embedded (downloads assets from HuggingFace) +# ============================================================ + +BEATRICE_AUDIO_FILE_SUFFIXES = {".wav", ".aif", ".aiff", ".fla", ".flac", ".oga", ".ogg", ".opus", ".mp3"} + + +def preprocess_audio_for_beatrice(audio_path: str, output_dir: str, speaker_name: str = "speaker"): + """Preprocess audio for Beatrice training using silence-based splitting""" + + # Create speaker directory structure required by Beatrice + speaker_dir = os.path.join(output_dir, speaker_name) + os.makedirs(speaker_dir, exist_ok=True) + + # Load audio at 16kHz (Beatrice input sample rate) + audio, sr = librosa.load(audio_path, sr=16000, mono=True) + + # Simple silence-based splitting (RMS threshold) + chunk_size = int(4.0 * sr) # 4 second chunks + hop = int(3.5 * sr) # 0.5s overlap + threshold = 0.01 # RMS threshold + + chunks_saved = 0 + for i, start in enumerate(range(0, len(audio) - chunk_size, hop)): + chunk = audio[start:start + chunk_size] + rms = np.sqrt(np.mean(chunk ** 2)) + + if rms > threshold: # Skip silence + # Normalize + max_val = np.abs(chunk).max() + if max_val > 0: + chunk = chunk / max_val * 0.9 + + chunk_path = os.path.join(speaker_dir, f"{speaker_name}_{chunks_saved:04d}.wav") + sf.write(chunk_path, chunk, sr) + chunks_saved += 1 + + logger.info(f"Beatrice preprocessing: {chunks_saved} chunks saved to {speaker_dir}") + return chunks_saved, output_dir + + +def train_beatrice_generator( + data_dir: str, + output_dir: str, + epochs: int = 30, + batch_size: int = 8, + lr_g: float = 5e-5, + lr_d: float = 5e-5, + use_augmentation: bool = False, + resume: bool = False, + progress_callback=None, +): + """Train Beatrice v2 model - generator yielding (message, model_path) tuples""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Download pretrained + yield "Downloading pretrained models...", None + phone_extractor_path = download_beatrice_asset("phone_extractor") + pitch_estimator_path = download_beatrice_asset("pitch_estimator") + pretrained_model_path = download_beatrice_asset("pretrained_model") + + # Discover speakers from directory structure + # Expected: data_dir/speaker_name/*.wav + speakers = [] + training_filelist = [] + speaker_audio_files = [] + for speaker_dir in sorted(Path(data_dir).iterdir()): + if not speaker_dir.is_dir(): + continue + candidates = [f for f in sorted(speaker_dir.rglob("*")) + if f.is_file() and f.suffix.lower() in BEATRICE_AUDIO_FILE_SUFFIXES] + if candidates: + speaker_id = len(speakers) + speakers.append(speaker_dir.name) + training_filelist.extend([(f, speaker_id) for f in candidates]) + speaker_audio_files.append(candidates) + + n_speakers = len(speakers) + if n_speakers == 0: + yield "Error: No speakers found in data directory", None + return + + yield f"Found {n_speakers} speaker(s), {len(training_filelist)} files", None + + # Augmentation assets (optional) + noise_files = None + ir_files = None + if use_augmentation: + try: + noise_dir, ir_dir = download_beatrice_augmentation() + if noise_dir and ir_dir: + noise_files = sorted(list(Path(noise_dir).rglob("*.wav")) + list(Path(noise_dir).rglob("*.flac"))) + ir_files = sorted(list(Path(ir_dir).rglob("*.wav")) + list(Path(ir_dir).rglob("*.flac"))) + if noise_files and ir_files: + yield f"Loaded augmentation: {len(noise_files)} noise, {len(ir_files)} IR files", None + else: + noise_files = None + ir_files = None + except Exception as e: + yield f"Warning: Could not load augmentation assets: {e}", None + + # Build models + yield "Building models...", None + + phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False) + pe_ckpt = torch.load(phone_extractor_path, map_location="cpu", weights_only=True) + phone_extractor.load_state_dict(pe_ckpt["phone_extractor"], strict=False) + del pe_ckpt + + pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False) + pi_ckpt = torch.load(pitch_estimator_path, map_location="cpu", weights_only=True) + pitch_estimator.load_state_dict(pi_ckpt["pitch_estimator"]) + del pi_ckpt + + hidden_channels = 256 + pitch_bins = 448 + + net_g = ConverterNetwork( + phone_extractor, pitch_estimator, + n_speakers=n_speakers, + pitch_bins=pitch_bins, + hidden_channels=hidden_channels, + vq_topk=4, + training_time_vq="none", + phone_noise_ratio=0.5, + floor_noise_level=1e-3, + ).to(device) + + net_d = BeatriceMultiPeriodDiscriminator(san=True).to(device) + + # Optimizers + optim_g = torch.optim.AdamW(net_g.parameters(), lr_g, betas=(0.8, 0.99), eps=1e-6) + optim_d = torch.optim.AdamW(net_d.parameters(), lr_d, betas=(0.8, 0.99), eps=1e-6) + + grad_scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda") + grad_balancer = GradBalancer( + weights={ + "loss_loudness": 1.0, + "loss_mel": 45.0, + "loss_adv": 1.0, + "loss_fm": 2.0, + }, + ema_decay=0.999, + ) + + initial_iteration = 0 + os.makedirs(output_dir, exist_ok=True) + + # Load pretrained or resume + if resume: + latest_ckpt = os.path.join(output_dir, "checkpoint_latest.pt.gz") + if os.path.isfile(latest_ckpt): + yield "Resuming from checkpoint...", None + with gzip.open(latest_ckpt, "rb") as f: + ckpt = torch.load(f, map_location="cpu", weights_only=True) + net_g.load_state_dict(ckpt["net_g"], strict=False) + # Filter discriminator for shape mismatches + net_d_state = net_d.state_dict() + filtered_d = {k: v for k, v in ckpt["net_d"].items() + if k in net_d_state and v.shape == net_d_state[k].shape} + net_d.load_state_dict(filtered_d, strict=False) + optim_g.load_state_dict(get_decompressed_optimizer_state_dict(ckpt["optim_g"])) + optim_d.load_state_dict(get_decompressed_optimizer_state_dict(ckpt["optim_d"])) + if "grad_balancer" in ckpt: + grad_balancer.load_state_dict(ckpt["grad_balancer"]) + if "grad_scaler" in ckpt: + grad_scaler.load_state_dict(ckpt["grad_scaler"]) + initial_iteration = ckpt.get("iteration", 0) + del ckpt + else: + yield "No checkpoint found, starting fresh with pretrained", None + resume = False + + if not resume: + yield "Loading pretrained weights...", None + with gzip.open(pretrained_model_path, "rb") as f: + pretrained_ckpt = torch.load(f, map_location="cpu", weights_only=True) + # Adapt pretrained for our n_speakers + initial_speaker_emb = pretrained_ckpt["net_g"]["embed_speaker.weight"][:1] + pretrained_ckpt["net_g"]["embed_speaker.weight"] = initial_speaker_emb[[0] * n_speakers] + initial_kv_emb = pretrained_ckpt["net_g"]["key_value_speaker_embedding.weight"][:1] + pretrained_ckpt["net_g"]["key_value_speaker_embedding.weight"] = initial_kv_emb[[0] * n_speakers] + pretrained_ckpt["net_g"]["vq.codebooks"] = pretrained_ckpt["net_g"]["vq.codebooks"][[0] * n_speakers] + net_g.load_state_dict(pretrained_ckpt["net_g"], strict=False) + # Filter discriminator state dict for shape mismatches (pretrained may use san=False) + net_d_state = net_d.state_dict() + filtered_d = {k: v for k, v in pretrained_ckpt["net_d"].items() + if k in net_d_state and v.shape == net_d_state[k].shape} + net_d.load_state_dict(filtered_d, strict=False) + logger.info(f"Loaded {len(filtered_d)}/{len(pretrained_ckpt['net_d'])} discriminator weights") + # Don't load grad_balancer/grad_scaler from pretrained - our loss weights may differ + # These will be re-initialized fresh for fine-tuning + del pretrained_ckpt + + # Build VQ codebooks + yield "Building VQ codebooks...", None + + def wav_iterator(files): + for file in files: + wav, sr = beatrice_load_audio(file) + wav = wav.to(device) + if sr != 16000: + wav = get_resampler(sr, 16000, str(device))(wav) + yield wav[:, None, :] + + if resume: + net_g.enable_hook() + else: + net_g.initialize_vq([wav_iterator(files) for files in speaker_audio_files]) + + # Dataset + dataset = WavDataset( + training_filelist, + in_sample_rate=16000, + out_sample_rate=24000, + wav_length=96000, + segment_length=100, + noise_files=noise_files, + ir_files=ir_files, + ) + + _num_workers = min(4, os.cpu_count() or 1) + effective_batch = min(batch_size, len(training_filelist)) + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=_num_workers, + collate_fn=dataset.collate, + shuffle=True, + batch_size=effective_batch, + pin_memory=True, + drop_last=len(training_filelist) > effective_batch, + persistent_workers=_num_workers > 0, + ) + + # Calculate steps + steps_per_epoch = max(1, len(training_filelist) // batch_size) + total_steps = epochs * steps_per_epoch + warmup_steps = min(total_steps // 4, 5000) + + # LR scheduler with warmup + def lr_lambda(step): + if step < warmup_steps: + return step / max(1, warmup_steps) + return 0.999 ** (step - warmup_steps) + + scheduler_g = torch.optim.lr_scheduler.LambdaLR(optim_g, lr_lambda) + scheduler_d = torch.optim.lr_scheduler.LambdaLR(optim_d, lr_lambda) + + # Advance schedulers if resuming + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=r"Detected call of `lr_scheduler\.step\(\)") + for _ in range(initial_iteration + 1): + scheduler_g.step() + scheduler_d.step() + + net_g.train() + net_d.train() + + yield f"Training {total_steps} steps ({epochs} epochs x {steps_per_epoch} steps/epoch)", None + + # Training loop + step = initial_iteration + data_iter = None + ckpt_path = None + + for epoch in range(epochs): + epoch_loss_g = 0.0 + epoch_loss_d = 0.0 + epoch_steps = 0 + + for batch_idx in range(steps_per_epoch): + if data_iter is None: + data_iter = iter(dataloader) + batch = next(data_iter, None) + if batch is None: + data_iter = iter(dataloader) + batch = next(data_iter, None) + if batch is None: + break + + clean_wavs, noisy_wavs_16k, slice_starts, speaker_ids, formant_shifts = \ + [x.to(device, non_blocking=True) for x in batch] + + with torch.amp.autocast(device.type, enabled=device.type == "cuda"): + # Generator forward + y, y_hat, y_hat_for_backward, loss_loudness, loss_mel, loss_ap, gen_stats = \ + net_g.forward_and_compute_loss( + noisy_wavs_16k[:, None, :], + speaker_ids, + formant_shifts, + slice_start_indices=slice_starts, + slice_segment_length=100, + y_all=clean_wavs[:, None, :], + ) + + # Discriminator forward + loss_disc, loss_adv, loss_fm, disc_stats = \ + net_d.forward_and_compute_loss(y, y_hat) + + # Discriminator backward + optim_d.zero_grad(set_to_none=True) + grad_scaler.scale(loss_disc).backward(retain_graph=True, inputs=list(net_d.parameters())) + grad_scaler.unscale_(optim_d) + + # Generator backward + optim_g.zero_grad(set_to_none=True) + grad_balancer.backward( + {"loss_loudness": loss_loudness, "loss_mel": loss_mel, + "loss_adv": loss_adv, "loss_fm": loss_fm}, + y_hat_for_backward, grad_scaler, + skip_update_ema=step > 10 and step % 5 != 0, + ) + grad_scaler.unscale_(optim_g) + + # Update + grad_scaler.step(optim_g) + grad_scaler.step(optim_d) + grad_scaler.update() + optim_g.zero_grad(set_to_none=True) + optim_d.zero_grad(set_to_none=True) + + scheduler_g.step() + scheduler_d.step() + + epoch_loss_g += loss_mel.item() + epoch_loss_d += loss_disc.item() + epoch_steps += 1 + step += 1 + + if progress_callback: + progress_callback(step / total_steps) + + avg_loss_g = epoch_loss_g / max(1, epoch_steps) + avg_loss_d = epoch_loss_d / max(1, epoch_steps) + + yield f"Epoch {epoch+1}/{epochs} | G loss: {avg_loss_g:.4f} | D loss: {avg_loss_d:.4f} | LR: {scheduler_g.get_last_lr()[0]:.2e}", None + + # Save checkpoint periodically + if (epoch + 1) % max(1, epochs // 5) == 0 or epoch == epochs - 1: + ckpt_path = os.path.join(output_dir, f"checkpoint_{step:08d}.pt.gz") + with gzip.open(ckpt_path, "wb") as f: + torch.save({ + "iteration": step, + "net_g": net_g.state_dict(), + "phone_extractor": phone_extractor.state_dict(), + "pitch_estimator": pitch_estimator.state_dict(), + "net_d": {k: v.half() for k, v in net_d.state_dict().items()}, + "optim_g": get_compressed_optimizer_state_dict(optim_g), + "optim_d": get_compressed_optimizer_state_dict(optim_d), + "grad_balancer": grad_balancer.state_dict(), + "grad_scaler": grad_scaler.state_dict(), + "h": { + "hidden_channels": hidden_channels, + "pitch_bins": pitch_bins, + "vq_topk": 4, + "training_time_vq": "none", + "phone_noise_ratio": 0.5, + "floor_noise_level": 1e-3, + "san": True, + }, + "speakers": speakers, + }, f) + shutil.copy(ckpt_path, os.path.join(output_dir, "checkpoint_latest.pt.gz")) + yield f"Saved checkpoint: {ckpt_path}", ckpt_path + + # Cleanup + del net_g, net_d, optim_g, optim_d, phone_extractor, pitch_estimator + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + yield "Training complete!", ckpt_path + + +def convert_voice_beatrice( + source_audio, + model_file, + target_speaker: int = 0, + pitch_shift: int = 0, + formant_shift: float = 0.0, + progress=None, +): + """Convert voice using Beatrice v2 model + + Args: + source_audio: Path to source audio file + model_file: Path to Beatrice checkpoint (.pt.gz) or file object with .name + target_speaker: Target speaker index + pitch_shift: Pitch shift in semitones + formant_shift: Formant shift in semitones (-2 to 2) + progress: Gradio progress callback + + Returns: + (output_path, status_message) tuple + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Get model path + if hasattr(model_file, 'name'): + model_path = model_file.name + elif isinstance(model_file, str): + model_path = model_file + else: + return None, "Invalid model file" + + if not model_path or not os.path.exists(model_path): + return None, f"Model file not found: {model_path}" + + try: + if progress: + progress(0.1, "Loading models...") + + # Download pretrained assets (phone extractor + pitch estimator) + phone_extractor_path = download_beatrice_asset("phone_extractor") + pitch_estimator_path = download_beatrice_asset("pitch_estimator") + + # Build phone extractor + phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False) + pe_ckpt = torch.load(phone_extractor_path, map_location="cpu", weights_only=True) + phone_extractor.load_state_dict(pe_ckpt["phone_extractor"], strict=False) + del pe_ckpt + + # Build pitch estimator + pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False) + pi_ckpt = torch.load(pitch_estimator_path, map_location="cpu", weights_only=True) + pitch_estimator.load_state_dict(pi_ckpt["pitch_estimator"]) + del pi_ckpt + + if progress: + progress(0.3, "Loading trained model...") + + # Load trained checkpoint + with gzip.open(model_path, "rb") as f: + checkpoint = torch.load(f, map_location="cpu", weights_only=True) + + # Determine model params from checkpoint + n_speakers = checkpoint["net_g"]["embed_speaker.weight"].shape[0] + h = checkpoint.get("h", {}) + hidden_channels = h.get("hidden_channels", 256) + pitch_bins = h.get("pitch_bins", 448) + speakers = checkpoint.get("speakers", [f"Speaker {i}" for i in range(n_speakers)]) + + if target_speaker >= n_speakers: + target_speaker = 0 + + net_g = ConverterNetwork( + phone_extractor, pitch_estimator, + n_speakers=n_speakers, + pitch_bins=pitch_bins, + hidden_channels=hidden_channels, + vq_topk=h.get("vq_topk", 4), + training_time_vq=h.get("training_time_vq", "none"), + phone_noise_ratio=h.get("phone_noise_ratio", 0.5), + floor_noise_level=h.get("floor_noise_level", 1e-3), + ).to(device).eval() + + net_g.load_state_dict(checkpoint["net_g"], strict=False) + net_g.enable_hook() + del checkpoint + + if progress: + progress(0.5, "Converting voice...") + + # Load audio at 16kHz + audio_path = source_audio if isinstance(source_audio, str) else source_audio + audio, sr = librosa.load(audio_path, sr=16000, mono=True) + audio_tensor = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(0).to(device) + + # Pad to multiple of 160 (phone extractor stride) + original_length = audio_tensor.shape[-1] + if original_length % 160 != 0: + pad_len = 160 - original_length % 160 + audio_tensor = F.pad(audio_tensor, (0, pad_len)) + + # Convert + with torch.inference_mode(): + y_hat = net_g( + audio_tensor, + torch.tensor([target_speaker], device=device), + torch.tensor([formant_shift], device=device), + torch.tensor([float(pitch_shift)], device=device), + ) + + # Output is 24kHz, trim to match input duration + output_length = original_length // 160 * 240 # 16kHz→24kHz frame ratio + output = y_hat.squeeze().cpu().numpy()[:output_length] + + # Save + fd, output_path = tempfile.mkstemp(suffix=".wav") + os.close(fd) + sf.write(output_path, output, 24000) + + # Cleanup + del net_g, phone_extractor, pitch_estimator + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + speaker_name = speakers[target_speaker] if target_speaker < len(speakers) else f"Speaker {target_speaker}" + return output_path, f"Converted using Beatrice v2 | 24kHz | Speaker: {speaker_name} | Pitch: {pitch_shift:+d} | Formant: {formant_shift:+.1f}" + + except Exception as e: + logger.exception("Beatrice inference error") + return None, f"Error: {str(e)}" + + +# ============================================================ +# GRADIO UI - Gradio 6 Compatible +# ============================================================ + +def train_ui( + audio_file, + model_name: str, + epochs: int, + batch_size: int, + sample_rate: int, + f0_method: str = "rmvpe", + progress=gr.Progress() +): + """Training function for Gradio UI - Generator for live log updates""" + if audio_file is None: + yield None, None, "❌ Please upload training audio" + return + + if not model_name or model_name.strip() == "": + yield None, None, "❌ Please enter a model name" + return + + # Check if CUDA available + has_cuda = torch.cuda.is_available() + device_info = "GPU (CUDA)" if has_cuda else "CPU" + + # Log accumulator for live updates + logs = [] + + try: + model_name = sanitize_model_name(model_name) + output_dir = f"trained_models/{model_name}" + data_dir = f"{output_dir}/data" + + # Preprocessing phase + logs.append(f"🚀 Starting on {device_info}") + logs.append(f"📂 Output: {output_dir}") + yield None, None, "\n".join(logs) + + progress(0.1, "Preprocessing...") + logs.append("🔄 Preprocessing audio...") + yield None, None, "\n".join(logs) + + audio_path = audio_file if isinstance(audio_file, str) else audio_file.name + result = preprocess_audio_for_training(audio_path, data_dir, target_sr=sample_rate, f0_method=f0_method) + + if result is None: + logs.append("❌ Preprocessing failed - no valid audio chunks") + yield None, None, "\n".join(logs) + return + + logs.append("✅ Preprocessing complete") + logs.append(f"🏋️ Training {epochs} epochs...") + logs.append("─" * 40) + yield None, None, "\n".join(logs) + + # Training phase - iterate over generator for live updates + ckpt = None + idx = None + for msg, path, index in train_rvc_generator( + data_dir=data_dir, + output_dir=output_dir, + epochs=epochs, + batch_size=batch_size, + lr=1e-5, + target_sr=sample_rate, + progress_callback=progress + ): + logs.append(msg) + yield None, None, "\n".join(logs) + if path: + ckpt = path + if index: + idx = index + + if ckpt: + logs.append("─" * 40) + logs.append(f"✅ Training complete!") + logs.append(f"📦 Model: {ckpt}") + if idx: + logs.append(f"📦 Index: {idx}") + progress(1.0, "Done!") + yield ckpt, idx, "\n".join(logs) + else: + logs.append("❌ Training failed") + yield None, None, "\n".join(logs) + + except Exception as e: + logger.exception("Training error") + logs.append(f"❌ Error: {str(e)}") + yield None, None, "\n".join(logs) + +with gr.Blocks() as demo: + gr.Markdown(f"# 🎤 Voice Conversion (RVC + Beatrice)\nInference: CPU • Training: {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}") + + with gr.Tabs(): + # ==================== TAB 1: VOICE CONVERSION ==================== + with gr.Tab("🎵 Voice Conversion"): + with gr.Row(): + with gr.Column(): + source_audio = gr.Audio(label="Source Audio", type="filepath") + + gr.Markdown("### Model") + model_type = gr.Radio( + ["RVC v2", "Beatrice v2"], value="RVC v2", + label="Model Type", + info="RVC: .pth files | Beatrice: .pt.gz files" + ) + + # RVC model inputs + with gr.Group(visible=True) as rvc_model_group: + with gr.Row(): + model_file = gr.File(label="RVC Model (.pth)", file_types=[".pth"]) + load_example_btn = gr.Button("Load Example (Benee)", size="sm") + index_file = gr.File(label="Index File (.index) - Optional", file_types=[".index"]) + + # Beatrice model inputs + with gr.Group(visible=False) as beatrice_model_group: + beatrice_model_file = gr.File(label="Beatrice Model (.pt.gz)", file_types=[".gz"]) + with gr.Row(): + beatrice_target_speaker = gr.Number(value=0, label="Target Speaker", precision=0) + beatrice_formant_shift = gr.Slider(-2, 2, value=0.0, step=0.5, label="Formant Shift") + + with gr.Row(): + pitch_shift = gr.Slider(-12, 12, value=0, step=1, label="Pitch (semitones)") + f0_method = gr.Radio(["rmvpe", "pm", "harvest"], value="rmvpe", label="F0 Method", visible=True) + + with gr.Row(visible=True) as rvc_extra_options: + index_rate = gr.Slider(0, 1, value=0.75, step=0.05, label="Index Rate") + protect = gr.Slider(0, 0.5, value=0.33, step=0.01, label="Protect (voiceless consonants)") + convert_btn = gr.Button("Convert", variant="primary") + + with gr.Column(): + output_audio = gr.Audio(label="Converted Audio", type="filepath") + output_info = gr.Textbox(label="Status", lines=2) + + def update_model_type(model_type_val): + is_rvc = model_type_val == "RVC v2" + return ( + gr.update(visible=is_rvc), # rvc_model_group + gr.update(visible=not is_rvc), # beatrice_model_group + gr.update(visible=is_rvc), # f0_method + gr.update(visible=is_rvc), # rvc_extra_options + ) + + model_type.change( + update_model_type, + [model_type], + [rvc_model_group, beatrice_model_group, f0_method, rvc_extra_options] + ) + + load_example_btn.click( + load_example_model, + [], + [model_file, index_file, output_info] + ) + + def convert_unified(source, m_type, rvc_model, rvc_index, beat_model, + beat_speaker, beat_formant, pitch, f0, idx_rate, prot, + progress=gr.Progress()): + if m_type == "RVC v2": + return convert_voice(source, rvc_model, rvc_index, pitch, f0, idx_rate, prot, progress) + else: + return convert_voice_beatrice( + source, beat_model, + target_speaker=int(beat_speaker), + pitch_shift=int(pitch), + formant_shift=float(beat_formant), + progress=progress + ) + + convert_btn.click( + convert_unified, + [source_audio, model_type, model_file, index_file, beatrice_model_file, + beatrice_target_speaker, beatrice_formant_shift, + pitch_shift, f0_method, index_rate, protect], + [output_audio, output_info], + api_name="convert", + concurrency_limit=1, + ) + + gr.Markdown("**Models:** [HuggingFace](https://huggingface.co/models?search=rvc) | [Weights.gg](https://weights.gg)") + + # ==================== TAB 2: TRAINING ==================== + with gr.Tab("🏋️ Training"): + # GPU Warning + gpu_status = "🟢 GPU Available (CUDA)" if torch.cuda.is_available() else "🟡 CPU Only (Training will be slow)" + gr.Markdown(f""" + ### Training Status: {gpu_status} + + {'**GPU detected!** Training will use CUDA acceleration.' if torch.cuda.is_available() else '**No GPU detected.** Training will run on CPU (~30 sec/epoch). For faster training, run locally with CUDA GPU.'} + """) + + # Trainer selector - Beatrice always available (downloads assets from HF) + trainer_selector = gr.Dropdown( + choices=["RVC v2", "Beatrice v2"], + value="RVC v2", + label="Trainer", + info="RVC v2: general purpose | Beatrice v2: low latency (~50ms)" + ) + + with gr.Row(): + with gr.Column(): + train_audio = gr.Audio(label="Training Audio", type="filepath") + train_model_name = gr.Textbox(label="Model Name", placeholder="my_voice", value="my_voice") + + # RVC-specific options + with gr.Group(visible=True) as rvc_options: + with gr.Row(): + train_epochs = gr.Slider(1, 500, value=50, step=1, label="Epochs") + train_batch = gr.Slider(1, 8, value=2, step=1, label="Batch Size") + with gr.Row(): + train_sr = gr.Radio([32000, 40000, 48000], value=40000, label="Sample Rate") + train_f0 = gr.Radio(["rmvpe", "pm", "harvest"], value="rmvpe", label="F0 Method") + + # Beatrice-specific options + with gr.Group(visible=False) as beatrice_options: + with gr.Row(): + beatrice_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs (30 recommended)") + beatrice_batch = gr.Slider(1, 64, value=8, step=1, label="Batch Size") + beatrice_resume = gr.Checkbox(label="Resume from checkpoint", value=False) + + train_btn = gr.Button("Start Training", variant="primary") + + with gr.Column(): + train_output_model = gr.File(label="Trained Model (.pth)") + train_output_index = gr.File(label="Index File (.index)") + train_status = gr.Textbox(label="Training Status", lines=6) + + # Toggle visibility based on trainer selection + def update_trainer_options(trainer): + is_rvc = trainer == "RVC v2" + return gr.update(visible=is_rvc), gr.update(visible=not is_rvc) + + trainer_selector.change( + update_trainer_options, + [trainer_selector], + [rvc_options, beatrice_options] + ) + + # Unified training function + def train_unified(trainer, audio, name, rvc_epochs, rvc_batch, rvc_sr, rvc_f0, + beat_epochs, beat_batch, beat_resume, progress=gr.Progress()): + if trainer == "RVC v2": + # Use generator for live updates (yields 3-tuples: model, index, status) + for result in train_ui(audio, name, rvc_epochs, rvc_batch, rvc_sr, rvc_f0, progress): + yield result + else: + # Beatrice training (embedded - downloads assets from HF) + if audio is None: + yield None, None, "❌ Please upload training audio" + return + if not name or name.strip() == "": + yield None, None, "❌ Please enter a model name" + return + + try: + name = sanitize_model_name(name) + output_dir = f"trained_models/beatrice_{name}" + data_dir = f"{output_dir}/training_data" + logs = [] + + # Preprocess audio into speaker chunks + logs.append(f"🚀 Starting Beatrice training on {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}") + yield None, None, "\n".join(logs) + + progress(0.05, "Preprocessing audio...") + audio_path = audio if isinstance(audio, str) else audio.name + chunks, _ = preprocess_audio_for_beatrice(audio_path, data_dir, name) + + if chunks == 0: + yield None, None, "❌ Preprocessing failed - no valid audio chunks" + return + + logs.append(f"✅ Preprocessed {chunks} audio chunks") + logs.append("─" * 40) + yield None, None, "\n".join(logs) + + # Train using embedded generator + ckpt = None + for msg, path in train_beatrice_generator( + data_dir=data_dir, + output_dir=output_dir, + epochs=beat_epochs, + batch_size=beat_batch, + resume=beat_resume, + progress_callback=progress + ): + logs.append(msg) + yield None, None, "\n".join(logs) + if path: + ckpt = path + + if ckpt: + logs.append("─" * 40) + logs.append(f"✅ Training complete!") + logs.append(f"📦 Model: {ckpt}") + progress(1.0, "Done!") + yield ckpt, None, "\n".join(logs) + else: + logs.append("❌ Training failed") + yield None, None, "\n".join(logs) + + except Exception as e: + logger.exception("Beatrice training error") + yield None, None, f"❌ Error: {str(e)}" + + train_btn.click( + train_unified, + [trainer_selector, train_audio, train_model_name, + train_epochs, train_batch, train_sr, train_f0, + beatrice_epochs, beatrice_batch, beatrice_resume], + [train_output_model, train_output_index, train_status], + api_name="train", + concurrency_limit=1, + ) + + gr.Markdown(""" + --- + ### Training Tips + + **RVC v2:** + - 50-100 epochs for quick test, 200-500 for quality + - CPU training: ~30 sec/epoch (100 epochs ≈ 50 min) + + **Beatrice v2:** + - 20-50 epochs recommended, GPU recommended (CPU works but slow) + - Pretrained assets downloaded automatically from HuggingFace + - Output: .pt.gz checkpoint (use in Voice Conversion tab) + - Lower latency (~50ms vs ~100ms) + + ### CLI + ```bash + python app.py train -a voice.mp3 -o ./model --epochs 100 + python app.py train-beatrice -a voice.mp3 -o ./beatrice_model --epochs 30 + python app.py infer -i input.wav -m beatrice_model.pt.gz -o output.wav + ``` + """) + + +def cli_convert(args): + """CLI mode conversion - supports both RVC and Beatrice models""" + print(f"Converting: {args.input}") + print(f"Model: {args.model}") + + # Auto-detect model type from extension or --type flag + model_type = getattr(args, 'type', None) + model_path = args.model + index_path = getattr(args, 'index', None) + + if args.example: + print("Downloading example model...") + model_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_MODEL_FILE) + index_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_INDEX_FILE) + model_type = "rvc" + print(f"Model: {model_path}") + + if not model_path: + print("Error: No model specified. Use -m MODEL.pth or --example") + sys.exit(1) + + # Auto-detect type from extension if not specified + if model_type is None: + if model_path.endswith('.pt.gz') or model_path.endswith('.gz'): + model_type = "beatrice" + else: + model_type = "rvc" + + if model_type == "beatrice": + # Beatrice inference + print(f"Using Beatrice v2 inference") + output_path, status = convert_voice_beatrice( + source_audio=args.input, + model_file=model_path, + target_speaker=getattr(args, 'speaker', 0), + pitch_shift=args.pitch, + formant_shift=getattr(args, 'formant_shift', 0.0), + ) + else: + # RVC inference + class FileObj: + def __init__(self, path): + self.name = path + + model_file = FileObj(model_path) + index_file = FileObj(index_path) if index_path else None + + output_path, status = convert_voice( + source_audio=args.input, + model_file=model_file, + index_file=index_file, + pitch_shift=args.pitch, + f0_method=args.f0, + index_rate=args.index_rate, + progress=lambda *a, **k: None + ) + + if output_path: + shutil.copy(output_path, args.output) + print(f"Output: {args.output}") + print(status) + else: + print(f"Failed: {status}") + sys.exit(1) + +def cli_train_beatrice(args): + """CLI Beatrice training mode (embedded - downloads assets from HF)""" + print(f"=== Beatrice v2 Training (Embedded) ===") + print(f"Input audio: {args.audio}") + print(f"Output dir: {args.output}") + + # Preprocess audio - Beatrice expects: data_dir/speaker_name/*.wav + data_dir = os.path.join(args.output, "training_data") + speaker_name = os.path.splitext(os.path.basename(args.audio))[0] + + print(f"\n[1/2] Preprocessing audio for Beatrice...") + chunks, _ = preprocess_audio_for_beatrice(args.audio, data_dir, speaker_name) + + if chunks == 0: + print("Preprocessing failed - no valid audio chunks") + sys.exit(1) + + print(f"Created {chunks} audio chunks") + + # Train using embedded code + print(f"\n[2/2] Training Beatrice model ({args.epochs} epochs)...") + ckpt = None + for msg, path in train_beatrice_generator( + data_dir=data_dir, + output_dir=args.output, + epochs=args.epochs, + batch_size=args.batch, + resume=args.resume, + ): + print(msg) + if path: + ckpt = path + + if ckpt: + print(f"\nTraining complete!") + print(f"Model: {ckpt}") + else: + print("Training failed!") + sys.exit(1) + +def cli_train(args): + """CLI training mode""" + print(f"=== RVC Training ===") + print(f"Input audio: {args.audio}") + print(f"Output dir: {args.output}") + + # Create temp dir for preprocessing + data_dir = f"{args.output}/data" + + # Preprocess + print("\n[1/2] Preprocessing audio...") + result = preprocess_audio_for_training(args.audio, data_dir, target_sr=args.sr, f0_method=args.f0) + if result is None: + print("Preprocessing failed!") + sys.exit(1) + + # Train + print(f"\n[2/2] Training for {args.epochs} epochs...") + ckpt, idx = train_rvc( + data_dir=data_dir, + output_dir=args.output, + epochs=args.epochs, + batch_size=args.batch, + lr=args.lr, + target_sr=args.sr + ) + + if ckpt: + print(f"\nTraining complete!") + print(f"Model saved: {ckpt}") + if idx: + print(f"Index saved: {idx}") + else: + print("Training failed!") + sys.exit(1) + +if __name__ == "__main__": + # Check if any CLI args (besides script name) + if len(sys.argv) > 1: + parser = argparse.ArgumentParser( + description="RVC Voice Conversion - Inference and Training", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + subparsers = parser.add_subparsers(dest="command", help="Commands") + + # Inference subcommand + infer_parser = subparsers.add_parser("infer", help="Voice conversion inference (RVC + Beatrice)", + epilog=""" +Examples: + python app.py infer -i voice.wav -m model.pth -o output.wav + python app.py infer -i voice.wav -m beatrice_model.pt.gz -o output.wav --type beatrice + python app.py infer -i voice.wav --example -o output.wav + """) + infer_parser.add_argument("-i", "--input", required=True, help="Input audio file") + infer_parser.add_argument("-o", "--output", required=True, help="Output audio file") + infer_parser.add_argument("-m", "--model", help="Model file (.pth for RVC, .pt.gz for Beatrice)") + infer_parser.add_argument("--type", choices=["rvc", "beatrice"], default=None, help="Model type (auto-detected from extension)") + infer_parser.add_argument("--index", help="Index file (.index) - RVC only") + infer_parser.add_argument("--example", action="store_true", help="Use example model (Benee-RVC)") + infer_parser.add_argument("-p", "--pitch", type=int, default=0, help="Pitch shift (-12 to 12)") + infer_parser.add_argument("--f0", choices=["rmvpe", "pm", "harvest"], default="rmvpe", help="F0 method (RVC only)") + infer_parser.add_argument("--index-rate", type=float, default=0.75, help="Index rate 0-1 (RVC only)") + infer_parser.add_argument("--speaker", type=int, default=0, help="Target speaker index (Beatrice only)") + infer_parser.add_argument("--formant-shift", type=float, default=0.0, help="Formant shift -2 to 2 (Beatrice only)") + + # Training subcommand (RVC) + train_parser = subparsers.add_parser("train", help="Train RVC model", + epilog=""" +Examples: + python app.py train -a voice.mp3 -o ./my_model --epochs 5 + python app.py train -a dataset.wav -o ./trained --epochs 10 --batch 4 + """) + train_parser.add_argument("-a", "--audio", required=True, help="Training audio file") + train_parser.add_argument("-o", "--output", required=True, help="Output directory for model") + train_parser.add_argument("--epochs", type=int, default=5, help="Number of epochs (default: 5)") + train_parser.add_argument("--batch", type=int, default=2, help="Batch size (default: 2)") + train_parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate (default: 1e-5)") + train_parser.add_argument("--sr", type=int, default=40000, help="Sample rate (default: 40000)") + train_parser.add_argument("--f0", choices=["rmvpe", "pm", "harvest"], default="rmvpe", help="F0 method (default: rmvpe)") + + # Beatrice v2 Training subcommand + beatrice_parser = subparsers.add_parser("train-beatrice", help="Train Beatrice v2 model (downloads assets from HF)", + epilog=""" +Examples: + python app.py train-beatrice -a voice.mp3 -o ./beatrice_model --epochs 20 + python app.py train-beatrice -a dataset.wav -o ./trained --epochs 50 --batch 8 --resume + +Pretrained assets are downloaded automatically from HuggingFace. + """) + beatrice_parser.add_argument("-a", "--audio", required=True, help="Training audio file") + beatrice_parser.add_argument("-o", "--output", required=True, help="Output directory for model") + beatrice_parser.add_argument("--epochs", type=int, default=20, help="Number of epochs (default: 20)") + beatrice_parser.add_argument("--batch", type=int, default=24, help="Batch size (default: 24, reduce for less VRAM)") + beatrice_parser.add_argument("--resume", action="store_true", help="Resume from checkpoint") + + args = parser.parse_args() + + if args.command == "infer": + cli_convert(args) + elif args.command == "train": + cli_train(args) + elif args.command == "train-beatrice": + cli_train_beatrice(args) + else: + parser.print_help() + else: + # No args = Gradio mode (Gradio 6 syntax) + demo.launch( + mcp_server=True, + show_error=True, + ssr_mode=False, + )