# Source: app.py uploaded to HuggingFace Spaces by Zsage0 (commit 3fd8c76, verified)
"""
RVC + Beatrice v2 Voice Conversion - Single-file app for HuggingFace Spaces
RVC-Project + Beatrice v2 (fierce-cats/beatrice-trainer), consolidated into single file
- Inference: RVC v2 (.pth) + Beatrice v2 (.pt.gz), CPU or GPU
- Training: RVC v2 + Beatrice v2, GPU recommended
Usage:
CLI: python app.py infer -i input.wav -m model.pth -o output.wav
python app.py infer -i input.wav -m beatrice.pt.gz -o output.wav
Gradio: python app.py
"""
import os
import sys
# MPS fallback for macOS
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import argparse
import gc
import gzip
import json as json_module
import logging
import math
import re
import shutil
import tempfile
import warnings
# Suppress known harmless warnings from HF Spaces / torch internals
warnings.filterwarnings("ignore", message=".*torch.distributed.reduce_op.*", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*torch.nn.utils.weight_norm.*", category=FutureWarning)
from collections import defaultdict
from fractions import Fraction
from functools import partial
from pathlib import Path
from random import Random
from typing import Optional, List, Tuple, Union, BinaryIO, Literal, Sequence, Iterable, Callable
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
import librosa
import pyworld
import soundfile as sf
import torchaudio
from scipy import signal
from huggingface_hub import hf_hub_download
from tqdm.auto import tqdm
# 48 Hz high-pass Butterworth filter to remove low-frequency artifacts
# (DC offset / rumble) — same settings as Applio. Designed at 16 kHz, so it
# should only be applied to audio resampled to SAMPLE_RATE.
FILTER_ORDER = 5
CUTOFF_FREQUENCY = 48 # Hz
SAMPLE_RATE = 16000 # Hz
# bh / ah: transfer-function numerator / denominator coefficients for use
# with scipy.signal filtering functions (e.g. lfilter / filtfilt).
bh, ah = signal.butter(N=FILTER_ORDER, Wn=CUTOFF_FREQUENCY, btype="high", fs=SAMPLE_RATE)
def sanitize_model_name(name: str) -> str:
    """Return *name* reduced to a path-safe basename.

    Strips surrounding whitespace, drops any directory components, and
    replaces every character that is not a word character, dot, or hyphen
    with an underscore. Falls back to "unnamed_model" when the result is
    empty, so the value is always usable as a filename.
    """
    basename = os.path.basename(name.strip())
    safe = re.sub(r'[^\w\-.]', '_', basename)
    if not safe:
        return "unnamed_model"
    return safe
def list_rvc_models() -> list:
    """Scan the local weights/ directory and return a sorted list of .pth filenames.

    Returns an empty list when the directory does not exist (e.g. a fresh
    deployment before any model has been uploaded).
    """
    models_root = Path("weights")
    if not models_root.exists():
        return []
    names = [entry.name for entry in models_root.glob("*.pth")]
    names.sort()
    return names
# Default example model (fetched from the HF Hub on demand)
DEFAULT_MODEL_REPO = "audo/Benee-RVC"
DEFAULT_MODEL_FILE = "BENEE8000.pth"
DEFAULT_INDEX_FILE = "added_IVF1054_Flat_nprobe_8.index"
# RVC v2 pretrained weights from official repo.
# Key naming: optional "f0" prefix = pitch-conditioned variant; G = generator,
# D = discriminator; 32k/40k/48k = target sample rate. Values are file paths
# inside RVC_PRETRAINED_REPO.
RVC_PRETRAINED_REPO = "lj1995/VoiceConversionWebUI"
RVC_PRETRAINED_V2 = {
    # Generator with f0 (pitch)
    "f0G48k": "pretrained_v2/f0G48k.pth",
    "f0G40k": "pretrained_v2/f0G40k.pth",
    "f0G32k": "pretrained_v2/f0G32k.pth",
    # Discriminator with f0
    "f0D48k": "pretrained_v2/f0D48k.pth",
    "f0D40k": "pretrained_v2/f0D40k.pth",
    "f0D32k": "pretrained_v2/f0D32k.pth",
    # Generator without f0
    "G48k": "pretrained_v2/G48k.pth",
    "G40k": "pretrained_v2/G40k.pth",
    "G32k": "pretrained_v2/G32k.pth",
    # Discriminator without f0
    "D48k": "pretrained_v2/D48k.pth",
    "D40k": "pretrained_v2/D40k.pth",
    "D32k": "pretrained_v2/D32k.pth",
}
def download_pretrained_rvc(name: str) -> str:
    """Download an RVC v2 pretrained checkpoint by short name; return its cached local path.

    Raises ValueError when *name* is not a key of RVC_PRETRAINED_V2.
    """
    if name not in RVC_PRETRAINED_V2:
        raise ValueError(f"Unknown pretrained: {name}. Available: {list(RVC_PRETRAINED_V2.keys())}")
    remote_path = RVC_PRETRAINED_V2[name]
    logger.info(f"Downloading pretrained {name} from {RVC_PRETRAINED_REPO}...")
    return hf_hub_download(repo_id=RVC_PRETRAINED_REPO, filename=remote_path)
# Beatrice v2 pretrained assets — values are file paths inside BEATRICE_REPO
# on the HF Hub; keys name the role each checkpoint plays.
BEATRICE_REPO = "fierce-cats/beatrice-trainer"
BEATRICE_PRETRAINED = {
    "phone_extractor": "assets/pretrained/122_checkpoint_03000000.pt",
    "pitch_estimator": "assets/pretrained/104_3_checkpoint_00300000.pt",
    "pretrained_model": "assets/pretrained/151_checkpoint_libritts_r_200_02750000.pt.gz",
}
def download_beatrice_asset(name: str) -> str:
    """Download a Beatrice v2 pretrained asset by short name; return its cached local path.

    Raises ValueError when *name* is not a key of BEATRICE_PRETRAINED.
    """
    if name not in BEATRICE_PRETRAINED:
        raise ValueError(f"Unknown asset: {name}. Available: {list(BEATRICE_PRETRAINED.keys())}")
    remote_path = BEATRICE_PRETRAINED[name]
    logger.info(f"Downloading Beatrice asset {name} from {BEATRICE_REPO}...")
    return hf_hub_download(repo_id=BEATRICE_REPO, filename=remote_path)
def download_beatrice_augmentation():
    """Download Beatrice augmentation assets (noise + impulse responses).

    Optional, best-effort helper used for training. Returns a
    (noise_dir, ir_dir) tuple of local paths, or (None, None) when the
    download fails or either directory is missing from the snapshot.
    """
    try:
        from huggingface_hub import snapshot_download
        cache_dir = snapshot_download(repo_id=BEATRICE_REPO, allow_patterns=["assets/noise/*", "assets/ir/*"])
        noise_dir = os.path.join(cache_dir, "assets", "noise")
        ir_dir = os.path.join(cache_dir, "assets", "ir")
        if os.path.isdir(noise_dir) and os.path.isdir(ir_dir):
            return noise_dir, ir_dir
    except Exception as e:
        # Augmentation is optional; log and fall through to the failure value.
        logger.warning(f"Could not download augmentation assets: {e}")
    return None, None
def load_pretrained_weights(model: nn.Module, pretrained_path: str) -> None:
    """Load pretrained weights into *model*, tolerating shape mismatches.

    Keys whose shapes match are copied verbatim. The speaker embedding
    ("emb_g.weight") gets special handling: when the speaker count differs,
    every target row is initialized to the mean of the pretrained
    embeddings — a much better starting point than random init. Every other
    mismatched or unknown key is skipped (and counted in the log).
    """
    logger.info(f"Loading pretrained weights: {pretrained_path}")
    state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
    # Some checkpoint formats nest the weights under a "model" key.
    if "model" in state_dict:
        state_dict = state_dict["model"]
    model_state = model.state_dict()
    filtered_state = {}
    skipped = []
    for key, value in state_dict.items():
        if key not in model_state:
            skipped.append(f"{key}: not in model")
            continue
        target_shape = model_state[key].shape
        if value.shape == target_shape:
            filtered_state[key] = value
        elif key == "emb_g.weight":
            # Broadcast the mean pretrained embedding row to our speaker count.
            mean_emb = value.mean(dim=0, keepdim=True)  # [1, emb_dim]
            num_speakers = target_shape[0]
            filtered_state[key] = mean_emb.expand(num_speakers, -1).clone()
            logger.info(f"Initialized emb_g from pretrained mean ({value.shape[0]} -> {num_speakers} speakers)")
        else:
            skipped.append(f"{key}: {value.shape} vs {target_shape}")
    if skipped:
        logger.info(f"Skipped {len(skipped)} mismatched keys")
    model.load_state_dict(filtered_state, strict=False)
    logger.info(f"Loaded {len(filtered_state)}/{len(state_dict)} pretrained weights")
# Logging is configured here, after the helper functions above; those
# helpers reference `logger` only at call time, so the ordering is safe.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Device selection:
# - Inference: Always CPU (HF Spaces free tier, also works everywhere)
# - Training: GPU if available for speed, CPU fallback
device = torch.device("cpu") # For inference
train_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # For training
logger.info(f"Inference device: {device}")
logger.info(f"Training device: {train_device}")
# ============================================================
# CPU OPTIMIZATION — Locked 2-core HuggingFace Spaces config
# ============================================================
# Restrict PyTorch to exactly 2 physical cores.
# NOTE(review): these env vars are exported AFTER `import torch`; thread
# pools that OpenMP/MKL initialized during import may ignore them. For full
# effect they would need to be set before the torch import — confirm, or
# treat torch.set_num_threads() as the authoritative cap.
_CPU_CORES = 2
torch.set_num_threads(_CPU_CORES)
# NOTE(review): set_num_interop_threads raises if the interop pool has
# already started parallel work — confirm this always runs early enough.
torch.set_num_interop_threads(_CPU_CORES)
os.environ["OMP_NUM_THREADS"] = str(_CPU_CORES)
os.environ["MKL_NUM_THREADS"] = str(_CPU_CORES)
os.environ["OPENBLAS_NUM_THREADS"] = str(_CPU_CORES)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(_CPU_CORES)
os.environ["NUMEXPR_NUM_THREADS"] = str(_CPU_CORES)
# Enable oneDNN graph fusion (fuses conv+bn, linear+relu etc. into
# single kernels — measurable speedup on Intel Xeon VMs).
torch.backends.mkldnn.enabled = True
try:
    torch.jit.enable_onednn_fusion(True)
except Exception:
    # Older torch builds don't expose this API; fusion is best-effort.
    pass
logger.info(f"PyTorch CPU threads: {torch.get_num_threads()} (interop={torch.get_num_interop_threads()})")
# ============================================================
# MEMORY MANAGEMENT — purge_memory()
# Call this between every heavy operation to prevent OOM on
# the 16 GB HuggingFace Spaces free-tier CPU instance.
# ============================================================
import ctypes, platform
def purge_memory(*tensors_or_arrays):
    """
    Aggressively release memory between heavy processing steps.

    Positional arguments exist so callers can hand over big tensors/arrays
    they no longer need; this function drops its own references to them.
    NOTE: Python cannot delete names from the *caller's* scope — callers
    must also `del` (or rebind) their own variables for the objects to
    become collectable. (The previous implementation looped `del obj`,
    which only unbound the loop variable while the argument tuple kept
    every object alive.)

    Steps:
      1. Drop local references to everything passed in.
      2. Run two gc passes (the first collects cycles, the second anything
         the first pass freed).
      3. On Linux (HuggingFace Spaces), call malloc_trim(0) via ctypes so
         glibc returns freed pages to the OS immediately. Without this,
         RSS can stay high even after gc.collect().
      4. Clear the CUDA cache if a GPU is somehow available.
    """
    # Drop the argument tuple itself — our only reference to the objects.
    tensors_or_arrays = None
    # Two-pass gc: cycles first, then their referents
    gc.collect()
    gc.collect()
    # Return glibc memory to the OS (Linux only — HF Spaces is Linux)
    if platform.system() == "Linux":
        try:
            ctypes.CDLL("libc.so.6").malloc_trim(0)
        except Exception:
            pass
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
# ============================================================
# COMMONS - Helper functions from infer/lib/infer_pack/commons.py
# ============================================================
def init_weights(m, mean=0.0, std=0.01):
    """Initialize conv-layer weights in place from N(mean, std); no-op for other modules."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """Padding that keeps a stride-1 dilated conv's output length equal to its input."""
    return (kernel_size - 1) * dilation // 2
def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None):
    """Boolean mask [B, T]: True where the position index < length[b]."""
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """WaveNet gated activation: tanh((a+b)[:n]) * sigmoid((a+b)[n:]).

    Inputs are [B, 2*n, T] tensors; n_channels is a 1-element IntTensor
    (a TorchScript-friendly way to pass the channel split point).
    """
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    return t_act * s_act
def slice_segments(x, ids_str, segment_size=4):
    """Cut window [ids_str[b] : ids_str[b]+segment_size] from each batch row of a [B, C, T] tensor."""
    out = torch.zeros_like(x[:, :, :segment_size])
    for b in range(x.size(0)):
        start = ids_str[b]
        out[b] = x[b, :, start:start + segment_size]
    return out
def slice_segments2(x, ids_str, segment_size=4):
    """Cut window [ids_str[b] : ids_str[b]+segment_size] from each row of a [B, T] tensor."""
    out = torch.zeros_like(x[:, :segment_size])
    for row in range(x.size(0)):
        start = ids_str[row]
        out[row] = x[row, start:start + segment_size]
    return out
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Randomly slice a window of *segment_size* from each batch row of x.

    Args:
        x: [B, C, T] tensor.
        x_lengths: per-row valid lengths as a [B] tensor; defaults to the
            full time dimension T for every row.
        segment_size: window length to cut.

    Returns:
        (segments [B, C, segment_size], start indices [B]).
    """
    b, d, t = x.size()
    if x_lengths is None:
        # The original assigned the plain int `t` here, which crashes below:
        # torch.clamp requires a Tensor input and ints have no .float().
        x_lengths = torch.full((b,), t, dtype=torch.long, device=x.device)
    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=1)
    ids_str = (torch.rand([b], device=x.device) * ids_str_max.float()).long()
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str
# ============================================================
# MODULES - From infer/lib/infer_pack/modules.py
# ============================================================
LRELU_SLOPE = 0.1  # negative slope shared by the leaky-ReLUs in this file
class LayerNorm(nn.Module):
    """Channel-wise LayerNorm for [B, C, T] tensors (normalizes over C)."""
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))
    def forward(self, x):
        # F.layer_norm normalizes the trailing dim, so move C last and back.
        x = x.transpose(1, -1)
        normed = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return normed.transpose(1, -1)
class WN(nn.Module):
    """Non-causal WaveNet-style stack of gated dilated 1-D convolutions.

    Used as the core of the posterior encoder and the coupling layers.
    When gin_channels != 0, a single 1x1 conv produces per-layer global
    conditioning that forward() slices out by offset. All convolutions are
    weight-normalized.
    """
    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        super().__init__()
        assert kernel_size % 2 == 1  # odd kernel keeps "same" padding symmetric
        self.hidden_channels = hidden_channels
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = float(p_dropout)
        self.in_layers = nn.ModuleList()
        self.res_skip_layers = nn.ModuleList()
        self.drop = nn.Dropout(float(p_dropout))
        if gin_channels != 0:
            # One conv emits 2*hidden*n_layers channels: a (tanh, sigmoid)
            # gate pair for every layer, sliced per-layer in forward().
            cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
            self.cond_layer = weight_norm(cond_layer, name="weight")
        for i in range(n_layers):
            dilation = dilation_rate ** i  # exponentially growing receptive field
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding)
            in_layer = weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)
            # Last layer only needs the skip half; earlier layers emit
            # residual + skip stacked along channels.
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels
            res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)
    def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None):
        """Run the gated conv stack; returns the masked sum of skip connections."""
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])
        if g is not None:
            g = self.cond_layer(g)
        for i, (in_layer, res_skip_layer) in enumerate(zip(self.in_layers, self.res_skip_layers)):
            x_in = in_layer(x)
            if g is not None:
                # Slice this layer's conditioning gates out of the shared conv output.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)
            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)
            res_skip_acts = res_skip_layer(acts)
            if i < self.n_layers - 1:
                # First half updates the residual path; second half adds to the skips.
                res_acts = res_skip_acts[:, :self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels:, :]
            else:
                output = output + res_skip_acts
        return output * x_mask
    def remove_weight_norm(self):
        # NOTE: the bare name below resolves to the module-level
        # torch.nn.utils.remove_weight_norm import, not this method.
        if self.gin_channels != 0:
            remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            remove_weight_norm(l)
        for l in self.res_skip_layers:
            remove_weight_norm(l)
class ResBlock1(nn.Module):
    """HiFi-GAN residual block: three (dilated conv, plain conv) branch pairs."""
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        dils = (dilation[0], dilation[1], dilation[2])
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=d, padding=get_padding(kernel_size, d)))
            for d in dils
        ])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)))
            for _ in range(3)
        ])
        self.convs2.apply(init_weights)
        self.lrelu_slope = LRELU_SLOPE
    def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None):
        for conv_a, conv_b in zip(self.convs1, self.convs2):
            h = F.leaky_relu(x, self.lrelu_slope)
            if x_mask is not None:
                h = h * x_mask
            h = conv_a(h)
            h = F.leaky_relu(h, self.lrelu_slope)
            if x_mask is not None:
                h = h * x_mask
            h = conv_b(h)
            x = x + h
        if x_mask is not None:
            x = x * x_mask
        return x
    def remove_weight_norm(self):
        # Bare name resolves to the module-level torch remove_weight_norm import.
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)
class ResBlock2(nn.Module):
    """Lighter HiFi-GAN residual block: two single-conv branches with dilations."""
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        dils = (dilation[0], dilation[1])
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=d, padding=get_padding(kernel_size, d)))
            for d in dils
        ])
        self.convs.apply(init_weights)
        self.lrelu_slope = LRELU_SLOPE
    def forward(self, x, x_mask: Optional[torch.Tensor] = None):
        for conv in self.convs:
            h = F.leaky_relu(x, self.lrelu_slope)
            if x_mask is not None:
                h = h * x_mask
            h = conv(h)
            x = x + h
        if x_mask is not None:
            x = x * x_mask
        return x
    def remove_weight_norm(self):
        # Bare name resolves to the module-level torch remove_weight_norm import.
        for layer in self.convs:
            remove_weight_norm(layer)
class Flip(nn.Module):
    """Flow layer that reverses channel order; volume-preserving (logdet = 0)."""
    def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False):
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped, torch.zeros([1], device=x.device)
        logdet = torch.zeros(flipped.size(0)).to(dtype=flipped.dtype, device=flipped.device)
        return flipped, logdet
class ResidualCouplingLayer(nn.Module):
    """Affine coupling flow step (VITS-style): the first half of the channels
    parameterizes an affine transform of the second half via a WN stack.

    With mean_only=True (how this file uses it) only a shift is predicted and
    logs is fixed at zero, making the flow volume-preserving.
    """
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
        assert channels % 2 == 0  # channels must split into two equal halves
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.half_channels = channels // 2
        self.mean_only = mean_only
        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=float(p_dropout), gin_channels=gin_channels)
        # (2 - mean_only): mean + log-std normally, just the mean when mean_only.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        # Zero init so the layer starts out as an identity transform.
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()
    def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False):
        """Apply (reverse=False) or invert (reverse=True) the coupling transform.

        Returns (y, logdet); the reverse path returns a dummy zero logdet.
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)
        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x, torch.zeros([1])
    def remove_weight_norm(self):
        self.enc.remove_weight_norm()
# ============================================================
# ATTENTIONS - From infer/lib/infer_pack/attentions.py
# ============================================================
class MultiHeadAttention(nn.Module):
    """Multi-head attention on [B, C, T] tensors (1x1 convs as projections),
    with optional windowed relative-position embeddings (window_size),
    proximal bias, and proximal initialization — as in VITS.
    """
    def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0
        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.k_channels = channels // n_heads  # per-head dimensionality
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)
        if window_size is not None:
            # Learnable relative-position embeddings for keys and values,
            # covering offsets in [-window_size, +window_size].
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            # Start keys identical to queries.
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)
    def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """x: query source [B, C, T_t]; c: key/value source [B, C, T_s]."""
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)
        x, _ = self.attention(q, k, v, mask=attn_mask)
        x = self.conv_o(x)
        return x
    def attention(self, query, key, value, mask=None):
        # Reshape [B, C, T] -> [B, heads, T, k_channels].
        b, d, t_s = key.size()
        t_t = query.size(2)
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            # Add relative-position logits, converted to absolute indexing.
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            # -1e4 rather than -inf keeps the masking fp16-safe.
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            # Relative-position contribution on the value side as well.
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn
    def _matmul_with_relative_values(self, x, y):
        # x: [B, H, T, 2T-1] @ y: [1 or H, 2T-1, k] -> [B, H, T, k]
        return torch.matmul(x, y.unsqueeze(0))
    def _matmul_with_relative_keys(self, x, y):
        # x: [B, H, T, k] @ y^T: [1 or H, k, 2T-1] -> [B, H, T, 2T-1]
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
    def _get_relative_embeddings(self, relative_embeddings, length):
        # Pad (short sequences) or slice (long ones) the stored window so it
        # covers exactly 2*length - 1 relative offsets.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
        else:
            padded_relative_embeddings = relative_embeddings
        return padded_relative_embeddings[:, slice_start_position:slice_end_position]
    def _relative_position_to_absolute_position(self, x):
        # [B, H, T, 2T-1] relative logits -> [B, H, T, T] absolute logits
        # via the standard pad-reshape-slice trick.
        batch, heads, length, _ = x.size()
        x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0])
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, [0, int(length) - 1, 0, 0, 0, 0])
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
        return x_final
    def _absolute_position_to_relative_position(self, x):
        # Inverse of the above: [B, H, T, T] -> [B, H, T, 2T-1].
        batch, heads, length, _ = x.size()
        x = F.pad(x, [0, int(length) - 1, 0, 0, 0, 0, 0, 0])
        x_flat = x.view([batch, heads, int(length ** 2) + int(length * (length - 1))])
        x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final
    def _attention_bias_proximal(self, length):
        # Bias favoring nearby positions: -log(1 + |i - j|).
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
    """Position-wise feed-forward block (conv1d -> activation -> dropout -> conv1d),
    applied under a mask, with optional causal padding."""
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.causal = causal
        self.is_activation = activation == "gelu"  # True -> sigmoid-GELU approx, False -> ReLU
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)
    def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
        h = self.conv_1(self._padding(x, x_mask))
        if self.is_activation:
            h = h * torch.sigmoid(1.702 * h)  # fast GELU approximation
        else:
            h = torch.relu(h)
        h = self.drop(h)
        h = self.conv_2(self._padding(h, x_mask))
        return h * x_mask
    def _padding(self, x, x_mask):
        # Mask, then pad so the conv output keeps the input length.
        if self.kernel_size == 1:
            return x * x_mask
        if self.causal:
            left = self.kernel_size - 1
            return F.pad(x * x_mask, [left, 0, 0, 0, 0, 0])
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        return F.pad(x * x_mask, [left, right, 0, 0, 0, 0])
class Encoder(nn.Module):
    """Transformer encoder: stacked (self-attention + FFN) with post-layernorm."""
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.n_layers = int(n_layers)
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))
    def forward(self, x, x_mask):
        # Pairwise mask: position (i, j) is valid iff both i and j are valid.
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for attn, norm_a, ffn, norm_b in zip(self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2):
            residual = attn(x, x, attn_mask)
            residual = self.drop(residual)
            x = norm_a(x + residual)
            residual = ffn(x, x_mask)
            residual = self.drop(residual)
            x = norm_b(x + residual)
        return x * x_mask
# ============================================================
# MODELS - From infer/lib/infer_pack/models.py
# ============================================================
sr2sr = {"32k": 32000, "40k": 40000, "48k": 48000}  # sample-rate tag -> Hz
class TextEncoder256(nn.Module):
    """Prior encoder for 256-dim content features with optional pitch embedding.

    Returns the prior distribution parameters (m, logs) and the frame mask.
    """
    def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.emb_phone = nn.Linear(256, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            # Coarse pitch is bucketed to 256 bins before embedding.
            self.emb_pitch = nn.Embedding(256, hidden_channels)
        self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout))
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
    def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor):
        x = self.emb_phone(phone)
        if pitch is not None:
            x = x + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask
class TextEncoder768(nn.Module):
    """Prior encoder for 768-dim content features; identical to TextEncoder256
    except for the input projection width."""
    def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.emb_phone = nn.Linear(768, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            # Coarse pitch is bucketed to 256 bins before embedding.
            self.emb_pitch = nn.Embedding(256, hidden_channels)
        self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout))
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
    def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor):
        x = self.emb_phone(phone)
        if pitch is not None:
            x = x + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
    """Normalizing-flow stack: n_flows x (affine coupling layer + channel flip)."""
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
        super().__init__()
        self.n_flows = n_flows
        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
            self.flows.append(Flip())
    def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False):
        # Forward runs the flows in order; inversion runs them reversed.
        ordered = self.flows if not reverse else self.flows[::-1]
        for flow in ordered:
            x, _ = flow.forward(x, x_mask, g=g, reverse=reverse)
        return x
    def remove_weight_norm(self):
        # Even indices are coupling layers; odd indices are Flips (no weights).
        for i in range(self.n_flows):
            self.flows[i * 2].remove_weight_norm()
class PosteriorEncoder(nn.Module):
    """Encodes spectrogram frames into a latent sample via a WaveNet stack.

    Returns (z, m, logs, x_mask) where z ~ N(m, exp(logs)) under the mask.
    """
    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
        super().__init__()
        self.out_channels = out_channels
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None):
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        h = self.pre(x) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.proj(h) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterization trick: sample the latent under the mask.
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
    def remove_weight_norm(self):
        self.enc.remove_weight_norm()
class Generator(nn.Module):
    """HiFi-GAN-style waveform generator: stacked transposed-conv upsampling,
    each stage followed by a bank of multi-kernel residual blocks whose
    outputs are averaged. Optional gin_channels adds global conditioning
    right after conv_pre.
    """
    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        resblock_class = ResBlock1 if resblock == "1" else ResBlock2
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count halves at every upsampling stage.
            self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock_class(ch, k, d))
        # `ch` leaks from the loop above: it is the channel count after the
        # final upsample stage, which is what conv_post must accept.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
    def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's multi-kernel resblock bank.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
class SineGen(nn.Module):
    """Sine-wave source generator for the NSF vocoder.

    From a frame-level f0 track it builds per-harmonic sine waves at sample
    rate (frames upsampled by `upp`), adding small noise in voiced regions
    and replacing the sines with louder noise in unvoiced regions.
    """
    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = harmonic_num + 1  # fundamental + harmonics
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
    def _f02uv(self, f0):
        # Voiced/unvoiced flag: 1.0 where f0 exceeds the threshold, else 0.0.
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv.float()
    def forward(self, f0: torch.Tensor, upp: int):
        """f0: frame-level pitch in Hz; upp: upsampling factor (samples per frame).

        Returns (sine_waves, uv, noise); runs entirely under no_grad.
        """
        with torch.no_grad():
            f0 = f0[:, None].transpose(1, 2)
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in range(self.harmonic_num):
                # Harmonic k runs at (k + 1) times the fundamental frequency.
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
            # Per-frame phase increments in revolutions, wrapped to [0, 1).
            rad_values = (f0_buf / self.sampling_rate) % 1
            # Random initial phase per harmonic; fundamental fixed at 0.
            rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device)
            rand_ini[:, 0] = 0
            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
            # Accumulate phase at frame rate, then upsample to sample rate:
            # linear interpolation for the accumulated phase, nearest for the
            # raw increments.
            tmp_over_one = torch.cumsum(rad_values, 1)
            tmp_over_one *= upp
            tmp_over_one = F.interpolate(tmp_over_one.transpose(2, 1), scale_factor=float(upp), mode="linear", align_corners=True).transpose(2, 1)
            rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
            tmp_over_one %= 1
            # Detect phase wrap-arounds so the cumulative phase stays continuous.
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
            sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi)
            sine_waves = sine_waves * self.sine_amp
            uv = self._f02uv(f0)
            uv = F.interpolate(uv.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
            # Voiced: small additive noise; unvoiced: noise at sine_amp / 3.
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
class SourceModuleHnNSF(nn.Module):
    """NSF source module: merges SineGen harmonics into one excitation channel via linear + tanh."""

    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0, is_half=False):
        super().__init__()
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
        self.l_linear = nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = nn.Tanh()

    def forward(self, x: torch.Tensor, upp: int = 1):
        """Return (merged excitation, None, None); the trailing Nones keep the historical interface."""
        sine_wavs, _uv, _noise = self.l_sin_gen(x, upp)
        # Match the linear layer's parameter dtype (fp16/fp32) before projecting.
        sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
        merged = self.l_tanh(self.l_linear(sine_wavs))
        return merged, None, None
class GeneratorNSF(nn.Module):
    """HiFi-GAN-style generator with an NSF harmonic source (used by the f0-conditioned synthesizers).

    Upsamples latent features to a waveform while injecting a sine-based
    harmonic excitation (derived from f0) at every upsampling stage.
    """

    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, is_half=False):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # Upsamples frame-rate f0 to sample rate for the sine source.
        self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates))
        self.m_source = SourceModuleHnNSF(sampling_rate=sr, harmonic_num=0, is_half=is_half)
        self.noise_convs = nn.ModuleList()
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        resblock_class = ResBlock1 if resblock == "1" else ResBlock2
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channels halve at every upsampling stage.
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
            # Each stage gets a conv that downsamples the sample-rate harmonic
            # source to that stage's temporal resolution.
            if i + 1 < len(upsample_rates):
                stride_f0 = math.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock_class(ch, k, d))
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)
        if gin_channels != 0:
            # Speaker-embedding conditioning projection.
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
        self.upp = math.prod(upsample_rates)  # total frames -> samples factor
        self.lrelu_slope = LRELU_SLOPE

    def forward(self, x, f0, g: Optional[torch.Tensor] = None):
        """x: (B, C, frames) latent; f0: (B, frames) in Hz; g: optional (B, gin, 1) speaker embedding."""
        # Sample-rate harmonic excitation, (B, 1, frames * upp) after transpose.
        har_source, _, _ = self.m_source(f0, self.upp)
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)
        for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
            if i < self.num_upsamples:
                x = F.leaky_relu(x, self.lrelu_slope)
                x = ups(x)
                # Inject the harmonic source at this stage's resolution.
                x_source = noise_convs(har_source)
                x = x + x_source
                # Average the parallel multi-kernel resblocks of this stage.
                xs = None
                l = [i * self.num_kernels + j for j in range(self.num_kernels)]
                for j, resblock in enumerate(self.resblocks):
                    if j in l:
                        if xs is None:
                            xs = resblock(x)
                        else:
                            xs += resblock(x)
                x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        # Strip weight norm before inference/export.
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
# Synthesizer classes for different model versions
class SynthesizerTrnMs256NSFsid(nn.Module):
    """RVC v1 model with f0 (256-dim ContentVec features, NSF decoder)."""

    def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, **kwargs):
        super().__init__()
        # Accept symbolic sample rates (e.g. "40k") from older checkpoints.
        if isinstance(sr, str):
            sr = sr2sr[sr]
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        # Prior encoder over 256-dim phone features + coarse pitch.
        self.enc_p = TextEncoder256(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout))
        # NSF decoder conditioned on continuous f0.
        self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs.get("is_half", False))
        # Posterior encoder and normalizing flow (training-time components).
        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
        # Speaker-ID embedding table.
        self.emb_g = nn.Embedding(spk_embed_dim, gin_channels)

    @torch.jit.export
    def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: torch.Tensor, nsff0: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None):
        """Inference: encode features + pitch, sample the prior, invert the flow, vocode with nsff0.

        rate (0..1), when given, keeps only the trailing fraction of frames.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Sample the prior at reduced temperature (0.66666).
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        if rate is not None:
            head = int(z_p.shape[2] * (1 - rate.item()))
            z_p = z_p[:, :, head:]
            x_mask = x_mask[:, :, head:]
            nsff0 = nsff0[:, head:]
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec(z * x_mask, nsff0, g=g)
        return o, x_mask, (z, z_p, m_p, logs_p)
class SynthesizerTrnMs768NSFsid(nn.Module):
    """RVC v2 model with f0 (768-dim ContentVec features, NSF decoder)."""

    def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, **kwargs):
        super().__init__()
        # Accept symbolic sample rates (e.g. "40k") from older checkpoints.
        if isinstance(sr, str):
            sr = sr2sr[sr]
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        # Prior encoder over 768-dim phone features + coarse pitch.
        self.enc_p = TextEncoder768(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout))
        # NSF decoder conditioned on continuous f0.
        self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs.get("is_half", False))
        # Posterior encoder and normalizing flow (training-time components).
        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
        # Speaker-ID embedding table.
        self.emb_g = nn.Embedding(spk_embed_dim, gin_channels)

    def forward(self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds):
        """Training forward pass"""
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Posterior latent from the ground-truth spectrogram y.
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        # Decode only a random segment to bound memory; pitchf is sliced to match.
        z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
        pitchf = slice_segments2(pitchf, ids_slice, self.segment_size)
        o = self.dec(z_slice, pitchf, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    @torch.jit.export
    def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: torch.Tensor, nsff0: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None):
        """Inference: encode features + pitch, sample the prior, invert the flow, vocode with nsff0.

        rate (0..1), when given, keeps only the trailing fraction of frames.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Sample the prior at reduced temperature (0.66666).
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        if rate is not None:
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p = z_p[:, :, head:]
            x_mask = x_mask[:, :, head:]
            nsff0 = nsff0[:, head:]
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec(z * x_mask, nsff0, g=g)
        return o, x_mask, (z, z_p, m_p, logs_p)
class SynthesizerTrnMs256NSFsid_nono(nn.Module):
    """RVC v1 model without f0 (256-dim features, plain HiFi-GAN decoder)."""

    def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs):
        super().__init__()
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        # f0=False: the encoder takes no pitch input.
        self.enc_p = TextEncoder256(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), f0=False)
        # Plain (non-NSF) generator; no f0 conditioning.
        self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
        self.emb_g = nn.Embedding(spk_embed_dim, gin_channels)

    @torch.jit.export
    def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None):
        """Inference without pitch; rate (0..1) keeps only the trailing fraction of frames."""
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        # Sample the prior at reduced temperature (0.66666).
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        if rate is not None:
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p = z_p[:, :, head:]
            x_mask = x_mask[:, :, head:]
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec(z * x_mask, g=g)
        return o, x_mask, (z, z_p, m_p, logs_p)
class SynthesizerTrnMs768NSFsid_nono(nn.Module):
    """RVC v2 model without f0 (768-dim features, plain HiFi-GAN decoder)."""

    def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs):
        super().__init__()
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        # f0=False: the encoder takes no pitch input.
        self.enc_p = TextEncoder768(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), f0=False)
        # Plain (non-NSF) generator; no f0 conditioning.
        self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
        self.emb_g = nn.Embedding(spk_embed_dim, gin_channels)

    @torch.jit.export
    def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, sid: torch.Tensor, rate: Optional[torch.Tensor] = None):
        """Inference without pitch; rate (0..1) keeps only the trailing fraction of frames."""
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        # Sample the prior at reduced temperature (0.66666).
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        if rate is not None:
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p = z_p[:, :, head:]
            x_mask = x_mask[:, :, head:]
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec(z * x_mask, g=g)
        return o, x_mask, (z, z_p, m_p, logs_p)
