Dramabox / src /super_resolution.py
Manmay's picture
Long-form chunking + RE-USE on reference
7e0eb32
"""RE-USE (nvidia/RE-USE) speech-enhancement wrapper.
Used by ``TTSServer._denoise_voice_ref`` to denoise the input voice reference
before VAE conditioning. Lazy-loads weights + code on first call so importing
this module is cheap.
up = REUSEUpsampler(target_sr=48000, device="cuda")
clean, sr = up(wav, in_sr=24000) # wav: (C, T) or (T,) float
"""
from __future__ import annotations
import logging
import sys
from pathlib import Path
from typing import Optional, Tuple
import torch
# REUSE_DIR is resolved lazily via model_downloader.get_reuse_code_path on
# first use of REUSEUpsampler — it returns the vendored third_party/RE-USE/
# tree if present, otherwise snapshot-downloads just the code from HF.
_REUSE_DIR: Optional[Path] = None
def _resolve_reuse_dir() -> Path:
global _REUSE_DIR
if _REUSE_DIR is None:
from model_downloader import get_reuse_code_path
_REUSE_DIR = Path(get_reuse_code_path())
return _REUSE_DIR
class REUSEUpsampler:
"""Universal speech enhancement with optional bandwidth extension.
nvidia/RE-USE is a 9.6 M-param bidirectional-Mamba model that operates on
STFT amplitude+phase. With ``target_sr`` set it both denoises *and* extends
the bandwidth to that rate via librosa kaiser-best resample + restoration.
License: NSCLv1 (noncommercial). The base ``SEMamba`` class lives in the
HF repo under ``models/generator_SEMamba_time_d4.py`` and pulls in the
``mamba_ssm`` / ``causal-conv1d`` CUDA kernels.
"""
def __init__(
self,
target_sr: int = 48000,
config_path: Optional[str] = None,
chunk_size_s: float = 1.0,
hop_portion: float = 0.5,
device: str | torch.device = "cuda",
) -> None:
# chunk_size_s: peak VRAM scales linearly with chunk length.
# 5.0s -> 2.95 GB | 2.5s -> 1.52 GB | 1.0s -> 0.67 GB (default).
# 1.0s is chosen as default so RE-USE fits comfortably on top of the
# rest of the DramaBox pipeline on any 24 GB-class GPU.
self.device = torch.device(device)
self.target_sr = int(target_sr)
self.chunk_size_s = float(chunk_size_s)
self.hop_portion = float(hop_portion)
# Config path is resolved lazily on first use (alongside the code tree)
# so importing this module never triggers a download.
self._config_path_override = Path(config_path) if config_path else None
self.config_path: Optional[Path] = None
self._model = None
self._cfg = None
self._stft_fns = None # (mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim)
@staticmethod
def _ensure_mamba_ssm_importable() -> None:
"""Import ``mamba_ssm`` cleanly, with a kernel-free fallback if needed.
Normal path (kernels present): just import — fast path uses
``selective_scan_cuda`` natively.
Fallback (kernels missing): the official package does an unconditional
``import selective_scan_cuda`` at module load. We stub it into
``sys.modules`` before importing, then redirect ``selective_scan_fn``
to the pure-PyTorch ``selective_scan_ref`` so the model still runs
(~5-10x slower).
"""
try:
import selective_scan_cuda # noqa: F401
import mamba_ssm # noqa: F401
return # Fast path: kernel present.
except ImportError:
pass
import types
if "selective_scan_cuda" not in sys.modules:
stub = types.ModuleType("selective_scan_cuda")
def _missing(*a, **kw): # pragma: no cover - safety net only
raise NotImplementedError(
"selective_scan_cuda kernel missing; the call should have "
"been routed to selective_scan_ref via the runtime patch."
)
stub.fwd = _missing
stub.bwd = _missing
sys.modules["selective_scan_cuda"] = stub
from mamba_ssm.ops import selective_scan_interface as ssi
from mamba_ssm.modules import mamba_simple
if getattr(ssi, "_dramabox_kernel_free_patch_applied", False):
return
ssi.selective_scan_fn = ssi.selective_scan_ref
ssi.mamba_inner_fn = ssi.mamba_inner_ref
# mamba_simple imported these names by reference at module load -
# rebind there too, otherwise Mamba.forward keeps the original handles.
mamba_simple.selective_scan_fn = ssi.selective_scan_ref
mamba_simple.mamba_inner_fn = ssi.mamba_inner_ref
ssi._dramabox_kernel_free_patch_applied = True
logging.info(
"mamba_ssm kernel missing - using kernel-free fallback "
"(selective_scan_fn -> selective_scan_ref). Expect ~5-10x slowdown."
)
def _lazy_load(self) -> None:
if self._model is not None:
return
# Prefer real CUDA kernels; gracefully fall back to pure-PyTorch impl.
self._ensure_mamba_ssm_importable()
# The RE-USE module imports `from models...` and `from utils...` —
# both relative to the repo root. Add to path during load.
reuse_dir = _resolve_reuse_dir()
if str(reuse_dir) not in sys.path:
sys.path.insert(0, str(reuse_dir))
if self.config_path is None:
self.config_path = self._config_path_override or (
reuse_dir / "recipes" /
"USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml"
)
from models.generator_SEMamba_time_d4 import SEMamba # type: ignore
from models.stfts import mag_phase_stft, mag_phase_istft # type: ignore
from utils.util import load_config, pad_or_trim_to_match # type: ignore
self._cfg = load_config(str(self.config_path))
compress_factor = self._cfg["model_cfg"]["compress_factor"]
self._stft_fns = (mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim_to_match)
# SEMamba is a PyTorchModelHubMixin; from_pretrained pulls weights from HF.
model = SEMamba.from_pretrained("nvidia/RE-USE", cfg=self._cfg).to(self.device)
model.train(False)
self._model = model
n_params = sum(p.numel() for p in model.parameters())
logging.info(f"RE-USE loaded: SEMamba ({n_params / 1e6:.1f}M params) -> {self.target_sr} Hz")
@staticmethod
def _make_even(v: float) -> int:
v = int(round(v))
return v if v % 2 == 0 else v + 1
@torch.inference_mode()
def __call__(self, waveform: torch.Tensor, in_sr: int = 16000) -> Tuple[torch.Tensor, int]:
"""Chunked overlap-add denoise / BWE (ports nvidia/RE-USE inference_chunk.py).
Peak VRAM is bounded by ``chunk_size_s * target_sr`` rather than the
whole clip, so a 60 s clip costs the same as a 5 s one. Crossfade is
a Hann-window normalized overlap-add with default 50% hop.
"""
import math
self._lazy_load()
import librosa
mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim_to_match = self._stft_fns
# STFT params are scaled relative to the config's training rate (8000).
base_n_fft = self._cfg["stft_cfg"]["n_fft"]
base_hop = self._cfg["stft_cfg"]["hop_size"]
base_win = self._cfg["stft_cfg"]["win_size"]
base_sr = self._cfg["stft_cfg"]["sampling_rate"]
if waveform.dim() == 1:
waveform = waveform.unsqueeze(0)
# 1. Resample to target rate first (skips if target_sr == in_sr).
if self.target_sr != in_sr:
wav_np = waveform.cpu().float().numpy()
wav_np = librosa.resample(
wav_np, orig_sr=in_sr, target_sr=self.target_sr, res_type="kaiser_best"
)
wav = torch.from_numpy(wav_np).to(self.device, dtype=torch.float32)
else:
wav = waveform.to(self.device, dtype=torch.float32)
op_sr = self.target_sr
n_fft = self._make_even(base_n_fft * op_sr // base_sr)
hop = self._make_even(base_hop * op_sr // base_sr)
win = self._make_even(base_win * op_sr // base_sr)
# 2. Chunked OLA with Hann analysis window. Mirrors inference_chunk.py.
chunk_size = int(self.chunk_size_s * op_sr)
hop_length = int(self.hop_portion * chunk_size)
window = torch.hann_window(chunk_size, device=self.device)
n_ch, total = wav.shape
enhanced = torch.zeros_like(wav)
window_sum = torch.zeros_like(wav)
n_chunks = max(1, math.ceil((total - chunk_size) / hop_length) + 1) if total > chunk_size else 1
for c in range(n_ch):
ch_in = wav[c : c + 1] # (1, T)
for i in range(n_chunks):
start = i * hop_length
end = min(start + chunk_size, total)
chunk = ch_in[:, start:end]
if chunk.shape[-1] < 2: # skip degenerate tail
continue
noisy_mag, noisy_pha, _ = mag_phase_stft(
chunk, n_fft=n_fft, hop_size=hop, win_size=win,
compress_factor=compress_factor, center=True, addeps=False,
)
amp_g, pha_g, _ = self._model(noisy_mag, noisy_pha)
# "Sweep artifact" filter — match the official inference.
mag = torch.expm1(torch.relu(amp_g))
zero_portion = (mag == 0).sum(dim=1) / mag.shape[1]
amp_g[:, :, (zero_portion > 0.5)[0]] = 0
audio_g = mag_phase_istft(amp_g, pha_g, n_fft, hop, win, compress_factor)
audio_g = pad_or_trim_to_match(chunk.detach(), audio_g, pad_value=1e-8)
w_slice = window[: audio_g.shape[-1]]
enhanced[c : c + 1, start : start + audio_g.shape[-1]] += audio_g * w_slice
window_sum[c : c + 1, start : start + audio_g.shape[-1]] += w_slice
# 3. Normalize where windows overlap. Avoid divide-by-zero at clip tails.
mask = window_sum > 1e-8
enhanced[mask] = enhanced[mask] / window_sum[mask]
return enhanced.clamp(-1.0, 1.0).cpu().float(), op_sr