Spaces:
Running on Zero
Running on Zero
| """RE-USE (nvidia/RE-USE) speech-enhancement wrapper. | |
| Used by ``TTSServer._denoise_voice_ref`` to denoise the input voice reference | |
| before VAE conditioning. Weights and code are lazy-loaded on the first call, | |
| so importing this module is cheap. | |
| Example:: | |
| up = REUSEUpsampler(target_sr=48000, device="cuda") | |
| clean, sr = up(wav, in_sr=24000) # wav: (C, T) or (T,) float | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import torch | |
# Process-wide cache for the location of the RE-USE code tree. It stays
# ``None`` until ``_resolve_reuse_dir`` is first called, so importing this
# module never touches ``model_downloader``. Per the original note,
# ``model_downloader.get_reuse_code_path`` returns the vendored
# ``third_party/RE-USE/`` tree if present, otherwise it snapshot-downloads
# just the code from HF.
_REUSE_DIR: Optional[Path] = None


def _resolve_reuse_dir() -> Path:
    """Return the RE-USE code directory, resolving it once and caching it.

    The first call imports ``model_downloader`` and stores the result of
    ``get_reuse_code_path()`` (converted to ``Path``) in the module-level
    ``_REUSE_DIR``; later calls return that cached value unchanged.
    """
    global _REUSE_DIR
    cached = _REUSE_DIR
    if cached is not None:
        return cached
    # Imported here rather than at module top so the lookup stays lazy.
    from model_downloader import get_reuse_code_path
    cached = Path(get_reuse_code_path())
    _REUSE_DIR = cached
    return cached
class REUSEUpsampler:
    """Universal speech enhancement with optional bandwidth extension.

    nvidia/RE-USE is a 9.6 M-param bidirectional-Mamba model that operates on
    STFT amplitude+phase. With ``target_sr`` set it both denoises *and* extends
    the bandwidth to that rate via librosa kaiser-best resample + restoration.

    License: NSCLv1 (noncommercial). The base ``SEMamba`` class lives in the
    HF repo under ``models/generator_SEMamba_time_d4.py`` and pulls in the
    ``mamba_ssm`` / ``causal-conv1d`` CUDA kernels.

    Nothing heavy happens at construction time: the code tree, the config and
    the weights are loaded by ``_lazy_load`` on the first ``__call__``.
    """

    def __init__(
        self,
        target_sr: int = 48000,
        config_path: Optional[str] = None,
        chunk_size_s: float = 1.0,
        hop_portion: float = 0.5,
        device: str | torch.device = "cuda",
    ) -> None:
        """Store the inference settings; no model or code is loaded here.

        Args:
            target_sr: Rate (Hz) the audio is resampled to before enhancement;
                also the rate returned by ``__call__``.
            config_path: Optional path to a RE-USE YAML config. When omitted,
                the default recipe inside the resolved code tree is used.
            chunk_size_s: Length of each processing chunk, in seconds.
            hop_portion: Fraction of a chunk between consecutive chunk starts
                (``0.5`` gives 50% overlap).
            device: Device the model and the working tensors are placed on.
        """
        # chunk_size_s: peak VRAM scales linearly with chunk length.
        #   5.0s -> 2.95 GB | 2.5s -> 1.52 GB | 1.0s -> 0.67 GB (default).
        # 1.0s is chosen as default so RE-USE fits comfortably on top of the
        # rest of the DramaBox pipeline on any 24 GB-class GPU.
        self.device = torch.device(device)
        self.target_sr = int(target_sr)
        self.chunk_size_s = float(chunk_size_s)
        self.hop_portion = float(hop_portion)
        # Config path is resolved lazily on first use (alongside the code tree)
        # so importing this module never triggers a download.
        self._config_path_override = Path(config_path) if config_path else None
        self.config_path: Optional[Path] = None
        self._model = None
        self._cfg = None
        # (mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim)
        self._stft_fns = None

    @staticmethod
    def _ensure_mamba_ssm_importable() -> None:
        """Import ``mamba_ssm`` cleanly, with a kernel-free fallback if needed.

        Normal path (kernels present): just import - fast path uses
        ``selective_scan_cuda`` natively.

        Fallback (kernels missing): the official package does an unconditional
        ``import selective_scan_cuda`` at module load. We stub it into
        ``sys.modules`` before importing, then redirect ``selective_scan_fn``
        to the pure-PyTorch ``selective_scan_ref`` so the model still runs
        (~5-10x slower).

        Declared as a static method because it takes no instance state but is
        invoked as ``self._ensure_mamba_ssm_importable()``.
        """
        try:
            import selective_scan_cuda  # noqa: F401
            import mamba_ssm  # noqa: F401
            return  # Fast path: kernel present.
        except ImportError:
            pass

        import types

        if "selective_scan_cuda" not in sys.modules:
            stub = types.ModuleType("selective_scan_cuda")

            def _missing(*a, **kw):  # pragma: no cover - safety net only
                raise NotImplementedError(
                    "selective_scan_cuda kernel missing; the call should have "
                    "been routed to selective_scan_ref via the runtime patch."
                )

            stub.fwd = _missing
            stub.bwd = _missing
            sys.modules["selective_scan_cuda"] = stub

        from mamba_ssm.ops import selective_scan_interface as ssi
        from mamba_ssm.modules import mamba_simple

        # The marker attribute makes the rebinding below run once per process.
        if getattr(ssi, "_dramabox_kernel_free_patch_applied", False):
            return
        ssi.selective_scan_fn = ssi.selective_scan_ref
        ssi.mamba_inner_fn = ssi.mamba_inner_ref
        # mamba_simple imported these names by reference at module load -
        # rebind there too, otherwise Mamba.forward keeps the original handles.
        mamba_simple.selective_scan_fn = ssi.selective_scan_ref
        mamba_simple.mamba_inner_fn = ssi.mamba_inner_ref
        ssi._dramabox_kernel_free_patch_applied = True
        logging.info(
            "mamba_ssm kernel missing - using kernel-free fallback "
            "(selective_scan_fn -> selective_scan_ref). Expect ~5-10x slowdown."
        )

    def _lazy_load(self) -> None:
        """Load the RE-USE code, config and weights on first use.

        A no-op once ``self._model`` is set. Otherwise it makes ``mamba_ssm``
        importable, puts the RE-USE code tree on ``sys.path``, resolves the
        config path, and populates ``self._cfg``, ``self._stft_fns`` and
        ``self._model`` (moved to ``self.device``, switched out of training
        mode).
        """
        if self._model is not None:
            return
        # Prefer real CUDA kernels; gracefully fall back to pure-PyTorch impl.
        self._ensure_mamba_ssm_importable()

        # The RE-USE module imports `from models...` and `from utils...` -
        # both relative to the repo root. Add to path during load.
        # NOTE(review): `models` / `utils` are generic top-level names; if the
        # host application already imported modules with those names, the
        # imports below may resolve to them instead - confirm.
        reuse_dir = _resolve_reuse_dir()
        if str(reuse_dir) not in sys.path:
            sys.path.insert(0, str(reuse_dir))
        if self.config_path is None:
            self.config_path = self._config_path_override or (
                reuse_dir / "recipes" /
                "USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml"
            )

        from models.generator_SEMamba_time_d4 import SEMamba  # type: ignore
        from models.stfts import mag_phase_stft, mag_phase_istft  # type: ignore
        from utils.util import load_config, pad_or_trim_to_match  # type: ignore

        self._cfg = load_config(str(self.config_path))
        compress_factor = self._cfg["model_cfg"]["compress_factor"]
        self._stft_fns = (mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim_to_match)

        # SEMamba is a PyTorchModelHubMixin; from_pretrained pulls weights from HF.
        model = SEMamba.from_pretrained("nvidia/RE-USE", cfg=self._cfg).to(self.device)
        model.train(False)
        self._model = model
        n_params = sum(p.numel() for p in model.parameters())
        # Lazy %-args render the same text as the previous f-string.
        logging.info(
            "RE-USE loaded: SEMamba (%.1fM params) -> %d Hz",
            n_params / 1e6,
            self.target_sr,
        )

    @staticmethod
    def _make_even(v: float) -> int:
        """Round ``v`` to the nearest integer and bump odd results up by one."""
        v = int(round(v))
        return v if v % 2 == 0 else v + 1

    @torch.no_grad()
    def __call__(self, waveform: torch.Tensor, in_sr: int = 16000) -> Tuple[torch.Tensor, int]:
        """Chunked overlap-add denoise / BWE (ports nvidia/RE-USE inference_chunk.py).

        Peak VRAM is bounded by ``chunk_size_s * target_sr`` rather than the
        whole clip, so a 60 s clip costs the same as a 5 s one. Crossfade is
        a Hann-window normalized overlap-add with default 50% hop. The whole
        call runs under ``torch.no_grad()`` so no autograd graph is kept
        across chunks and the result never requires grad.

        Args:
            waveform: Audio as ``(T,)`` or ``(C, T)``; converted to float32.
            in_sr: Sample rate of ``waveform`` in Hz.

        Returns:
            ``(enhanced, sr)``: a CPU float tensor of shape ``(C, T')`` at the
            target rate, clamped to ``[-1, 1]``, and ``sr == self.target_sr``.

        Raises:
            ValueError: If ``waveform`` is not 1-D or 2-D, or if the clip is
                longer than one chunk while the configured hop is shorter
                than one sample.
        """
        import math

        self._lazy_load()
        mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim_to_match = self._stft_fns

        # STFT params are scaled relative to the config's training rate
        # (8000 per the original note).
        stft_cfg = self._cfg["stft_cfg"]
        base_n_fft = stft_cfg["n_fft"]
        base_hop = stft_cfg["hop_size"]
        base_win = stft_cfg["win_size"]
        base_sr = stft_cfg["sampling_rate"]

        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        if waveform.dim() != 2:
            raise ValueError(
                f"waveform must be (T,) or (C, T); got shape {tuple(waveform.shape)}"
            )

        # 1. Resample to target rate first (skips if target_sr == in_sr).
        if self.target_sr != in_sr:
            # librosa is only needed on this branch, so import it here.
            import librosa

            wav_np = waveform.cpu().float().numpy()
            wav_np = librosa.resample(
                wav_np, orig_sr=in_sr, target_sr=self.target_sr, res_type="kaiser_best"
            )
            wav = torch.from_numpy(wav_np).to(self.device, dtype=torch.float32)
        else:
            wav = waveform.to(self.device, dtype=torch.float32)

        op_sr = self.target_sr
        n_fft = self._make_even(base_n_fft * op_sr // base_sr)
        hop = self._make_even(base_hop * op_sr // base_sr)
        win = self._make_even(base_win * op_sr // base_sr)

        # 2. Chunked OLA with Hann analysis window. Mirrors inference_chunk.py.
        chunk_size = int(self.chunk_size_s * op_sr)
        hop_length = int(self.hop_portion * chunk_size)
        window = torch.hann_window(chunk_size, device=self.device)

        n_ch, total = wav.shape
        enhanced = torch.zeros_like(wav)
        window_sum = torch.zeros_like(wav)

        if total > chunk_size:
            if hop_length < 1:
                # Previously a bare ZeroDivisionError (hop == 0) or a silently
                # truncated result (hop < 0).
                raise ValueError(
                    "hop_portion * chunk_size_s * target_sr must be at least "
                    f"1 sample; got hop_length={hop_length}"
                )
            n_chunks = math.ceil((total - chunk_size) / hop_length) + 1
        else:
            # The clip fits in a single (possibly short) chunk.
            n_chunks = 1

        for c in range(n_ch):
            ch_in = wav[c : c + 1]  # (1, T)
            for i in range(n_chunks):
                start = i * hop_length
                end = min(start + chunk_size, total)
                chunk = ch_in[:, start:end]
                if chunk.shape[-1] < 2:  # skip degenerate tail
                    continue
                noisy_mag, noisy_pha, _ = mag_phase_stft(
                    chunk, n_fft=n_fft, hop_size=hop, win_size=win,
                    compress_factor=compress_factor, center=True, addeps=False,
                )
                amp_g, pha_g, _ = self._model(noisy_mag, noisy_pha)
                # "Sweep artifact" filter - match the official inference:
                # zero every position along the last axis where more than half
                # of the dim-1 entries have zero magnitude.
                mag = torch.expm1(torch.relu(amp_g))
                zero_portion = (mag == 0).sum(dim=1) / mag.shape[1]
                amp_g[:, :, (zero_portion > 0.5)[0]] = 0
                audio_g = mag_phase_istft(amp_g, pha_g, n_fft, hop, win, compress_factor)
                audio_g = pad_or_trim_to_match(chunk.detach(), audio_g, pad_value=1e-8)
                n_out = audio_g.shape[-1]
                w_slice = window[:n_out]
                enhanced[c : c + 1, start : start + n_out] += audio_g * w_slice
                window_sum[c : c + 1, start : start + n_out] += w_slice

        # 3. Normalize where windows overlap. Avoid divide-by-zero at clip tails.
        mask = window_sum > 1e-8
        enhanced[mask] = enhanced[mask] / window_sum[mask]
        return enhanced.clamp(-1.0, 1.0).cpu().float(), op_sr