# ============================================================
# DISCRIMINATOR - For training
# ============================================================
class DiscriminatorS(nn.Module):
    """Scale (waveform-level) discriminator: a stack of grouped 1-D convs over the raw signal."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        # Pick the normalization wrapper once; every conv shares it.
        norm_f = nn.utils.spectral_norm if use_spectral_norm else weight_norm
        # (in_ch, out_ch, kernel, stride, groups, padding) per layer.
        layer_cfg = [
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            norm_f(Conv1d(ci, co, k, s, groups=g, padding=p))
            for ci, co, k, s, g, p in layer_cfg
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return (flattened logits, list of all intermediate feature maps)."""
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
class DiscriminatorP(nn.Module):
    """Period discriminator: folds the waveform into (time//period, period) and applies 2-D convs."""

    def __init__(self, period, use_spectral_norm=False):
        super().__init__()
        self.period = period
        norm_f = nn.utils.spectral_norm if use_spectral_norm else weight_norm
        channels = [1, 32, 128, 512, 1024]
        layers = [
            norm_f(nn.Conv2d(ci, co, (5, 1), (3, 1), padding=(2, 0)))
            for ci, co in zip(channels[:-1], channels[1:])
        ]
        layers.append(norm_f(nn.Conv2d(1024, 1024, (5, 1), 1, padding=(2, 0))))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return (flattened logits, list of all intermediate feature maps)."""
        fmap = []
        b, c, t = x.shape
        # Right-pad (reflect) so the time axis is a multiple of the period, then fold to 2-D.
        remainder = t % self.period
        if remainder != 0:
            pad = self.period - remainder
            x = F.pad(x, (0, pad), "reflect")
            t += pad
        x = x.view(b, c, t // self.period, self.period)
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
class MultiPeriodDiscriminator(nn.Module):
    """Ensemble of one scale discriminator plus one period discriminator per prime period."""

    # 8 periods for the v2 pretrained weights (9 discriminators in total).
    PERIODS = (2, 3, 5, 7, 11, 17, 23, 37)

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        discs = [DiscriminatorS(use_spectral_norm)]
        discs.extend(DiscriminatorP(p, use_spectral_norm) for p in self.PERIODS)
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        """Run every sub-discriminator on real (y) and generated (y_hat) audio.

        Returns (real logits, fake logits, real feature maps, fake feature maps),
        each a list with one entry per discriminator.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            logits_r, feats_r = disc(y)
            logits_g, feats_g = disc(y_hat)
            y_d_rs.append(logits_r)
            y_d_gs.append(logits_g)
            fmap_rs.append(feats_r)
            fmap_gs.append(feats_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
# ============================================================
# TRAINING LOSSES
# ============================================================
def feature_loss(fmap_r, fmap_g):
    """L1 feature-matching loss between real and generated discriminator feature maps (scaled by 2).

    Real features are detached so only the generator receives gradients.
    """
    total = 0
    for feats_real, feats_fake in zip(fmap_r, fmap_g):
        for real, fake in zip(feats_real, feats_fake):
            total = total + torch.mean(torch.abs(real.float().detach() - fake.float()))
    return 2 * total
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """LSGAN discriminator loss: real outputs pushed toward 1, generated toward 0."""
    loss = 0
    for real, fake in zip(disc_real_outputs, disc_generated_outputs):
        real = real.float()
        fake = fake.float()
        loss = loss + torch.mean((1 - real) ** 2) + torch.mean(fake ** 2)
    return loss
def generator_loss(disc_outputs):
    """LSGAN generator loss: generated outputs pushed toward 1."""
    return sum(torch.mean((1 - dg.float()) ** 2) for dg in disc_outputs)
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """KL divergence between the posterior q and prior p, averaged over masked positions."""
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()
    kl = logs_p - logs_q - 0.5
    kl = kl + 0.5 * (z_p - m_p) ** 2 * torch.exp(-2.0 * logs_p)
    return torch.sum(kl * z_mask) / torch.sum(z_mask)
# ============================================================
# HUBERT EXTRACTION - Using torchaudio bundle
# ============================================================
# ContentVec model for v1 (256-dim) and HuBERT for v2 (768-dim)
_contentvec_model = None # For v1 models (256-dim output)
_hubert_model = None # For v2 models (768-dim output)
_hubert_bundle = None
CONTENTVEC_REPO = "IAHispano/Applio"
CONTENTVEC_MODEL = "Resources/embedders/contentvec/pytorch_model.bin"
CONTENTVEC_CONFIG = "Resources/embedders/contentvec/config.json"
def load_contentvec():
    """Lazily download and cache the ContentVec encoder (Applio weights).

    Returns the cached model, or None when loading fails (the caller falls
    back to torchaudio HuBERT). The returned model exposes `final_proj`
    for the 256-dim v1 path. On failure the global stays None, so every
    call retries the download.
    """
    global _contentvec_model
    if _contentvec_model is None:
        try:
            from transformers import HubertModel, HubertConfig
            logger.info("Loading ContentVec model from HuggingFace...")
            # Download model files
            model_path = hf_hub_download(repo_id=CONTENTVEC_REPO, filename=CONTENTVEC_MODEL)
            config_path = hf_hub_download(repo_id=CONTENTVEC_REPO, filename=CONTENTVEC_CONFIG)
            # ContentVec checkpoints carry an extra projection layer that plain
            # HubertModel lacks; subclass to add it so load_state_dict matches.
            class HubertModelWithFinalProj(HubertModel):
                def __init__(self, config):
                    super().__init__(config)
                    self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
            config = HubertConfig.from_pretrained(config_path)
            _contentvec_model = HubertModelWithFinalProj(config)
            state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
            _contentvec_model.load_state_dict(state_dict)
            _contentvec_model.to(device).eval()
            logger.info(f"ContentVec loaded: hidden={config.hidden_size}, proj={config.classifier_proj_size}")
        except Exception as e:
            # Best-effort: log and return None so the caller can fall back.
            logger.warning(f"Failed to load ContentVec: {e}, falling back to torchaudio HuBERT")
            _contentvec_model = None
    return _contentvec_model
def load_hubert():
    """Lazily load torchaudio's HUBERT_BASE bundle (768-dim features) and cache it globally.

    Returns:
        (model, bundle) — the eval-mode model on the global device and its pipeline bundle.
    """
    global _hubert_model, _hubert_bundle
    if _hubert_model is None:
        import torchaudio
        logger.info("Loading HuBERT model via torchaudio...")
        _hubert_bundle = torchaudio.pipelines.HUBERT_BASE
        _hubert_model = _hubert_bundle.get_model().to(device)
        _hubert_model.eval()
        logger.info("HuBERT model loaded")
    return _hubert_model, _hubert_bundle
def extract_hubert_features(audio: np.ndarray, sr: int = 16000, version: str = "v2") -> torch.Tensor:
    """Extract ContentVec features from audio (same as Applio)
    v1 models: Use ContentVec with final_proj (256-dim)
    v2 models: Use ContentVec without final_proj (768-dim)

    Falls back to torchaudio HuBERT (with a crude 768->256 averaging
    projection for v1) when ContentVec cannot be loaded.

    Args:
        audio: mono waveform; expected at 16 kHz (sr is not used for resampling here
               — NOTE(review): verify callers resample before calling).
        version: "v1" (256-dim) or "v2" (768-dim).
    Returns:
        (1, frames, dim) float tensor on the global device.
    """
    audio = audio.astype(np.float32)
    # Peak-normalize only when out of [-1, 1].
    if np.abs(audio).max() > 1.0:
        audio = audio / np.abs(audio).max()
    inputs = torch.from_numpy(audio).unsqueeze(0).to(device)
    # Use ContentVec for ALL versions (same as Applio)
    contentvec = load_contentvec()
    if contentvec is not None:
        with torch.no_grad():
            output = contentvec(inputs)
            if version == "v1":
                # v1: use final_proj for 256-dim
                feats = contentvec.final_proj(output.last_hidden_state)
            else:
                # v2: use raw hidden state (768-dim)
                feats = output.last_hidden_state
        return feats
    # Fallback to torchaudio HuBERT if ContentVec not available
    logger.warning("ContentVec not available, using torchaudio HuBERT (results may be degraded)")
    hubert, bundle = load_hubert()
    with torch.no_grad():
        features, _ = hubert.extract_features(inputs)
        # Deeper layer for v2, mid layer for v1 (clamped to available depth).
        layer_idx = 11 if version == "v2" else 8
        feats = features[min(layer_idx, len(features)-1)]
        if version == "v1":
            # Approximate 768->256 reduction: average each group of 3 dims.
            proj = nn.Linear(768, 256, bias=False).to(device)
            with torch.no_grad():
                w = torch.zeros(256, 768)
                for i in range(256):
                    w[i, i*3:(i+1)*3] = 1/3
                proj.weight.copy_(w)
            feats = proj(feats)
    return feats
# ============================================================
# F0 EXTRACTION
# ============================================================
def extract_f0_pm(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """Extract F0 using parselmouth (pm method).

    Args:
        audio: mono waveform at 16 kHz (sr is ignored — the 16000 constants
               below assume it; NOTE(review): confirm callers).
        f0_up_key: pitch shift in semitones.
    Returns:
        (f0_coarse, f0) as produced by f0_to_coarse.
    """
    import parselmouth
    # One frame per 160 samples (10 ms at 16 kHz), plus one.
    p_len = audio.shape[0] // 160 + 1
    f0_min = 65
    f0_max = 1100
    # Pad so Praat can estimate pitch near the edges (~1.5 periods of f0_min).
    l_pad = int(np.ceil(1.5 / f0_min * 16000))
    r_pad = l_pad + 1
    s = parselmouth.Sound(np.pad(audio, (l_pad, r_pad)), 16000).to_pitch_ac(
        time_step=0.01, voicing_threshold=0.6, pitch_floor=f0_min, pitch_ceiling=f0_max,
    )
    f0 = s.selected_array["frequency"]
    # Pad/trim to exactly p_len frames.
    if len(f0) < p_len:
        f0 = np.pad(f0, (0, p_len - len(f0)))
    f0 = f0[:p_len]
    # Apply the semitone shift.
    f0 *= pow(2, f0_up_key / 12)
    return f0_to_coarse(f0)
def extract_f0_harvest(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """Extract F0 with WORLD's harvest algorithm, median-filter, pitch-shift, and quantize."""
    import pyworld
    from scipy import signal as scipy_signal
    # harvest needs float64; 10 ms frame period matches the 160-sample hop at 16 kHz.
    raw_f0, _times = pyworld.harvest(audio.astype(np.double), fs=16000, f0_ceil=1100, f0_floor=50, frame_period=10)
    smoothed = scipy_signal.medfilt(raw_f0, 3)
    smoothed *= 2 ** (f0_up_key / 12)
    return f0_to_coarse(smoothed)
def f0_to_coarse(f0: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Quantize f0 (Hz) to coarse mel-scale bins in 1..255.

    Returns (coarse_bins, original_f0); bin 1 doubles as "unvoiced" (f0 == 0).
    """
    f0_min, f0_max = 50, 1100
    mel_min = 1127 * np.log(1 + f0_min / 700)
    mel_max = 1127 * np.log(1 + f0_max / 700)
    f0_raw = f0.copy()
    mel = 1127 * np.log(1 + f0 / 700)
    # Map voiced frames linearly onto bins 1..255; unvoiced (mel == 0) stays at 0.
    voiced = mel > 0
    mel[voiced] = (mel[voiced] - mel_min) * 254 / (mel_max - mel_min) + 1
    mel = np.clip(mel, 1, 255)
    coarse = np.rint(mel).astype(np.int32)
    return coarse, f0_raw
# ============================================================
# RMVPE F0 EXTRACTION (from Applio - IAHispano/Applio)
# ============================================================
class RMVPE_ConvBlockRes(nn.Module):
    """Two 3x3 conv+BN+ReLU layers with a residual connection (1x1 projection when channels change)."""

    def __init__(self, in_channels, out_channels, momentum=0.01):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        # Project the skip path only when the channel count changes.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
        else:
            self.shortcut = None

    def forward(self, x):
        out = self.conv(x)
        if self.shortcut is None:
            return out + x
        return out + self.shortcut(x)
class RMVPE_ResEncoderBlock(nn.Module):
    """n_blocks residual conv blocks, optionally followed by average pooling."""

    def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
        super().__init__()
        blocks = [RMVPE_ConvBlockRes(in_channels, out_channels, momentum)]
        blocks += [
            RMVPE_ConvBlockRes(out_channels, out_channels, momentum)
            for _ in range(n_blocks - 1)
        ]
        self.conv = nn.ModuleList(blocks)
        # Pooling is skipped entirely when kernel_size is None (intermediate stage).
        self.kernel_size = kernel_size
        if kernel_size is not None:
            self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        """Return (features, pooled features), or just features when pooling is disabled."""
        for block in self.conv:
            x = block(x)
        if self.kernel_size is None:
            return x
        return x, self.pool(x)
class RMVPE_Encoder(nn.Module):
    """Stack of downsampling encoder blocks; yields the bottleneck plus per-level skip tensors."""

    def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
        super().__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        ch_in, ch_out, size = in_channels, out_channels, in_size
        for _ in range(n_encoders):
            self.layers.append(RMVPE_ResEncoderBlock(ch_in, ch_out, kernel_size, n_blocks, momentum))
            # Channels double and spatial size halves at each level.
            ch_in = ch_out
            ch_out *= 2
            size //= 2
        # Spatial size / channel count after the last level (read by the decoder stages).
        self.out_size = size
        self.out_channel = ch_out

    def forward(self, x):
        """Return (bottleneck, [skip tensors in encoder order])."""
        skips = []
        x = self.bn(x)
        for layer in self.layers:
            skip, x = layer(x)
            skips.append(skip)
        return x, skips
class RMVPE_Intermediate(nn.Module):
    """n_inters residual stages with no pooling, applied at the U-Net bottleneck."""

    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super().__init__()
        stages = [RMVPE_ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)]
        stages += [
            RMVPE_ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
            for _ in range(n_inters - 1)
        ]
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        for stage in self.layers:
            x = stage(x)
        return x
class RMVPE_ResDecoderBlock(nn.Module):
    """Upsampling decoder block: transposed conv, concat with encoder skip, then residual blocks."""

    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super().__init__()
        # Asymmetric output padding keeps sizes aligned when only one axis is upsampled.
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, (3, 3), stride, (1, 1), out_padding, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        # First block consumes 2*out_channels (upsampled + concatenated skip).
        blocks = [RMVPE_ConvBlockRes(out_channels * 2, out_channels, momentum)]
        blocks += [
            RMVPE_ConvBlockRes(out_channels, out_channels, momentum)
            for _ in range(n_blocks - 1)
        ]
        self.conv2 = nn.ModuleList(blocks)

    def forward(self, x, concat_tensor):
        x = torch.cat((self.conv1(x), concat_tensor), dim=1)
        for block in self.conv2:
            x = block(x)
        return x
class RMVPE_Decoder(nn.Module):
    """Chain of decoder blocks, consuming encoder skips in reverse order; channels halve each step."""

    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super().__init__()
        self.layers = nn.ModuleList()
        ch = in_channels
        for _ in range(n_decoders):
            self.layers.append(RMVPE_ResDecoderBlock(ch, ch // 2, stride, n_blocks, momentum))
            ch //= 2
        self.n_decoders = n_decoders

    def forward(self, x, concat_tensors):
        # Deepest (last) skip feeds the first decoder stage.
        for idx, layer in enumerate(self.layers):
            x = layer(x, concat_tensors[-1 - idx])
        return x
class RMVPE_DeepUnet(nn.Module):
    """U-Net backbone of RMVPE: encoder -> intermediate bottleneck -> skip-connected decoder."""

    def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
        super().__init__()
        # 128 = number of mel bins the network expects on the frequency axis.
        self.encoder = RMVPE_Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
        bottleneck = self.encoder.out_channel
        self.intermediate = RMVPE_Intermediate(bottleneck // 2, bottleneck, inter_layers, n_blocks)
        self.decoder = RMVPE_Decoder(bottleneck, en_de_layers, kernel_size, n_blocks)

    def forward(self, x):
        x, skips = self.encoder(x)
        x = self.intermediate(x)
        return self.decoder(x, skips)
class RMVPE_BiGRU(nn.Module):
    """Bidirectional GRU that returns only the output sequence (final hidden state discarded)."""

    def __init__(self, input_features, hidden_features, num_layers):
        super().__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        out, _hidden = self.gru(x)
        return out
RMVPE_N_MELS = 128
RMVPE_N_CLASS = 360
class RMVPE_E2E(nn.Module):
    """Full RMVPE pitch network: U-Net features -> 3-channel conv -> (Bi)GRU/linear head over 360 bins."""

    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
        super().__init__()
        self.unet = RMVPE_DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            head = [
                RMVPE_BiGRU(3 * 128, 256, n_gru),
                nn.Linear(512, RMVPE_N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            ]
        else:
            head = [nn.Linear(3 * RMVPE_N_MELS, RMVPE_N_CLASS), nn.Dropout(0.25), nn.Sigmoid()]
        self.fc = nn.Sequential(*head)

    def forward(self, mel):
        """mel: (B, T, n_mels) log-mel -> (B, T, RMVPE_N_CLASS) per-frame pitch salience."""
        # (B, T, n_mels) -> (B, 1, n_mels, T) for the 2-D U-Net.
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        return self.fc(x)
class RMVPE_MelSpectrogram(nn.Module):
    """Log-mel extractor used by RMVPE, with optional key-shift/speed warping of the STFT."""

    def __init__(self, n_mel_channels=128, sample_rate=16000, win_length=1024, hop_length=160, n_fft=None, mel_fmin=30, mel_fmax=8000, clamp=1e-5):
        super().__init__()
        from librosa.filters import mel as librosa_mel
        n_fft = win_length if n_fft is None else n_fft
        # Hann windows cached per (keyshift, device); window sizes differ per keyshift.
        self.hann_window = {}
        # HTK-style mel filterbank, registered as a buffer so it follows .to(device).
        mel_basis = librosa_mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
        self.register_buffer("mel_basis", torch.from_numpy(mel_basis).float())
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.clamp = clamp  # floor applied before the log

    def forward(self, audio, keyshift=0, speed=1, center=True):
        """Return the log-mel of `audio` (B, T); keyshift scales FFT/window sizes, speed scales the hop."""
        # Semitone shift -> multiplicative frequency factor.
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))
        key = f"{keyshift}_{audio.device}"
        if key not in self.hann_window:
            self.hann_window[key] = torch.hann_window(win_length_new).to(audio.device)
        fft = torch.stft(audio, n_fft=n_fft_new, hop_length=hop_length_new, win_length=win_length_new,
                         window=self.hann_window[key], center=center, return_complex=True)
        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
        if keyshift != 0:
            # Resize the shifted spectrum back to the nominal bin count and
            # compensate amplitude for the changed window length.
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
        mel_output = torch.matmul(self.mel_basis, magnitude)
        return torch.log(torch.clamp(mel_output, min=self.clamp))
_rmvpe_model = None
def load_rmvpe():
    """Download (once) and cache the RMVPE pitch model, its mel extractor, and the cents lookup table.

    Returns:
        (model, mel_extractor, padded_cents_mapping) — model in eval mode on the global device.
    """
    global _rmvpe_model
    if _rmvpe_model is None:
        logger.info("Downloading RMVPE model...")
        rmvpe_path = hf_hub_download(repo_id="IAHispano/Applio", filename="Resources/predictors/rmvpe.pt")
        model = RMVPE_E2E(4, 1, (2, 2))
        ckpt = torch.load(rmvpe_path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt)
        model.eval().to(device)
        mel_extractor = RMVPE_MelSpectrogram().to(device)
        # 20-cent bins offset by ~1997.38 cents (RMVPE's pitch grid), padded by
        # 4 on each side for the windowed weighted-average decode.
        cents_mapping = 20 * np.arange(RMVPE_N_CLASS) + 1997.3794084376191
        _rmvpe_model = (model, mel_extractor, np.pad(cents_mapping, (4, 4)))
        logger.info("RMVPE model loaded")
    return _rmvpe_model
def extract_f0_rmvpe(audio: np.ndarray, sr: int = 16000, f0_up_key: int = 0, thred: float = 0.03) -> Tuple[np.ndarray, np.ndarray]:
    """Extract F0 using RMVPE (best quality, neural network based)

    Args:
        audio: mono waveform at 16 kHz (sr is not used directly; the mel
               extractor defaults assume 16 kHz — NOTE(review): confirm callers).
        f0_up_key: pitch shift in semitones.
        thred: salience threshold below which a frame is marked unvoiced (f0 = 0).
    Returns:
        (f0_coarse, f0) as produced by f0_to_coarse.
    """
    model, mel_extractor, cents_mapping = load_rmvpe()
    audio_t = torch.from_numpy(audio).float().to(device).unsqueeze(0)
    mel = mel_extractor(audio_t, center=True)
    del audio_t
    # mel2hidden with chunking
    with torch.no_grad():
        n_frames = mel.shape[-1]
        # Pad the frame count up to a multiple of 32 (U-Net downsampling requirement).
        mel_padded = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")
        # Process in 32000-frame chunks to bound memory.
        chunks = []
        for start in range(0, mel_padded.shape[-1], 32000):
            end = min(start + 32000, mel_padded.shape[-1])
            chunks.append(model(mel_padded[..., start:end]))
        hidden = torch.cat(chunks, dim=1)[:, :n_frames].squeeze(0).cpu().numpy()
    # Decode hidden to f0: salience-weighted average of cents over a 9-bin
    # window centered on the argmax bin (the mapping is pre-padded by 4).
    center = np.argmax(hidden, axis=1)
    salience = np.pad(hidden, ((0, 0), (4, 4)))
    center += 4
    todo_salience = []
    todo_cents = []
    for idx in range(salience.shape[0]):
        s, e = center[idx] - 4, center[idx] + 5
        todo_salience.append(salience[idx, s:e])
        todo_cents.append(cents_mapping[s:e])
    todo_salience = np.array(todo_salience)
    todo_cents = np.array(todo_cents)
    cents_pred = np.sum(todo_salience * todo_cents, 1) / np.sum(todo_salience, 1)
    # Frames with weak peak salience are unvoiced.
    cents_pred[np.max(salience, axis=1) <= thred] = 0
    # Cents -> Hz with a 10 Hz reference; the 10 Hz sentinel maps back to 0 (unvoiced).
    f0 = 10 * (2 ** (cents_pred / 1200))
    f0[f0 == 10] = 0
    f0 *= pow(2, f0_up_key / 12)
    return f0_to_coarse(f0)
# ============================================================
# MODEL LOADING
# ============================================================
_model_cache = {}
def load_rvc_model(model_path: str):
    """Load an RVC checkpoint, auto-detecting version (v1/v2) and f0 conditioning.

    Results are cached per path in the module-level _model_cache.
    Returns:
        (model, sample_rate, version, if_f0) — model in eval mode on the global device.
    Raises:
        ValueError: when no recognizable weights key exists in the checkpoint.
    """
    if model_path in _model_cache:
        return _model_cache[model_path]
    logger.info(f"Loading RVC model: {model_path}")
    # Prefer safe (weights-only) loading; older checkpoints need full unpickling.
    try:
        cpt = torch.load(model_path, map_location="cpu", weights_only=True)
    except Exception:
        logger.warning("Model requires unsafe loading - may be an older format")
        cpt = torch.load(model_path, map_location="cpu", weights_only=False)
    # Checkpoints store weights under different keys depending on their origin.
    weight_key = None
    for key in ["weight", "model", "state_dict", "net_g"]:
        if key in cpt:
            weight_key = key
            break
    if weight_key is None:
        raise ValueError(f"Cannot find model weights. Keys: {list(cpt.keys())}")
    config = cpt.get("config", None)
    if config is None:
        # Typical v2 40k layout; the last three entries are spk_embed_dim, gin_channels, sr.
        config = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000]
        logger.warning("No config found, using v2 defaults")
    version = cpt.get("version", "v1")
    if_f0 = cpt.get("f0", 1)
    # NOTE(review): weight_key is always in cpt at this point; this guard is redundant.
    if weight_key in cpt:
        # Trust the checkpoint's actual embedding size over the config's spk_embed_dim.
        emb_weight = cpt[weight_key].get("emb_g.weight")
        if emb_weight is not None:
            config[-3] = emb_weight.shape[0]
    sr = config[-1] if isinstance(config[-1], int) else 40000
    if version == "v1":
        model_class = SynthesizerTrnMs256NSFsid if if_f0 == 1 else SynthesizerTrnMs256NSFsid_nono
    else:
        model_class = SynthesizerTrnMs768NSFsid if if_f0 == 1 else SynthesizerTrnMs768NSFsid_nono
    model = model_class(
        spec_channels=config[0], segment_size=config[1], inter_channels=config[2],
        hidden_channels=config[3], filter_channels=config[4], n_heads=config[5],
        n_layers=config[6], kernel_size=config[7], p_dropout=config[8],
        resblock=config[9], resblock_kernel_sizes=config[10],
        resblock_dilation_sizes=config[11], upsample_rates=config[12],
        upsample_initial_channel=config[13], upsample_kernel_sizes=config[14],
        spk_embed_dim=config[15], gin_channels=config[16], sr=sr, is_half=False
    )
    # strict=False tolerates missing training-only weights (enc_q, etc.).
    model.load_state_dict(cpt[weight_key], strict=False)
    model.eval().to(device)
    _model_cache[model_path] = (model, sr, version, if_f0)
    logger.info(f"Model loaded: version={version}, f0={if_f0}, sr={sr}")
    return model, sr, version, if_f0
# ============================================================
# TRAINING - Simplified for CPU testing
# ============================================================
def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
    """Compute a linear magnitude spectrogram.

    Args:
        y: (B, T) waveform batch.
        n_fft, hop_size, win_size: STFT parameters.
        center: passed through to torch.stft.
    Returns:
        (B, n_fft // 2 + 1, frames) magnitude spectrogram.
    """
    # Cache Hann windows per (size, device, dtype) instead of rebuilding and
    # re-transferring one on every call — this runs once per training step.
    key = (win_size, y.device, y.dtype)
    window = spectrogram_torch._windows.get(key)
    if window is None:
        window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
        spectrogram_torch._windows[key] = window
    pad = (n_fft - hop_size) // 2
    y = F.pad(y.unsqueeze(1), (pad, pad), mode='reflect').squeeze(1)
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=window,
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    # Small epsilon keeps sqrt's gradient finite at zero magnitude.
    return torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)

# Per-(size, device, dtype) Hann window cache (see above).
spectrogram_torch._windows = {}
# Mel spectrogram for training loss
_mel_basis_cache = {}
def spec_to_mel_torch(spec, n_fft=2048, num_mels=125, sampling_rate=40000, fmin=0, fmax=None):
    """Project a linear spectrogram onto a mel filterbank and take the log (clamped for stability)."""
    from librosa.filters import mel as librosa_mel_fn
    global _mel_basis_cache
    if fmax is None:
        fmax = sampling_rate // 2
    # The filterbank is cached per (fft/mel/rate/range/dtype/device) combination.
    cache_key = f"{n_fft}_{num_mels}_{sampling_rate}_{fmin}_{fmax}_{spec.dtype}_{spec.device}"
    basis = _mel_basis_cache.get(cache_key)
    if basis is None:
        fb = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        basis = torch.from_numpy(fb).to(dtype=spec.dtype, device=spec.device)
        _mel_basis_cache[cache_key] = basis
    mel = torch.matmul(basis, spec)
    # Log-amplitude with a 1e-5 floor to avoid log(0).
    return torch.log(torch.clamp(mel, min=1e-5))
def preprocess_audio_for_training(audio_path: str, output_dir: str, target_sr: int = 40000, f0_method: str = "rmvpe"):
    """Preprocess audio file for training - slice and extract features

    Writes wavs/, hubert/, f0/ subdirectories and manifest.txt under output_dir.
    Returns output_dir on success, None when no usable (non-silent) chunks exist.
    """
    import scipy.signal as signal
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(f"{output_dir}/wavs", exist_ok=True)
    os.makedirs(f"{output_dir}/hubert", exist_ok=True)
    os.makedirs(f"{output_dir}/f0", exist_ok=True)
    logger.info(f"Preprocessing: {audio_path}")
    # Load and resample audio
    audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
    # High-pass filter (48 Hz) to remove rumble / DC offset.
    bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=target_sr)
    audio = signal.lfilter(bh, ah, audio)
    # Slice into chunks (3.7 seconds with 0.3 overlap)
    chunk_size = int(3.7 * target_sr)
    hop = int(3.4 * target_sr)
    chunks = []
    for i, start in enumerate(range(0, len(audio) - chunk_size, hop)):
        chunk = audio[start:start + chunk_size]
        # Normalize to a 0.9 peak; chunks quieter than 0.01 are treated as silence.
        max_val = np.abs(chunk).max()
        if max_val > 0.01:  # Skip silence
            chunk = chunk / max_val * 0.9
            chunks.append((i, chunk))
    if not chunks:
        logger.warning("No valid audio chunks found")
        return None
    logger.info(f"Created {len(chunks)} chunks")
    # Save chunks and extract features
    manifest = []
    for idx, chunk in chunks:
        # Save wav
        wav_path = f"{output_dir}/wavs/{idx:04d}.wav"
        sf.write(wav_path, chunk, target_sr)
        # Resample to 16k for HuBERT
        chunk_16k = librosa.resample(chunk, orig_sr=target_sr, target_sr=16000)
        # Extract HuBERT features
        feats = extract_hubert_features(chunk_16k, sr=16000, version="v2")
        hubert_path = f"{output_dir}/hubert/{idx:04d}.npy"
        np.save(hubert_path, feats.squeeze(0).cpu().numpy())
        # Extract F0
        if f0_method == "rmvpe":
            f0_coarse, f0 = extract_f0_rmvpe(chunk_16k, 16000, 0)
        elif f0_method == "harvest":
            f0_coarse, f0 = extract_f0_harvest(chunk_16k, 16000, 0)
        else:
            f0_coarse, f0 = extract_f0_pm(chunk_16k, 16000, 0)
        f0_path = f"{output_dir}/f0/{idx:04d}.npy"
        # Stack [coarse, raw] so both variants load from one file.
        np.save(f0_path, np.stack([f0_coarse, f0], axis=0))
        manifest.append(f"{idx:04d}")
    # Save manifest
    with open(f"{output_dir}/manifest.txt", "w") as f:
        f.write("\n".join(manifest))
    logger.info(f"Preprocessing complete: {len(manifest)} samples")
    return output_dir
def train_rvc_generator(
    data_dir: str,
    output_dir: str,
    epochs: int = 10,
    batch_size: int = 2,
    lr: float = 1e-5,  # Lower LR prevents overfitting on small data
    target_sr: int = 40000,
    progress_callback: Optional[Callable] = None
):
    """Train an RVC v2 model (generator + discriminator) on preprocessed data.

    Generator function: yields ``(message, checkpoint_path, index_path)``
    triples. During training it yields ``(epoch_msg, None, None)`` once per
    epoch (for live UI updates); the final yield carries the saved checkpoint
    and faiss index paths.

    Args:
        data_dir: Dataset dir produced by preprocess_audio_for_training
            (contains wavs/, hubert/, f0/ and manifest.txt).
        output_dir: Destination for model.pth and model.index.
        epochs: Number of passes over the dataset.
        batch_size: Samples per optimization step.
        lr: AdamW learning rate for both networks.
        target_sr: Training sample rate (must match preprocessing).
        progress_callback: Optional ``fn(fraction, message)`` hook.
    """
    logger.info(f"Starting training: {data_dir} -> {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    # Load manifest
    with open(f"{data_dir}/manifest.txt") as f:
        samples = [l.strip() for l in f if l.strip()]
    if len(samples) < 1:
        logger.error("No training samples found")
        return None
    logger.info(f"Training with {len(samples)} samples")
    # Model config (v2 40k defaults)
    config = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 1, 256, target_sr]
    # Create models (v2 only - 768-dim HuBERT features)
    net_g = SynthesizerTrnMs768NSFsid(
        spec_channels=config[0], segment_size=config[1], inter_channels=config[2],
        hidden_channels=config[3], filter_channels=config[4], n_heads=config[5],
        n_layers=config[6], kernel_size=config[7], p_dropout=config[8],
        resblock=config[9], resblock_kernel_sizes=config[10],
        resblock_dilation_sizes=config[11], upsample_rates=config[12],
        upsample_initial_channel=config[13], upsample_kernel_sizes=config[14],
        spk_embed_dim=config[15], gin_channels=config[16], sr=target_sr
    ).to(train_device)
    net_d = MultiPeriodDiscriminator().to(train_device)
    logger.info(f"Training on device: {train_device}")
    # Download and load pretrained weights (essential for good results)
    sr_key = f"{target_sr // 1000}k"  # e.g., "40k"
    try:
        pretrain_g_path = download_pretrained_rvc(f"f0G{sr_key}")
        pretrain_d_path = download_pretrained_rvc(f"f0D{sr_key}")
        load_pretrained_weights(net_g, pretrain_g_path)
        load_pretrained_weights(net_d, pretrain_d_path)
    except Exception as e:
        logger.warning(f"Failed to load pretrained weights: {e}")
        logger.warning("Training from scratch (results may be poor)")
    # Optimizers (after loading pretrained weights)
    optim_g = torch.optim.AdamW(net_g.parameters(), lr=lr, betas=(0.8, 0.99))
    optim_d = torch.optim.AdamW(net_d.parameters(), lr=lr, betas=(0.8, 0.99))
    # LR scheduler (matches Applio - exponential decay)
    lr_decay = 0.999875
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=lr_decay)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=lr_decay)
    net_g.train()
    net_d.train()
    # Training loop
    for epoch in range(epochs):
        total_loss_g, total_loss_d = 0, 0
        np.random.shuffle(samples)
        for i in range(0, len(samples), batch_size):
            batch_samples = samples[i:i+batch_size]
            # Load batch data
            wavs, huberts, f0s = [], [], []
            for s in batch_samples:
                wav, _ = librosa.load(f"{data_dir}/wavs/{s}.wav", sr=target_sr, mono=True)
                hubert = np.load(f"{data_dir}/hubert/{s}.npy")
                # Upsample 50Hz -> 100Hz using interpolation (same as inference)
                hubert_t = torch.from_numpy(hubert).unsqueeze(0).permute(0, 2, 1)  # (1, 768, seq)
                hubert_t = F.interpolate(hubert_t, scale_factor=2, mode='linear', align_corners=False)
                hubert = hubert_t.permute(0, 2, 1).squeeze(0).numpy()  # (seq*2, 768)
                f0_data = np.load(f"{data_dir}/f0/{s}.npy")
                wavs.append(wav)
                huberts.append(hubert)
                f0s.append(f0_data)
            # Compute spectrogram first to get target length
            max_wav_len = max(len(w) for w in wavs)
            # Zero-pad shorter waveforms up to the longest one in the batch.
            wav_batch = np.zeros((len(wavs), max_wav_len))
            for j, w in enumerate(wavs):
                wav_batch[j, :len(w)] = w
            wav_t = torch.FloatTensor(wav_batch).unsqueeze(1).to(train_device)
            spec = spectrogram_torch(wav_t.squeeze(1), 2048, 400, 2048)
            spec_len = spec.shape[2]  # Target length for all features
            # Pad/truncate features to match spec length exactly
            hubert_batch = np.zeros((len(huberts), spec_len, huberts[0].shape[1]))
            f0_batch = np.zeros((len(f0s), spec_len))
            f0f_batch = np.zeros((len(f0s), spec_len))
            for j, (h, f) in enumerate(zip(huberts, f0s)):
                # Truncate or pad HuBERT to spec_len
                h_len = min(h.shape[0], spec_len)
                hubert_batch[j, :h_len] = h[:h_len]
                # Truncate or pad F0 to spec_len
                f0_len = min(f.shape[1], spec_len)
                f0_batch[j, :f0_len] = f[0, :f0_len]
                f0f_batch[j, :f0_len] = f[1, :f0_len]
            # To tensors - all features now have spec_len
            hubert_t = torch.FloatTensor(hubert_batch).to(train_device)
            f0_t = torch.LongTensor(f0_batch.astype(np.int64)).to(train_device)
            f0f_t = torch.FloatTensor(f0f_batch).to(train_device)
            lengths_t = torch.LongTensor([spec_len] * len(batch_samples)).to(train_device)
            sid_t = torch.LongTensor([0] * len(batch_samples)).to(train_device)
            spec_lengths = torch.LongTensor([spec_len] * len(batch_samples)).to(train_device)
            # Forward pass generator
            # Args: phone, phone_lengths, pitch, pitchf, y (spec), y_lengths, ds
            try:
                y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(
                    hubert_t, lengths_t, f0_t, f0f_t, spec, spec_lengths, sid_t
                )
            except Exception as e:
                # Skip batches the model cannot consume (e.g. too short) rather
                # than aborting the whole training run.
                logger.warning(f"Generator forward failed: {e}")
                continue
            # Slice wav at same position model generated (CRITICAL for proper loss)
            # ids_slice is in latent space, multiply by hop_length to get waveform position
            hop_length = 400
            segment_size_wav = 32 * hop_length  # segment_size in latent * hop_length
            y = slice_segments(wav_t, ids_slice * hop_length, segment_size_wav)
            # Discriminator forward
            y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat.detach())
            # Discriminator loss
            loss_d = discriminator_loss(y_d_rs, y_d_gs)
            optim_d.zero_grad()
            loss_d.backward()
            optim_d.step()
            # Generator loss
            y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat)
            loss_gen = generator_loss(y_d_gs)
            loss_fm = feature_loss(fmap_rs, fmap_gs)
            loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
            # Mel spectrogram loss (crucial for quality)
            # Config: n_fft=2048, hop=400, win=2048, n_mels=125, fmin=0, fmax=None
            y_mel = spec_to_mel_torch(spectrogram_torch(y.squeeze(1), 2048, 400, 2048),
                                      n_fft=2048, num_mels=125, sampling_rate=target_sr, fmin=0, fmax=None)
            y_hat_mel = spec_to_mel_torch(spectrogram_torch(y_hat.squeeze(1), 2048, 400, 2048),
                                          n_fft=2048, num_mels=125, sampling_rate=target_sr, fmin=0, fmax=None)
            # Align lengths if needed
            min_len = min(y_mel.shape[2], y_hat_mel.shape[2])
            loss_mel = F.l1_loss(y_mel[:, :, :min_len], y_hat_mel[:, :, :min_len]) * 45  # c_mel = 45
            loss_g = loss_gen + loss_fm + loss_mel + loss_kl
            optim_g.zero_grad()
            loss_g.backward()
            optim_g.step()
            total_loss_g += loss_g.item()
            total_loss_d += loss_d.item()
        avg_loss_g = total_loss_g / max(1, len(samples) // batch_size)
        avg_loss_d = total_loss_d / max(1, len(samples) // batch_size)
        epoch_msg = f"Epoch {epoch+1}/{epochs} - G: {avg_loss_g:.2f}, D: {avg_loss_d:.2f}"
        logger.info(epoch_msg)
        # Update progress callback if provided
        if progress_callback:
            # Training occupies the 30%..95% span of the overall progress bar.
            progress_pct = 0.30 + (0.65 * (epoch + 1) / epochs)
            progress_callback(progress_pct, epoch_msg)
        # Yield epoch message for live UI updates
        yield epoch_msg, None, None
        # Step LR schedulers
        scheduler_g.step()
        scheduler_d.step()
    # Save checkpoint
    ckpt_path = f"{output_dir}/model.pth"
    torch.save({
        "weight": net_g.state_dict(),
        "config": config,
        "version": "v2",  # v2 only
        "f0": 1,
    }, ckpt_path)
    logger.info(f"Saved checkpoint: {ckpt_path}")
    # Generate index file for better speaker similarity
    index_path = None
    try:
        import faiss
        hubert_dir = f"{data_dir}/hubert"
        npys = []
        for name in sorted(os.listdir(hubert_dir)):
            if name.endswith('.npy'):
                phone = np.load(os.path.join(hubert_dir, name))
                npys.append(phone)
        if npys:
            big_npy = np.concatenate(npys, axis=0)
            # IVF list count heuristic used by RVC tooling; 39 is the minimum
            # number of training points faiss wants per list.
            n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
            n_ivf = max(1, n_ivf)  # Ensure at least 1
            index = faiss.index_factory(big_npy.shape[1], f"IVF{n_ivf},Flat")
            index.train(big_npy)
            index.add(big_npy)
            index_path = f"{output_dir}/model.index"
            faiss.write_index(index, index_path)
            logger.info(f"Saved index: {index_path}")
    except Exception as e:
        # Index generation is optional; training output is still usable without it.
        logger.warning(f"Failed to generate index: {e}")
    # Cleanup training models to free memory
    purge_memory(net_g, net_d, optim_g, optim_d, scheduler_g, scheduler_d)
    yield "Training complete!", ckpt_path, index_path
def train_rvc(
    data_dir: str,
    output_dir: str,
    epochs: int = 10,
    batch_size: int = 2,
    lr: float = 1e-5,  # Lower LR prevents overfitting on small data
    target_sr: int = 40000,
    progress_callback=None
):
    """Blocking wrapper around train_rvc_generator for CLI use.

    Drains the generator and returns ``(checkpoint_path, index_path)``,
    keeping the last non-empty values yielded.
    """
    checkpoint_path, index_path = None, None
    epoch_stream = train_rvc_generator(
        data_dir, output_dir, epochs, batch_size, lr, target_sr, progress_callback
    )
    for _msg, ckpt_candidate, index_candidate in epoch_stream:
        checkpoint_path = ckpt_candidate or checkpoint_path
        index_path = index_candidate or index_path
    return checkpoint_path, index_path
# ============================================================
# INFERENCE
# ============================================================
def convert_voice(
    source_audio: str,
    model_file,
    index_file=None,
    pitch_shift: int = 0,
    f0_method: str = "pm",
    index_rate: float = 0.5,
    protect: float = 0.33,
    volume_envelope: float = 1.0,
    progress=gr.Progress()
) -> Tuple[Optional[str], str]:
    """Convert voice using RVC model (Applio-compatible pipeline).

    Args:
        source_audio: Path to the input audio file.
        model_file: RVC .pth checkpoint (path string or object with ``.name``).
        index_file: Optional faiss index for feature retrieval (speaker similarity).
        pitch_shift: Semitone offset applied during F0 extraction.
        f0_method: F0 extractor - "rmvpe", "harvest", or the "pm" fallback.
        index_rate: Blend factor between retrieved and extracted features (0..1).
        protect: Values below 0.5 preserve source features on unvoiced segments.
        volume_envelope: 1.0 keeps the model's output volume; lower values mix
            in the source's RMS envelope.
        progress: Gradio progress reporter.

    Returns:
        (output_wav_path, status_message); the path is None on error.
    """
    try:
        if source_audio is None:
            return None, "Please upload source audio"
        if model_file is None:
            return None, "Please upload RVC model (.pth)"
        model_path = model_file.name if hasattr(model_file, 'name') else model_file
        progress(0.1, "Loading model...")
        model, tgt_sr, version, if_f0 = load_rvc_model(model_path)
        progress(0.2, "Loading audio...")
        audio, sr = librosa.load(source_audio, sr=16000, mono=True)
        # Apply 48Hz high-pass filter (critical - removes low-frequency artifacts)
        audio = signal.filtfilt(bh, ah, audio)
        # Normalize audio
        audio_max = np.abs(audio).max() / 0.95
        if audio_max > 1:
            audio /= audio_max
        # Pipeline constants (same as Applio)
        window = 160  # Critical for feature/pitch alignment
        x_pad = 1  # Padding in seconds
        t_pad = 16000 * x_pad  # Padding in samples
        # Pad audio
        audio_pad = np.pad(audio, (t_pad, t_pad), mode="reflect")
        p_len = audio_pad.shape[0] // window
        progress(0.3, "Extracting features...")
        feats = extract_hubert_features(audio_pad, sr=16000, version=version)
        # Save original features for protect mechanism
        feats0 = feats.clone() if if_f0 == 1 and protect < 0.5 else None
        # Index retrieval (speaker similarity)
        if index_file is not None and index_rate > 0:
            try:
                import faiss
                index_path = index_file.name if hasattr(index_file, 'name') else index_file
                progress(0.4, "Loading index...")
                index = faiss.read_index(index_path)
                big_npy = index.reconstruct_n(0, index.ntotal)
                npy = feats[0].cpu().numpy().astype("float32")
                # Blend each frame with its 8 nearest neighbours, weighted by
                # inverse squared distance, then mix with the original features.
                score, ix = index.search(npy, k=8)
                weight = np.square(1 / score)
                weight /= weight.sum(axis=1, keepdims=True)
                npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
                feats = torch.from_numpy(npy).unsqueeze(0).to(device) * index_rate + (1 - index_rate) * feats
            except Exception as e:
                # Retrieval is optional; fall back to raw HuBERT features.
                logger.warning(f"Index retrieval failed: {e}")
        # Feature upsampling by 2x
        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
        # Adjust length based on audio
        p_len = min(audio_pad.shape[0] // window, feats.shape[1])
        pitch, pitchf = None, None
        if if_f0 == 1:
            progress(0.5, f"Extracting F0 ({f0_method})...")
            if f0_method == "rmvpe":
                pitch, pitchf = extract_f0_rmvpe(audio_pad, 16000, pitch_shift)
            elif f0_method == "harvest":
                pitch, pitchf = extract_f0_harvest(audio_pad, 16000, pitch_shift)
            else:
                pitch, pitchf = extract_f0_pm(audio_pad, 16000, pitch_shift)
            pitch = pitch[:p_len]
            pitchf = pitchf[:p_len]
            # Upsample feats0 for protect
            if feats0 is not None:
                feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
            # Apply protect mechanism (preserve original features for unvoiced segments)
            if protect < 0.5 and feats0 is not None:
                pitchf_tensor = torch.from_numpy(pitchf).float().to(device)
                # Voiced frames (f0 > 0) keep the converted features fully (weight 1);
                # unvoiced frames blend back toward the originals by `protect`.
                pitchff = pitchf_tensor.clone()
                pitchff[pitchf_tensor > 0] = 1
                pitchff[pitchf_tensor < 1] = protect
                pitchff = pitchff.unsqueeze(0).unsqueeze(-1)
                feats = feats[:, :p_len, :] * pitchff + feats0[:, :p_len, :] * (1 - pitchff)
            if len(pitch) < p_len:
                pitch = np.pad(pitch, (0, p_len - len(pitch)))
                pitchf = np.pad(pitchf, (0, p_len - len(pitchf)))
            pitch = torch.LongTensor(pitch).unsqueeze(0).to(device)
            pitchf = torch.FloatTensor(pitchf).unsqueeze(0).to(device)
        p_len_tensor = torch.LongTensor([p_len]).to(device)
        sid = torch.LongTensor([0]).to(device)
        progress(0.7, "Running inference...")
        with torch.no_grad():
            if if_f0 == 1:
                audio_out = model.infer(feats[:, :p_len, :], p_len_tensor, pitch, pitchf, sid)[0][0, 0].data.cpu().float().numpy()
            else:
                audio_out = model.infer(feats[:, :p_len, :], p_len_tensor, sid)[0][0, 0].data.cpu().float().numpy()
        # Remove padding from output
        t_pad_tgt = int(t_pad * tgt_sr / 16000)
        if len(audio_out) > 2 * t_pad_tgt:
            audio_out = audio_out[t_pad_tgt:-t_pad_tgt]
        # RMS mixing - match volume dynamics of source audio
        if volume_envelope != 1.0:
            try:
                source_at_tgt_sr = librosa.resample(audio, orig_sr=16000, target_sr=tgt_sr)
                frame_len = tgt_sr // 2 * 2
                hop_len = tgt_sr // 2
                rms_source = librosa.feature.rms(y=source_at_tgt_sr, frame_length=frame_len, hop_length=hop_len)
                rms_output = librosa.feature.rms(y=audio_out, frame_length=frame_len, hop_length=hop_len)
                # Stretch both RMS envelopes to per-sample resolution.
                rms_source = F.interpolate(
                    torch.from_numpy(rms_source).float().unsqueeze(0),
                    size=audio_out.shape[0], mode="linear"
                ).squeeze()
                rms_output = F.interpolate(
                    torch.from_numpy(rms_output).float().unsqueeze(0),
                    size=audio_out.shape[0], mode="linear"
                ).squeeze()
                rms_output = torch.maximum(rms_output, torch.zeros_like(rms_output) + 1e-6)
                # Applio formula: target * (source^(1-rate) * output^(rate-1))
                audio_out = audio_out * (torch.pow(rms_source, 1 - volume_envelope) * torch.pow(rms_output, volume_envelope - 1)).numpy()
            except Exception as e:
                logger.warning(f"RMS mixing failed: {e}")
        # Final normalization
        audio_max = np.abs(audio_out).max() / 0.99
        if audio_max > 1:
            audio_out /= audio_max
        progress(0.9, "Saving output...")
        fd, output_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        sf.write(output_path, audio_out, tgt_sr)
        # Aggressive memory purge after inference — frees glibc arena on Linux
        _model_cache.clear()
        _cleanup_args = [model, feats, audio_out, audio, audio_pad]
        if feats0 is not None:
            _cleanup_args.append(feats0)
        purge_memory(*_cleanup_args)
        return output_path, f"Converted: {version}, sr={tgt_sr}, pitch={pitch_shift:+d}"
    except Exception as e:
        logger.exception("Conversion failed")
        _model_cache.clear()
        purge_memory()
        return None, f"Error: {str(e)}"
# ============================================================
# DEFAULT MODEL DOWNLOAD
# ============================================================
def load_example_model():
    """Download and load the default example model from HuggingFace.

    Copies the model and index out of the HF cache into a fresh temp
    directory because Gradio 6 only serves files from allowed directories
    (cwd or /tmp).

    Returns:
        (model_path, index_path, status_message); paths are None on failure.
    """
    try:
        logger.info(f"Downloading example model from {DEFAULT_MODEL_REPO}...")
        model_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_MODEL_FILE)
        index_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_INDEX_FILE)
        # Gradio 6 requires files to be in allowed directories (cwd or /tmp)
        # Copy from HF cache to temp directory
        temp_dir = tempfile.mkdtemp()
        temp_model = os.path.join(temp_dir, DEFAULT_MODEL_FILE)
        temp_index = os.path.join(temp_dir, DEFAULT_INDEX_FILE)
        # Uses the module-level shutil import; the previous function-local
        # `import shutil` duplicated it and has been removed.
        shutil.copy2(model_path, temp_model)
        shutil.copy2(index_path, temp_index)
        return temp_model, temp_index, f"Loaded: {DEFAULT_MODEL_REPO}"
    except Exception as e:
        logger.exception("Failed to download example model")
        return None, None, f"Error: {str(e)}"
# ============================================================
# BEATRICE V2 MODEL
# ============================================================
def beatrice_load_audio(file, **kwargs):
    """Load audio using soundfile directly (for Beatrice dataset)"""
    data, sample_rate = sf.read(file, dtype='float32')
    wav = torch.from_numpy(data)
    # soundfile yields (samples,) for mono and (samples, channels) otherwise;
    # normalize both shapes to torch's (channels, samples) layout.
    wav = wav.unsqueeze(0) if wav.ndim == 1 else wav.T
    return wav, sample_rate
class AttrDict(dict):
    """dict subclass whose items double as attributes (d.key == d["key"])."""
    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Aliasing __dict__ to the mapping itself makes attribute access
        # and item access share one storage.
        self.__dict__ = self
def dump_params(params: Optional[torch.Tensor], f: BinaryIO):
    """Write a tensor's raw values to a binary stream (no-op for None).

    bfloat16 tensors are written as their 16-bit payload; other dtypes are
    written as their native little-endian byte layout.
    """
    if params is None:
        return
    data = params.detach()
    if params.dtype == torch.bfloat16:
        # Upcast to float32, reinterpret as int16 pairs and keep every second
        # element — on little-endian hosts those are the high 16 bits of each
        # float32, i.e. exactly the original bfloat16 payload.
        raw = data.clone().float().view(torch.short).numpy().ravel()[1::2]
        f.write(raw.tobytes())
    else:
        f.write(data.numpy().ravel().tobytes())
    f.flush()
def dump_layer(layer: nn.Module, f: BinaryIO):
    """Serialize a layer's parameters to `f` in the Beatrice dump order.

    Layers exposing their own ``dump`` method serialize themselves;
    ModuleLists recurse; unknown layer types are a hard error.
    """
    if hasattr(layer, "dump"):
        layer.dump(f)
    elif isinstance(layer, (nn.Linear, nn.Conv1d, nn.LayerNorm)):
        dump_params(layer.weight, f)
        dump_params(layer.bias, f)
    elif isinstance(layer, nn.MultiheadAttention):
        embed_dim = layer.embed_dim
        num_heads = layer.num_heads
        head_dim = embed_dim // num_heads
        # Pre-scale the q and k slices by 1/sqrt(sqrt(head_dim)); applied to
        # both, this replaces the attention temperature at runtime.
        qk_scale = 1.0 / math.sqrt(math.sqrt(head_dim))
        # [3 * embed_dim, embed_dim]
        in_proj_weight = layer.in_proj_weight.data.clone()
        in_proj_weight[: 2 * embed_dim] *= qk_scale
        # [3, num_heads, head_dim, embed_dim] -> [num_heads, 3, head_dim, embed_dim]
        in_proj_weight = in_proj_weight.view(3, num_heads, head_dim, embed_dim).transpose(0, 1)
        # [3 * embed_dim]
        in_proj_bias = layer.in_proj_bias.data.clone()
        in_proj_bias[: 2 * embed_dim] *= qk_scale
        # [3, num_heads, head_dim] -> [num_heads, 3, head_dim]
        in_proj_bias = in_proj_bias.view(3, num_heads, head_dim).transpose(0, 1)
        dump_params(in_proj_weight, f)
        dump_params(in_proj_bias, f)
        dump_params(layer.out_proj.weight, f)
        dump_params(layer.out_proj.bias, f)
    elif isinstance(layer, nn.Embedding):
        dump_params(layer.weight, f)
    elif isinstance(layer, nn.Parameter):
        dump_params(layer, f)
    elif isinstance(layer, nn.ModuleList):
        for sublayer in layer:
            dump_layer(sublayer, f)
    else:
        assert False, layer
class CausalConv1d(nn.Conv1d):
    """1D convolution that attends only to the past (plus up to ``delay`` future samples).

    Pads by ``(kernel_size - 1) * dilation - delay`` on both sides via the
    base class, then trims the surplus frames from the right so the output
    length equals the input length.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        delay: int = 0,
    ):
        receptive = (kernel_size - 1) * dilation
        self.trim = receptive - 2 * delay
        if self.trim < 0:
            # delay cannot exceed half the receptive field
            raise ValueError
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=receptive - delay,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = super().forward(input)
        # Drop the trailing frames introduced by the symmetric padding.
        return out if self.trim == 0 else out[:, :, : -self.trim]
class WSConv1d(CausalConv1d):
    """Causal conv with weight standardization and a learnable per-filter gain."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        delay: int = 0,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            delay=delay,
        )
        fan_in = in_channels * kernel_size // groups
        self.weight.data.normal_(0.0, math.sqrt(1.0 / fan_in))
        if bias:
            self.bias.data.zero_()
        self.gain = nn.Parameter(torch.ones((out_channels, 1, 1)))

    def standardized_weight(self) -> torch.Tensor:
        # Zero-mean the filters and scale by gain / sqrt(fan_in * var + eps).
        var, mean = torch.var_mean(self.weight, [1, 2], keepdim=True)
        fan_in = self.in_channels * self.kernel_size[0] // self.groups
        scale = self.gain * (fan_in * var + 1e-8).rsqrt()
        return scale * (self.weight - mean)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = F.conv1d(
            input,
            self.standardized_weight(),
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        # Same causal right-trim as the base class.
        return out if self.trim == 0 else out[:, :, : -self.trim]

    def merge_weights(self):
        # Bake the standardization into the raw weights (used before export).
        self.weight.data[:] = self.standardized_weight().detach()
        self.gain.data.fill_(1.0)
class WSLinear(nn.Linear):
    """Linear layer with weight standardization and a learnable per-output gain.

    Bug fix: the original zeroed ``self.bias`` unconditionally, which raised
    ``AttributeError`` when constructed with ``bias=False`` (``self.bias`` is
    None); the zeroing is now guarded.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        self.weight.data.normal_(0.0, math.sqrt(1.0 / in_features))
        if bias:  # self.bias is None when bias=False
            self.bias.data.zero_()
        self.gain = nn.Parameter(torch.ones((out_features, 1)))

    def standardized_weight(self) -> torch.Tensor:
        # Zero-mean rows, scaled by gain / sqrt(fan_in * var + eps).
        var, mean = torch.var_mean(self.weight, 1, keepdim=True)
        scale = self.gain * (self.in_features * var + 1e-8).rsqrt()
        return scale * (self.weight - mean)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.standardized_weight(), self.bias)

    def merge_weights(self):
        # Bake the standardization into the raw weights (used before export).
        self.weight.data[:] = self.standardized_weight().detach()
        self.gain.data.fill_(1.0)
class CrossAttention(nn.Module):
    """Multi-head cross-attention with decoupled q/k and v/o channel widths.

    Queries are projected from one sequence and keys/values (via a single
    fused projection) from another. ``dump``/``dump_kv`` write the weights in
    the binary layout consumed by the Beatrice inference runtime, with q and
    k pre-scaled by 1/sqrt(sqrt(head_dim)) so the runtime can skip the
    softmax temperature division.
    """
    def __init__(
        self,
        qk_channels: int,
        vo_channels: int,
        num_heads: int,
        in_q_channels: int,
        in_kv_channels: int,
        out_channels: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert qk_channels % num_heads == 0
        self.qk_channels = qk_channels
        self.vo_channels = vo_channels
        self.num_heads = num_heads
        self.in_q_channels = in_q_channels
        self.in_kv_channels = in_kv_channels
        self.out_channels = out_channels
        self.dropout = dropout
        self.head_qk_channels = qk_channels // num_heads
        self.head_vo_channels = vo_channels // num_heads
        # Fan-in-scaled normal init and zero biases for all three projections.
        self.q_projection = nn.Linear(in_q_channels, qk_channels)
        self.q_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_q_channels))
        self.q_projection.bias.data.zero_()
        # Keys and values come from one fused projection, split in forward().
        self.kv_projection = nn.Linear(in_kv_channels, qk_channels + vo_channels)
        self.kv_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_kv_channels))
        self.kv_projection.bias.data.zero_()
        self.out_projection = nn.Linear(vo_channels, out_channels)
        self.out_projection.weight.data.normal_(0.0, math.sqrt(1.0 / vo_channels))
        self.out_projection.bias.data.zero_()
    def forward(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
    ) -> torch.Tensor:
        """Attend from q over kv; returns [batch_size, q_length, out_channels]."""
        # q: [batch_size, q_length, in_q_channels]
        # kv: [batch_size, kv_length, in_kv_channels]
        batch_size, q_length, _ = q.size()
        _, kv_length, _ = kv.size()
        # [batch_size, q_length, qk_channels]
        q = self.q_projection(q)
        # [batch_size, kv_length, qk_channels + vo_channels]
        kv = self.kv_projection(kv)
        # [batch_size, kv_length, qk_channels], [batch_size, kv_length, vo_channels]
        k, v = kv.split([self.qk_channels, self.vo_channels], dim=2)
        # Split channels into heads: [batch_size, num_heads, length, head_channels]
        q = q.view(
            batch_size, q_length, self.num_heads, self.head_qk_channels
        ).transpose(1, 2)
        k = k.view(
            batch_size, kv_length, self.num_heads, self.head_qk_channels
        ).transpose(1, 2)
        v = v.view(
            batch_size, kv_length, self.num_heads, self.head_vo_channels
        ).transpose(1, 2)
        # [batch_size, num_heads, q_length, head_vo_channels]
        attn_out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout)
        # [batch_size, q_length, vo_channels]
        attn_out = (
            attn_out.transpose(1, 2)
            .contiguous()
            .view(batch_size, q_length, self.vo_channels)
        )
        # [batch_size, q_length, out_channels]
        attn_out = self.out_projection(attn_out)
        return attn_out
    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Serialize the q and output projections (q pre-scaled for the runtime)."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        q_projection_weight = self.q_projection.weight.data.clone()
        q_projection_bias = self.q_projection.bias.data.clone()
        # Pre-scale q by 1/sqrt(sqrt(head_dim)); k gets the same factor in
        # dump_kv, together replacing the attention temperature at runtime.
        q_projection_weight *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels))
        q_projection_bias *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels))
        dump_params(q_projection_weight, f)
        dump_params(q_projection_bias, f)
        dump_layer(self.out_projection, f)
    def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Serialize the fused k/v projection, split and interleaved per head."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump_kv(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        kv_projection_weight = self.kv_projection.weight.data.clone()
        kv_projection_bias = self.kv_projection.bias.data.clone()
        # Split the fused projection back into its k and v halves.
        k_projection_weight, v_projection_weight = kv_projection_weight.split(
            [self.qk_channels, self.vo_channels]
        )
        k_projection_bias, v_projection_bias = kv_projection_bias.split(
            [self.qk_channels, self.vo_channels]
        )
        # Same 1/sqrt(sqrt(head_dim)) pre-scale as the q side in dump().
        k_projection_weight *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels))
        k_projection_bias *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels))
        # [qk_channels, in_kv_channels] -> [num_heads, head_qk_channels, in_kv_channels]
        k_projection_weight = k_projection_weight.view(
            self.num_heads, self.head_qk_channels, self.in_kv_channels
        )
        # [qk_channels] -> [num_heads, head_qk_channels]
        k_projection_bias = k_projection_bias.view(
            self.num_heads, self.head_qk_channels
        )
        # [vo_channels, in_kv_channels] -> [num_heads, head_vo_channels, in_kv_channels]
        v_projection_weight = v_projection_weight.view(
            self.num_heads, self.head_vo_channels, self.in_kv_channels
        )
        # [vo_channels] -> [num_heads, head_vo_channels]
        v_projection_bias = v_projection_bias.view(
            self.num_heads, self.head_vo_channels
        )
        # Weights first (k,v interleaved per head), then biases in the same order.
        for i in range(self.num_heads):
            # [head_qk_channels, in_kv_channels]
            dump_params(k_projection_weight[i], f)
            # [head_vo_channels, in_kv_channels]
            dump_params(v_projection_weight[i], f)
        for i in range(self.num_heads):
            # [head_qk_channels]
            dump_params(k_projection_bias[i], f)
            # [head_vo_channels]
            dump_params(v_projection_bias[i], f)
