clone / core /chunked_convert.py
PatnaikAshish's picture
Create chunked_convert.py
66525f3 verified
"""
core/chunked_convert.py
-----------------------
VRAM-aware chunked voice conversion using the Kanade model.
On CUDA devices, the source waveform is split into overlapping chunks so that
peak activation memory stays within a configurable fraction of total VRAM
(default 50%). On CPU the waveform is still chunked to respect the model's
RoPE sequence-length limit.
RoPE ceiling (why chunks must be small)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The Kanade ``mel_decoder`` Transformer processes mel-spectrogram frames of the
source chunk. Its RoPE positional embeddings are precomputed for
``_ROPE_MAX_FRAMES = 1024`` positions. The mel frame count for a window of
``W`` samples is ``W // hop_length + 1``. Keeping that ≀ 1024 requires:
W ≀ (1024 βˆ’ 1) Γ— hop_length = 1023 Γ— 256 = 261,888 samples β‰ˆ 10.9 s
Each chunk window includes a 0.5 s overlap on both sides for boundary
smoothing, so the *chunk* itself must be:
chunk ≀ 261,888 βˆ’ 2 Γ— (0.5 s Γ— sample_rate) β‰ˆ 9.9 s
A 10 % safety margin is applied, giving ``_ROPE_SAFE_CHUNK_FACTOR β‰ˆ 8.9 s``
worth of source audio per chunk.
Overlap / boundary handling
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Each chunk includes a short overlap window on both sides. After the
voice-conversion forward pass, the overlap frames are trimmed from the mel
output before the pieces are concatenated. The final assembled mel is vocoded
in a single pass.
"""
from __future__ import annotations
import time
import torch
from kanade_tokenizer import vocode
# Empirical constant: ~10 seconds of audio fit in 1 GB of VRAM budget for the
# Kanade-12.5hz model. Adjust downward if you observe OOM errors.
_SECONDS_PER_GB: float = 10.0
# Overlap window on each side of a chunk (seconds).
_OVERLAP_SECONDS: float = 0.5
# --------------------------------------------------------------------------
# RoPE safety ceiling β€” derived from the mel_decoder Transformer
# --------------------------------------------------------------------------
# mel_decoder seqlen = audio_length // hop_length + 1 (center-padding mel).
# Its RoPE freqs_cis is precomputed for _ROPE_MAX_FRAMES positions.
# hop_length comes directly from KanadeModelConfig (hop_length = 256).
_ROPE_MAX_FRAMES: int = 1024 # precomputed RoPE window (freqs_cis.shape[0])
_MEL_HOP_LENGTH: int = 256 # KanadeModelConfig.hop_length
_ROPE_SAFETY_MARGIN: float = 0.75
# Output mel frame rate β€” kept for reference only; NOT used for overlap trimming.
# Mel frames used internally are at sample_rate / hop_length (93.75 fps), not 12.5 fps.
_MEL_FPS: float = 12.5
def chunked_voice_conversion(
kanade,
vocoder_model,
source_wav: torch.Tensor,
ref_wav: torch.Tensor,
sample_rate: int,
vram_fraction: float = 0.9,
) -> torch.Tensor:
"""Convert *source_wav* to the reference voice in VRAM-safe chunks.
Parameters
----------
kanade:
A loaded ``KanadeModel`` instance (already on the target device).
vocoder_model:
The vocoder loaded via ``load_vocoder`` (already on the target device).
source_wav:
Source waveform tensor of shape ``[T]`` or ``[1, T]``, on the same
device as *kanade*.
ref_wav:
Reference waveform tensor of shape ``[T]`` or ``[1, T]``, on the same
device as *kanade*.
sample_rate:
Audio sample rate in Hz (taken from ``kanade.config.sample_rate``).
vram_fraction:
Fraction of total VRAM to target per chunk. Default ``0.5`` β†’ 50 %.
Returns
-------
torch.Tensor
Converted waveform as a 1-D CPU float32 tensor.
"""
device: torch.device = source_wav.device
n_samples: int = source_wav.shape[-1]
_start = time.perf_counter()
# ── 1. Determine chunk size ──────────────────────────────────────────────
# The mel_decoder RoPE ceiling limits the total window (chunk + overlaps).
# Max window in samples: (ROPE_MAX_FRAMES - 1) * MEL_HOP_LENGTH
# Subtract both overlap sides, then apply a safety margin.
overlap_samples = int(_OVERLAP_SECONDS * sample_rate)
rope_max_window = (_ROPE_MAX_FRAMES - 1) * _MEL_HOP_LENGTH # 261,888 samples β‰ˆ 10.9 s
rope_safe_chunk = int((rope_max_window - 2 * overlap_samples) * _ROPE_SAFETY_MARGIN)
rope_safe_seconds = rope_safe_chunk / sample_rate
if device.type == "cuda":
total_vram_bytes = torch.cuda.get_device_properties(device).total_memory
budget_bytes = total_vram_bytes * vram_fraction
budget_gb = budget_bytes / (1024 ** 3)
vram_chunk_samples = int(max(5.0, budget_gb * _SECONDS_PER_GB) * sample_rate)
# Take the smaller of VRAM-based and RoPE-safe limits.
chunk_samples = min(vram_chunk_samples, rope_safe_chunk)
chunk_seconds = chunk_samples / sample_rate
print(
f"[chunked_convert] VRAM budget: {budget_gb:.2f} GB "
f"({vram_fraction*100:.0f}% of {total_vram_bytes / (1024**3):.2f} GB) "
f"β†’ chunk size: {chunk_seconds:.1f}s / {chunk_samples:,} samples "
f"(RoPE ceiling: {rope_safe_seconds:.1f}s)"
)
else:
# CPU: no VRAM limit, but still respect the RoPE ceiling for quality.
chunk_samples = rope_safe_chunk
# ── 2. Short-circuit when the whole file fits in one chunk ───────────────
if n_samples <= chunk_samples:
with torch.inference_mode():
mel = kanade.voice_conversion(
source_waveform=source_wav, reference_waveform=ref_wav
)
wav = vocode(vocoder_model, mel.unsqueeze(0))
elapsed = time.perf_counter() - _start
print(f"[chunked_convert] Completed in {elapsed:.1f}s")
return wav.squeeze().cpu()
# ── 3. Chunked processing with overlap ──────────────────────────────────
# Mel frames corresponding to the overlap window.
# The mel output is at sample_rate / hop_length = 93.75 fps, NOT _MEL_FPS.
overlap_frames = overlap_samples // _MEL_HOP_LENGTH # 12000 // 256 = 46
mel_parts: list[torch.Tensor] = []
pos = 0
while pos < n_samples:
# Extend the window on both sides by overlap_samples so the model has
# context at each boundary.
win_start = max(0, pos - overlap_samples)
win_end = min(n_samples, pos + chunk_samples + overlap_samples)
chunk = source_wav[..., win_start:win_end]
with torch.inference_mode():
mel_chunk: torch.Tensor = kanade.voice_conversion(
source_waveform=chunk, reference_waveform=ref_wav
)
# Move to CPU immediately so the GPU buffer is freed before the next chunk.
mel_chunk = mel_chunk.cpu()
# Trim overlap frames that were only there for context.
left_trim = 0 if pos == 0 else overlap_frames
right_trim = mel_chunk.shape[-1] if win_end >= n_samples else mel_chunk.shape[-1] - overlap_frames
mel_parts.append(mel_chunk[..., left_trim:right_trim])
pos += chunk_samples
if device.type == "cuda":
torch.cuda.empty_cache()
# ── 4. Assemble full mel and vocode in one pass ──────────────────────────
full_mel = torch.cat(mel_parts, dim=-1).to(device)
with torch.inference_mode():
wav = vocode(vocoder_model, full_mel.unsqueeze(0))
elapsed = time.perf_counter() - _start
print(f"[chunked_convert] Completed in {elapsed:.1f}s")
return wav.squeeze().cpu()