# Spaces: Running on Zero
# File size: 10,275 Bytes (revision 7e0eb32)
"""RE-USE (nvidia/RE-USE) speech-enhancement wrapper.
Used by ``TTSServer._denoise_voice_ref`` to denoise the input voice reference
before VAE conditioning. Lazy-loads weights + code on first call so importing
this module is cheap.
Example::

    up = REUSEUpsampler(target_sr=48000, device="cuda")
    clean, sr = up(wav, in_sr=24000)  # wav: (C, T) or (T,) float
"""
from __future__ import annotations
import logging
import sys
from pathlib import Path
from typing import Optional, Tuple
import torch
# REUSE_DIR is resolved lazily via model_downloader.get_reuse_code_path on
# first use of REUSEUpsampler — it returns the vendored third_party/RE-USE/
# tree if present, otherwise snapshot-downloads just the code from HF.
_REUSE_DIR: Optional[Path] = None
def _resolve_reuse_dir() -> Path:
    """Return the RE-USE code directory, resolving it once and caching it.

    The lookup is delegated to ``model_downloader.get_reuse_code_path``; its
    result is wrapped in a ``Path`` and stored in the module-level
    ``_REUSE_DIR`` so subsequent calls return the cached value directly.
    """
    global _REUSE_DIR
    cached = _REUSE_DIR
    if cached is not None:
        return cached

    # Imported on demand: the downloader is only needed the first time the
    # directory is actually requested.
    from model_downloader import get_reuse_code_path

    cached = Path(get_reuse_code_path())
    _REUSE_DIR = cached
    return cached