class ConvNeXtBlock(nn.Module):
    """ConvNeXt-style residual block with optional attention and scaling.

    Structure: [optional self- or cross-attention with its own residual] ->
    depthwise causal conv -> LayerNorm -> pointwise MLP (GELU) -> layer scale
    -> residual. With ``use_weight_standardization`` the explicit norms and
    layer scale are replaced by weight-standardized conv/linear layers; with
    ``enable_scaling`` fixed pre/post scale buffers plus a learnable
    post-scale weight are applied (normalization-free training style).
    ``merge_weights``/``dump`` support export to the Beatrice runtime.
    """
    def __init__(
        self,
        channels: int,
        intermediate_channels: int,
        layer_scale_init_value: float,
        kernel_size: int = 7,
        use_weight_standardization: bool = False,
        enable_scaling: bool = False,
        pre_scale: float = 1.0,
        post_scale: float = 1.0,
        use_mha: bool = False,
        cross_attention: bool = False,
        num_heads: int = 4,
        attention_dropout: float = 0.1,
        attention_channels: Optional[int] = None,
        kv_channels: Optional[int] = None,
    ):
        super().__init__()
        self.use_weight_standardization = use_weight_standardization
        self.enable_scaling = enable_scaling
        self.use_mha = use_mha
        self.cross_attention = cross_attention
        if use_mha:
            self.attn_norm = nn.LayerNorm(channels)
            if cross_attention:
                self.mha = CrossAttention(
                    qk_channels=attention_channels,
                    vo_channels=attention_channels,
                    num_heads=num_heads,
                    in_q_channels=channels,
                    in_kv_channels=kv_channels,
                    out_channels=channels,
                    dropout=attention_dropout,
                )
            else:  # self-attention
                assert attention_channels is None
                assert kv_channels is None
                self.mha = nn.MultiheadAttention(
                    embed_dim=channels,
                    num_heads=num_heads,
                    dropout=attention_dropout,
                    batch_first=True,
                )
        self.dwconv = CausalConv1d(
            channels, channels, kernel_size=kernel_size, groups=channels
        )
        self.norm = nn.LayerNorm(channels)
        self.pwconv1 = nn.Linear(channels, intermediate_channels)
        self.pwconv2 = nn.Linear(intermediate_channels, channels)
        # Layer scale (per-channel multiplier on the MLP output).
        self.gamma = nn.Parameter(torch.full((channels,), layer_scale_init_value))
        self.dwconv.weight.data.normal_(0.0, math.sqrt(1.0 / kernel_size))
        self.dwconv.bias.data.zero_()
        self.pwconv1.weight.data.normal_(0.0, math.sqrt(2.0 / channels))
        self.pwconv1.bias.data.zero_()
        self.pwconv2.weight.data.normal_(0.0, math.sqrt(1.0 / intermediate_channels))
        self.pwconv2.bias.data.zero_()
        if use_weight_standardization:
            # Weight standardization replaces the explicit norm and layer
            # scale; rebuild the affected layers accordingly.
            self.norm = nn.Identity()
            self.dwconv = WSConv1d(channels, channels, kernel_size, groups=channels)
            self.pwconv1 = WSLinear(channels, intermediate_channels)
            self.pwconv2 = WSLinear(intermediate_channels, channels)
            del self.gamma
        if enable_scaling:
            # Fixed scale factors stored as buffers plus one learnable scalar.
            self.register_buffer("pre_scale", torch.tensor(pre_scale))
            self.register_buffer("post_scale", torch.tensor(post_scale))
            self.post_scale_weight = nn.Parameter(torch.ones(()))
    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        kv: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the block to x [batch, channels, length]; kv is required iff
        the block was built with cross-attention."""
        if self.use_mha:
            batch_size, channels, length = x.size()
            if self.cross_attention:
                assert kv is not None
            else:
                assert kv is None
                assert length % 4 == 0
            identity = x
            if self.cross_attention:
                # kv: [batch_size, kv_length, kv_channels]
                x = x.transpose(1, 2)
                x = self.attn_norm(x)
                x = self.mha(x, kv)
                x = x.transpose(1, 2)
            else:
                # Fold the sequence into 4 stride-4 sub-sequences (phase
                # offsets 0..3) stacked into the batch, then run causal
                # self-attention over the shortened length.
                x = x.view(batch_size, channels, length // 4, 4)
                x = x.permute(0, 3, 2, 1)
                x = x.reshape(batch_size * 4, length // 4, channels)
                x = self.attn_norm(x)
                x, _ = self.mha(
                    x, x, x, attn_mask=attn_mask, is_causal=True, need_weights=False
                )
                # Undo the fold back to [batch_size, channels, length].
                x = x.view(batch_size, 4, length // 4, channels)
                x = x.permute(0, 3, 2, 1)
                x = x.reshape(batch_size, channels, length)
            x += identity
        identity = x
        if self.enable_scaling:
            x = x * self.pre_scale
        x = self.dwconv(x)
        x = x.transpose(1, 2)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = F.gelu(x, approximate="tanh")
        x = self.pwconv2(x)
        if not self.use_weight_standardization:
            x *= self.gamma
        if self.enable_scaling:
            x *= self.post_scale * self.post_scale_weight
        x = x.transpose(1, 2)
        x += identity
        return x
    def merge_weights(self):
        """Fold norms, layer scale and scaling buffers into adjacent weights
        so the exported model needs no explicit normalization at runtime."""
        if self.use_mha:
            if self.cross_attention:
                assert isinstance(self.mha, CrossAttention)
                # Fold attn_norm's affine transform into the q projection.
                self.mha.q_projection.bias.data += torch.mv(
                    self.mha.q_projection.weight.data, self.attn_norm.bias.data
                )
                self.mha.q_projection.weight.data *= self.attn_norm.weight.data[None, :]
                self.attn_norm.bias.data[:] = 0.0
                self.attn_norm.weight.data[:] = 1.0
            else:  # self-attention
                assert isinstance(self.mha, nn.MultiheadAttention)
                # Fold attn_norm's affine transform into the fused in-projection.
                self.mha.in_proj_bias.data += torch.mv(
                    self.mha.in_proj_weight.data, self.attn_norm.bias.data
                )
                self.mha.in_proj_weight.data *= self.attn_norm.weight.data[None, :]
                self.attn_norm.bias.data[:] = 0.0
                self.attn_norm.weight.data[:] = 1.0
        if self.use_weight_standardization:
            self.dwconv.merge_weights()
            self.pwconv1.merge_weights()
            self.pwconv2.merge_weights()
        else:
            # Fold LayerNorm affine into pwconv1 and gamma into pwconv2.
            self.pwconv1.bias.data += torch.mv(
                self.pwconv1.weight.data, self.norm.bias.data
            )
            self.pwconv1.weight.data *= self.norm.weight.data[None, :]
            self.norm.bias.data[:] = 0.0
            self.norm.weight.data[:] = 1.0
            self.pwconv2.weight.data *= self.gamma.data[:, None]
            self.pwconv2.bias.data *= self.gamma.data
            self.gamma.data[:] = 1.0
        if self.enable_scaling:
            # Bake the fixed scale buffers and the learnable post-scale weight
            # into the conv/linear weights, then reset them to identity.
            self.dwconv.weight.data *= self.pre_scale.data
            self.pre_scale.data.fill_(1.0)
            self.pwconv2.weight.data *= (
                self.post_scale.data * self.post_scale_weight.data
            )
            self.pwconv2.bias.data *= self.post_scale.data * self.post_scale_weight.data
            self.post_scale.data.fill_(1.0)
            self.post_scale_weight.data.fill_(1.0)
    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Serialize block weights: attention first (if present), then conv/MLP."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        if self.use_mha:
            dump_layer(self.mha, f)
        dump_layer(self.dwconv, f)
        dump_layer(self.pwconv1, f)
        dump_layer(self.pwconv2, f)
class ConvNeXtStack(nn.Module):
    """A stack of ConvNeXtBlocks behind a causal embedding convolution.

    Input and output are channels-first: [batch_size, channels, length].
    Optionally uses self- or cross-attention inside each block; with weight
    standardization the LayerNorms are replaced by identities (their affine
    parameters are folded into the convolutions by merge_weights()).
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        intermediate_channels: int,
        n_blocks: int,
        delay: int,
        embed_kernel_size: int,
        kernel_size: int,
        use_weight_standardization: bool = False,
        enable_scaling: bool = False,
        use_mha: bool = False,
        cross_attention: bool = False,
        kv_channels: Optional[int] = None,
    ):
        super().__init__()
        assert delay * 2 + 1 <= embed_kernel_size
        assert not (use_weight_standardization and use_mha)  # not supported
        self.use_weight_standardization = use_weight_standardization
        self.use_mha = use_mha
        self.cross_attention = cross_attention
        self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay)
        self.norm = nn.LayerNorm(channels)
        self.convnext = nn.ModuleList()
        for i in range(n_blocks):
            # Depth-dependent residual scaling (used when enable_scaling).
            pre_scale = 1.0 / math.sqrt(1.0 + i / n_blocks) if enable_scaling else 1.0
            post_scale = 1.0 / math.sqrt(n_blocks) if enable_scaling else 1.0
            block = ConvNeXtBlock(
                channels=channels,
                intermediate_channels=intermediate_channels,
                layer_scale_init_value=1.0 / n_blocks,
                kernel_size=kernel_size,
                use_weight_standardization=use_weight_standardization,
                enable_scaling=enable_scaling,
                pre_scale=pre_scale,
                post_scale=post_scale,
                use_mha=use_mha,
                cross_attention=cross_attention,
                num_heads=4,
                attention_dropout=0.1,
                attention_channels=kv_channels,
                kv_channels=kv_channels,
            )
            self.convnext.append(block)
        self.final_layer_norm = nn.LayerNorm(channels)
        self.embed.weight.data.normal_(
            0.0, math.sqrt(0.5 / (embed_kernel_size * in_channels))
        )
        self.embed.bias.data.zero_()
        if use_weight_standardization:
            # NOTE(review): this replaces the embed created above (the manual
            # init is discarded) and disables both LayerNorms — confirm the
            # discarded init is intentional.
            self.embed = WSConv1d(in_channels, channels, embed_kernel_size, delay=delay)
            self.norm = nn.Identity()
            self.final_layer_norm = nn.Identity()

    def forward(
        self, x: torch.Tensor, kv: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # x: [batch_size, in_channels, length]
        x = self.embed(x)
        x = self.norm(x.transpose(1, 2)).transpose(1, 2)
        if self.use_mha and not self.cross_attention:
            # Self-attention runs at 1/4 time resolution (see t40): pad the
            # length to a multiple of 4 and build a causal attention mask.
            pad_length = -x.size(2) % 4
            if pad_length:
                x = F.pad(x, (0, pad_length))
            t40 = x.size(2) // 4
            attn_mask = torch.ones((t40, t40), dtype=torch.bool, device=x.device).triu(
                1
            )
        else:
            attn_mask = None
        for conv_block in self.convnext:
            x = conv_block(x, attn_mask=attn_mask, kv=kv)
        if self.use_mha and not self.cross_attention and pad_length:
            # Drop the padding added above.
            x = x[:, :, :-pad_length]
        x = self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2)
        return x

    def merge_weights(self):
        # Fold normalization/scaling into conv weights for export.
        if self.use_weight_standardization:
            self.embed.merge_weights()
        for conv_block in self.convnext:
            conv_block.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        # Serialize layers in forward order; accepts a path or binary stream.
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.embed, f)
        if not self.use_weight_standardization:
            dump_layer(self.norm, f)
        dump_layer(self.convnext, f)
        if not self.use_weight_standardization:
            dump_layer(self.final_layer_norm, f)

    def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        # Serialize only the cross-attention key/value projections.
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump_kv(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        assert self.use_mha and self.cross_attention
        for conv_block in self.convnext:
            if not conv_block.use_mha or not conv_block.cross_attention:
                continue
            assert isinstance(conv_block, ConvNeXtBlock)
            assert hasattr(conv_block, "mha")
            assert isinstance(conv_block.mha, CrossAttention)
            conv_block.mha.dump_kv(f)
class FeatureExtractor(nn.Module):
    """Convolutional front-end turning raw audio into frame features.

    Six strided, weight-normalized Conv1d layers reduce the time axis by a
    total factor of 160 (strides 5*2*2*2*2*2), i.e. one frame per 10 ms of
    16 kHz audio.
    """

    def __init__(self, hidden_channels: int):
        super().__init__()
        # fmt: off
        self.conv0 = weight_norm(nn.Conv1d(1, hidden_channels // 8, 10, 5, bias=False))
        self.conv1 = weight_norm(nn.Conv1d(hidden_channels // 8, hidden_channels // 4, 3, 2, bias=False))
        self.conv2 = weight_norm(nn.Conv1d(hidden_channels // 4, hidden_channels // 2, 3, 2, bias=False))
        self.conv3 = weight_norm(nn.Conv1d(hidden_channels // 2, hidden_channels, 3, 2, bias=False))
        self.conv4 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 3, 2, bias=False))
        self.conv5 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 2, 2, bias=False))
        # fmt: on

    def _conv_layers(self):
        # The six convolutions, in application order.
        return (self.conv0, self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, 1, wav_length]
        if x.size(2) % 160 != 0:
            warnings.warn("wav_length % 160 != 0")
        x = F.pad(x, (40, 40))
        for conv in self._conv_layers():
            x = F.gelu(conv(x), approximate="tanh")
        # [batch_size, hidden_channels, wav_length / 160]
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every convolution (for export)."""
        for conv in self._conv_layers():
            remove_weight_norm(conv)

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Serialize the convolutions in order to a binary stream or path."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError
        for conv in self._conv_layers():
            dump_layer(conv, f)
class FeatureProjection(nn.Module):
    """LayerNorm over channels followed by dropout.

    Input and output are both [batch_size, channels, length]; the norm is
    applied over the channel axis by transposing to channels-last and back.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [batch_size, channels, length] -> channels-last for LayerNorm.
        normed = self.norm(x.transpose(1, 2))
        # Back to channels-first; dropout is the identity in eval mode.
        return self.dropout(normed.transpose(1, 2))
class PhoneExtractor(nn.Module):
    """Extracts phone (content) embeddings from raw waveform.

    Pipeline: FeatureExtractor (strided convs, one frame / 10 ms) ->
    FeatureProjection (LayerNorm + dropout) -> ConvNeXtStack with
    self-attention -> 1x1 conv head producing `phone_channels` per frame.
    """

    def __init__(
        self,
        phone_channels: int = 128,
        hidden_channels: int = 128,
        backbone_embed_kernel_size: int = 9,
        kernel_size: int = 17,
        n_blocks: int = 20,
    ):
        super().__init__()
        self.feature_extractor = FeatureExtractor(hidden_channels)
        self.feature_projection = FeatureProjection(hidden_channels)
        self.backbone = ConvNeXtStack(
            in_channels=hidden_channels,
            channels=hidden_channels,
            intermediate_channels=hidden_channels * 3,
            n_blocks=n_blocks,
            delay=0,
            embed_kernel_size=backbone_embed_kernel_size,
            kernel_size=kernel_size,
            use_mha=True,
        )
        self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1))

    def forward(
        self, x: torch.Tensor, return_stats: bool = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]:
        """Return phone logits [batch, phone_channels, length], optionally
        with a dict of monitoring statistics (norm of features / codes)."""
        # x: [batch_size, 1, wav_length]
        stats = {}
        # [batch_size, 1, wav_length] -> [batch_size, feature_extractor_hidden_channels, length]
        x = self.feature_extractor(x)
        if return_stats:
            stats["feature_norm"] = x.detach().norm(dim=1).mean()
        # [batch_size, feature_extractor_hidden_channels, length] -> [batch_size, hidden_channels, length]
        x = self.feature_projection(x)
        # [batch_size, hidden_channels, length]
        x = self.backbone(x)
        # [batch_size, hidden_channels, length] -> [batch_size, phone_channels, length]
        phone = self.head(F.gelu(x, approximate="tanh"))
        results = [phone]
        if return_stats:
            stats["code_norm"] = phone.detach().norm(dim=1).mean()
            results.append(stats)
        if len(results) == 1:
            return results[0]
        return tuple(results)

    @torch.inference_mode()
    def units(self, x: torch.Tensor) -> torch.Tensor:
        """Inference helper: phone embeddings in channels-last layout."""
        # x: [batch_size, 1, wav_length]
        # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length]
        phone = self.forward(x, return_stats=False)
        # [batch_size, phone_channels, length] -> [batch_size, length, phone_channels]
        phone = phone.transpose(1, 2)
        # [batch_size, length, phone_channels]
        return phone

    def remove_weight_norm(self):
        # Strip weight normalization for export/inference.
        self.feature_extractor.remove_weight_norm()
        remove_weight_norm(self.head)

    def merge_weights(self):
        # Fold the backbone's norms, then absorb the projection LayerNorm's
        # affine transform into the backbone embedding conv (bias picks up
        # the contraction of the norm bias with the conv weight).
        self.backbone.merge_weights()
        self.backbone.embed.bias.data += (
            (
                self.feature_projection.norm.bias.data[None, :, None]
                * self.backbone.embed.weight.data  # [o, i, k]
            )
            .sum(1)
            .sum(1)
        )
        self.backbone.embed.weight.data *= self.feature_projection.norm.weight.data[
            None, :, None
        ]
        self.feature_projection.norm.bias.data[:] = 0.0
        self.feature_projection.norm.weight.data[:] = 1.0

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        # Serialize sub-modules in forward order; accepts a path or stream.
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.feature_extractor, f)
        dump_layer(self.backbone, f)
        dump_layer(self.head, f)
class VectorQuantizer(nn.Module):
    """Per-speaker top-k vector quantizer applied via a forward hook.

    One codebook per speaker is built offline with k-means over normalized
    activations (build_codebooks) and stored as a half-precision buffer.
    When active, each frame is replaced by the mean of its top-k most
    cosine-similar codebook entries for the requested speaker.
    """

    def __init__(
        self,
        n_speakers: int,
        codebook_size: int,
        channels: int,
        topk: int = 4,
        training_time_vq: Literal["none", "self", "random"] = "none",
    ):
        super().__init__()
        assert 1 <= topk <= codebook_size
        self.n_speakers = n_speakers
        self.codebook_size = codebook_size
        self.channels = channels
        self.topk = topk
        self.training_time_vq = training_time_vq
        self.register_buffer(
            "codebooks",
            torch.empty(n_speakers, codebook_size, channels, dtype=torch.half),
        )
        self.codebooks: torch.Tensor
        # Implemented as a forward hook so the point where VQ is applied can
        # be changed easily.
        self._hook_handle: Optional[torch.utils.hooks.RemovableHandle] = None
        self.target_speaker_ids: Optional[torch.Tensor] = None
        def _hook(_, __, output):
            return self(output, self.target_speaker_ids)
        self._hook_fn = _hook

    @torch.no_grad()
    def build_codebooks(
        self,
        collector_func: Callable,
        target_layer: nn.Module,
        inputs: Sequence[Iterable[torch.Tensor]],
        kmeans_n_iters: int = 50,
    ):
        """Fit one k-means codebook per speaker from target_layer outputs.

        collector_func runs the model on each input; a temporary forward
        hook on target_layer captures its [batch, channels, length] outputs.
        """
        assert len(inputs) == self.n_speakers
        assert self._hook_handle is None, "hook already installed"
        device = next(self.buffers()).device
        for spk_id, inps in enumerate(tqdm(inputs, desc="Building codebooks")):
            activations: list[torch.Tensor] = []
            # TODO: subsample when there is too much data
            def _collect(_, __, output):
                # output: [batch_size, channels, length]
                activations.append(output.detach())
            handle = target_layer.register_forward_hook(_collect)
            for x in inps:
                collector_func(x.to(device))
            handle.remove()
            if not activations:
                raise RuntimeError(f"No activation collected for speaker {spk_id}")
            # Flatten all frames of all batches to [n_data, channels].
            activations: torch.Tensor = torch.cat(
                [
                    a.transpose(1, 2).reshape(a.size(0) * a.size(2), self.channels)
                    for a in activations
                ]
            )
            activations = activations.float()
            activations = F.normalize(activations, dim=1, eps=1e-6)
            # [codebook_size, channels]
            centers = (
                self._kmeans_plus_plus(activations, self.codebook_size, kmeans_n_iters)
                if activations.size(0) >= self.codebook_size
                else self._pad_replicate(activations, self.codebook_size)
            )
            self.codebooks[spk_id] = centers.to(self.codebooks.dtype)

    def forward(
        self, x: torch.Tensor, speaker_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Quantize x [batch, channels, length] against speaker codebooks.

        Returns x unchanged when quantization is disabled (training with
        training_time_vq == "none", or inference without speaker_ids).
        """
        batch_size, channels, length = x.size()
        assert channels == self.channels
        device = x.device
        dtype = x.dtype
        if self.training:
            if self.training_time_vq == "none":
                return x
            elif self.training_time_vq == "self":
                if self.target_speaker_ids is None:
                    raise ValueError("target_speaker_ids is not set")
            elif self.training_time_vq == "random":
                # Quantize against a random speaker's codebook each step.
                speaker_ids = torch.randint(
                    0, self.n_speakers, (batch_size,), device=device
                )
            else:
                raise ValueError(f"Unknown training_time_vq: {self.training_time_vq}")
        else:
            if speaker_ids is None:
                return x
            speaker_ids = speaker_ids.to(device)
        # Cosine similarity: normalize frames along the channel dimension.
        q = F.normalize(x, dim=1, eps=1e-6)
        codes = self.codebooks[speaker_ids].to(q.dtype)
        # [batch_size, length, codebook_size]
        sim = torch.einsum("bcl,bkc->blk", q, codes)
        # [batch_size, length, topk]
        _, topk_idx = sim.topk(self.topk, dim=-1)
        # [batch_size, length, codebook_size, channels]
        expanded_codes = codes[:, None, :, :].expand(-1, length, -1, -1)
        # [batch_size, length, topk, channels]
        expanded_topk_idx = topk_idx[:, :, :, None].expand(-1, -1, -1, channels)
        # [batch_size, length, topk, channels]
        gathered = expanded_codes.gather(2, expanded_topk_idx)
        # Replace each frame by the mean of its top-k nearest code vectors.
        # [batch_size, length, channels]
        gathered = gathered.mean(2)
        # [batch_size, channels, length]
        return gathered.transpose(1, 2).to(dtype)

    def enable_hook(self, target_layer: nn.Module):
        """Install the VQ forward hook on target_layer."""
        if self._hook_handle is not None:
            raise RuntimeError("hook already installed")
        self._hook_handle = target_layer.register_forward_hook(self._hook_fn)

    def disable_hook(self):
        """Remove the previously installed VQ forward hook."""
        if self._hook_handle is None:
            raise RuntimeError("hook not installed")
        self._hook_handle.remove()
        self._hook_handle = None

    def set_target_speaker_ids(self, speaker_ids: Optional[torch.Tensor]):
        # See forward() for when these speaker ids are actually used.
        self.target_speaker_ids = speaker_ids

    @staticmethod
    def _pad_replicate(x: torch.Tensor, n: int) -> torch.Tensor:
        # When there are fewer than n data points, fill by cyclic replication.
        idx = torch.arange(n, device=x.device) % x.size(0)
        return x[idx]

    @staticmethod
    def _kmeans_plus_plus(
        x: torch.Tensor, n_clusters: int, n_iters: int = 50
    ) -> torch.Tensor:
        """k-means with k-means++ seeding; returns [n_clusters, dim] centers."""
        n_data, _ = x.size()
        # Seeding: sample each next center with probability proportional to
        # the squared distance from the nearest already-chosen center.
        center_indices = [torch.randint(0, n_data, ()).item()]
        min_distances = torch.full((n_data,), math.inf, device=x.device)
        for _ in range(1, n_clusters):
            last_center_index = center_indices[-1]
            min_distances = min_distances.minimum(
                torch.cdist(x, x[last_center_index : last_center_index + 1])
                .float()
                .square_()
                .squeeze_(1)
            )
            probs = min_distances / (min_distances.sum() + 1e-12)
            center_indices.append(torch.multinomial(probs, 1).item())
        centers = x[center_indices]
        del min_distances, probs
        # Lloyd iterations.
        for _ in range(n_iters):
            distances = torch.cdist(x, centers)  # [n_data, n_clusters]
            labels = distances.argmin(1)  # [n_data]
            # [n_clusters, dim]
            new_centers = torch.zeros_like(centers).index_add_(0, labels, x)
            # [n_clusters]
            counts = labels.bincount(minlength=n_clusters)
            if (counts == 0).sum().item() != 0:
                # TODO: handle clusters with no assigned data points
                warnings.warn("Some clusters have no assigned data points.")
            new_centers /= counts[:, None].clamp_(min=1).float()
            centers = new_centers
        return centers
def extract_pitch_features(
    y: torch.Tensor,  # [..., wav_length]
    hop_length: int = 160,  # 10ms
    win_length: int = 560,  # 35ms
    max_corr_period: int = 256,  # 16ms, 62.5Hz (16000 / 256)
    corr_win_length: int = 304,  # 19ms
    instfreq_features_cutoff_bin: int = 64,  # 1828Hz (16000 * 64 / 560)
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute frame-wise pitch-related features from raw audio.

    Returns (instfreq_features, corr_diff, energy):
    - instfreq_features: log power spectrum concatenated with the real and
      imaginary parts of the per-bin instantaneous phase time-difference.
    - corr_diff: normalized square-rooted difference function over lags
      (computed from autocorrelation and frame energies).
    - energy: 0.5 * log10 of the cosine-windowed frame energy.
    """
    assert max_corr_period + corr_win_length == win_length
    # Pad so that frames are centered
    padding_length = (win_length - hop_length) // 2
    y = F.pad(y, (padding_length, padding_length))
    # Slice into overlapping frames
    # [..., win_length, n_frames]
    y_frames = y.unfold(-1, win_length, hop_length).transpose_(-2, -1)
    # Complex spectrogram
    # Complex[..., (win_length // 2 + 1), n_frames]
    spec: torch.Tensor = torch.fft.rfft(y_frames, n=win_length, dim=-2)
    # Complex[..., instfreq_features_cutoff_bin, n_frames]
    spec = spec[..., :instfreq_features_cutoff_bin, :]
    # Log power spectrogram
    log_power_spec = spec.abs().add_(1e-5).log10_()
    # Time-difference of the instantaneous phase
    # (the value at time 0 is set to 0)
    delta_spec = spec[..., :, 1:] * spec[..., :, :-1].conj()
    delta_spec /= delta_spec.abs().add_(1e-5)
    delta_spec = torch.cat(
        [torch.zeros_like(delta_spec[..., :, :1]), delta_spec], dim=-1
    )
    # [..., instfreq_features_cutoff_bin * 3, n_frames]
    instfreq_features = torch.cat(
        [log_power_spec, delta_spec.real, delta_spec.imag], dim=-2
    )
    # Autocorrelation (via FFT).
    # NOTE: this was originally meant to be scaled by 2.0 / corr_win_length,
    # but that value is proportional to the squared amplitude and no good
    # way was found to standardize its variance for NN input, so the
    # scaling was dropped.
    flipped_y_frames = y_frames.flip((-2,))
    a = torch.fft.rfft(flipped_y_frames, n=win_length, dim=-2)
    b = torch.fft.rfft(y_frames[..., -corr_win_length:, :], n=win_length, dim=-2)
    # [..., max_corr_period, n_frames]
    corr = torch.fft.irfft(a * b, n=win_length, dim=-2)[..., corr_win_length:, :]
    # Energy terms (running sums over corr_win_length samples)
    energy = flipped_y_frames.square_().cumsum_(-2)
    energy0 = energy[..., corr_win_length - 1 : corr_win_length, :]
    energy = energy[..., corr_win_length:, :] - energy[..., :-corr_win_length, :]
    # Difference function
    corr_diff = (energy0 + energy).sub_(corr.mul_(2.0))
    assert corr_diff.min() >= -1e-3, corr_diff.min()
    corr_diff.clamp_(min=0.0)  # guard against numerical error
    # Standardize
    corr_diff *= 2.0 / corr_win_length
    corr_diff.sqrt_()
    # Energy used as input to the conversion model
    energy = (
        (y_frames * torch.signal.windows.cosine(win_length, device=y.device)[..., None])
        .square_()
        .sum(-2, keepdim=True)
    )
    energy.clamp_(min=1e-3).log10_()  # >= -3; roughly 2.15 for a unit-amplitude sine
    energy *= 0.5  # >= -1.5; roughly 1.07 for a unit sine; a difference of 1 is 20 dB in amplitude
    return (
        instfreq_features,  # [..., instfreq_features_cutoff_bin * 3, n_frames]
        corr_diff,  # [..., max_corr_period, n_frames]
        energy,  # [..., 1, n_frames]
    )
class PitchEstimator(nn.Module):
    """Frame-wise pitch (F0) classifier over discrete pitch bins.

    Consumes instantaneous-frequency and autocorrelation features from
    extract_pitch_features, embeds them with 1x1 convs, runs a
    ConvNeXtStack, and outputs logits over `pitch_bins` classes per frame
    (bin 0 is treated as unvoiced in sample_pitch).
    """

    def __init__(
        self,
        input_instfreq_channels: int = 192,
        input_corr_channels: int = 256,
        pitch_bins: int = 448,
        channels: int = 192,
        intermediate_channels: int = 192 * 2,
        n_blocks: int = 9,
        delay: int = 1,  # 10 ms; 22.5 ms in total combined with feature extraction
        embed_kernel_size: int = 3,
        kernel_size: int = 33,
        pitch_bins_per_octave: int = 96,
    ):
        super().__init__()
        self.pitch_bins_per_octave = pitch_bins_per_octave
        self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1)
        self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1)
        self.corr_embed_0 = nn.Conv1d(input_corr_channels, channels, 1)
        self.corr_embed_1 = nn.Conv1d(channels, channels, 1)
        self.backbone = ConvNeXtStack(
            channels,
            channels,
            intermediate_channels,
            n_blocks,
            delay,
            embed_kernel_size,
            kernel_size,
            enable_scaling=True,
        )
        self.head = nn.Conv1d(channels, pitch_bins, 1)

    def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (pitch logits [batch, pitch_bins, length],
        energy [batch, 1, length])."""
        # wav: [batch_size, 1, wav_length]
        # [batch_size, input_instfreq_channels, length],
        # [batch_size, input_corr_channels, length]
        # Feature extraction always runs in full precision (autocast off).
        with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False):
            instfreq_features, corr_diff, energy = extract_pitch_features(
                wav.squeeze(1),
                hop_length=160,
                win_length=560,
                max_corr_period=256,
                corr_win_length=304,
                instfreq_features_cutoff_bin=64,
            )
        instfreq_features = F.gelu(
            self.instfreq_embed_0(instfreq_features), approximate="tanh"
        )
        instfreq_features = self.instfreq_embed_1(instfreq_features)
        corr_diff = F.gelu(self.corr_embed_0(corr_diff), approximate="tanh")
        corr_diff = self.corr_embed_1(corr_diff)
        # [batch_size, channels, length]
        x = F.gelu(instfreq_features + corr_diff, approximate="tanh")
        x = self.backbone(x)
        # [batch_size, pitch_bins, length]
        x = self.head(x)
        return x, energy

    def sample_pitch(
        self, pitch: torch.Tensor, band_width: int = 4, return_features: bool = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Pick the most likely pitch bin per frame from the logits.

        First finds the best band of `band_width` consecutive bins, then the
        best bin within that band. When return_features is True, also
        returns [unvoiced, half-pitch, double-pitch] probability features.
        """
        # pitch: [batch_size, pitch_bins, length]
        # The returned pitch values never include bin 0 (unvoiced).
        batch_size, pitch_bins, length = pitch.size()
        pitch = pitch.softmax(1)
        if return_features:
            unvoiced_proba = pitch[:, :1, :].clone()
        # Exclude the unvoiced bin from the argmax below.
        pitch[:, 0, :] = -100.0
        pitch = (
            pitch.transpose(1, 2).contiguous().view(batch_size * length, 1, pitch_bins)
        )
        # Total probability of each window of band_width consecutive bins.
        band_pitch = F.conv1d(
            pitch,
            torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width),
        )
        # [batch_size * length, 1, pitch_bins - band_width + 1] -> Long[batch_size * length, 1]
        quantized_band_pitch = band_pitch.argmax(2)
        if return_features:
            # [batch_size * length, 1]
            band_proba = band_pitch.gather(2, quantized_band_pitch[:, :, None])
            # Probability mass one octave below the selected band.
            # [batch_size * length, 1]
            half_pitch_band_proba = band_pitch.gather(
                2,
                (quantized_band_pitch - self.pitch_bins_per_octave).clamp_(min=1)[
                    :, :, None
                ],
            )
            half_pitch_band_proba[
                quantized_band_pitch <= self.pitch_bins_per_octave
            ] = 0.0
            half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view(
                batch_size, 1, length
            )
            # Probability mass one octave above the selected band.
            # [batch_size * length, 1]
            double_pitch_band_proba = band_pitch.gather(
                2,
                (quantized_band_pitch + self.pitch_bins_per_octave).clamp_(
                    max=pitch_bins - band_width
                )[:, :, None],
            )
            double_pitch_band_proba[
                quantized_band_pitch
                > pitch_bins - band_width - self.pitch_bins_per_octave
            ] = 0.0
            double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view(
                batch_size, 1, length
            )
        # Long[1, pitch_bins]
        mask = torch.arange(pitch_bins, device=pitch.device)[None, :]
        # bool[batch_size * length, pitch_bins]: bins inside the chosen band.
        mask = (quantized_band_pitch <= mask) & (
            mask < quantized_band_pitch + band_width
        )
        # Long[batch_size, length]
        quantized_pitch = (pitch.squeeze(1) * mask).argmax(1).view(batch_size, length)
        if return_features:
            features = torch.cat(
                [unvoiced_proba, half_pitch_proba, double_pitch_proba], dim=1
            )
            # Long[batch_size, length], [batch_size, 3, length]
            return quantized_pitch, features
        else:
            return quantized_pitch

    def merge_weights(self):
        # Fold normalization/scaling into the backbone's convs for export.
        self.backbone.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        # Serialize layers in forward order; accepts a path or binary stream.
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.instfreq_embed_0, f)
        dump_layer(self.instfreq_embed_1, f)
        dump_layer(self.corr_embed_0, f)
        dump_layer(self.corr_embed_1, f)
        dump_layer(self.backbone, f)
        dump_layer(self.head, f)
def overlap_add(
    ir_amp: torch.Tensor,
    ir_phase: torch.Tensor,
    window: torch.Tensor,
    pitch: torch.Tensor,
    hop_length: int = 240,
    delay: int = 0,
    sr: float = 24000.0,
) -> torch.Tensor:
    """Pitch-synchronous overlap-add synthesis of windowed impulse responses.

    ir_amp / ir_phase: [batch, ir_length/2 + 1, length] per-frame IR spectra;
    pitch: [batch, length * hop_length] per-sample F0 in Hz. One impulse
    response is placed at each pitch mark (phase wrap) with sub-sample
    alignment applied as a linear phase term. Nondeterministic: the initial
    phase is randomized. Returns [batch, length * hop_length] audio.
    """
    batch_size, ir_length, length = ir_amp.size()
    ir_length = (ir_length - 1) * 2
    assert ir_phase.size() == ir_amp.size()
    assert window.size() == (ir_length,), (window.size(), ir_amp.size())
    assert pitch.size() == (batch_size, length * hop_length)
    assert 0 <= delay < ir_length, (delay, ir_length)
    # Normalized frequency [in units of 2*pi rad]
    normalized_freq = pitch / sr
    # Randomize the initial phase [in units of 2*pi rad]
    normalized_freq[:, 0] = torch.rand(batch_size, device=pitch.device)
    # Accumulate the phase in float64 to avoid drift over long signals.
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False):
        phase = (normalized_freq.double().cumsum_(1) % 1.0).float()
    # Find the overlap positions (pitch marks: where the phase wraps)
    # [n_pitchmarks], [n_pitchmarks]
    indices0, indices1 = torch.nonzero(phase[:, :-1] > phase[:, 1:], as_tuple=True)
    # Fractional part (phase delay) of each overlap position
    numer = 1.0 - phase[indices0, indices1]
    # [n_pitchmarks]
    fractional_part = numer / (numer + phase[indices0, indices1 + 1])
    # Gather the IR spectrum for each pitch mark
    # Complex[n_pitchmarks, ir_length / 2 + 1]
    ir_amp = ir_amp[indices0, :, indices1 // hop_length]
    ir_phase = ir_phase[indices0, :, indices1 // hop_length]
    # Amount of phase delay [rad]
    # [n_pitchmarks, ir_length / 2 + 1]
    delay_phase = (
        torch.arange(ir_length // 2 + 1, device=pitch.device, dtype=torch.float32)[
            None, :
        ]
        * (-math.tau / ir_length)
        * fractional_part[:, None]
    )
    # Complex[n_pitchmarks, ir_length / 2 + 1]
    spec = torch.polar(ir_amp, ir_phase + delay_phase)
    # [n_pitchmarks, ir_length]
    ir = torch.fft.irfft(spec, n=ir_length, dim=1)
    ir *= window
    # Scatter the contributions into per-sample flat indices
    # [n_pitchmarks * ir_length]
    ir = ir.ravel()
    # Long[n_pitchmarks * ir_length]
    indices0 = indices0[:, None].expand(-1, ir_length).ravel()
    # Long[n_pitchmarks * ir_length]
    indices1 = (
        indices1[:, None] + torch.arange(ir_length, device=pitch.device)
    ).ravel()
    # Overlap-add
    overlap_added_signal = torch.zeros(
        (batch_size, length * hop_length + ir_length), device=pitch.device
    )
    overlap_added_signal.index_put_((indices0, indices1), ir, accumulate=True)
    overlap_added_signal = overlap_added_signal[:, delay : -ir_length + delay]
    return overlap_added_signal
def generate_noise(
    aperiodicity: torch.Tensor, delay: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate spectrally shaped noise (aperiodic excitation).

    aperiodicity: [batch_size, hop_length, length] per-frame spectral gains
    applied to uniform noise. Returns (shaped noise, raw excitation), both
    [batch_size, length * hop_length]. Nondeterministic: the excitation is
    freshly sampled on each call.
    """
    # aperiodicity: [batch_size, hop_length, length]
    batch_size, hop_length, length = aperiodicity.size()
    excitation = torch.rand(
        batch_size, (length + 1) * hop_length, device=aperiodicity.device
    )
    excitation -= 0.5
    n_fft = 2 * hop_length
    # Analyze with a rectangular window
    # Complex[batch_size, hop_length + 1, length]
    noise = torch.stft(
        excitation,
        n_fft=n_fft,
        hop_length=hop_length,
        window=torch.ones(n_fft, device=excitation.device),
        center=False,
        return_complex=True,
    )
    assert noise.size(2) == aperiodicity.size(2)
    noise[:, 0, :] = 0.0  # zero out the DC bin
    noise[:, 1:, :] *= aperiodicity
    # Synthesize with a Hann window.
    # Note: torch.istft cannot be used here because it applies the optimal
    # synthesis window.
    # [batch_size, 2 * hop_length, length]
    noise = torch.fft.irfft(noise, n=2 * hop_length, dim=1)
    noise *= torch.hann_window(2 * hop_length, device=noise.device)[None, :, None]
    # Overlap-add the windowed frames back to a signal.
    # [batch_size, (length + 1) * hop_length]
    noise = F.fold(
        noise,
        (1, (length + 1) * hop_length),
        (1, 2 * hop_length),
        stride=(1, hop_length),
    ).squeeze_((1, 2))
    assert delay < hop_length
    noise = noise[:, delay : -hop_length + delay]
    excitation = excitation[:, delay : -hop_length + delay]
    return noise, excitation  # [batch_size, length * hop_length]
D4C_PREVENT_ZERO_DIVISION = True  # Set to False for the upstream (original WORLD) behavior
def interp(x: torch.Tensor, y: torch.Tensor, xi: torch.Tensor) -> torch.Tensor:
    """Batched linear interpolation of y(x) evaluated at xi.

    Assumes x is monotonically increasing and evenly spaced along its last
    axis, and that no extrapolation occurs (all xi inside [x[0], x[-1]]).
    xi is broadcast against y's leading dimensions when needed.
    """
    x = torch.as_tensor(x)
    y = torch.as_tensor(y)
    xi = torch.as_tensor(xi)
    if xi.ndim < y.ndim:
        # Prepend singleton dims so xi lines up with y's batch shape.
        diff_ndim = y.ndim - xi.ndim
        xi = xi.view(tuple([1] * diff_ndim) + xi.size())
    if xi.size()[:-1] != y.size()[:-1]:
        xi = xi.expand(y.size()[:-1] + (xi.size(-1),))
    assert (x.min(-1).values == x[..., 0]).all()
    assert (x.max(-1).values == x[..., -1]).all()
    assert (xi.min(-1).values >= x[..., 0]).all()
    assert (xi.max(-1).values <= x[..., -1]).all()
    # Grid spacing, computed in float64 for accuracy.
    delta_x = (x[..., -1].double() - x[..., 0].double()) / (x.size(-1) - 1.0)
    delta_x = delta_x.to(x.dtype)
    # Convert xi to fractional grid indices, then split into base + fraction.
    xi = (xi - x[..., :1]).div_(delta_x[..., None])
    xi_base = xi.floor()
    xi_fraction = xi.sub_(xi_base)
    xi_base = xi_base.long()
    # Forward differences; the appended last value keeps gather in range at
    # the right edge (where the fraction is 0, so the value is irrelevant).
    delta_y = y.diff(dim=-1, append=y[..., -1:])
    yi = y.gather(-1, xi_base) + delta_y.gather(-1, xi_base) * xi_fraction
    return yi
def linear_smoothing(
    group_delay: torch.Tensor, sr: int, n_fft: int, width: torch.Tensor
) -> torch.Tensor:
    """Rectangular (moving-average) smoothing over frequency, WORLD-style.

    Averages each spectrum over a band of `width` Hz centered on each bin,
    implemented via a cumulative integral evaluated at the band edges with
    linear interpolation on a reflect-padded frequency axis. The last axis
    must have length n_fft // 2 + 1.
    """
    group_delay = torch.as_tensor(group_delay)
    assert group_delay.size(-1) == n_fft // 2 + 1
    width = torch.as_tensor(width)
    # Padding wide enough for the largest smoothing band.
    boundary = (width.max() * n_fft / sr).long() + 1
    dtype = group_delay.dtype
    device = group_delay.device
    fft_resolution = sr / n_fft
    # Frequency axis of the reflect-padded spectrum (bin centers + 0.5 bin).
    mirroring_freq_axis = (
        torch.arange(-boundary, n_fft // 2 + 1 + boundary, dtype=dtype, device=device)
        .add(0.5)
        .mul(fft_resolution)
    )
    if group_delay.ndim == 1:
        # F.pad's reflect mode needs a batch dimension.
        mirroring_spec = F.pad(
            group_delay[None], (boundary, boundary), mode="reflect"
        ).squeeze_(0)
    elif group_delay.ndim >= 4:
        # Reflect padding supports limited ranks: flatten the batch dims.
        shape = group_delay.size()
        mirroring_spec = F.pad(
            group_delay.view(math.prod(shape[:-1]), group_delay.size(-1)),
            (boundary, boundary),
            mode="reflect",
        ).view(shape[:-1] + (shape[-1] + 2 * boundary,))
    else:
        mirroring_spec = F.pad(group_delay, (boundary, boundary), mode="reflect")
    # Cumulative integral over frequency.
    mirroring_segment = mirroring_spec.mul(fft_resolution).cumsum_(-1)
    center_freq = torch.arange(n_fft // 2 + 1, dtype=dtype, device=device).mul_(
        fft_resolution
    )
    low_freq = center_freq - width[..., None] * 0.5
    high_freq = center_freq + width[..., None] * 0.5
    # Integral evaluated at both band edges; difference / width = band mean.
    levels = interp(
        mirroring_freq_axis, mirroring_segment, torch.cat([low_freq, high_freq], dim=-1)
    )
    low_levels, high_levels = levels.split([n_fft // 2 + 1] * 2, dim=-1)
    smoothed = (high_levels - low_levels).div_(width[..., None])
    return smoothed
def dc_correction(
    spec: torch.Tensor, sr: int, n_fft: int, f0: torch.Tensor
) -> torch.Tensor:
    """Compensate the spectrum below f0 by mirroring around f0 (WORLD-style).

    Bins below the fundamental are unreliable; a low-frequency replica
    interpolated from the spectrum mirrored at f0 is added to them.
    Returns a new tensor; `spec` itself is not modified.
    """
    spec = torch.as_tensor(spec)
    f0 = torch.as_tensor(f0)
    dtype = spec.dtype
    device = spec.device
    # Highest affected bin per frame (f0 converted to a bin index, + 2).
    upper_limit = 2 + (f0 * (n_fft / sr)).long()
    max_upper_limit = upper_limit.max()
    # Mask selecting bins below each frame's own upper limit.
    upper_limit_mask = (
        torch.arange(max_upper_limit - 1, device=device) < (upper_limit - 1)[..., None]
    )
    low_freq_axis = torch.arange(max_upper_limit + 1, dtype=dtype, device=device) * (
        sr / n_fft
    )
    # Interpolate the f0-mirrored spectrum at the low-frequency bins.
    low_freq_replica = interp(
        f0[..., None] - low_freq_axis.flip(-1),
        spec[..., : max_upper_limit + 1].flip(-1),
        low_freq_axis[..., : max_upper_limit - 1] * upper_limit_mask,
    )
    output = spec.clone()
    output[..., : max_upper_limit - 1] += low_freq_replica * upper_limit_mask
    return output
def nuttall(n: int, device: torch.types.Device) -> torch.Tensor:
    """Return an n-point symmetric 4-term Nuttall window on `device`.

    The window is the cosine series a0 + a1*cos(t) + a2*cos(2t) + a3*cos(3t)
    evaluated over t in [0, 2*pi].
    """
    t = torch.linspace(0.0, math.tau, n, device=device)
    harmonics = torch.arange(4, dtype=t.dtype, device=device)
    coefs = torch.tensor([0.355768, -0.487396, 0.144232, -0.012604], device=device)
    # [4, n] cosine basis contracted with the coefficients.
    return coefs @ torch.cos(harmonics[:, None] * t)
def get_windowed_waveform(
    x: torch.Tensor,
    sr: int,
    f0: torch.Tensor,
    position: torch.Tensor,
    half_window_length_ratio: float,
    window_type: Literal["hann", "blackman"],
    n_fft: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Cut pitch-adaptive windowed segments of x around given positions.

    For each frame, a window spanning `half_window_length_ratio` pitch
    periods on each side of `position` (seconds) is applied, and the
    window-weighted mean is subtracted (zero-mean segment, as in WORLD).
    Returns (waveform, window), both [..., n_fft].
    """
    x = torch.as_tensor(x)
    f0 = torch.as_tensor(f0)
    position = torch.as_tensor(position)
    current_sample = position * sr
    # Half window length in samples, rounded to nearest integer.
    # [...]
    half_window_length = (half_window_length_ratio * sr / f0).add_(0.5).long()
    # [..., fft_size]
    base_index = -half_window_length[..., None] + torch.arange(n_fft, device=x.device)
    base_index_mask = base_index <= half_window_length[..., None]
    # Clamp sample indices to the valid range (edge samples repeat).
    # [..., fft_size]
    safe_index = ((current_sample + 0.501).long()[..., None] + base_index).clamp_(
        0, x.size(-1) - 1
    )
    # [..., fft_size]
    time_axis = base_index.to(x.dtype).div_(half_window_length_ratio)
    # [...]
    normalized_f0 = math.pi / sr * f0
    # [..., fft_size]
    phase = time_axis.mul_(normalized_f0[..., None])
    if window_type == "hann":
        window = phase.cos_().mul_(0.5).add_(0.5)
    elif window_type == "blackman":
        window = phase.mul(2.0).cos_().mul_(0.08).add_(phase.cos().mul_(0.5)).add_(0.42)
    else:
        assert False
    window *= base_index_mask
    # Broadcast x against the index tensor's batch shape before gathering.
    prefix_shape = tuple(
        max(x_size, i_size) for x_size, i_size in zip(x.size(), safe_index.size())
    )[:-1]
    waveform = (
        x.expand(prefix_shape + (-1,))
        .gather(-1, safe_index.expand(prefix_shape + (-1,)))
        .mul_(window)
    )
    if not D4C_PREVENT_ZERO_DIVISION:
        # Upstream WORLD adds tiny noise instead of clamping denominators.
        waveform += torch.randn_like(window).mul_(1e-12)
    waveform *= base_index_mask
    # Remove the window-weighted mean so the segment is zero-mean.
    waveform -= window * waveform.sum(-1, keepdim=True).div_(
        window.sum(-1, keepdim=True)
    )
    return waveform, window
def get_centroid(x: torch.Tensor, n_fft: int) -> torch.Tensor:
    """Centroid-like spectral quantity used by D4C's group-delay step.

    Normalizes x along its last axis, then combines the spectrum of x with
    the spectrum of a time-ramped copy of x; the result equals the real
    part of spec0 * conj(spec1), expanded by hand.
    """
    x = torch.as_tensor(x)
    norm = x.norm(dim=-1, keepdim=True)
    if D4C_PREVENT_ZERO_DIVISION:
        # Guard against silent frames producing inf/nan.
        norm = norm.clamp(min=6e-8)
    x = x / norm
    ramp = torch.arange(1, x.size(-1) + 1, dtype=x.dtype, device=x.device) / n_fft
    spec0 = torch.fft.rfft(x, n_fft)
    spec1 = torch.fft.rfft(x * ramp, n_fft)
    return spec0.real * spec1.real + spec0.imag * spec1.imag
def get_static_centroid(
    x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int
) -> torch.Tensor:
    """First step: calculation of temporally static parameters on basis of group delay"""
    # Two Blackman-windowed snapshots a quarter period before and after
    # `position`; their centroids are summed and DC-corrected.
    shift = 0.25 / f0
    early, _ = get_windowed_waveform(x, sr, f0, position + shift, 2.0, "blackman", n_fft)
    late, _ = get_windowed_waveform(x, sr, f0, position - shift, 2.0, "blackman", n_fft)
    summed = get_centroid(early, n_fft) + get_centroid(late, n_fft)
    return dc_correction(summed, sr, n_fft, f0)
def get_smoothed_power_spec(
    x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hann-windowed, f0-smoothed power spectrum plus frame RMS (D4C).

    Returns (smoothed_power_spec [..., n_fft // 2 + 1], rms [...]).
    """
    x = torch.as_tensor(x)
    f0 = torch.as_tensor(f0)
    x, window = get_windowed_waveform(x, sr, f0, position, 2.0, "hann", n_fft)
    window_weight = window.square().sum(-1, keepdim=True)
    rms = x.square().sum(-1, keepdim=True).div_(window_weight).sqrt_()
    if D4C_PREVENT_ZERO_DIVISION:
        # Normalize out the frame energy (clamped to avoid divide-by-zero).
        x = x / (rms * math.sqrt(n_fft)).clamp_(min=6e-8)
    smoothed_power_spec = torch.fft.rfft(x, n_fft).abs().square_()
    smoothed_power_spec = dc_correction(smoothed_power_spec, sr, n_fft, f0)
    smoothed_power_spec = linear_smoothing(smoothed_power_spec, sr, n_fft, f0)
    return smoothed_power_spec, rms.detach().squeeze(-1)
def get_static_group_delay(
    static_centroid: torch.Tensor,
    smoothed_power_spec: torch.Tensor,
    sr: int,
    f0: torch.Tensor,
    n_fft: int,
) -> torch.Tensor:
    """Second step: calculation of parameter shaping"""
    denom = smoothed_power_spec
    if D4C_PREVENT_ZERO_DIVISION:
        denom = denom.clamp(min=6e-8)
    t_g = static_centroid / denom
    t_gs = linear_smoothing(t_g, sr, n_fft, f0 * 0.5)
    t_gb = linear_smoothing(t_gs, sr, n_fft, f0)
    # t_D: deviation of the half-f0-smoothed group delay from its
    # f0-smoothed trend.
    return t_gs - t_gb
def get_coarse_aperiodicity(
    group_delay: torch.Tensor,
    sr: int,
    n_fft: int,
    freq_interval: int,
    n_aperiodicities: int,
    window: torch.Tensor,
) -> torch.Tensor:
    """Third step: estimation of band-aperiodicity"""
    # For each band centered at freq_interval * (i + 1) Hz, window the group
    # delay and compare the power of the smallest spectral components to the
    # total power; the ratio in dB is the band's aperiodicity.
    group_delay = torch.as_tensor(group_delay)
    window = torch.as_tensor(window)
    boundary = int(round(n_fft * 8 / window.size(-1)))
    half_window_length = window.size(-1) // 2
    coarse_aperiodicity = torch.empty(
        group_delay.size()[:-1] + (n_aperiodicities,),
        dtype=group_delay.dtype,
        device=group_delay.device,
    )
    for i in range(n_aperiodicities):
        center = freq_interval * (i + 1) * n_fft // sr
        segment = (
            group_delay[
                ..., center - half_window_length : center + half_window_length + 1
            ]
            * window
        )
        power_spec: torch.Tensor = torch.fft.rfft(segment, n_fft).abs().square_()
        # Sorting lets the cumulative sum separate small components from
        # the dominant peaks.
        cumulative_power_spec = power_spec.sort(-1).values.cumsum_(-1)
        if D4C_PREVENT_ZERO_DIVISION:
            cumulative_power_spec.clamp_(min=6e-8)
        coarse_aperiodicity[..., i] = (
            cumulative_power_spec[..., n_fft // 2 - boundary - 1]
            / cumulative_power_spec[..., -1]
        )
    # Convert the power ratio to dB, in place.
    coarse_aperiodicity.log10_().mul_(10.0)
    return coarse_aperiodicity
def d4c_love_train(
    x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, threshold: float
) -> torch.Tensor:
    """D4C "LoveTrain" voiced/unvoiced refinement.

    Returns a boolean tensor: True where the frame is voiced (f0 != 0) AND the
    spectral power in the 100–4000 Hz band exceeds `threshold` times the power
    in the 100–7900 Hz band — i.e. energy is concentrated at low frequencies,
    as expected for voiced speech.
    """
    x = torch.as_tensor(x)
    position = torch.as_tensor(position)
    f0: torch.Tensor = torch.as_tensor(f0)
    vuv = f0 != 0
    lowest_f0 = 40
    f0 = f0.clamp(min=lowest_f0)
    # FFT size: next power of two covering 3 periods of the lowest F0.
    n_fft = 1 << (3 * sr // lowest_f0).bit_length()
    # Bin indices (ceil) of the 100 Hz / 4 kHz / 7.9 kHz band edges.
    boundary0 = (100 * n_fft - 1) // sr + 1
    boundary1 = (4000 * n_fft - 1) // sr + 1
    boundary2 = (7900 * n_fft - 1) // sr + 1
    waveform, _ = get_windowed_waveform(x, sr, f0, position, 1.5, "blackman", n_fft)
    power_spec = torch.fft.rfft(waveform, n_fft).abs().square_()
    # Zero out DC up to 100 Hz before accumulating band power.
    power_spec[..., : boundary0 + 1] = 0.0
    cumulative_spec = power_spec.cumsum_(-1)
    vuv = vuv & (
        cumulative_spec[..., boundary1] > threshold * cumulative_spec[..., boundary2]
    )
    return vuv
def d4c_general_body(
    x: torch.Tensor,
    sr: int,
    f0: torch.Tensor,
    freq_interval: int,
    position: torch.Tensor,
    n_fft: int,
    n_aperiodicities: int,
    window: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the three D4C steps and return (coarse_aperiodicity, rms).

    coarse_aperiodicity is in dB, clamped to <= 0, with a low-F0 penalty
    applied (frames below 100 Hz are pushed toward more aperiodic).
    """
    static_centroid = get_static_centroid(x, sr, f0, position, n_fft)
    smoothed_power_spec, rms = get_smoothed_power_spec(x, sr, f0, position, n_fft)
    static_group_delay = get_static_group_delay(
        static_centroid, smoothed_power_spec, sr, f0, n_fft
    )
    coarse_aperiodicity = get_coarse_aperiodicity(
        static_group_delay, sr, n_fft, freq_interval, n_aperiodicities, window
    )
    # Low-pitch correction: add (f0 - 100) / 50 dB, then clamp at 0 dB.
    coarse_aperiodicity.add_((f0[..., None] - 100.0).div_(50.0)).clamp_(max=0.0)
    return coarse_aperiodicity, rms
def d4c(
    x: torch.Tensor,
    f0: torch.Tensor,
    t: torch.Tensor,
    sr: int,
    threshold: float = 0.85,
    n_fft_spec: Optional[int] = None,
    coarse_only: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Adapted from https://github.com/tuanad121/Python-WORLD/blob/master/world/d4c.py

    D4C band-aperiodicity estimation.

    Args:
        x: waveform [..., wav_length].
        f0: per-frame F0 in Hz (0 = unvoiced) [..., n_frames].
        t: per-frame temporal positions in seconds.
        sr: sample rate (must be even).
        threshold: voiced/unvoiced decision threshold for d4c_love_train.
        n_fft_spec: FFT size of the full-band spectrum; derived from FLOOR_F0
            when None.
        coarse_only: when True, return (coarse_aperiodicity_dB, rms) without
            interpolating to the full spectral resolution.

    Returns:
        (aperiodicity, coarse_aperiodicity) — aperiodicity is a linear-scale
        spectrum per frame (set to ~1 for unvoiced frames); when coarse_only,
        (coarse_aperiodicity, rms) instead.
    """
    FLOOR_F0 = 71
    FLOOR_F0_D4C = 47
    UPPER_LIMIT = 15000
    FREQ_INTERVAL = 3000
    assert sr == int(sr)
    sr = int(sr)
    assert sr % 2 == 0
    x = torch.as_tensor(x)
    f0 = torch.as_tensor(f0)
    temporal_positions = torch.as_tensor(t)
    # FFT sizes: powers of two covering 4 (resp. 3) periods of the floor F0.
    n_fft_d4c = 1 << (4 * sr // FLOOR_F0_D4C).bit_length()
    if n_fft_spec is None:
        n_fft_spec = 1 << (3 * sr // FLOOR_F0).bit_length()
    n_aperiodicities = min(UPPER_LIMIT, sr // 2 - FREQ_INTERVAL) // FREQ_INTERVAL
    assert n_aperiodicities >= 1
    window_length = FREQ_INTERVAL * n_fft_d4c // sr * 2 + 1
    window = nuttall(window_length, device=x.device)
    freq_axis = torch.arange(n_fft_spec // 2 + 1, device=x.device) * (sr / n_fft_spec)
    coarse_aperiodicity, rms = d4c_general_body(
        x[..., None, :],
        sr,
        f0.clamp(min=FLOOR_F0_D4C),
        FREQ_INTERVAL,
        temporal_positions,
        n_fft_d4c,
        n_aperiodicities,
        window,
    )
    if coarse_only:
        return coarse_aperiodicity, rms
    even_coarse_axis = (
        torch.arange(n_aperiodicities + 3, device=x.device) * FREQ_INTERVAL
    )
    # Sanity check: the Nyquist frequency lies between the last two band edges.
    assert even_coarse_axis[-2] <= sr // 2 < even_coarse_axis[-1], sr
    coarse_axis_low = (
        torch.arange(n_aperiodicities + 1, dtype=torch.float, device=x.device)
        * FREQ_INTERVAL
    )
    # Interpolate the coarse bands onto the fine frequency axis; anchor
    # 0 Hz at -60 dB (fully periodic).
    aperiodicity_low = interp(
        coarse_axis_low,
        F.pad(coarse_aperiodicity, (1, 0), value=-60.0),
        freq_axis[freq_axis < n_aperiodicities * FREQ_INTERVAL],
    )
    coarse_axis_high = torch.tensor(
        [n_aperiodicities * FREQ_INTERVAL, sr * 0.5], device=x.device
    )
    # Above the last band, ramp toward ~0 dB (fully aperiodic) at Nyquist.
    aperiodicity_high = interp(
        coarse_axis_high,
        F.pad(coarse_aperiodicity[..., -1:], (0, 1), value=-1e-12),
        freq_axis[freq_axis >= n_aperiodicities * FREQ_INTERVAL],
    )
    aperiodicity = torch.cat([aperiodicity_low, aperiodicity_high], -1)
    # dB -> linear amplitude ratio.
    aperiodicity = 10.0 ** (aperiodicity / 20.0)
    vuv = d4c_love_train(x[..., None, :], sr, f0, temporal_positions, threshold)
    # Unvoiced frames are set to (almost) fully aperiodic.
    aperiodicity = torch.where(vuv[..., None], aperiodicity, 1 - 1e-12)
    return aperiodicity, coarse_aperiodicity
class Vocoder(nn.Module):
    """Neural source-filter vocoder for Beatrice v2.

    From hidden features, pitch, and a speaker embedding it generates:
    an impulse response for the periodic (harmonic) excitation, per-sample
    noise gains for the aperiodic excitation, and a post filter applied to
    both via FFT convolution.  Output sample rate is `out_sample_rate`.
    """

    def __init__(
        self,
        channels: int,
        speaker_embedding_channels: int = 128,
        hop_length: int = 240,
        n_pre_blocks: int = 4,
        out_sample_rate: float = 24000.0,
    ):
        super().__init__()
        self.hop_length = hop_length
        self.out_sample_rate = out_sample_rate
        self.prenet = ConvNeXtStack(
            in_channels=channels,
            channels=channels,
            intermediate_channels=channels * 2,
            n_blocks=n_pre_blocks,
            delay=2,  # 20 ms lookahead delay
            embed_kernel_size=7,
            kernel_size=33,
            enable_scaling=True,
            use_mha=True,
            cross_attention=True,
            kv_channels=speaker_embedding_channels,
        )
        self.ir_generator = ConvNeXtStack(
            in_channels=channels,
            channels=channels,
            intermediate_channels=channels * 2,
            n_blocks=2,
            delay=0,
            embed_kernel_size=3,
            kernel_size=33,
            use_weight_standardization=True,
            enable_scaling=True,
        )
        self.ir_generator_post = WSConv1d(channels, 512, 1)
        # Scale factors are buffers so they can be folded into the conv
        # weights by merge_weights() for export.
        self.register_buffer("ir_scale", torch.tensor(1.0))
        self.ir_window = nn.Parameter(torch.ones(512))
        self.aperiodicity_generator = ConvNeXtStack(
            in_channels=channels,
            channels=channels,
            intermediate_channels=channels * 2,
            n_blocks=1,
            delay=0,
            embed_kernel_size=3,
            kernel_size=33,
            use_weight_standardization=True,
            enable_scaling=True,
        )
        self.aperiodicity_generator_post = WSConv1d(channels, hop_length, 1, bias=False)
        self.register_buffer("aperiodicity_scale", torch.tensor(0.005))
        self.post_filter_generator = ConvNeXtStack(
            in_channels=channels,
            channels=channels,
            intermediate_channels=channels * 2,
            n_blocks=1,
            delay=0,
            embed_kernel_size=3,
            kernel_size=33,
            use_weight_standardization=True,
            enable_scaling=True,
        )
        self.post_filter_generator_post = WSConv1d(channels, 512, 1, bias=False)
        self.register_buffer("post_filter_scale", torch.tensor(0.01))

    def forward(
        self, x: torch.Tensor, pitch: torch.Tensor, speaker_embedding: torch.Tensor
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        # x: [batch_size, channels, length]
        # pitch: [batch_size, length]
        # speaker_embedding: [batch_size, speaker_embedding_length, speaker_embedding_channels]
        batch_size, _, length = x.size()
        x = self.prenet(x, speaker_embedding)
        ir = self.ir_generator(x)
        ir = F.silu(ir, inplace=True)
        # [batch_size, 512, length]
        ir = self.ir_generator_post(ir)
        ir *= self.ir_scale
        # First half (+1) channels: log-amplitude; remainder: phase, padded so
        # DC and Nyquist phases are fixed, with alternating pi offsets.
        ir_amp = ir[:, : ir.size(1) // 2 + 1, :].exp()
        ir_phase = F.pad(ir[:, ir.size(1) // 2 + 1 :, :], (0, 0, 1, 1))
        ir_phase[:, 1::2, :] += math.pi
        # TODO: fix that the DC component can only take positive values
        # Nearest-neighbour interpolation of the frame-rate pitch to sample rate.
        # [batch_size, length * hop_length]
        pitch = torch.repeat_interleave(pitch, self.hop_length, dim=1)
        # [batch_size, length * hop_length]
        periodic_signal = overlap_add(
            ir_amp,
            ir_phase,
            self.ir_window,
            pitch,
            self.hop_length,
            delay=0,
            sr=self.out_sample_rate,
        )
        aperiodicity = self.aperiodicity_generator(x)
        aperiodicity = F.silu(aperiodicity, inplace=True)
        # [batch_size, hop_length, length]
        aperiodicity = self.aperiodicity_generator_post(aperiodicity)
        aperiodicity *= self.aperiodicity_scale
        # [batch_size, length * hop_length], [batch_size, length * hop_length]
        aperiodic_signal, noise_excitation = generate_noise(aperiodicity, delay=0)
        post_filter = self.post_filter_generator(x)
        post_filter = F.silu(post_filter, inplace=True)
        # [batch_size, 512, length]
        post_filter = self.post_filter_generator_post(post_filter)
        post_filter *= self.post_filter_scale
        # Bias the first tap toward an identity filter.
        post_filter[:, 0, :] += 1.0
        # [batch_size, length, 512]
        post_filter = post_filter.transpose(1, 2)
        # FFT convolution of both excitations with the per-frame post filter,
        # then overlap-add (via F.fold) back to a continuous waveform.
        with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False):
            periodic_signal = periodic_signal.float()
            aperiodic_signal = aperiodic_signal.float()
            post_filter = post_filter.float()
            post_filter = torch.fft.rfft(post_filter, n=768)
            # [batch_size, length, 768]
            periodic_signal = torch.fft.irfft(
                torch.fft.rfft(
                    periodic_signal.view(batch_size, length, self.hop_length), n=768
                )
                * post_filter,
                n=768,
            )
            aperiodic_signal = torch.fft.irfft(
                torch.fft.rfft(
                    aperiodic_signal.view(batch_size, length, self.hop_length), n=768
                )
                * post_filter,
                n=768,
            )
            periodic_signal = F.fold(
                periodic_signal.transpose(1, 2),
                (1, (length - 1) * self.hop_length + 768),
                (1, 768),
                stride=(1, self.hop_length),
            ).squeeze_((1, 2))
            aperiodic_signal = F.fold(
                aperiodic_signal.transpose(1, 2),
                (1, (length - 1) * self.hop_length + 768),
                (1, 768),
                stride=(1, self.hop_length),
            ).squeeze_((1, 2))
        # Trim the 120-sample group delay introduced by the 768-tap filtering.
        periodic_signal = periodic_signal[:, 120 : 120 + length * self.hop_length]
        aperiodic_signal = aperiodic_signal[:, 120 : 120 + length * self.hop_length]
        noise_excitation = noise_excitation[:, 120:]
        # TODO: the accuracy of the compensation becomes questionable here;
        # is it still really needed?
        # [batch_size, 1, length * hop_length]
        y_g_hat = (periodic_signal + aperiodic_signal)[:, None, :]
        return y_g_hat, {
            "periodic_signal": periodic_signal.detach(),
            "aperiodic_signal": aperiodic_signal.detach(),
            "noise_excitation": noise_excitation.detach(),
        }

    def merge_weights(self):
        # Fold the scalar scale buffers into the conv weights so the exported
        # model no longer needs them; scales are reset to identity.
        self.prenet.merge_weights()
        self.ir_generator.merge_weights()
        self.ir_generator_post.merge_weights()
        self.aperiodicity_generator.merge_weights()
        self.aperiodicity_generator_post.merge_weights()
        self.ir_generator_post.weight.data *= self.ir_scale
        self.ir_generator_post.bias.data *= self.ir_scale
        self.ir_scale.fill_(1.0)
        self.aperiodicity_generator_post.weight.data *= self.aperiodicity_scale
        self.aperiodicity_scale.fill_(1.0)
        self.post_filter_generator.merge_weights()
        self.post_filter_generator_post.merge_weights()
        self.post_filter_generator_post.weight.data *= self.post_filter_scale
        self.post_filter_scale.fill_(1.0)

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Serialize the vocoder weights to a binary stream or file path."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.prenet, f)
        dump_layer(self.ir_generator, f)
        dump_layer(self.ir_generator_post, f)
        dump_layer(self.ir_window, f)
        dump_layer(self.aperiodicity_generator, f)
        dump_layer(self.aperiodicity_generator_post, f)
        dump_layer(self.post_filter_generator, f)
        dump_layer(self.post_filter_generator_post, f)
def compute_loudness(
    x: torch.Tensor, sr: int, win_lengths: list[int]
) -> list[torch.Tensor]:
    """Compute K-weighted-style log-energy curves at several window lengths.

    Applies a treble-shelf + high-pass pre-filter (ITU-R BS.1770-like) via
    chunked FFT convolution, then returns, for each window length, the
    Hann-weighted log10 energy per frame (hop = win_length // 4).
    """
    # x: [batch_size, wav_length]
    assert x.ndim == 2
    n_fft = 2048
    chunk_length = n_fft // 2
    n_taps = chunk_length + 1
    results = []
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False):
        # Per-sample-rate filter spectra are memoized on the function object.
        if not hasattr(compute_loudness, "filter"):
            compute_loudness.filter = {}
        if sr not in compute_loudness.filter:
            ir = torch.zeros(n_taps, device=x.device, dtype=torch.double)
            ir[0] = 0.5
            ir = torchaudio.functional.treble_biquad(
                ir, sr, 4.0, 1500.0, 1.0 / math.sqrt(2)
            )
            ir = torchaudio.functional.highpass_biquad(ir, sr, 38.0, 0.5)
            ir *= 2.0
            compute_loudness.filter[sr] = torch.fft.rfft(ir, n=n_fft).to(
                torch.complex64
            )
        x = x.float()
        wav_length = x.size(-1)
        if wav_length % chunk_length != 0:
            x = F.pad(x, (0, chunk_length - wav_length % chunk_length))
        padded_wav_length = x.size(-1)
        # Overlap-add FFT convolution: process in half-FFT chunks, fold back.
        x = x.view(x.size()[:-1] + (padded_wav_length // chunk_length, chunk_length))
        x = torch.fft.irfft(
            torch.fft.rfft(x, n=n_fft) * compute_loudness.filter[sr],
            n=n_fft,
        )
        x = F.fold(
            x.transpose(-2, -1),
            (1, padded_wav_length + chunk_length),
            (1, n_fft),
            stride=(1, chunk_length),
        ).squeeze_((-3, -2))[..., :wav_length]
        x.square_()
        for win_length in win_lengths:
            hop_length = win_length // 4
            # [..., n_frames]
            energy = (
                x.unfold(-1, win_length, hop_length)
                .matmul(torch.hann_window(win_length, device=x.device))
                .add_(win_length / 4.0 * 1e-5)
                .log10_()
            )
            # If the filtered waveform were a sine of amplitude 1, this would
            # be roughly log10(win_length / 4); a difference of 1 equals 10 dB.
            results.append(energy)
    return results
def beatrice_slice_segments(
    x: torch.Tensor, start_indices: torch.Tensor, segment_length: int
) -> torch.Tensor:
    """Gather a fixed-length time slice per batch element.

    Args:
        x: [batch_size, channels, time].
        start_indices: Long[batch_size] — start position for each element.
        segment_length: number of time steps to extract.

    Returns:
        [batch_size, channels, segment_length].
    """
    batch_size, channels, _ = x.size()
    offsets = torch.arange(segment_length, device=start_indices.device)
    # [batch_size, 1, segment_length] -> broadcast over channels for gather.
    gather_idx = (start_indices.view(-1, 1, 1) + offsets).expand(
        batch_size, channels, segment_length
    )
    return x.gather(2, gather_idx)
class ConverterNetwork(nn.Module):
    """Beatrice v2 voice-conversion network.

    Combines frozen phone-extractor and pitch-estimator front-ends with
    learned speaker/pitch/formant embeddings and a neural vocoder.  During
    training, a pitch-shift data augmentation (resampling) and phone noise
    are applied; a multi-resolution mel + loudness (+ optional aperiodicity)
    loss is computed in `forward_and_compute_loss`.
    """

    def __init__(
        self,
        phone_extractor: PhoneExtractor,
        pitch_estimator: PitchEstimator,
        n_speakers: int,
        pitch_bins: int,
        hidden_channels: int,
        vq_topk: int = 4,
        training_time_vq: Literal["none", "self", "random"] = "none",
        phone_noise_ratio: float = 0.5,
        floor_noise_level: float = 1e-3,
    ):
        super().__init__()
        # Kept in a plain dict (not submodules) so they are excluded from
        # state_dict / optimizer parameters.
        self.frozen_modules = {
            "phone_extractor": phone_extractor.eval().requires_grad_(False),
            "pitch_estimator": pitch_estimator.eval().requires_grad_(False),
        }
        self.pitch_bins = pitch_bins
        self.phone_noise_ratio = phone_noise_ratio
        self.floor_noise_level = floor_noise_level
        self.out_sample_rate = out_sample_rate = 24000
        phone_channels = 128
        self.vq = VectorQuantizer(
            n_speakers=n_speakers,
            codebook_size=512,
            channels=phone_channels,
            topk=vq_topk,
            training_time_vq=training_time_vq,
        )
        self.embed_phone = nn.Conv1d(phone_channels, hidden_channels, 1)
        self.embed_phone.weight.data.normal_(0.0, math.sqrt(2.0 / (256 * 5)))
        self.embed_phone.bias.data.zero_()
        # Fixed (non-trainable) sinusoidal embedding over quantized pitch bins.
        self.embed_quantized_pitch = nn.Embedding(pitch_bins, hidden_channels)
        phase = (
            torch.arange(pitch_bins, dtype=torch.float)[:, None]
            * (
                torch.arange(0, hidden_channels, 2, dtype=torch.float)
                * (-math.log(10000.0) / hidden_channels)
            ).exp_()
        )
        self.embed_quantized_pitch.weight.data[:, 0::2] = phase.sin()
        self.embed_quantized_pitch.weight.data[:, 1::2] = phase.cos_()
        self.embed_quantized_pitch.weight.data *= math.sqrt(4.0 / 5.0)
        self.embed_quantized_pitch.weight.requires_grad_(False)
        self.embed_pitch_features = nn.Conv1d(4, hidden_channels, 1)
        self.embed_pitch_features.weight.data.normal_(0.0, math.sqrt(2.0 / (4 * 5)))
        self.embed_pitch_features.bias.data.zero_()
        self.embed_speaker = nn.Embedding(n_speakers, hidden_channels)
        self.embed_speaker.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0))
        self.embed_formant_shift = nn.Embedding(9, hidden_channels)
        self.embed_formant_shift.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0))
        self.key_value_speaker_embedding_length = 384
        self.key_value_speaker_embedding_channels = 128
        self.key_value_speaker_embedding = nn.Embedding(
            n_speakers,
            self.key_value_speaker_embedding_length
            * self.key_value_speaker_embedding_channels,
        )
        # All speakers start from the same random embedding.
        self.key_value_speaker_embedding.weight.data[0].normal_()
        self.key_value_speaker_embedding.weight.data[1:] = (
            self.key_value_speaker_embedding.weight.data[0]
        )
        self.vocoder = Vocoder(
            channels=hidden_channels,
            speaker_embedding_channels=self.key_value_speaker_embedding_channels,
            hop_length=out_sample_rate // 100,
            n_pre_blocks=4,
            out_sample_rate=out_sample_rate,
        )
        # Multi-resolution mel spectrograms for the reconstruction loss.
        self.melspectrograms = nn.ModuleList()
        for win_length, n_mels in [
            (32, 5),
            (64, 10),
            (128, 20),
            (256, 40),
            (512, 80),
            (1024, 160),
            (2048, 320),
        ]:
            self.melspectrograms.append(
                torchaudio.transforms.MelSpectrogram(
                    sample_rate=out_sample_rate,
                    n_fft=win_length,
                    win_length=win_length,
                    hop_length=win_length // 4,
                    n_mels=n_mels,
                    power=2,
                    norm="slaney",
                    mel_scale="slaney",
                )
            )

    def initialize_vq(self, inputs: Sequence[Iterable[torch.Tensor]]):
        """Build the VQ codebooks from phone-extractor features and hook them in."""
        collector_func = self.frozen_modules["phone_extractor"].units
        target_layer = self.frozen_modules["phone_extractor"].head
        self.vq.build_codebooks(
            collector_func,
            target_layer,
            inputs,
        )
        self.vq.enable_hook(target_layer)

    def enable_hook(self):
        """Re-attach the VQ forward hook to the phone extractor head."""
        target_layer = self.frozen_modules["phone_extractor"].head
        self.vq.enable_hook(target_layer)

    def _get_resampler(
        self, orig_freq, new_freq, device, cache={}
    ) -> torchaudio.transforms.Resample:
        # NOTE: the mutable default `cache={}` is an intentional cross-call
        # memoization of resampler instances keyed by (orig_freq, new_freq).
        key = orig_freq, new_freq
        if key in cache:
            return cache[key]
        resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to(
            device, non_blocking=True
        )
        cache[key] = resampler
        return resampler

    def forward(
        self,
        x: torch.Tensor,
        target_speaker_id: torch.Tensor,
        formant_shift_semitone: torch.Tensor,
        pitch_shift_semitone: Optional[torch.Tensor] = None,
        slice_start_indices: Optional[torch.Tensor] = None,
        slice_segment_length: Optional[int] = None,
        return_stats: bool = False,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]:
        """Convert 16 kHz input audio to the target speaker's 24 kHz audio."""
        # x: [batch_size, 1, wav_length]
        # target_speaker_id: Long[batch_size]
        # formant_shift_semitone: [batch_size]
        # pitch_shift_semitone: [batch_size]
        # slice_start_indices: [batch_size]
        batch_size, _, _ = x.size()
        self.vq.set_target_speaker_ids(target_speaker_id)
        with torch.inference_mode():
            phone_extractor: PhoneExtractor = self.frozen_modules["phone_extractor"]
            pitch_estimator: PitchEstimator = self.frozen_modules["pitch_estimator"]
            # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length]
            phone = phone_extractor.units(x).transpose(1, 2)
            if self.training and self.phone_noise_ratio != 0.0:
                # Mix unit-RMS Gaussian noise into the (RMS-normalized) phone
                # features at the configured ratio, then re-normalize.
                phone *= (1.0 - self.phone_noise_ratio) / phone.square().mean(
                    1, keepdim=True
                ).sqrt_()
                noise = torch.randn_like(phone)
                noise *= (
                    self.phone_noise_ratio
                    / noise.square().mean(1, keepdim=True).sqrt_()
                )
                phone += noise
            # F.rms_norm requires PyTorch >= 2.4, so normalize manually.
            phone *= (
                1.0
                / phone.square()
                .mean(1, keepdim=True)
                .add_(torch.finfo(torch.float).eps)
                .sqrt_()
            )
            # [batch_size, 1, wav_length] -> [batch_size, pitch_bins, length], [batch_size, 1, length]
            pitch, energy = pitch_estimator(x)
            # augmentation: random pitch shift via resampling, re-estimating
            # pitch/energy on the shifted audio and splicing back per sample
            if self.training:
                # [batch_size, pitch_bins - 1]
                weights = pitch.softmax(1)[:, 1:, :].mean(2)
                # [batch_size]
                mean_pitch = (
                    weights
                    * torch.arange(
                        1,
                        self.embed_quantized_pitch.num_embeddings,
                        device=weights.device,
                    )
                ).sum(1) / weights.sum(1)
                mean_pitch = mean_pitch.round_().long()
                target_pitch = torch.randint_like(mean_pitch, 64, 257)
                shift = target_pitch - mean_pitch
                shift_ratio = (
                    2.0 ** (shift.float() / pitch_estimator.pitch_bins_per_octave)
                ).tolist()
                shift = []
                interval_length = 100  # 1 s of frames between samples
                interval_zeros = torch.zeros(
                    (1, 1, interval_length * 160), device=x.device
                )
                concatenated_shifted_x = []
                offsets = [0]
                torch.backends.cudnn.benchmark = False
                for i in range(batch_size):
                    # Snap the shift ratio to a small rational so an integer
                    # resampler can realize it, and recompute the bin shift.
                    shift_ratio_i = shift_ratio[i]
                    shift_ratio_fraction_i = Fraction.from_float(
                        shift_ratio_i
                    ).limit_denominator(30)
                    shift_numer_i = shift_ratio_fraction_i.numerator
                    shift_denom_i = shift_ratio_fraction_i.denominator
                    shift_ratio_i = shift_numer_i / shift_denom_i
                    shift_i = int(
                        round(
                            math.log2(shift_ratio_i)
                            * pitch_estimator.pitch_bins_per_octave
                        )
                    )
                    shift.append(shift_i)
                    shift_ratio[i] = shift_ratio_i
                    # [1, 1, wav_length / shift_ratio]
                    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False):
                        shifted_x_i = self._get_resampler(
                            shift_numer_i, shift_denom_i, x.device
                        )(x[i])[None]
                    if shifted_x_i.size(2) % 160 != 0:
                        shifted_x_i = F.pad(
                            shifted_x_i,
                            (0, 160 - shifted_x_i.size(2) % 160),
                            mode="reflect",
                        )
                    assert shifted_x_i.size(2) % 160 == 0
                    offsets.append(
                        offsets[-1] + interval_length + shifted_x_i.size(2) // 160
                    )
                    concatenated_shifted_x.extend([interval_zeros, shifted_x_i])
                if offsets[-1] % 256 != 0:
                    # Equal lengths seem to enable some caching and speed
                    # things up, so pad to a multiple of 256 frames to reduce
                    # the number of distinct length patterns.
                    concatenated_shifted_x.append(
                        torch.zeros(
                            (1, 1, (256 - offsets[-1] % 256) * 160), device=x.device
                        )
                    )
                # [batch_size, 1, sum(wav_length) + batch_size * 16000]
                concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2)
                assert concatenated_shifted_x.size(2) % (256 * 160) == 0
                # [1, pitch_bins, length / shift_ratio], [1, 1, length / shift_ratio]
                concatenated_pitch, concatenated_energy = pitch_estimator(
                    concatenated_shifted_x
                )
                for i in range(batch_size):
                    shift_i = shift[i]
                    shift_ratio_i = shift_ratio[i]
                    left = offsets[i] + interval_length
                    right = offsets[i + 1]
                    pitch_i = concatenated_pitch[:, :, left:right]
                    energy_i = concatenated_energy[:, :, left:right]
                    # Stretch back to the original time axis.
                    pitch_i = F.interpolate(
                        pitch_i,
                        scale_factor=shift_ratio_i,
                        mode="linear",
                        align_corners=False,
                    )
                    energy_i = F.interpolate(
                        energy_i,
                        scale_factor=shift_ratio_i,
                        mode="linear",
                        align_corners=False,
                    )
                    assert pitch_i.size(2) == energy_i.size(2)
                    assert abs(pitch_i.size(2) - pitch.size(2)) <= 10
                    length = min(pitch_i.size(2), pitch.size(2))
                    # Shift the pitch logits back by shift_i bins so labels
                    # match the unshifted audio; out-of-range bins get -10.
                    if shift_i > 0:
                        pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length]
                        pitch[i : i + 1, 1:-shift_i, :length] = pitch_i[
                            :, 1 + shift_i :, :length
                        ]
                        pitch[i : i + 1, -shift_i:, :length] = -10.0
                    elif shift_i < 0:
                        pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length]
                        pitch[i : i + 1, 1 : 1 - shift_i, :length] = -10.0
                        pitch[i : i + 1, 1 - shift_i :, :length] = pitch_i[
                            :, 1:shift_i, :length
                        ]
                    energy[i : i + 1, :, :length] = energy_i[:, :, :length]
                torch.backends.cudnn.benchmark = True
            # [batch_size, pitch_bins, length] -> Long[batch_size, length], [batch_size, 3, length]
            quantized_pitch, pitch_features = pitch_estimator.sample_pitch(
                pitch, return_features=True
            )
            if pitch_shift_semitone is not None:
                # Shift only voiced frames (bin 0 means unvoiced).
                quantized_pitch = torch.where(
                    quantized_pitch == 0,
                    quantized_pitch,
                    (
                        quantized_pitch
                        + (
                            pitch_shift_semitone[:, None]
                            * (pitch_estimator.pitch_bins_per_octave / 12.0)
                        )
                        .round_()
                        .long()
                    ).clamp_(1, self.pitch_bins - 1),
                )
            # Convert quantized bins to Hz (bin spacing anchored at 55 Hz).
            pitch = 55.0 * 2.0 ** (
                quantized_pitch.float() / pitch_estimator.pitch_bins_per_octave
            )
            # phone looks ahead 2.5 ms, while energy looks ahead 12.5 ms and
            # pitch_features 22.5 ms, so shift them to align with phone.
            energy = F.pad(energy[:, :, :-1], (1, 0), mode="reflect")
            quantized_pitch = F.pad(quantized_pitch[:, :-2], (2, 0), mode="reflect")
            pitch_features = F.pad(pitch_features[:, :, :-2], (2, 0), mode="reflect")
            # [batch_size, 1, length], [batch_size, 3, length] -> [batch_size, 4, length]
            pitch_features = torch.cat([energy, pitch_features], dim=1)
            # Map [-2, +2] semitones in 0.5 steps to embedding indices 0..8.
            formant_shift_indices = (
                ((formant_shift_semitone + 2.0) * 2.0).round_().long()
            )
        # Clone to leave inference_mode (tensors become normal autograd tensors).
        phone = phone.clone()
        quantized_pitch = quantized_pitch.clone()
        pitch_features = pitch_features.clone()
        formant_shift_indices = formant_shift_indices.clone()
        pitch = pitch.clone()
        # [batch_size, hidden_channels, length]
        x = (
            self.embed_phone(phone)
            + self.embed_quantized_pitch(quantized_pitch).transpose(1, 2)
            + self.embed_pitch_features(pitch_features)
            + (
                self.embed_speaker(target_speaker_id)[:, :, None]
                + self.embed_formant_shift(formant_shift_indices)[:, :, None]
            )
        )
        if slice_start_indices is not None:
            assert slice_segment_length is not None
            # [batch_size, hidden_channels, length] -> [batch_size, hidden_channels, segment_length]
            x = beatrice_slice_segments(x, slice_start_indices, slice_segment_length)
        x = F.silu(x, inplace=True)
        speaker_embedding = self.key_value_speaker_embedding(target_speaker_id).view(
            batch_size,
            self.key_value_speaker_embedding_length,
            self.key_value_speaker_embedding_channels,
        )
        # [batch_size, hidden_channels, segment_length] -> [batch_size, 1, segment_length * 240]
        y_g_hat, stats = self.vocoder(x, pitch, speaker_embedding)
        stats["pitch"] = pitch
        if return_stats:
            return y_g_hat, stats
        else:
            return y_g_hat

    def _normalize_melsp(self, x):
        # Log-compress the mel power spectrogram, floored to avoid log(0).
        return x.clamp(min=1e-10).log_()

    def forward_and_compute_loss(
        self,
        noisy_wavs_16k: torch.Tensor,
        target_speaker_id: torch.Tensor,
        formant_shift_semitone: torch.Tensor,
        slice_start_indices: torch.Tensor,
        slice_segment_length: int,
        y_all: torch.Tensor,
        enable_loss_ap: bool = False,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, float],
    ]:
        """Run conversion and compute loudness / mel / aperiodicity losses.

        Returns (y, y_hat, y_hat_all, loss_loudness, loss_mel, loss_ap, stats)
        where y / y_hat are the sliced segments fed to the discriminator.
        """
        # noisy_wavs_16k: [batch_size, 1, wav_length]
        # target_speaker_id: Long[batch_size]
        # formant_shift_semitone: [batch_size]
        # slice_start_indices: [batch_size]
        # slice_segment_length: int
        # y_all: [batch_size, 1, wav_length]
        stats = {}
        loss_mel = 0.0
        loss_loudness = 0.0
        loudness_win_lengths = [512, 1024, 2048, 4096]
        # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240]
        y_hat_all, intermediates = self(
            noisy_wavs_16k,
            target_speaker_id,
            formant_shift_semitone,
            return_stats=True,
        )
        # Detach gradients where the target is exactly silent.
        y_hat_all = y_hat_all.detach().where(y_all == 0.0, y_hat_all)
        with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False):
            periodic_signal = intermediates["periodic_signal"].float()
            aperiodic_signal = intermediates["aperiodic_signal"].float()
            noise_excitation = intermediates["noise_excitation"].float()
            periodic_signal = periodic_signal[:, : noise_excitation.size(1)]
            aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)]
            y_hat_all = y_hat_all.float()
            # Add a shared noise floor to both target and prediction so
            # silence does not dominate the log-domain losses.
            floor_noise = torch.randn_like(y_all) * self.floor_noise_level
            y_all = y_all + floor_noise
            y_hat_all += floor_noise
            y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)]
            y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)]
            y_loudness = compute_loudness(
                y_all_truncated, self.out_sample_rate, loudness_win_lengths
            )
            y_hat_loudness = compute_loudness(
                y_hat_all_truncated, self.out_sample_rate, loudness_win_lengths
            )
            for win_length, y_loudness_i, y_hat_loudness_i in zip(
                loudness_win_lengths, y_loudness, y_hat_loudness
            ):
                loss_loudness_i = F.mse_loss(y_hat_loudness_i, y_loudness_i)
                loss_loudness += loss_loudness_i * math.sqrt(win_length)
                stats[f"loss_loudness_{win_length}"] = loss_loudness_i.item()
            for melspectrogram in self.melspectrograms:
                melsp_periodic_signal = melspectrogram(periodic_signal)
                melsp_aperiodic_signal = melspectrogram(aperiodic_signal)
                melsp_noise_excitation = melspectrogram(noise_excitation)
                # [1, n_mels, 1]
                # 1/6 ... mean power of uniform noise in [-0.5, 0.5]
                # 3/8 ... power attenuation due to the Hann window
                # 0.5 ... unexplained
                reference_melsp = melspectrogram.mel_scale(
                    torch.full(
                        (1, melspectrogram.n_fft // 2 + 1, 1),
                        (1 / 6) * (3 / 8) * 0.5 * melspectrogram.win_length,
                        device=noisy_wavs_16k.device,
                    )
                )
                # Compensate the noise part of the prediction by the ratio of
                # the reference noise spectrum to the actual noise excitation.
                aperiodic_ratio = melsp_aperiodic_signal / (
                    melsp_periodic_signal + melsp_aperiodic_signal + 1e-5
                )
                compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5)
                melsp_y_hat = melspectrogram(y_hat_all_truncated)
                melsp_y_hat = melsp_y_hat * (
                    (1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio
                )
                y_hat_mel = self._normalize_melsp(melsp_y_hat)
                y_mel = self._normalize_melsp(melspectrogram(y_all_truncated))
                loss_mel_i = F.l1_loss(y_hat_mel, y_mel)
                loss_mel += loss_mel_i
                stats[
                    f"loss_mel_{melspectrogram.win_length}_{melspectrogram.n_mels}"
                ] = loss_mel_i.item()
            loss_mel /= len(self.melspectrograms)
            if enable_loss_ap:
                # Frame centers at 10 ms hop, offset 5 ms.
                t = (
                    torch.arange(intermediates["pitch"].size(1), device=y_all.device)
                    * 0.01
                    + 0.005
                )
                y_coarse_aperiodicity, y_rms = d4c(
                    y_all.squeeze(1),
                    intermediates["pitch"],
                    t,
                    self.vocoder.out_sample_rate,
                    coarse_only=True,
                )
                y_coarse_aperiodicity = 10.0 ** (y_coarse_aperiodicity / 10.0)
                y_hat_coarse_aperiodicity, y_hat_rms = d4c(
                    y_hat_all.squeeze(1),
                    intermediates["pitch"],
                    t,
                    self.vocoder.out_sample_rate,
                    coarse_only=True,
                )
                y_hat_coarse_aperiodicity = 10.0 ** (y_hat_coarse_aperiodicity / 10.0)
                rms = torch.maximum(y_rms, y_hat_rms)
                loss_ap = F.mse_loss(
                    y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none"
                )
                # Weight by frame energy so silent frames do not contribute.
                loss_ap *= (rms / (rms + 1e-3) * (rms > 1e-5))[:, :, None]
                loss_ap = loss_ap.mean()
            else:
                loss_ap = torch.tensor(0.0)
        # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240]
        y_hat = beatrice_slice_segments(
            y_hat_all, slice_start_indices * 240, slice_segment_length * 240
        )
        # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240]
        y = beatrice_slice_segments(y_all, slice_start_indices * 240, slice_segment_length * 240)
        return y, y_hat, y_hat_all, loss_loudness, loss_mel, loss_ap, stats

    def merge_weights(self):
        """Fold export-time scale factors into the vocoder weights."""
        self.vocoder.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Serialize the speaker-independent weights to a stream or file path."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.embed_phone, f)
        dump_layer(self.embed_quantized_pitch, f)
        dump_layer(self.embed_pitch_features, f)
        dump_layer(self.vocoder, f)

    def dump_speaker_embeddings(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Serialize the per-speaker embeddings to a stream or file path."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump_speaker_embeddings(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_params(self.vq.codebooks, f)
        dump_layer(self.embed_speaker, f)
        dump_layer(self.embed_formant_shift, f)
        dump_layer(self.key_value_speaker_embedding, f)

    def dump_embedding_setter(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Serialize the prenet key/value projections used to set embeddings."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump_embedding_setter(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        self.vocoder.prenet.dump_kv(f)
# Discriminator
def _normalize(tensor: torch.Tensor, dim: int) -> torch.Tensor:
denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-6)
return tensor / denom
class SANConv2d(nn.Conv2d):
    """Conv2d for SAN (Slicing Adversarial Network) discriminators.

    Decomposes the kernel into a unit-norm direction (`weight`) and a
    per-output-channel magnitude (`scale`).  In SAN training mode the forward
    pass returns two outputs — one that propagates gradients only to the
    magnitude/function path and one only to the direction — enabling the SAN
    objective.  Note: the bias is applied to the *input* (hence shaped
    `in_channels`), unlike a standard Conv2d.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        bias: bool = True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=padding,
            dilation=dilation,
            groups=1,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        # Split the initialized weight into direction and magnitude.
        scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-6)
        self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight))
        self.scale = nn.parameter.Parameter(scale.view(out_channels))
        if bias:
            # Input-side bias — deliberately sized in_channels (see class doc).
            self.bias = nn.parameter.Parameter(
                torch.zeros(in_channels, device=device, dtype=dtype)
            )
        else:
            self.register_parameter("bias", None)

    def forward(
        self, input: torch.Tensor, flg_san_train: bool = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        if self.bias is not None:
            input = input + self.bias.view(self.in_channels, 1, 1)
        normalized_weight = self._get_normalized_weight()
        scale = self.scale.view(self.out_channels, 1, 1)
        if flg_san_train:
            # "fun" path: gradient flows to input/scale, direction detached.
            out_fun = F.conv2d(
                input,
                normalized_weight.detach(),
                None,
                self.stride,
                self.padding,
                self.dilation,
                self.groups,
            )
            # "dir" path: gradient flows to the direction only.
            out_dir = F.conv2d(
                input.detach(),
                normalized_weight,
                None,
                self.stride,
                self.padding,
                self.dilation,
                self.groups,
            )
            out = out_fun * scale, out_dir * scale.detach()
        else:
            out = F.conv2d(
                input,
                normalized_weight,
                None,
                self.stride,
                self.padding,
                self.dilation,
                self.groups,
            )
            out = out * scale
        return out

    @torch.no_grad()
    def normalize_weight(self):
        # Re-project the stored weight back onto the unit sphere.
        self.weight.data = self._get_normalized_weight()

    def _get_normalized_weight(self) -> torch.Tensor:
        return _normalize(self.weight, dim=[1, 2, 3])
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return the "same" padding for a (dilated) convolution with an odd kernel."""
    return (kernel_size - 1) * dilation // 2
class BeatriceDiscriminatorP(nn.Module):
    """Period discriminator: folds the waveform into a 2D grid of width
    `period` and applies strided 2D convolutions, as in HiFi-GAN's MPD.
    Optionally uses a SAN output layer."""

    def __init__(
        self, period: int, kernel_size: int = 5, stride: int = 3, san: bool = False
    ):
        super().__init__()
        self.period = period
        self.san = san
        # fmt: off
        self.convs = nn.ModuleList([
            weight_norm(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))),
            weight_norm(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))),
            weight_norm(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))),
            weight_norm(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))),
            weight_norm(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, (get_padding(kernel_size, 1), 0))),
        ])
        # fmt: on
        if san:
            self.conv_post = SANConv2d(1024, 1, (3, 1), 1, (1, 0))
        else:
            self.conv_post = weight_norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0)))

    def forward(
        self, x: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
    ]:
        """Return (logits, feature_maps); logits is a (fun, dir) pair in SAN mode."""
        fmap = []
        b, c, t = x.shape
        # Pad so the time axis is divisible by the period, then fold to 2D.
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)
        for conv in self.convs:
            x = conv(x)
            x = F.silu(x, inplace=True)
            fmap.append(x)
        if self.san:
            x = self.conv_post(x, flg_san_train=flg_san_train)
        else:
            x = self.conv_post(x)
        if flg_san_train:
            x_fun, x_dir = x
            fmap.append(x_fun)
            x_fun = torch.flatten(x_fun, 1, -1)
            x_dir = torch.flatten(x_dir, 1, -1)
            x = x_fun, x_dir
        else:
            fmap.append(x)
            x = torch.flatten(x, 1, -1)
        return x, fmap