class REUSEUpsampler:
"""Universal speech enhancement with optional bandwidth extension.
nvidia/RE-USE is a 9.6 M-param bidirectional-Mamba model that operates on
STFT amplitude+phase. With ``target_sr`` set it both denoises *and* extends
the bandwidth to that rate via librosa kaiser-best resample + restoration.
License: NSCLv1 (noncommercial). The base ``SEMamba`` class lives in the
HF repo under ``models/generator_SEMamba_time_d4.py`` and pulls in the
``mamba_ssm`` / ``causal-conv1d`` CUDA kernels.
"""
def __init__(
self,
target_sr: int = 48000,
config_path: Optional[str] = None,
chunk_size_s: float = 1.0,
hop_portion: float = 0.5,
device: str | torch.device = "cuda",
) -> None:
# chunk_size_s: peak VRAM scales linearly with chunk length.
# 5.0s -> 2.95 GB | 2.5s -> 1.52 GB | 1.0s -> 0.67 GB (default).
# 1.0s is chosen as default so RE-USE fits comfortably on top of the
# rest of the DramaBox pipeline on any 24 GB-class GPU.
self.device = torch.device(device)
self.target_sr = int(target_sr)
self.chunk_size_s = float(chunk_size_s)
self.hop_portion = float(hop_portion)
# Config path is resolved lazily on first use (alongside the code tree)
# so importing this module never triggers a download.
self._config_path_override = Path(config_path) if config_path else None
self.config_path: Optional[Path] = None
self._model = None
self._cfg = None
self._stft_fns = None # (mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim)
@staticmethod
def _ensure_mamba_ssm_importable() -> None:
"""Import ``mamba_ssm`` cleanly, with a kernel-free fallback if needed.
Normal path (kernels present): just import — fast path uses
``selective_scan_cuda`` natively.
Fallback (kernels missing): the official package does an unconditional
``import selective_scan_cuda`` at module load. We stub it into
``sys.modules`` before importing, then redirect ``selective_scan_fn``
to the pure-PyTorch ``selective_scan_ref`` so the model still runs
(~5-10x slower).
"""
try:
import selective_scan_cuda # noqa: F401
import mamba_ssm # noqa: F401
return # Fast path: kernel present.
except ImportError:
pass
import types
if "selective_scan_cuda" not in sys.modules:
stub = types.ModuleType("selective_scan_cuda")
def _missing(*a, **kw): # pragma: no cover - safety net only
raise NotImplementedError(
"selective_scan_cuda kernel missing; the call should have "
"been routed to selective_scan_ref via the runtime patch."
)
stub.fwd = _missing
stub.bwd = _missing
sys.modules["selective_scan_cuda"] = stub
from mamba_ssm.ops import selective_scan_interface as ssi
from mamba_ssm.modules import mamba_simple
if getattr(ssi, "_dramabox_kernel_free_patch_applied", False):
return
ssi.selective_scan_fn = ssi.selective_scan_ref
ssi.mamba_inner_fn = ssi.mamba_inner_ref
# mamba_simple imported these names by reference at module load -
# rebind there too, otherwise Mamba.forward keeps the original handles.
mamba_simple.selective_scan_fn = ssi.selective_scan_ref
mamba_simple.mamba_inner_fn = ssi.mamba_inner_ref
ssi._dramabox_kernel_free_patch_applied = True
logging.info(
"mamba_ssm kernel missing - using kernel-free fallback "
"(selective_scan_fn -> selective_scan_ref). Expect ~5-10x slowdown."
)
def _lazy_load(self) -> None:
if self._model is not None:
return
# Prefer real CUDA kernels; gracefully fall back to pure-PyTorch impl.
self._ensure_mamba_ssm_importable()
# The RE-USE module imports `from models...` and `from utils...` —
# both relative to the repo root. Add to path during load.
reuse_dir = _resolve_reuse_dir()
if str(reuse_dir) not in sys.path:
sys.path.insert(0, str(reuse_dir))
if self.config_path is None:
self.config_path = self._config_path_override or (
reuse_dir / "recipes" /
"USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml"
)
from models.generator_SEMamba_time_d4 import SEMamba # type: ignore
from models.stfts import mag_phase_stft, mag_phase_istft # type: ignore
from utils.util import load_config, pad_or_trim_to_match # type: ignore
self._cfg = load_config(str(self.config_path))
compress_factor = self._cfg["model_cfg"]["compress_factor"]
self._stft_fns = (mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim_to_match)
# SEMamba is a PyTorchModelHubMixin; from_pretrained pulls weights from HF.
model = SEMamba.from_pretrained("nvidia/RE-USE", cfg=self._cfg).to(self.device)
model.train(False)
self._model = model
n_params = sum(p.numel() for p in model.parameters())
logging.info(f"RE-USE loaded: SEMamba ({n_params / 1e6:.1f}M params) -> {self.target_sr} Hz")
@staticmethod
def _make_even(v: float) -> int:
v = int(round(v))
return v if v % 2 == 0 else v + 1
@torch.inference_mode()
def __call__(self, waveform: torch.Tensor, in_sr: int = 16000) -> Tuple[torch.Tensor, int]:
"""Chunked overlap-add denoise / BWE (ports nvidia/RE-USE inference_chunk.py).
Peak VRAM is bounded by ``chunk_size_s * target_sr`` rather than the
whole clip, so a 60 s clip costs the same as a 5 s one. Crossfade is
a Hann-window normalized overlap-add with default 50% hop.
"""
import math
self._lazy_load()
import librosa
mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim_to_match = self._stft_fns
# STFT params are scaled relative to the config's training rate (8000).
base_n_fft = self._cfg["stft_cfg"]["n_fft"]
base_hop = self._cfg["stft_cfg"]["hop_size"]
base_win = self._cfg["stft_cfg"]["win_size"]
base_sr = self._cfg["stft_cfg"]["sampling_rate"]
if waveform.dim() == 1:
waveform = waveform.unsqueeze(0)
# 1. Resample to target rate first (skips if target_sr == in_sr).
if self.target_sr != in_sr:
wav_np = waveform.cpu().float().numpy()
wav_np = librosa.resample(
wav_np, orig_sr=in_sr, target_sr=self.target_sr, res_type="kaiser_best"
)
wav = torch.from_numpy(wav_np).to(self.device, dtype=torch.float32)
else:
wav = waveform.to(self.device, dtype=torch.float32)
op_sr = self.target_sr
n_fft = self._make_even(base_n_fft * op_sr // base_sr)
hop = self._make_even(base_hop * op_sr // base_sr)
win = self._make_even(base_win * op_sr // base_sr)
# 2. Chunked OLA with Hann analysis window. Mirrors inference_chunk.py.
chunk_size = int(self.chunk_size_s * op_sr)
hop_length = int(self.hop_portion * chunk_size)
window = torch.hann_window(chunk_size, device=self.device)
n_ch, total = wav.shape
enhanced = torch.zeros_like(wav)
window_sum = torch.zeros_like(wav)
n_chunks = max(1, math.ceil((total - chunk_size) / hop_length) + 1) if total > chunk_size else 1
for c in range(n_ch):
ch_in = wav[c : c + 1] # (1, T)
for i in range(n_chunks):
start = i * hop_length
end = min(start + chunk_size, total)
chunk = ch_in[:, start:end]
if chunk.shape[-1] < 2: # skip degenerate tail
continue
noisy_mag, noisy_pha, _ = mag_phase_stft(
chunk, n_fft=n_fft, hop_size=hop, win_size=win,
compress_factor=compress_factor, center=True, addeps=False,
)
amp_g, pha_g, _ = self._model(noisy_mag, noisy_pha)
# "Sweep artifact" filter — match the official inference.
mag = torch.expm1(torch.relu(amp_g))
zero_portion = (mag == 0).sum(dim=1) / mag.shape[1]
amp_g[:, :, (zero_portion > 0.5)[0]] = 0
audio_g = mag_phase_istft(amp_g, pha_g, n_fft, hop, win, compress_factor)
audio_g = pad_or_trim_to_match(chunk.detach(), audio_g, pad_value=1e-8)
w_slice = window[: audio_g.shape[-1]]
enhanced[c : c + 1, start : start + audio_g.shape[-1]] += audio_g * w_slice
window_sum[c : c + 1, start : start + audio_g.shape[-1]] += w_slice
# 3. Normalize where windows overlap. Avoid divide-by-zero at clip tails.
mask = window_sum > 1e-8
enhanced[mask] = enhanced[mask] / window_sum[mask]
return enhanced.clamp(-1.0, 1.0).cpu().float(), op_sr
|