class BeatriceDiscriminatorR(nn.Module):
    """Resolution (STFT magnitude) discriminator; optionally SAN output layer."""

    def __init__(self, resolution: Sequence[int], san: bool = False):
        # resolution: (n_fft, hop_length, win_length) — a 3-element sequence,
        # not a single int (enforced by the assert below).
        super().__init__()
        self.resolution = resolution
        self.san = san
        assert len(self.resolution) == 3
        self.convs = nn.ModuleList(
            [
                weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
            ]
        )
        if san:
            self.conv_post = SANConv2d(32, 1, (3, 3), padding=(1, 1))
        else:
            self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(
        self, x: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
    ]:
        """Return (logits, feature_maps); logits is a (fun, dir) pair in SAN mode."""
        fmap = []
        x = self._spectrogram(x).unsqueeze(1)
        for conv in self.convs:
            x = conv(x)
            x = F.silu(x, inplace=True)
            fmap.append(x)
        if self.san:
            x = self.conv_post(x, flg_san_train=flg_san_train)
        else:
            x = self.conv_post(x)
        if flg_san_train:
            x_fun, x_dir = x
            fmap.append(x_fun)
            x_fun = torch.flatten(x_fun, 1, -1)
            x_dir = torch.flatten(x_dir, 1, -1)
            x = x_fun, x_dir
        else:
            fmap.append(x)
            x = torch.flatten(x, 1, -1)
        return x, fmap

    def _spectrogram(self, x: torch.Tensor) -> torch.Tensor:
        """STFT magnitude at this discriminator's resolution (rectangular window)."""
        n_fft, hop_length, win_length = self.resolution
        x = F.pad(
            x, ((n_fft - hop_length) // 2, (n_fft - hop_length) // 2), mode="reflect"
        ).squeeze(1)
        with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False):
            mag = torch.stft(
                x.float(),
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=torch.ones(win_length, device=x.device),
                center=False,
                return_complex=True,
            ).abs()
        return mag
class BeatriceMultiPeriodDiscriminator(nn.Module):
    """Ensemble of multi-resolution (STFT) and multi-period discriminators.

    Real and generated audio are concatenated along the batch dimension and
    run through every sub-discriminator in one pass; outputs are split back
    into real/generated halves afterwards.
    """

    def __init__(self, san: bool = False):
        super().__init__()
        # (n_fft, hop, win) triples and fold periods.
        resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]]
        periods = [2, 3, 5, 7, 11]
        self.discriminators = nn.ModuleList(
            [BeatriceDiscriminatorR(r, san=san) for r in resolutions]
            + [BeatriceDiscriminatorP(p, san=san) for p in periods]
        )
        # Human-readable names used as keys in loss stats.
        self.discriminator_names = [f"R_{n}_{h}_{w}" for n, h, w in resolutions] + [
            f"P_{p}" for p in periods
        ]
        self.san = san

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]],
        list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]],
        list[list[torch.Tensor]],
        list[list[torch.Tensor]],
    ]:
        """Score real (y) and generated (y_hat) audio with every
        sub-discriminator.

        Returns:
            (y_d_rs, y_d_gs, fmap_rs, fmap_gs): per-discriminator scores
            and feature maps, each split into real (r) / generated (g).
        """
        batch_size = y.size(0)
        # Single batched pass over [real; generated].
        concatenated_y_y_hat = torch.cat([y, y_hat])
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            if flg_san_train:
                # SAN heads return a (fun, dir) pair; split each half.
                (y_d_fun, y_d_dir), fmap = d(
                    concatenated_y_y_hat, flg_san_train=flg_san_train
                )
                y_d_r_fun, y_d_g_fun = torch.split(y_d_fun, batch_size)
                y_d_r_dir, y_d_g_dir = torch.split(y_d_dir, batch_size)
                y_d_r = y_d_r_fun, y_d_r_dir
                y_d_g = y_d_g_fun, y_d_g_dir
            else:
                y_d, fmap = d(concatenated_y_y_hat, flg_san_train=flg_san_train)
                y_d_r, y_d_g = torch.split(y_d, batch_size)
            fmap_r = []
            fmap_g = []
            for fm in fmap:
                fm_r, fm_g = torch.split(fm, batch_size)
                fmap_r.append(fm_r)
                fmap_g.append(fm_g)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

    def forward_and_compute_loss(
        self, y: torch.Tensor, y_hat: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]]:
        """Compute discriminator, adversarial, and feature-matching losses.

        Returns:
            (d_loss, adv_loss, fm_loss, stats): d_loss trains the
            discriminator; adv_loss and fm_loss train the generator;
            stats holds per-discriminator scalar diagnostics.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=self.san)
        stats = {}
        assert len(y_d_gs) == len(y_d_rs) == len(self.discriminators)
        # Loss arithmetic in float32 regardless of surrounding autocast.
        with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=False):
            # discriminator loss
            d_loss = 0.0
            for dr, dg, name in zip(y_d_rs, y_d_gs, self.discriminator_names):
                if self.san:
                    # SAN hinge-style objective on the (fun, dir) outputs.
                    dr_fun, dr_dir = map(lambda x: x.float(), dr)
                    dg_fun, dg_dir = map(lambda x: x.float(), dg)
                    r_loss_fun = F.softplus(1.0 - dr_fun).square().mean()
                    g_loss_fun = F.softplus(dg_fun).square().mean()
                    r_loss_dir = F.softplus(1.0 - dr_dir).square().mean()
                    g_loss_dir = -F.softplus(1.0 - dg_dir).square().mean()
                    r_loss = r_loss_fun + r_loss_dir
                    g_loss = g_loss_fun + g_loss_dir
                else:
                    # Least-squares GAN objective.
                    dr = dr.float()
                    dg = dg.float()
                    r_loss = (1.0 - dr).square().mean()
                    g_loss = dg.square().mean()
                stats[f"{name}_dr_loss"] = r_loss.item()
                stats[f"{name}_dg_loss"] = g_loss.item()
                d_loss += r_loss + g_loss
            # adversarial loss
            adv_loss = 0.0
            for dg, name in zip(y_d_gs, self.discriminator_names):
                if self.san:
                    dg_fun = dg[0].float()
                    g_loss = F.softplus(1.0 - dg_fun).square().mean()
                else:
                    dg = dg.float()
                    g_loss = (1.0 - dg).square().mean()
                stats[f"{name}_gg_loss"] = g_loss.item()
                adv_loss += g_loss
            # feature matching loss
            fm_loss = 0.0
            for fr, fg, name in zip(fmap_rs, fmap_gs, self.discriminator_names):
                fm_loss_i = 0.0
                for j, (r, g) in enumerate(zip(fr, fg)):
                    # L1 distance to the (detached) real feature maps.
                    fm_loss_ij = (r.detach().float() - g.float()).abs().mean()
                    stats[f"~{name}_fm_loss_{j}"] = fm_loss_ij.item()
                    fm_loss_i += fm_loss_ij
                stats[f"{name}_fm_loss"] = fm_loss_i.item()
                fm_loss += fm_loss_i
        return d_loss, adv_loss, fm_loss, stats
class GradBalancer:
    """Adapted from https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py

    Balances multiple losses so that each one contributes a controlled
    share of the gradient norm at a shared intermediate tensor (the
    generator output), using an EMA of per-loss gradient norms.
    """

    def __init__(
        self,
        weights: dict[str, float],
        rescale_grads: bool = True,
        total_norm: float = 1.0,
        ema_decay: float = 0.999,
        per_batch_item: bool = True,
    ):
        # weights: relative importance of each named loss.
        # rescale_grads: if True, normalize each loss's gradient to a share
        #   of `total_norm` proportional to its weight; otherwise the raw
        #   weights are used as fixed multipliers.
        self.weights = weights
        self.per_batch_item = per_batch_item
        self.total_norm = total_norm
        self.ema_decay = ema_decay
        self.rescale_grads = rescale_grads
        # Decayed running sums for the EMA of per-loss gradient norms
        # (ema_total / ema_fix yields the bias-corrected EMA).
        self.ema_total: dict[str, float] = defaultdict(float)
        self.ema_fix: dict[str, float] = defaultdict(float)

    def backward(
        self,
        losses: dict[str, torch.Tensor],
        input: torch.Tensor,
        scaler: Optional[torch.amp.GradScaler] = None,
        skip_update_ema: bool = False,
    ) -> dict[str, float]:
        """Backpropagate a balanced combination of `losses` through `input`.

        When `skip_update_ema` is True the cached EMA norms are reused and
        the per-loss gradients are not measured (cheaper); otherwise each
        loss is differentiated w.r.t. `input` first to update the EMA.
        Returns scalar stats (empty dict if a non-finite gradient aborts
        the step).
        """
        stats = {}
        if skip_update_ema:
            assert len(losses) == len(self.ema_total)
            ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()}
        else:
            # Compute d loss / d input and its norm for each loss.
            norms = {}
            grads = {}
            for name, loss in losses.items():
                if scaler is not None:
                    loss = scaler.scale(loss)
                (grad,) = torch.autograd.grad(loss, [input], retain_graph=True)
                if not grad.isfinite().all():
                    # Non-finite gradient (e.g. AMP overflow): propagate it
                    # as-is so the GradScaler can detect it, and bail out.
                    input.backward(grad)
                    return {}
                grad = grad.detach() / (1.0 if scaler is None else scaler.get_scale())
                if self.per_batch_item:
                    dims = tuple(range(1, grad.dim()))
                    ema_norm = grad.norm(dim=dims).mean()
                else:
                    ema_norm = grad.norm()
                norms[name] = float(ema_norm)
                grads[name] = grad
            # Update the exponential moving average of the norms.
            for key, value in norms.items():
                self.ema_total[key] = self.ema_total[key] * self.ema_decay + value
                self.ema_fix[key] = self.ema_fix[key] * self.ema_decay + 1.0
            ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()}
            # Record stats.
            total_ema_norm = sum(ema_norms.values())
            for k, ema_norm in ema_norms.items():
                stats[f"grad_norm_value_{k}"] = ema_norm
                stats[f"grad_norm_ratio_{k}"] = ema_norm / (total_ema_norm + 1e-12)
        # Compute each loss's coefficient share.
        if self.rescale_grads:
            total_weights = sum([self.weights[k] for k in ema_norms])
            ratios = {k: w / total_weights for k, w in self.weights.items()}
        # Rescale the gradients (or losses) and apply them.
        loss = 0.0
        for name, ema_norm in ema_norms.items():
            if self.rescale_grads:
                scale = ratios[name] * self.total_norm / (ema_norm + 1e-12)
            else:
                scale = self.weights[name]
            loss += (losses if skip_update_ema else grads)[name] * scale
        if scaler is not None:
            loss = scaler.scale(loss)
        if skip_update_ema:
            # No fresh grads were measured; differentiate the weighted loss now.
            (loss,) = torch.autograd.grad(loss, [input])
        input.backward(loss)
        return stats

    def state_dict(self) -> dict[str, dict[str, float]]:
        """Serializable EMA state (plain dicts for safe torch.save/load)."""
        return {
            "ema_total": dict(self.ema_total),
            "ema_fix": dict(self.ema_fix),
        }

    def load_state_dict(self, state_dict):
        """Restore EMA state saved by ``state_dict``."""
        self.ema_total = defaultdict(float, state_dict["ema_total"])
        self.ema_fix = defaultdict(float, state_dict["ema_fix"])
class QualityTester(nn.Module):
    """Wraps the UTMOS speech-MOS predictor for scoring converted audio."""

    def __init__(self):
        super().__init__()
        # Fetched via torch.hub on construction; kept in eval mode.
        predictor = torch.hub.load(
            "tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True
        )
        self.utmos = predictor.eval()

    @torch.inference_mode()
    def compute_mos(self, wav: torch.Tensor) -> dict[str, list[float]]:
        """Per-utterance UTMOS scores for a [batch, wav_length] batch."""
        scores = self.utmos(wav, sr=16000)
        return {"utmos": scores.tolist()}

    def test(
        self, converted_wav: torch.Tensor, source_wav: torch.Tensor
    ) -> dict[str, list[float]]:
        # [batch_size, wav_length]; source_wav is currently unused but kept
        # for interface compatibility with reference-based metrics.
        results: dict[str, list[float]] = {}
        results.update(self.compute_mos(converted_wav))
        return results

    def test_many(
        self, converted_wavs: list[torch.Tensor], source_wavs: list[torch.Tensor]
    ) -> tuple[dict[str, float], dict[str, list[float]]]:
        """Score several batches; returns (per-metric means, raw values)."""
        # list[batch_size, wav_length]
        assert len(converted_wavs) == len(source_wavs)
        accumulated = defaultdict(list)
        for converted, source in zip(converted_wavs, source_wavs):
            for metric_name, values in self.test(converted, source).items():
                accumulated[metric_name].extend(values)
        means = {
            metric_name: sum(values) / len(values)
            for metric_name, values in accumulated.items()
        }
        return means, accumulated
def compute_grad_norm(
    model: nn.Module, return_stats: bool = False
) -> Union[float, tuple[float, dict[str, float]]]:
    """Compute the global L2 norm over all parameter gradients of *model*.

    Parameters whose ``.grad`` is ``None`` are skipped.

    Args:
        model: module whose gradients are inspected.
        return_stats: when True, also return a per-parameter norm mapping
            keyed ``grad_norm_<param_name>``.

    Returns:
        The global gradient norm, or ``(norm, stats)`` when
        ``return_stats`` is True.  (The original annotation claimed
        ``Union[float, dict]``, but the stats path returns a tuple.)
    """
    total_norm = 0.0
    stats = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        param_norm = p.grad.data.norm().item()
        # Low-precision grads can overflow to inf/nan; retry in float32.
        if not math.isfinite(param_norm):
            param_norm = p.grad.data.float().norm().item()
        total_norm += param_norm * param_norm
        if return_stats:
            stats[f"grad_norm_{name}"] = param_norm
    total_norm = math.sqrt(total_norm)
    if return_stats:
        return total_norm, stats
    else:
        return total_norm
def compute_mean_f0(
    files: list[Path], method: Literal["dio", "harvest"] = "dio"
) -> float:
    """Geometric-mean F0 (Hz) over all voiced frames of the given files.

    Unvoiced frames (f0 == 0) are ignored; NaN is returned when no voiced
    frame exists at all. Raises ValueError for an unknown *method*.
    """
    log_f0_sum = 0.0
    voiced_frames = 0
    for path in files:
        wav, sr = beatrice_load_audio(path)
        samples = wav.ravel().numpy().astype(np.float64)
        if method == "dio":
            f0, _ = pyworld.dio(samples, sr)
        elif method == "harvest":
            f0, _ = pyworld.harvest(samples, sr)
        else:
            raise ValueError(f"Invalid method: {method}")
        voiced = f0[f0 > 0]
        log_f0_sum += float(np.log(voiced).sum())
        voiced_frames += len(voiced)
    if voiced_frames == 0:
        return math.nan
    return math.exp(log_f0_sum / voiced_frames)
def get_resampler(
    sr_before: int, sr_after: int, device="cpu", cache={}
) -> torchaudio.transforms.Resample:
    """Return a memoized Resample transform for (sr_before, sr_after, device).

    The mutable ``cache`` default is intentional: it serves as the
    process-wide memo table shared across calls.
    """
    if not isinstance(device, str):
        device = str(device)
    key = (sr_before, sr_after, device)
    try:
        return cache[key]
    except KeyError:
        resampler = torchaudio.transforms.Resample(sr_before, sr_after).to(device)
        cache[key] = resampler
        return resampler
def convolve(signal: torch.Tensor, ir: torch.Tensor) -> torch.Tensor:
    """FFT-based linear convolution of *signal* with impulse response *ir*,
    truncated to the signal's original length (last dimension)."""
    out_len = signal.size(-1)
    # FFT size: next power of two at or above the full convolution span.
    n = 1 << (out_len + ir.size(-1) - 2).bit_length()
    spectrum = torch.fft.rfft(signal, n=n) * torch.fft.rfft(ir, n=n)
    full = torch.fft.irfft(spectrum, n=n)
    return full[..., :out_len]
def random_formant_shift(
    wav: torch.Tensor,
    sample_rate: int,
    formant_shift_semitone_min: float = -3.0,
    formant_shift_semitone_max: float = 3.0,
) -> torch.Tensor:
    """Randomly shift formants of a mono waveform without changing pitch.

    The WORLD spectral envelope (cheaptrick) is frequency-scaled by a
    random factor drawn uniformly in semitones, and the waveform's STFT is
    reweighted by the ratio of the shifted to the original envelope.

    Args:
        wav: [1, wav_length] mono waveform.
        sample_rate: sampling rate in Hz.
        formant_shift_semitone_min/max: uniform sampling range in semitones.

    Returns:
        [1, wav_length] formant-shifted waveform on the same device.
    """
    assert wav.ndim == 2
    assert wav.size(0) == 1
    device = wav.device
    hop_length = 256
    # [wav_length]
    wav_np = wav.ravel().double().cpu().numpy()
    # F0 track via DIO + StoneMask refinement, frame period matched to
    # the STFT hop so envelope frames align with STFT frames.
    f0, t = pyworld.dio(
        wav_np,
        sample_rate,
        f0_floor=55,
        f0_ceil=1400,
        frame_period=hop_length * 1000 / sample_rate,
    )
    f0 = pyworld.stonemask(wav_np, f0, t, sample_rate)
    world_sp = pyworld.cheaptrick(wav_np, f0, t, sample_rate)
    # sqrt: power envelope -> magnitude envelope.
    world_sp = (
        torch.from_numpy(world_sp).float().to(device).sqrt_()[None]
    )  # [1, length, n_fft // 2 + 1]
    n_fft = win_length = (world_sp.size(2) - 1) * 2
    window = torch.hann_window(win_length, device=device)
    # [1, n_fft // 2 + 1, length]
    stft_sp = torch.stft(
        wav,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    assert world_sp.size(1) == stft_sp.size(2), (world_sp.size(), stft_sp.size())
    assert world_sp.size(2) == stft_sp.size(1), (world_sp.size(), stft_sp.size())
    # Draw the shift uniformly in semitones, convert to a frequency ratio.
    shift_semitones = (
        torch.rand(()).item()
        * (formant_shift_semitone_max - formant_shift_semitone_min)
        + formant_shift_semitone_min
    )
    shift_ratio = 2.0 ** (shift_semitones / 12.0)
    # Stretch the envelope along the frequency axis, then crop/pad back to
    # n_fft // 2 + 1 bins.
    shifted_world_sp = F.interpolate(
        world_sp, scale_factor=shift_ratio, mode="linear", align_corners=True
    )
    if shifted_world_sp.size(2) > n_fft // 2 + 1:
        shifted_world_sp = shifted_world_sp[:, :, : n_fft // 2 + 1]
    elif shifted_world_sp.size(2) < n_fft // 2 + 1:
        shifted_world_sp = F.pad(
            shifted_world_sp, (0, n_fft // 2 + 1 - shifted_world_sp.size(2))
        )
    # Reweight the STFT by the envelope ratio; clamp avoids extreme gains.
    ratio = ((shifted_world_sp + 1e-5) / (world_sp + 1e-5)).clamp(0.1, 10.0)
    stft_sp *= ratio.transpose(-2, -1)  # [1, n_fft // 2 + 1, length]
    out = torch.istft(
        stft_sp,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        length=wav.size(-1),
    )
    return out
def random_filter(audio: torch.Tensor) -> torch.Tensor:
    """Apply an independent random order-2 IIR (biquad) filter per channel.

    Coefficients are drawn uniformly from [-0.375, 0.375) with the leading
    a0/b0 terms fixed to 1.0.
    """
    assert audio.ndim == 2
    coeffs = torch.rand(audio.size(0), 6) * 0.75 - 0.375
    a_coeffs = coeffs[:, :3]
    b_coeffs = coeffs[:, 3:]
    a_coeffs[:, 0] = 1.0
    b_coeffs[:, 0] = 1.0
    return torchaudio.functional.lfilter(audio, a_coeffs, b_coeffs, clamp=False)
def get_noise(
    n_samples: int, sample_rate: float, files: list[Union[str, bytes, os.PathLike]]
) -> torch.Tensor:
    """Assemble a [1, n_samples] noise clip from randomly chosen files.

    Each file is randomly resampled (±10% around *sample_rate*), passed
    through a random filter, and peak-normalized; pieces are concatenated
    until long enough, then a random window is sliced out.
    """
    resample_factors = [0.9, 0.95, 1.0, 1.05, 1.1]
    pieces = []
    accumulated = 0
    while accumulated < n_samples:
        chosen = files[torch.randint(0, len(files), ())]
        wav, sr = beatrice_load_audio(chosen)
        assert wav.size(0) == 1
        target_sr = int(
            round(
                sample_rate
                * resample_factors[torch.randint(0, len(resample_factors), ())]
            )
        )
        wav = get_resampler(sr, target_sr)(wav)
        wav = random_filter(wav)
        wav *= 0.99 / (wav.abs().max() + 1e-5)
        pieces.append(wav)
        accumulated += wav.size(1)
    start = torch.randint(0, accumulated - n_samples + 1, ())
    noise = torch.cat(pieces, dim=1)[:, start : start + n_samples]
    assert noise.size() == (1, n_samples), noise.size()
    return noise
def get_butterworth_lpf(
    cutoff_freq: float, sample_rate: int, cache={}
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (b, a) biquad coefficients for a Butterworth low-pass filter
    (RBJ cookbook form with Q = 1/sqrt(2)).

    Results are memoized in the intentionally-mutable ``cache`` default,
    so repeat calls return the exact same tensor pair.
    """
    key = (cutoff_freq, sample_rate)
    if key not in cache:
        q = math.sqrt(0.5)
        omega = math.tau * cutoff_freq / sample_rate
        cos_omega = math.cos(omega)
        alpha = math.sin(omega) / (2.0 * q)
        norm = 1.0 + alpha
        b1 = (1.0 - cos_omega) / norm
        b0 = 0.5 * b1
        a1 = -2.0 * cos_omega / norm
        a2 = (1.0 - alpha) / norm
        cache[key] = (
            torch.tensor([b0, b1, b0]),
            torch.tensor([1.0, a1, a2]),
        )
    return cache[key]
def augment_audio(
    clean: torch.Tensor,
    sample_rate: int,
    noise_files: list[Union[str, bytes, os.PathLike]],
    ir_files: list[Union[str, bytes, os.PathLike]],
    snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0],
    formant_shift_probability: float = 0.5,
    formant_shift_semitone_min: float = -3.0,
    formant_shift_semitone_max: float = 3.0,
    reverb_probability: float = 0.5,
    lpf_probability: float = 0.2,
    lpf_cutoff_freq_candidates: list[float] = [2000.0, 3000.0, 4000.0, 6000.0],
) -> torch.Tensor:
    """Create a degraded ("noisy") version of a clean mono waveform.

    Applies, with the given probabilities: formant shift, random filtering,
    reverberation (stereo IR), low-pass filtering, and additive noise at a
    random SNR.

    Args:
        clean: [1, wav_length] mono waveform.
        sample_rate: sampling rate in Hz (IR files must match).
        noise_files / ir_files: paths to noise and impulse-response audio.
        snr_candidates: SNR values (dB) to pick from for the noise mix.

    Returns:
        [1, wav_length] noisy waveform.

    NOTE(review): if ``snr_candidates`` is empty, the final mix block is
    skipped and the function falls through without an explicit return
    (yielding None) — callers always pass a non-empty list; confirm before
    relying on the empty case.
    """
    # [1, wav_length]
    assert clean.size(0) == 1
    n_samples = clean.size(1)
    original_clean_rms = clean.square().mean().sqrt_()
    # Optionally formant-shift the clean signal.
    if torch.rand(()) < formant_shift_probability:
        clean = random_formant_shift(
            clean, sample_rate, formant_shift_semitone_min, formant_shift_semitone_max
        )
    # Fetch a noise clip and stack it with clean (processed together below).
    noise = get_noise(n_samples, sample_rate, noise_files)
    signals = torch.cat([clean, noise])
    # Apply independent random filters to clean and noise.
    signals = random_filter(signals)
    # Optionally convolve both with a (stereo) impulse response.
    if torch.rand(()) < reverb_probability:
        ir_file = ir_files[torch.randint(0, len(ir_files), ())]
        ir, sr = beatrice_load_audio(ir_file)
        # IR is expected to be exactly 1 s of 2-channel audio.
        assert ir.size() == (2, sr), ir.size()
        assert sr == sample_rate, (sr, sample_rate)
        signals = convolve(signals, ir)
    # Optionally apply the same low-pass filter to both.
    if torch.rand(()) < lpf_probability:
        if signals.abs().max() > 0.8:
            signals /= signals.abs().max() * 1.25
        cutoff_freq = lpf_cutoff_freq_candidates[
            torch.randint(0, len(lpf_cutoff_freq_candidates), ())
        ]
        b, a = get_butterworth_lpf(cutoff_freq, sample_rate)
        signals = torchaudio.functional.lfilter(signals, a, b, clamp=False)
    # Restore clean's original loudness.
    clean, noise = signals
    clean_rms = clean.square().mean().sqrt_()
    clean *= original_clean_rms / clean_rms
    if len(snr_candidates) >= 1:
        # Measure clean/noise levels with emphasis on peaks (L4-style norm).
        clean_level = clean.square().square_().mean().sqrt_().sqrt_()
        noise_level = noise.square().square_().mean().sqrt_().sqrt_()
        # SNR
        snr = snr_candidates[torch.randint(0, len(snr_candidates), ())]
        # Mix to produce the noisy signal.
        noisy = clean + noise * (
            0.1 ** (snr / 20.0) * clean_level / (noise_level + 1e-5)
        )
        return noisy
class WavDataset(torch.utils.data.Dataset):
    """Paired clean/noisy waveform dataset for Beatrice training.

    Each item yields a padded clean waveform at ``out_sample_rate``, a
    matching (optionally augmented) noisy waveform at ``in_sample_rate``,
    the speaker id, and the random formant shift applied. ``collate``
    additionally picks a random aligned window of ``wav_length`` samples
    containing a voiced region.
    """

    def __init__(
        self,
        audio_files: list[tuple[Path, int]],
        in_sample_rate: int = 16000,
        out_sample_rate: int = 24000,
        wav_length: int = 4 * 24000,  # 4s
        segment_length: int = 100,  # 1s
        noise_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
        ir_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
        augmentation_snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0],
        augmentation_formant_shift_probability: float = 0.5,
        augmentation_formant_shift_semitone_min: float = -3.0,
        augmentation_formant_shift_semitone_max: float = 3.0,
        augmentation_reverb_probability: float = 0.5,
        augmentation_lpf_probability: float = 0.2,
        augmentation_lpf_cutoff_freq_candidates: list[float] = [
            2000.0,
            3000.0,
            4000.0,
            6000.0,
        ],
    ):
        # audio_files: (path, speaker_id) pairs.
        self.audio_files = audio_files
        self.in_sample_rate = in_sample_rate
        self.out_sample_rate = out_sample_rate
        self.wav_length = wav_length
        self.segment_length = segment_length
        self.noise_files = noise_files
        self.ir_files = ir_files
        self.augmentation_snr_candidates = augmentation_snr_candidates
        self.augmentation_formant_shift_probability = (
            augmentation_formant_shift_probability
        )
        self.augmentation_formant_shift_semitone_min = (
            augmentation_formant_shift_semitone_min
        )
        self.augmentation_formant_shift_semitone_max = (
            augmentation_formant_shift_semitone_max
        )
        self.augmentation_reverb_probability = augmentation_reverb_probability
        self.augmentation_lpf_probability = augmentation_lpf_probability
        self.augmentation_lpf_cutoff_freq_candidates = (
            augmentation_lpf_cutoff_freq_candidates
        )
        # Augmentation needs both noise and IR assets, or neither.
        if (noise_files is None) is not (ir_files is None):
            raise ValueError("noise_files and ir_files must be both None or not None")
        self.in_hop_length = in_sample_rate // 100
        self.out_hop_length = out_sample_rate // 100  # 10 ms hop

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int, int]:
        file, speaker_id = self.audio_files[index]
        clean_wav, sample_rate = beatrice_load_audio(file)
        # Multi-channel input: pick one channel at random.
        if clean_wav.size(0) != 1:
            ch = torch.randint(0, clean_wav.size(0), ())
            clean_wav = clean_wav[ch : ch + 1]
        # Apply a random formant shift by resampling with a detuned ratio
        # (the shift is later given to the model as conditioning).
        formant_shift_candidates = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]
        formant_shift = formant_shift_candidates[
            torch.randint(0, len(formant_shift_candidates), ()).item()
        ]
        resampler_fraction = Fraction(
            sample_rate / self.out_sample_rate * 2.0 ** (formant_shift / 12.0)
        ).limit_denominator(300)
        clean_wav = get_resampler(
            resampler_fraction.numerator, resampler_fraction.denominator
        )(clean_wav)
        assert clean_wav.size(0) == 1
        assert clean_wav.size(1) != 0
        # Pad both sides so collate can slice any window around a voiced sample.
        clean_wav = F.pad(clean_wav, (self.wav_length, self.wav_length))
        if self.noise_files is None:
            # No augmentation assets: noisy input is just a resampled copy.
            noisy_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)(
                clean_wav
            )
        else:
            clean_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)(
                clean_wav
            )
            noisy_wav_16k = augment_audio(
                clean_wav_16k,
                self.in_sample_rate,
                self.noise_files,
                self.ir_files,
                self.augmentation_snr_candidates,
                self.augmentation_formant_shift_probability,
                self.augmentation_formant_shift_semitone_min,
                self.augmentation_formant_shift_semitone_max,
                self.augmentation_reverb_probability,
                self.augmentation_lpf_probability,
                self.augmentation_lpf_cutoff_freq_candidates,
            )
        clean_wav = clean_wav.squeeze_(0)
        noisy_wav_16k = noisy_wav_16k.squeeze_(0)
        # Randomize loudness (peak drawn from [0.1, 0.999)).
        amplitude = torch.rand(()).item() * 0.899 + 0.1
        factor = amplitude / clean_wav.abs().max()
        clean_wav *= factor
        noisy_wav_16k *= factor
        # Keep the noisy signal strictly below full scale.
        while noisy_wav_16k.abs().max() >= 1.0:
            clean_wav *= 0.5
            noisy_wav_16k *= 0.5
        return clean_wav, noisy_wav_16k, speaker_id, formant_shift

    def __len__(self) -> int:
        return len(self.audio_files)

    def collate(
        self, batch: list[tuple[torch.Tensor, torch.Tensor, int, int]]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Slice each item to a random aligned ``wav_length`` window that
        contains a voiced region, and stack the batch."""
        assert self.wav_length % self.out_hop_length == 0
        length = self.wav_length // self.out_hop_length
        clean_wavs = []
        noisy_wavs = []
        slice_starts = []
        speaker_ids = []
        formant_shifts = []
        for clean_wav, noisy_wav, speaker_id, formant_shift in batch:
            # Pick one random voiced (nonzero) sample position.
            (voiced,) = clean_wav.nonzero(as_tuple=True)
            assert voiced.numel() != 0
            center = voiced[torch.randint(0, voiced.numel(), ()).item()].item()
            # Choose the slice segment so the voiced sample sits at its center.
            slice_start = center - self.segment_length * self.out_hop_length // 2
            assert slice_start >= 0
            # Randomly cut wav_length samples such that the slice segment
            # is fully contained in the window.
            r = torch.randint(0, length - self.segment_length + 1, ()).item()
            offset = slice_start - r * self.out_hop_length
            clean_wavs.append(clean_wav[offset : offset + self.wav_length])
            offset_in_sample_rate = int(
                round(offset * self.in_sample_rate / self.out_sample_rate)
            )
            noisy_wavs.append(
                noisy_wav[
                    offset_in_sample_rate : offset_in_sample_rate
                    + length * self.in_hop_length
                ]
            )
            # slice_start is reported in frames relative to the window.
            slice_start = r
            slice_starts.append(slice_start)
            speaker_ids.append(speaker_id)
            formant_shifts.append(formant_shift)
        clean_wavs = torch.stack(clean_wavs)
        noisy_wavs = torch.stack(noisy_wavs)
        slice_starts = torch.tensor(slice_starts)
        speaker_ids = torch.tensor(speaker_ids)
        formant_shifts = torch.tensor(formant_shifts)
        return (
            clean_wavs,  # [batch_size, wav_length]
            noisy_wavs,  # [batch_size, wav_length]
            slice_starts,  # Long[batch_size]
            speaker_ids,  # Long[batch_size]
            formant_shifts,  # Long[batch_size]
        )
# File extensions recognized as audio when scanning dataset directories.
AUDIO_FILE_SUFFIXES = {
    ".wav",
    ".aif",
    ".aiff",
    ".fla",
    ".flac",
    ".oga",
    ".ogg",
    ".opus",
    ".mp3",
}
def get_compressed_optimizer_state_dict(
    optimizer: torch.optim.Optimizer,
) -> dict:
    """Copy of ``optimizer.state_dict()`` with all state tensors cast to
    bfloat16 to shrink checkpoints.

    Non-"state" sections (e.g. param_groups) are passed through unchanged.
    Asserts that every compressed tensor remains finite.
    """
    compressed = {}
    for section, content in optimizer.state_dict().items():
        if section != "state":
            compressed[section] = content
            continue
        compressed_state = {}
        for param_id, param_state in content.items():
            entry = {}
            for key, value in param_state.items():
                if isinstance(value, torch.Tensor):
                    value = value.bfloat16()
                    assert value.isfinite().all()
                entry[key] = value
            compressed_state[param_id] = entry
        compressed[section] = compressed_state
    return compressed
def get_decompressed_optimizer_state_dict(compressed_state_dict: dict) -> dict:
    """Inverse of ``get_compressed_optimizer_state_dict``: cast state
    tensors back to float32 so the dict can be loaded into an optimizer.

    Non-"state" sections are passed through unchanged; every restored
    tensor is asserted finite.
    """
    restored = {}
    for section, content in compressed_state_dict.items():
        if section != "state":
            restored[section] = content
            continue
        restored_state = {}
        for param_id, param_state in content.items():
            entry = {}
            for key, value in param_state.items():
                if isinstance(value, torch.Tensor):
                    value = value.float()
                    assert value.isfinite().all()
                entry[key] = value
            restored_state[param_id] = entry
        restored[section] = restored_state
    return restored
# ============================================================
# BEATRICE V2 TRAINING - Embedded (downloads assets from HuggingFace)
# ============================================================
# Audio extensions accepted when scanning speaker directories for training.
BEATRICE_AUDIO_FILE_SUFFIXES = {".wav", ".aif", ".aiff", ".fla", ".flac", ".oga", ".ogg", ".opus", ".mp3"}
def preprocess_audio_for_beatrice(audio_path: str, output_dir: str, speaker_name: str = "speaker"):
    """Split an audio file into 4 s training chunks for Beatrice.

    Loads the file at 16 kHz mono, slides a 4 s window with 0.5 s overlap,
    skips windows whose RMS falls below a silence threshold, peak-normalizes
    the rest to 0.9, and writes them as WAVs under
    ``output_dir/speaker_name/`` (the layout Beatrice training expects).

    Args:
        audio_path: source audio file.
        output_dir: root of the preprocessed dataset.
        speaker_name: sub-directory / filename prefix for the chunks.

    Returns:
        (chunks_saved, output_dir)
    """
    # Create speaker directory structure required by Beatrice
    speaker_dir = os.path.join(output_dir, speaker_name)
    os.makedirs(speaker_dir, exist_ok=True)
    # Load audio at 16kHz (Beatrice input sample rate)
    audio, sr = librosa.load(audio_path, sr=16000, mono=True)
    # Simple silence-based splitting (RMS threshold)
    chunk_size = int(4.0 * sr)  # 4 second chunks
    hop = int(3.5 * sr)  # 0.5s overlap
    threshold = 0.01  # RMS threshold
    chunks_saved = 0
    # NOTE: a trailing segment shorter than chunk_size is dropped.
    for start in range(0, len(audio) - chunk_size, hop):
        chunk = audio[start:start + chunk_size]
        rms = np.sqrt(np.mean(chunk ** 2))
        if rms > threshold:  # Skip silence
            # Peak-normalize to 0.9.
            max_val = np.abs(chunk).max()
            if max_val > 0:
                chunk = chunk / max_val * 0.9
            chunk_path = os.path.join(speaker_dir, f"{speaker_name}_{chunks_saved:04d}.wav")
            sf.write(chunk_path, chunk, sr)
            chunks_saved += 1
    logger.info(f"Beatrice preprocessing: {chunks_saved} chunks saved to {speaker_dir}")
    return chunks_saved, output_dir
def train_beatrice_generator(
    data_dir: str,
    output_dir: str,
    epochs: int = 30,
    batch_size: int = 8,
    lr_g: float = 5e-5,
    lr_d: float = 5e-5,
    use_augmentation: bool = False,
    resume: bool = False,
    progress_callback=None,
):
    """Train Beatrice v2 model - generator yielding (message, model_path) tuples.

    Args:
        data_dir: directory with one sub-directory of audio files per
            speaker (``data_dir/speaker_name/*.wav`` etc.).
        output_dir: destination for ``checkpoint_*.pt.gz`` files.
        epochs: number of passes over the dataset.
        batch_size: requested batch size (clamped to the file count).
        lr_g / lr_d: base learning rates for generator / discriminator.
        use_augmentation: when True, try to download noise/IR assets and
            enable audio augmentation in the dataset.
        resume: when True, resume from ``checkpoint_latest.pt.gz`` if it
            exists; otherwise fine-tune from the downloaded pretrained model.
        progress_callback: optional callable receiving progress in [0, 1].

    Yields:
        (status_message, checkpoint_path_or_None) tuples for UI streaming.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Download pretrained
    yield "Downloading pretrained models...", None
    phone_extractor_path = download_beatrice_asset("phone_extractor")
    pitch_estimator_path = download_beatrice_asset("pitch_estimator")
    pretrained_model_path = download_beatrice_asset("pretrained_model")
    # Discover speakers from directory structure
    # Expected: data_dir/speaker_name/*.wav
    speakers = []
    training_filelist = []
    speaker_audio_files = []
    for speaker_dir in sorted(Path(data_dir).iterdir()):
        if not speaker_dir.is_dir():
            continue
        candidates = [f for f in sorted(speaker_dir.rglob("*"))
                      if f.is_file() and f.suffix.lower() in BEATRICE_AUDIO_FILE_SUFFIXES]
        if candidates:
            speaker_id = len(speakers)
            speakers.append(speaker_dir.name)
            training_filelist.extend([(f, speaker_id) for f in candidates])
            speaker_audio_files.append(candidates)
    n_speakers = len(speakers)
    if n_speakers == 0:
        yield "Error: No speakers found in data directory", None
        return
    yield f"Found {n_speakers} speaker(s), {len(training_filelist)} files", None
    # Augmentation assets (optional)
    noise_files = None
    ir_files = None
    if use_augmentation:
        try:
            noise_dir, ir_dir = download_beatrice_augmentation()
            if noise_dir and ir_dir:
                noise_files = sorted(list(Path(noise_dir).rglob("*.wav")) + list(Path(noise_dir).rglob("*.flac")))
                ir_files = sorted(list(Path(ir_dir).rglob("*.wav")) + list(Path(ir_dir).rglob("*.flac")))
                if noise_files and ir_files:
                    yield f"Loaded augmentation: {len(noise_files)} noise, {len(ir_files)} IR files", None
                else:
                    # WavDataset requires both asset kinds or neither.
                    noise_files = None
                    ir_files = None
        except Exception as e:
            yield f"Warning: Could not load augmentation assets: {e}", None
    # Build models
    yield "Building models...", None
    # Frozen feature extractors; only net_g / net_d receive gradients.
    phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False)
    pe_ckpt = torch.load(phone_extractor_path, map_location="cpu", weights_only=True)
    phone_extractor.load_state_dict(pe_ckpt["phone_extractor"], strict=False)
    del pe_ckpt
    pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False)
    pi_ckpt = torch.load(pitch_estimator_path, map_location="cpu", weights_only=True)
    pitch_estimator.load_state_dict(pi_ckpt["pitch_estimator"])
    del pi_ckpt
    hidden_channels = 256
    pitch_bins = 448
    net_g = ConverterNetwork(
        phone_extractor, pitch_estimator,
        n_speakers=n_speakers,
        pitch_bins=pitch_bins,
        hidden_channels=hidden_channels,
        vq_topk=4,
        training_time_vq="none",
        phone_noise_ratio=0.5,
        floor_noise_level=1e-3,
    ).to(device)
    net_d = BeatriceMultiPeriodDiscriminator(san=True).to(device)
    # Optimizers
    optim_g = torch.optim.AdamW(net_g.parameters(), lr_g, betas=(0.8, 0.99), eps=1e-6)
    optim_d = torch.optim.AdamW(net_d.parameters(), lr_d, betas=(0.8, 0.99), eps=1e-6)
    grad_scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda")
    grad_balancer = GradBalancer(
        weights={
            "loss_loudness": 1.0,
            "loss_mel": 45.0,
            "loss_adv": 1.0,
            "loss_fm": 2.0,
        },
        ema_decay=0.999,
    )
    initial_iteration = 0
    os.makedirs(output_dir, exist_ok=True)
    # Load pretrained or resume
    if resume:
        latest_ckpt = os.path.join(output_dir, "checkpoint_latest.pt.gz")
        if os.path.isfile(latest_ckpt):
            yield "Resuming from checkpoint...", None
            with gzip.open(latest_ckpt, "rb") as f:
                ckpt = torch.load(f, map_location="cpu", weights_only=True)
            net_g.load_state_dict(ckpt["net_g"], strict=False)
            # Filter discriminator for shape mismatches
            net_d_state = net_d.state_dict()
            filtered_d = {k: v for k, v in ckpt["net_d"].items()
                          if k in net_d_state and v.shape == net_d_state[k].shape}
            net_d.load_state_dict(filtered_d, strict=False)
            optim_g.load_state_dict(get_decompressed_optimizer_state_dict(ckpt["optim_g"]))
            optim_d.load_state_dict(get_decompressed_optimizer_state_dict(ckpt["optim_d"]))
            if "grad_balancer" in ckpt:
                grad_balancer.load_state_dict(ckpt["grad_balancer"])
            if "grad_scaler" in ckpt:
                grad_scaler.load_state_dict(ckpt["grad_scaler"])
            initial_iteration = ckpt.get("iteration", 0)
            del ckpt
        else:
            yield "No checkpoint found, starting fresh with pretrained", None
            resume = False
    if not resume:
        yield "Loading pretrained weights...", None
        with gzip.open(pretrained_model_path, "rb") as f:
            pretrained_ckpt = torch.load(f, map_location="cpu", weights_only=True)
        # Adapt pretrained for our n_speakers
        # (replicate the first pretrained speaker row for every speaker).
        initial_speaker_emb = pretrained_ckpt["net_g"]["embed_speaker.weight"][:1]
        pretrained_ckpt["net_g"]["embed_speaker.weight"] = initial_speaker_emb[[0] * n_speakers]
        initial_kv_emb = pretrained_ckpt["net_g"]["key_value_speaker_embedding.weight"][:1]
        pretrained_ckpt["net_g"]["key_value_speaker_embedding.weight"] = initial_kv_emb[[0] * n_speakers]
        pretrained_ckpt["net_g"]["vq.codebooks"] = pretrained_ckpt["net_g"]["vq.codebooks"][[0] * n_speakers]
        net_g.load_state_dict(pretrained_ckpt["net_g"], strict=False)
        # Filter discriminator state dict for shape mismatches (pretrained may use san=False)
        net_d_state = net_d.state_dict()
        filtered_d = {k: v for k, v in pretrained_ckpt["net_d"].items()
                      if k in net_d_state and v.shape == net_d_state[k].shape}
        net_d.load_state_dict(filtered_d, strict=False)
        logger.info(f"Loaded {len(filtered_d)}/{len(pretrained_ckpt['net_d'])} discriminator weights")
        # Don't load grad_balancer/grad_scaler from pretrained - our loss weights may differ
        # These will be re-initialized fresh for fine-tuning
        del pretrained_ckpt
    # Build VQ codebooks
    yield "Building VQ codebooks...", None
    def wav_iterator(files):
        # Yields [1, 1, time] tensors at 16 kHz for VQ initialization.
        for file in files:
            wav, sr = beatrice_load_audio(file)
            wav = wav.to(device)
            if sr != 16000:
                wav = get_resampler(sr, 16000, str(device))(wav)
            yield wav[:, None, :]
    if resume:
        net_g.enable_hook()
    else:
        net_g.initialize_vq([wav_iterator(files) for files in speaker_audio_files])
    # Dataset
    dataset = WavDataset(
        training_filelist,
        in_sample_rate=16000,
        out_sample_rate=24000,
        wav_length=96000,
        segment_length=100,
        noise_files=noise_files,
        ir_files=ir_files,
    )
    _num_workers = min(4, os.cpu_count() or 1)
    effective_batch = min(batch_size, len(training_filelist))
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=_num_workers,
        collate_fn=dataset.collate,
        shuffle=True,
        batch_size=effective_batch,
        pin_memory=True,
        drop_last=len(training_filelist) > effective_batch,
        persistent_workers=_num_workers > 0,
    )
    # Calculate steps
    steps_per_epoch = max(1, len(training_filelist) // batch_size)
    total_steps = epochs * steps_per_epoch
    warmup_steps = min(total_steps // 4, 5000)
    # LR scheduler with warmup
    def lr_lambda(step):
        # Linear warmup to 1.0, then exponential decay per step.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return 0.999 ** (step - warmup_steps)
    scheduler_g = torch.optim.lr_scheduler.LambdaLR(optim_g, lr_lambda)
    scheduler_d = torch.optim.lr_scheduler.LambdaLR(optim_d, lr_lambda)
    # Advance schedulers if resuming
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=r"Detected call of `lr_scheduler\.step\(\)")
        for _ in range(initial_iteration + 1):
            scheduler_g.step()
            scheduler_d.step()
    net_g.train()
    net_d.train()
    yield f"Training {total_steps} steps ({epochs} epochs x {steps_per_epoch} steps/epoch)", None
    # Training loop
    step = initial_iteration
    data_iter = None
    ckpt_path = None
    for epoch in range(epochs):
        epoch_loss_g = 0.0
        epoch_loss_d = 0.0
        epoch_steps = 0
        for batch_idx in range(steps_per_epoch):
            # Re-create the iterator whenever the dataloader is exhausted.
            if data_iter is None:
                data_iter = iter(dataloader)
            batch = next(data_iter, None)
            if batch is None:
                data_iter = iter(dataloader)
                batch = next(data_iter, None)
                if batch is None:
                    break
            clean_wavs, noisy_wavs_16k, slice_starts, speaker_ids, formant_shifts = \
                [x.to(device, non_blocking=True) for x in batch]
            with torch.amp.autocast(device.type, enabled=device.type == "cuda"):
                # Generator forward
                y, y_hat, y_hat_for_backward, loss_loudness, loss_mel, loss_ap, gen_stats = \
                    net_g.forward_and_compute_loss(
                        noisy_wavs_16k[:, None, :],
                        speaker_ids,
                        formant_shifts,
                        slice_start_indices=slice_starts,
                        slice_segment_length=100,
                        y_all=clean_wavs[:, None, :],
                    )
                # Discriminator forward
                loss_disc, loss_adv, loss_fm, disc_stats = \
                    net_d.forward_and_compute_loss(y, y_hat)
            # Discriminator backward
            optim_d.zero_grad(set_to_none=True)
            grad_scaler.scale(loss_disc).backward(retain_graph=True, inputs=list(net_d.parameters()))
            grad_scaler.unscale_(optim_d)
            # Generator backward (balanced across the four generator losses;
            # EMA of grad norms is only refreshed every 5th step after warm-up).
            optim_g.zero_grad(set_to_none=True)
            grad_balancer.backward(
                {"loss_loudness": loss_loudness, "loss_mel": loss_mel,
                 "loss_adv": loss_adv, "loss_fm": loss_fm},
                y_hat_for_backward, grad_scaler,
                skip_update_ema=step > 10 and step % 5 != 0,
            )
            grad_scaler.unscale_(optim_g)
            # Update
            grad_scaler.step(optim_g)
            grad_scaler.step(optim_d)
            grad_scaler.update()
            optim_g.zero_grad(set_to_none=True)
            optim_d.zero_grad(set_to_none=True)
            scheduler_g.step()
            scheduler_d.step()
            epoch_loss_g += loss_mel.item()
            epoch_loss_d += loss_disc.item()
            epoch_steps += 1
            step += 1
            if progress_callback:
                progress_callback(step / total_steps)
        avg_loss_g = epoch_loss_g / max(1, epoch_steps)
        avg_loss_d = epoch_loss_d / max(1, epoch_steps)
        yield f"Epoch {epoch+1}/{epochs} | G loss: {avg_loss_g:.4f} | D loss: {avg_loss_d:.4f} | LR: {scheduler_g.get_last_lr()[0]:.2e}", None
        # Save checkpoint periodically (5 times per run, plus the final epoch)
        if (epoch + 1) % max(1, epochs // 5) == 0 or epoch == epochs - 1:
            ckpt_path = os.path.join(output_dir, f"checkpoint_{step:08d}.pt.gz")
            with gzip.open(ckpt_path, "wb") as f:
                torch.save({
                    "iteration": step,
                    "net_g": net_g.state_dict(),
                    "phone_extractor": phone_extractor.state_dict(),
                    "pitch_estimator": pitch_estimator.state_dict(),
                    "net_d": {k: v.half() for k, v in net_d.state_dict().items()},
                    "optim_g": get_compressed_optimizer_state_dict(optim_g),
                    "optim_d": get_compressed_optimizer_state_dict(optim_d),
                    "grad_balancer": grad_balancer.state_dict(),
                    "grad_scaler": grad_scaler.state_dict(),
                    "h": {
                        "hidden_channels": hidden_channels,
                        "pitch_bins": pitch_bins,
                        "vq_topk": 4,
                        "training_time_vq": "none",
                        "phone_noise_ratio": 0.5,
                        "floor_noise_level": 1e-3,
                        "san": True,
                    },
                    "speakers": speakers,
                }, f)
            shutil.copy(ckpt_path, os.path.join(output_dir, "checkpoint_latest.pt.gz"))
            yield f"Saved checkpoint: {ckpt_path}", ckpt_path
    # Cleanup
    purge_memory(net_g, net_d, optim_g, optim_d, phone_extractor, pitch_estimator)
    yield "Training complete!", ckpt_path
def convert_voice_beatrice(
    source_audio,
    model_file,
    target_speaker: int = 0,
    pitch_shift: int = 0,
    formant_shift: float = 0.0,
    progress=None,
):
    """Convert voice using a Beatrice v2 model.

    Args:
        source_audio: Path to the source audio file, or a file-like object
            with a ``.name`` attribute (e.g. a Gradio upload).
        model_file: Path to a Beatrice checkpoint (.pt.gz), or a file-like
            object with ``.name``.
        target_speaker: Target speaker index; out-of-range values fall back to 0.
        pitch_shift: Pitch shift in semitones.
        formant_shift: Formant shift in semitones (-2 to 2).
        progress: Optional Gradio progress callback.

    Returns:
        (output_path, status_message) tuple; output_path is None on failure.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Resolve checkpoint path from either a file object or a plain string
    if hasattr(model_file, 'name'):
        model_path = model_file.name
    elif isinstance(model_file, str):
        model_path = model_file
    else:
        return None, "Invalid model file"
    if not model_path or not os.path.exists(model_path):
        return None, f"Model file not found: {model_path}"
    try:
        if progress:
            progress(0.1, "Loading models...")
        # Download pretrained assets (phone extractor + pitch estimator)
        phone_extractor_path = download_beatrice_asset("phone_extractor")
        pitch_estimator_path = download_beatrice_asset("pitch_estimator")
        # Build phone extractor (frozen — inference only)
        phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False)
        pe_ckpt = torch.load(phone_extractor_path, map_location="cpu", weights_only=True)
        phone_extractor.load_state_dict(pe_ckpt["phone_extractor"], strict=False)
        del pe_ckpt
        # Build pitch estimator (frozen — inference only)
        pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False)
        pi_ckpt = torch.load(pitch_estimator_path, map_location="cpu", weights_only=True)
        pitch_estimator.load_state_dict(pi_ckpt["pitch_estimator"])
        del pi_ckpt
        if progress:
            progress(0.3, "Loading trained model...")
        # Load trained (gzip-compressed) checkpoint
        with gzip.open(model_path, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu", weights_only=True)
        # Recover model hyper-parameters from the checkpoint itself
        n_speakers = checkpoint["net_g"]["embed_speaker.weight"].shape[0]
        h = checkpoint.get("h", {})
        hidden_channels = h.get("hidden_channels", 256)
        pitch_bins = h.get("pitch_bins", 448)
        speakers = checkpoint.get("speakers", [f"Speaker {i}" for i in range(n_speakers)])
        # Fall back to speaker 0 for any out-of-range index
        # (fix: original only guarded target_speaker >= n_speakers, so a
        # negative index would silently select from the end of the table)
        if not 0 <= target_speaker < n_speakers:
            target_speaker = 0
        net_g = ConverterNetwork(
            phone_extractor, pitch_estimator,
            n_speakers=n_speakers,
            pitch_bins=pitch_bins,
            hidden_channels=hidden_channels,
            vq_topk=h.get("vq_topk", 4),
            training_time_vq=h.get("training_time_vq", "none"),
            phone_noise_ratio=h.get("phone_noise_ratio", 0.5),
            floor_noise_level=h.get("floor_noise_level", 1e-3),
        ).to(device).eval()
        net_g.load_state_dict(checkpoint["net_g"], strict=False)
        net_g.enable_hook()
        del checkpoint
        if progress:
            progress(0.5, "Converting voice...")
        # Load audio at 16kHz mono.
        # BUGFIX: the original `source_audio if isinstance(source_audio, str)
        # else source_audio` was a no-op — file-like uploads (Gradio) must be
        # unwrapped via their .name attribute, mirroring model_file above.
        audio_path = source_audio.name if hasattr(source_audio, "name") else source_audio
        audio, sr = librosa.load(audio_path, sr=16000, mono=True)
        audio_tensor = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(0).to(device)
        # Pad to multiple of 160 (phone extractor stride)
        original_length = audio_tensor.shape[-1]
        if original_length % 160 != 0:
            pad_len = 160 - original_length % 160
            audio_tensor = F.pad(audio_tensor, (0, pad_len))
        # Convert
        with torch.inference_mode():
            y_hat = net_g(
                audio_tensor,
                torch.tensor([target_speaker], device=device),
                torch.tensor([formant_shift], device=device),
                torch.tensor([float(pitch_shift)], device=device),
            )
        # Output is 24kHz — trim to match input duration (16k frame 160 → 24k frame 240)
        output_length = original_length // 160 * 240
        output = y_hat.squeeze().cpu().numpy()[:output_length]
        # Write result to a temp WAV
        fd, output_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        sf.write(output_path, output, 24000)
        # Cleanup
        purge_memory(net_g, phone_extractor, pitch_estimator)
        speaker_name = speakers[target_speaker] if target_speaker < len(speakers) else f"Speaker {target_speaker}"
        return output_path, f"Converted using Beatrice v2 | 24kHz | Speaker: {speaker_name} | Pitch: {pitch_shift:+d} | Formant: {formant_shift:+.1f}"
    except Exception as e:
        logger.exception("Beatrice inference error")
        return None, f"Error: {str(e)}"
# ============================================================
# GRADIO UI - Gradio 6 Compatible
# ============================================================
def train_ui(
    audio_file,
    model_name: str,
    epochs: int,
    batch_size: int,
    sample_rate: int,
    f0_method: str = "rmvpe",
    progress=gr.Progress()
):
    """Gradio training handler for RVC.

    Generator yielding (model_path, index_path, log_text) tuples so the
    status textbox updates live while preprocessing and training run.
    """
    # Early validation — surface a user-visible message and stop.
    if audio_file is None:
        yield None, None, "❌ Please upload training audio"
        return
    if not model_name or model_name.strip() == "":
        yield None, None, "❌ Please enter a model name"
        return

    device_info = "GPU (CUDA)" if torch.cuda.is_available() else "CPU"
    log_lines = []

    def log_text():
        # Snapshot of the accumulated log for the UI textbox.
        return "\n".join(log_lines)

    try:
        safe_name = sanitize_model_name(model_name)
        output_dir = f"trained_models/{safe_name}"
        data_dir = f"{output_dir}/data"

        log_lines.append(f"🚀 Starting on {device_info}")
        log_lines.append(f"📂 Output: {output_dir}")
        yield None, None, log_text()

        # ---- Preprocessing phase ----
        progress(0.1, "Preprocessing...")
        log_lines.append("🔄 Preprocessing audio...")
        yield None, None, log_text()

        audio_path = audio_file if isinstance(audio_file, str) else audio_file.name
        prep = preprocess_audio_for_training(
            audio_path, data_dir, target_sr=sample_rate, f0_method=f0_method
        )
        if prep is None:
            log_lines.append("❌ Preprocessing failed - no valid audio chunks")
            yield None, None, log_text()
            return

        log_lines.append("✅ Preprocessing complete")
        log_lines.append(f"🏋️ Training {epochs} epochs...")
        log_lines.append("─" * 40)
        yield None, None, log_text()

        # ---- Training phase: stream generator messages to the UI ----
        final_model = None
        final_index = None
        for msg, path, index in train_rvc_generator(
            data_dir=data_dir,
            output_dir=output_dir,
            epochs=epochs,
            batch_size=batch_size,
            lr=1e-5,
            target_sr=sample_rate,
            progress_callback=progress
        ):
            log_lines.append(msg)
            yield None, None, log_text()
            if path:
                final_model = path
            if index:
                final_index = index

        if final_model:
            log_lines.append("─" * 40)
            log_lines.append("✅ Training complete!")
            log_lines.append(f"📦 Model: {final_model}")
            if final_index:
                log_lines.append(f"📦 Index: {final_index}")
            progress(1.0, "Done!")
            yield final_model, final_index, log_text()
        else:
            log_lines.append("❌ Training failed")
            yield None, None, log_text()
    except Exception as e:
        logger.exception("Training error")
        log_lines.append(f"❌ Error: {str(e)}")
        yield None, None, log_text()
# ============================================================
# SPLIT-AND-STITCH BACKGROUND PROCESSOR
#
# Architecture rationale:
# Heavy RVC inference on a 2-core CPU can push 10–14 GB RAM for
# long audio. The solution is to:
# 1. Chop the input into CHUNK_SEC-second slices with a
# OVERLAP_SEC-second overlap on each side. The overlap gives
# the vocoder context at boundaries so there are no clicks.
# 2. Process each chunk independently through the full RVC
# pipeline (RMVPE + FAISS + vocoder). After each chunk,
# call purge_memory() so RAM never accumulates.
# 3. Cross-fade adjacent chunks over the overlap region using a
# linear fade-out/fade-in (equal-power is not needed here
# because RVC output is already bandlimited). The crossfade
# is the only "small render" step — it's just numpy slicing
# and a linear ramp, not another model pass.
# 4. Concatenate the stitched segments and write the final WAV.
#
# Non-negotiable settings preserved:
# - f0_method is always forwarded as-is (RMVPE by default).
# - index_file is always forwarded (FAISS stays active).
# - Model sample rate (40k/48k) is respected — chunks are saved
# at 16k for processing and the stitched output is at tgt_sr.
# ============================================================
# Tunables for the split-and-stitch RVC processor below.
CHUNK_SEC = 30      # seconds per chunk fed to RVC (keeps RAM under ~4 GB/chunk)
OVERLAP_SEC = 0.4   # seconds of overlap for crossfade seam (at tgt_sr)
MIN_CHUNK_SEC = 5   # don't split files shorter than this
def _crossfade_join(seg_a: np.ndarray, seg_b: np.ndarray,
overlap_samples: int) -> np.ndarray:
"""
Overlap-add two mono float32 segments.
seg_a: …audio… [overlap_samples tail]
seg_b: [overlap_samples head] …audio…
Returns the seamlessly stitched result.
"""
if overlap_samples <= 0 or len(seg_a) < overlap_samples or len(seg_b) < overlap_samples:
return np.concatenate([seg_a, seg_b])
fade_out = np.linspace(1.0, 0.0, overlap_samples, dtype=np.float32)
fade_in = np.linspace(0.0, 1.0, overlap_samples, dtype=np.float32)
blended = seg_a[-overlap_samples:] * fade_out + seg_b[:overlap_samples] * fade_in
return np.concatenate([seg_a[:-overlap_samples], blended, seg_b[overlap_samples:]])
def convert_voice_chunked(
    source_audio,
    model_file,
    index_file=None,
    pitch_shift: int = 0,
    f0_method: str = "rmvpe",  # RMVPE is non-negotiable — forwarded as-is
    index_rate: float = 0.75,  # FAISS active — forwarded as-is
    protect: float = 0.33,
    volume_envelope: float = 1.0,
    progress=None,
    chunk_sec: int = CHUNK_SEC,
    overlap_sec: float = OVERLAP_SEC,
) -> Tuple[str, str]:
    """
    Split-and-stitch wrapper around convert_voice().

    For audio shorter than MIN_CHUNK_SEC seconds, falls straight
    through to convert_voice() with no splitting overhead.
    For longer audio:
    • Loads the full waveform once at 16 kHz (source SR for RVC).
    • Slices into overlapping chunks at the 16k level.
    • Each chunk is written to a temp WAV, run through convert_voice(),
      and the result is loaded back as a numpy array.
    • purge_memory() is called after every chunk so RAM is returned
      to the OS via malloc_trim() before the next chunk starts.
    • Chunks are crossfaded and concatenated.
    • The final stitched audio is written to a single output WAV.

    Returns:
        (output_path, status_message); output_path is None on failure.
    """
    if source_audio is None:
        return None, "Please upload source audio"
    if model_file is None:
        return None, "Please upload RVC model (.pth)"
    # ------------------------------------------------------------------
    # 1. Load source audio once to check duration
    # ------------------------------------------------------------------
    try:
        audio_full, _ = librosa.load(source_audio, sr=16000, mono=True)
    except Exception as e:
        return None, f"Failed to load audio: {e}"
    duration_sec = len(audio_full) / 16000.0
    # Short file — no splitting needed, avoids overhead
    if duration_sec <= MIN_CHUNK_SEC:
        purge_memory(audio_full)
        return convert_voice(
            source_audio, model_file, index_file,
            pitch_shift, f0_method, index_rate, protect, volume_envelope,
            progress if progress is not None else gr.Progress()
        )
    # ------------------------------------------------------------------
    # 2. Determine chunk boundaries (in 16k samples)
    # ------------------------------------------------------------------
    sr_in = 16000
    chunk_samp = int(chunk_sec * sr_in)
    overlap_samp = int(overlap_sec * sr_in)
    hop_samp = chunk_samp - overlap_samp  # non-overlapping step
    starts = list(range(0, len(audio_full), hop_samp))
    n_chunks = len(starts)
    logger.info(f"[Chunked] {duration_sec:.1f}s → {n_chunks} chunks "
                f"({chunk_sec}s each, {overlap_sec}s overlap)")
    # ------------------------------------------------------------------
    # 3. Process each chunk through convert_voice()
    # ------------------------------------------------------------------
    stitched_segments: list[np.ndarray] = []
    tgt_sr = None         # learned from first chunk output
    overlap_out_samp = 0  # overlap at target SR (learned after first chunk)
    for i, start in enumerate(starts):
        end = min(start + chunk_samp, len(audio_full))
        chunk = audio_full[start:end]
        if progress is not None:
            # Best-effort UI update — never let a progress error kill a chunk
            try:
                progress((i + 0.5) / n_chunks,
                         f"Processing chunk {i+1}/{n_chunks}…")
            except Exception:
                pass
        # Write chunk to temp WAV
        fd, chunk_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        sf.write(chunk_path, chunk, sr_in)
        try:
            # Run full RVC pipeline on this chunk (RMVPE + FAISS inside).
            # NOTE(review): the outer progress object is forwarded when
            # present — a fresh gr.Progress() is used only when progress is
            # None — so inner calls DO drive the shared bar.
            out_path, status = convert_voice(
                chunk_path, model_file, index_file,
                pitch_shift, f0_method, index_rate, protect, volume_envelope,
                gr.Progress() if progress is None else progress
            )
        except Exception as e:
            logger.warning(f"[Chunked] Chunk {i+1} failed: {e}")
            # Remove temp and skip — silence gap is better than crash
            try: os.unlink(chunk_path)
            except OSError: pass
            purge_memory(chunk)
            continue
        finally:
            # Temp chunk WAV is always removed (double-unlink on the
            # failure path is harmless — the second one raises OSError,
            # which is swallowed)
            try: os.unlink(chunk_path)
            except OSError: pass
        if out_path is None:
            logger.warning(f"[Chunked] Chunk {i+1} returned no output: {status}")
            purge_memory(chunk)
            continue
        # Load the converted chunk
        chunk_out, chunk_sr = sf.read(out_path, dtype="float32")
        try: os.unlink(out_path)
        except OSError: pass
        # Learn target SR from first successful chunk
        if tgt_sr is None:
            tgt_sr = chunk_sr
            # Scale overlap to target SR
            overlap_out_samp = int(overlap_sec * tgt_sr)
        stitched_segments.append(chunk_out)
        # Critical: free everything before next chunk loads the model
        purge_memory(chunk, chunk_out)
    # ------------------------------------------------------------------
    # 4. Stitch segments with crossfade
    # ------------------------------------------------------------------
    if not stitched_segments:
        return None, "All chunks failed — no output produced"
    result = stitched_segments[0]
    for seg in stitched_segments[1:]:
        result = _crossfade_join(result, seg, overlap_out_samp)
    # Final normalization (matches convert_voice behaviour)
    audio_max = np.abs(result).max() / 0.99
    if audio_max > 1.0:
        result /= audio_max
    # ------------------------------------------------------------------
    # 5. Write stitched output
    # ------------------------------------------------------------------
    final_sr = tgt_sr if tgt_sr is not None else 40000
    fd, output_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    sf.write(output_path, result, final_sr)
    purge_memory(result, audio_full)
    logger.info(f"[Chunked] Stitched output: {len(stitched_segments)} chunks → {output_path}")
    return output_path, (
        f"Chunked conversion: {n_chunks} chunks stitched | "
        f"sr={final_sr} | pitch={pitch_shift:+d} | f0={f0_method}"
    )
# Gradio UI: two tabs — Voice Conversion (RVC/Beatrice inference) and Training.
with gr.Blocks() as demo:
    gr.Markdown(f"# 🎤 Voice Conversion (RVC + Beatrice)\nInference: CPU • Training: {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}")
    with gr.Tabs():
        # ==================== TAB 1: VOICE CONVERSION ====================
        with gr.Tab("🎵 Voice Conversion"):
            with gr.Row():
                with gr.Column():
                    source_audio = gr.Audio(label="Source Audio", type="filepath")
                    gr.Markdown("### Model")
                    model_type = gr.Radio(
                        ["RVC v2", "Beatrice v2"], value="RVC v2",
                        label="Model Type",
                        info="RVC: .pth files | Beatrice: .pt.gz files"
                    )
                    # RVC model inputs
                    with gr.Group(visible=True) as rvc_model_group:
                        _available_models = list_rvc_models()
                        rvc_model_dropdown = gr.Dropdown(
                            choices=_available_models,
                            label="Select Voice Model",
                            info="Models auto-loaded from weights/ folder",
                            value="model_kunni.pth" if "model_kunni.pth" in _available_models else (_available_models[0] if _available_models else None),
                        )
                        with gr.Row():
                            model_file = gr.File(label="OR Upload New (.pth)", file_types=[".pth"])
                            load_example_btn = gr.Button("Load Example (Benee)", size="sm")
                        index_file = gr.File(label="Index File (.index) - Optional", file_types=[".index"])
                    # Beatrice model inputs
                    with gr.Group(visible=False) as beatrice_model_group:
                        beatrice_model_file = gr.File(label="Beatrice Model (.pt.gz)", file_types=[".gz"])
                        with gr.Row():
                            beatrice_target_speaker = gr.Number(value=0, label="Target Speaker", precision=0)
                            beatrice_formant_shift = gr.Slider(-2, 2, value=0.0, step=0.5, label="Formant Shift")
                    with gr.Row():
                        pitch_shift = gr.Slider(-12, 12, value=0, step=1, label="Pitch (semitones)")
                        f0_method = gr.Radio(["rmvpe", "pm", "harvest"], value="rmvpe", label="F0 Method", visible=True)
                    with gr.Row(visible=True) as rvc_extra_options:
                        index_rate = gr.Slider(0, 1, value=0.75, step=0.05, label="Index Rate")
                        protect = gr.Slider(0, 0.5, value=0.33, step=0.01, label="Protect (voiceless consonants)")
                    convert_btn = gr.Button("Convert", variant="primary")
                with gr.Column():
                    output_audio = gr.Audio(label="Converted Audio", type="filepath")
                    output_info = gr.Textbox(label="Status", lines=2)

            def update_model_type(model_type_val):
                # Show RVC-only widgets for "RVC v2", Beatrice-only otherwise
                is_rvc = model_type_val == "RVC v2"
                return (
                    gr.update(visible=is_rvc),      # rvc_model_group
                    gr.update(visible=not is_rvc),  # beatrice_model_group
                    gr.update(visible=is_rvc),      # f0_method
                    gr.update(visible=is_rvc),      # rvc_extra_options
                )
            model_type.change(
                update_model_type,
                [model_type],
                [rvc_model_group, beatrice_model_group, f0_method, rvc_extra_options]
            )
            load_example_btn.click(
                load_example_model,
                [],
                [model_file, index_file, output_info]
            )

            def convert_unified(source, m_type, rvc_dropdown_model, rvc_model, rvc_index, beat_model,
                                beat_speaker, beat_formant, pitch, f0, idx_rate, prot,
                                progress=gr.Progress()):
                # Dispatch to the RVC or Beatrice inference path based on the radio
                if m_type == "RVC v2":
                    # Resolve model: prefer uploaded file, fall back to dropdown selection
                    resolved_model = rvc_model
                    resolved_index = rvc_index
                    if resolved_model is None and rvc_dropdown_model:
                        pth_path = Path("weights") / rvc_dropdown_model
                        # Wrap in a simple object so convert_voice can call .name on it
                        class _FileObj:
                            def __init__(self, p): self.name = str(p)
                        resolved_model = _FileObj(pth_path)
                        # Auto-hunt for matching .index in weights/
                        if resolved_index is None:
                            stem = pth_path.stem
                            index_path = Path("weights") / f"{stem}.index"
                            if index_path.exists():
                                resolved_index = _FileObj(index_path)
                    # Use the split-and-stitch processor for RVC.
                    # It falls through to plain convert_voice() for short clips
                    # and auto-chunks long audio to prevent OOM on HF Spaces.
                    return convert_voice_chunked(source, resolved_model, resolved_index, pitch, f0, idx_rate, prot, progress=progress)
                else:
                    return convert_voice_beatrice(
                        source, beat_model,
                        target_speaker=int(beat_speaker),
                        pitch_shift=int(pitch),
                        formant_shift=float(beat_formant),
                        progress=progress
                    )
            convert_btn.click(
                convert_unified,
                [source_audio, model_type, rvc_model_dropdown, model_file, index_file, beatrice_model_file,
                 beatrice_target_speaker, beatrice_formant_shift,
                 pitch_shift, f0_method, index_rate, protect],
                [output_audio, output_info],
                api_name="convert",
                concurrency_limit=1,
            )
            gr.Markdown("**Models:** [HuggingFace](https://huggingface.co/models?search=rvc) | [Weights.gg](https://weights.gg)")
        # ==================== TAB 2: TRAINING ====================
        with gr.Tab("🏋️ Training"):
            # GPU Warning
            gpu_status = "🟢 GPU Available (CUDA)" if torch.cuda.is_available() else "🟡 CPU Only (Training will be slow)"
            gr.Markdown(f"""
### Training Status: {gpu_status}
{'**GPU detected!** Training will use CUDA acceleration.' if torch.cuda.is_available() else '**No GPU detected.** Training will run on CPU (~30 sec/epoch). For faster training, run locally with CUDA GPU.'}
""")
            # Trainer selector - Beatrice always available (downloads assets from HF)
            trainer_selector = gr.Dropdown(
                choices=["RVC v2", "Beatrice v2"],
                value="RVC v2",
                label="Trainer",
                info="RVC v2: general purpose | Beatrice v2: low latency (~50ms)"
            )
            with gr.Row():
                with gr.Column():
                    train_audio = gr.Audio(label="Training Audio", type="filepath")
                    train_model_name = gr.Textbox(label="Model Name", placeholder="my_voice", value="my_voice")
                    # RVC-specific options
                    with gr.Group(visible=True) as rvc_options:
                        with gr.Row():
                            train_epochs = gr.Slider(1, 500, value=50, step=1, label="Epochs")
                            train_batch = gr.Slider(1, 8, value=2, step=1, label="Batch Size")
                        with gr.Row():
                            train_sr = gr.Radio([32000, 40000, 48000], value=40000, label="Sample Rate")
                            train_f0 = gr.Radio(["rmvpe", "pm", "harvest"], value="rmvpe", label="F0 Method")
                    # Beatrice-specific options
                    with gr.Group(visible=False) as beatrice_options:
                        with gr.Row():
                            beatrice_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs (30 recommended)")
                            beatrice_batch = gr.Slider(1, 64, value=8, step=1, label="Batch Size")
                        beatrice_resume = gr.Checkbox(label="Resume from checkpoint", value=False)
                    train_btn = gr.Button("Start Training", variant="primary")
                with gr.Column():
                    train_output_model = gr.File(label="Trained Model (.pth)")
                    train_output_index = gr.File(label="Index File (.index)")
                    train_status = gr.Textbox(label="Training Status", lines=6)

            # Toggle visibility based on trainer selection
            def update_trainer_options(trainer):
                is_rvc = trainer == "RVC v2"
                return gr.update(visible=is_rvc), gr.update(visible=not is_rvc)
            trainer_selector.change(
                update_trainer_options,
                [trainer_selector],
                [rvc_options, beatrice_options]
            )

            # Unified training function
            def train_unified(trainer, audio, name, rvc_epochs, rvc_batch, rvc_sr, rvc_f0,
                              beat_epochs, beat_batch, beat_resume, progress=gr.Progress()):
                if trainer == "RVC v2":
                    # Use generator for live updates (yields 3-tuples: model, index, status)
                    for result in train_ui(audio, name, rvc_epochs, rvc_batch, rvc_sr, rvc_f0, progress):
                        yield result
                else:
                    # Beatrice training (embedded - downloads assets from HF)
                    if audio is None:
                        yield None, None, "❌ Please upload training audio"
                        return
                    if not name or name.strip() == "":
                        yield None, None, "❌ Please enter a model name"
                        return
                    try:
                        name = sanitize_model_name(name)
                        output_dir = f"trained_models/beatrice_{name}"
                        data_dir = f"{output_dir}/training_data"
                        logs = []
                        # Preprocess audio into speaker chunks
                        logs.append(f"🚀 Starting Beatrice training on {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}")
                        yield None, None, "\n".join(logs)
                        progress(0.05, "Preprocessing audio...")
                        audio_path = audio if isinstance(audio, str) else audio.name
                        chunks, _ = preprocess_audio_for_beatrice(audio_path, data_dir, name)
                        if chunks == 0:
                            yield None, None, "❌ Preprocessing failed - no valid audio chunks"
                            return
                        logs.append(f"✅ Preprocessed {chunks} audio chunks")
                        logs.append("─" * 40)
                        yield None, None, "\n".join(logs)
                        # Train using embedded generator
                        ckpt = None
                        for msg, path in train_beatrice_generator(
                            data_dir=data_dir,
                            output_dir=output_dir,
                            epochs=beat_epochs,
                            batch_size=beat_batch,
                            resume=beat_resume,
                            progress_callback=progress
                        ):
                            logs.append(msg)
                            yield None, None, "\n".join(logs)
                            if path:
                                ckpt = path
                        if ckpt:
                            logs.append("─" * 40)
                            logs.append(f"✅ Training complete!")
                            logs.append(f"📦 Model: {ckpt}")
                            progress(1.0, "Done!")
                            yield ckpt, None, "\n".join(logs)
                        else:
                            logs.append("❌ Training failed")
                            yield None, None, "\n".join(logs)
                    except Exception as e:
                        logger.exception("Beatrice training error")
                        yield None, None, f"❌ Error: {str(e)}"
            train_btn.click(
                train_unified,
                [trainer_selector, train_audio, train_model_name,
                 train_epochs, train_batch, train_sr, train_f0,
                 beatrice_epochs, beatrice_batch, beatrice_resume],
                [train_output_model, train_output_index, train_status],
                api_name="train",
                concurrency_limit=1,
            )
            gr.Markdown("""
---
### Training Tips
**RVC v2:**
- 50-100 epochs for quick test, 200-500 for quality
- CPU training: ~30 sec/epoch (100 epochs ≈ 50 min)
**Beatrice v2:**
- 20-50 epochs recommended, GPU recommended (CPU works but slow)
- Pretrained assets downloaded automatically from HuggingFace
- Output: .pt.gz checkpoint (use in Voice Conversion tab)
- Lower latency (~50ms vs ~100ms)
### CLI
```bash
python app.py train -a voice.mp3 -o ./model --epochs 100
python app.py train-beatrice -a voice.mp3 -o ./beatrice_model --epochs 30
python app.py infer -i input.wav -m beatrice_model.pt.gz -o output.wav
```
""")
def cli_convert(args):
    """CLI inference entry point — dispatches to RVC or Beatrice conversion."""
    print(f"Converting: {args.input}")
    print(f"Model: {args.model}")

    # Start from the user-supplied model/type; --example overrides both.
    model_type = getattr(args, 'type', None)
    model_path = args.model
    index_path = getattr(args, 'index', None)

    if args.example:
        print("Downloading example model...")
        model_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_MODEL_FILE)
        index_path = hf_hub_download(repo_id=DEFAULT_MODEL_REPO, filename=DEFAULT_INDEX_FILE)
        model_type = "rvc"
        print(f"Model: {model_path}")

    if not model_path:
        print("Error: No model specified. Use -m MODEL.pth or --example")
        sys.exit(1)

    # Auto-detect type from extension if not specified
    if model_type is None:
        model_type = "beatrice" if model_path.endswith(('.pt.gz', '.gz')) else "rvc"

    if model_type == "beatrice":
        # Beatrice inference path
        print("Using Beatrice v2 inference")
        output_path, status = convert_voice_beatrice(
            source_audio=args.input,
            model_file=model_path,
            target_speaker=getattr(args, 'speaker', 0),
            pitch_shift=args.pitch,
            formant_shift=getattr(args, 'formant_shift', 0.0),
        )
    else:
        # RVC inference path — wrap plain paths so convert_voice can read .name
        class _PathWrapper:
            def __init__(self, path):
                self.name = path
        output_path, status = convert_voice(
            source_audio=args.input,
            model_file=_PathWrapper(model_path),
            index_file=_PathWrapper(index_path) if index_path else None,
            pitch_shift=args.pitch,
            f0_method=args.f0,
            index_rate=args.index_rate,
            progress=lambda *a, **k: None
        )

    if not output_path:
        print(f"Failed: {status}")
        sys.exit(1)
    shutil.copy(output_path, args.output)
    print(f"Output: {args.output}")
    print(status)
def cli_train_beatrice(args):
    """CLI Beatrice v2 training (embedded; pretrained assets fetched from HF)."""
    print(f"=== Beatrice v2 Training (Embedded) ===")
    print(f"Input audio: {args.audio}")
    print(f"Output dir: {args.output}")

    # Beatrice expects the layout: data_dir/<speaker_name>/*.wav
    data_dir = os.path.join(args.output, "training_data")
    speaker_name = os.path.splitext(os.path.basename(args.audio))[0]

    print(f"\n[1/2] Preprocessing audio for Beatrice...")
    n_chunks, _ = preprocess_audio_for_beatrice(args.audio, data_dir, speaker_name)
    if n_chunks == 0:
        print("Preprocessing failed - no valid audio chunks")
        sys.exit(1)
    print(f"Created {n_chunks} audio chunks")

    print(f"\n[2/2] Training Beatrice model ({args.epochs} epochs)...")
    final_ckpt = None
    for msg, path in train_beatrice_generator(
        data_dir=data_dir,
        output_dir=args.output,
        epochs=args.epochs,
        batch_size=args.batch,
        resume=args.resume,
    ):
        print(msg)
        # Remember the most recent checkpoint the generator reported
        if path:
            final_ckpt = path

    if not final_ckpt:
        print("Training failed!")
        sys.exit(1)
    print(f"\nTraining complete!")
    print(f"Model: {final_ckpt}")
def cli_train(args):
    """CLI RVC training entry point: preprocess, then train."""
    print(f"=== RVC Training ===")
    print(f"Input audio: {args.audio}")
    print(f"Output dir: {args.output}")

    # Preprocessed features live under the output directory
    data_dir = f"{args.output}/data"

    print("\n[1/2] Preprocessing audio...")
    prep = preprocess_audio_for_training(args.audio, data_dir, target_sr=args.sr, f0_method=args.f0)
    if prep is None:
        print("Preprocessing failed!")
        sys.exit(1)

    print(f"\n[2/2] Training for {args.epochs} epochs...")
    ckpt_path, index_path = train_rvc(
        data_dir=data_dir,
        output_dir=args.output,
        epochs=args.epochs,
        batch_size=args.batch,
        lr=args.lr,
        target_sr=args.sr
    )
    if not ckpt_path:
        print("Training failed!")
        sys.exit(1)
    print(f"\nTraining complete!")
    print(f"Model saved: {ckpt_path}")
    if index_path:
        print(f"Index saved: {index_path}")
if __name__ == "__main__":
    # Check if any CLI args (besides script name); with args → CLI mode,
    # without → launch the Gradio UI.
    if len(sys.argv) > 1:
        parser = argparse.ArgumentParser(
            description="RVC Voice Conversion - Inference and Training",
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )
        subparsers = parser.add_subparsers(dest="command", help="Commands")
        # Inference subcommand
        infer_parser = subparsers.add_parser("infer", help="Voice conversion inference (RVC + Beatrice)",
                                             epilog="""
Examples:
python app.py infer -i voice.wav -m model.pth -o output.wav
python app.py infer -i voice.wav -m beatrice_model.pt.gz -o output.wav --type beatrice
python app.py infer -i voice.wav --example -o output.wav
""")
        infer_parser.add_argument("-i", "--input", required=True, help="Input audio file")
        infer_parser.add_argument("-o", "--output", required=True, help="Output audio file")
        infer_parser.add_argument("-m", "--model", help="Model file (.pth for RVC, .pt.gz for Beatrice)")
        infer_parser.add_argument("--type", choices=["rvc", "beatrice"], default=None, help="Model type (auto-detected from extension)")
        infer_parser.add_argument("--index", help="Index file (.index) - RVC only")
        infer_parser.add_argument("--example", action="store_true", help="Use example model (Benee-RVC)")
        infer_parser.add_argument("-p", "--pitch", type=int, default=0, help="Pitch shift (-12 to 12)")
        infer_parser.add_argument("--f0", choices=["rmvpe", "pm", "harvest"], default="rmvpe", help="F0 method (RVC only)")
        infer_parser.add_argument("--index-rate", type=float, default=0.75, help="Index rate 0-1 (RVC only)")
        infer_parser.add_argument("--speaker", type=int, default=0, help="Target speaker index (Beatrice only)")
        infer_parser.add_argument("--formant-shift", type=float, default=0.0, help="Formant shift -2 to 2 (Beatrice only)")
        # Training subcommand (RVC)
        train_parser = subparsers.add_parser("train", help="Train RVC model",
                                             epilog="""
Examples:
python app.py train -a voice.mp3 -o ./my_model --epochs 5
python app.py train -a dataset.wav -o ./trained --epochs 10 --batch 4
""")
        train_parser.add_argument("-a", "--audio", required=True, help="Training audio file")
        train_parser.add_argument("-o", "--output", required=True, help="Output directory for model")
        train_parser.add_argument("--epochs", type=int, default=5, help="Number of epochs (default: 5)")
        train_parser.add_argument("--batch", type=int, default=2, help="Batch size (default: 2)")
        train_parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate (default: 1e-5)")
        train_parser.add_argument("--sr", type=int, default=40000, help="Sample rate (default: 40000)")
        train_parser.add_argument("--f0", choices=["rmvpe", "pm", "harvest"], default="rmvpe", help="F0 method (default: rmvpe)")
        # Beatrice v2 Training subcommand
        beatrice_parser = subparsers.add_parser("train-beatrice", help="Train Beatrice v2 model (downloads assets from HF)",
                                                epilog="""
Examples:
python app.py train-beatrice -a voice.mp3 -o ./beatrice_model --epochs 20
python app.py train-beatrice -a dataset.wav -o ./trained --epochs 50 --batch 8 --resume
Pretrained assets are downloaded automatically from HuggingFace.
""")
        beatrice_parser.add_argument("-a", "--audio", required=True, help="Training audio file")
        beatrice_parser.add_argument("-o", "--output", required=True, help="Output directory for model")
        beatrice_parser.add_argument("--epochs", type=int, default=20, help="Number of epochs (default: 20)")
        beatrice_parser.add_argument("--batch", type=int, default=24, help="Batch size (default: 24, reduce for less VRAM)")
        beatrice_parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
        args = parser.parse_args()
        if args.command == "infer":
            cli_convert(args)
        elif args.command == "train":
            cli_train(args)
        elif args.command == "train-beatrice":
            cli_train_beatrice(args)
        else:
            parser.print_help()
    else:
        # ============================================================
        # GRADIO QUEUE — Stability config for 2-core CPU HF Spaces
        #
        # max_size=3
        #   Hard cap on pending jobs. On a 2-core CPU, RVC can take
        #   3–10 minutes per job. Without a cap, users pile up and
        #   the Space exhausts RAM while holding dozens of audio
        #   blobs in the queue. 3 is a safe upper bound: one running
        #   + two waiting. Any new request beyond that gets a clear
        #   "queue full" HTTP 503 instead of a silent timeout.
        #
        # default_concurrency_limit=1
        #   Only one inference job runs at a time. Two simultaneous
        #   RVC jobs on a 2-core CPU would each get 1 thread, which
        #   is slower than one job using both cores sequentially.
        #   Serial execution is strictly faster here.
        #
        #   The convert_btn.click already has concurrency_limit=1 set
        #   at the event level; this queue-level setting enforces it
        #   globally (including any API calls) so nothing bypasses it.
        #
        # status_update_rate="auto"
        #   Gradio sends SSE heartbeats to the browser on this cadence.
        #   "auto" (≈ 1 Hz) is enough to keep the WebSocket / SSE
        #   connection alive across long jobs without flooding the
        #   2-core CPU with keep-alive overhead.
        # ============================================================
        demo.queue(
            max_size=3,
            default_concurrency_limit=1,
            status_update_rate="auto",
        ).launch(
            mcp_server=True,
            show_error=True,
            ssr_mode=False,
        )