PatnaikAshish commited on
Commit
66525f3
·
verified ·
1 Parent(s): 8da48d8

Create chunked_convert.py

Browse files
Files changed (1) hide show
  1. core/chunked_convert.py +185 -0
core/chunked_convert.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ core/chunked_convert.py
3
+ -----------------------
4
+ VRAM-aware chunked voice conversion using the Kanade model.
5
+
6
+ On CUDA devices, the source waveform is split into overlapping chunks so that
7
+ peak activation memory stays within a configurable fraction of total VRAM
8
+ (default 50%). On CPU the waveform is still chunked to respect the model's
9
+ RoPE sequence-length limit.
10
+
11
+ RoPE ceiling (why chunks must be small)
12
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
13
+ The Kanade ``mel_decoder`` Transformer processes mel-spectrogram frames of the
14
+ source chunk. Its RoPE positional embeddings are precomputed for
15
+ ``_ROPE_MAX_FRAMES = 1024`` positions. The mel frame count for a window of
16
+ ``W`` samples is ``W // hop_length + 1``. Keeping that ≀ 1024 requires:
17
+
18
+ W ≀ (1024 βˆ’ 1) Γ— hop_length = 1023 Γ— 256 = 261,888 samples β‰ˆ 10.9 s
19
+
20
+ Each chunk window includes a 0.5 s overlap on both sides for boundary
21
+ smoothing, so the *chunk* itself must be:
22
+
23
+ chunk ≀ 261,888 βˆ’ 2 Γ— (0.5 s Γ— sample_rate) β‰ˆ 9.9 s
24
+
25
+ A 10 % safety margin is applied, giving ``_ROPE_SAFE_CHUNK_FACTOR β‰ˆ 8.9 s``
26
+ worth of source audio per chunk.
27
+
28
+ Overlap / boundary handling
29
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
30
+ Each chunk includes a short overlap window on both sides. After the
31
+ voice-conversion forward pass, the overlap frames are trimmed from the mel
32
+ output before the pieces are concatenated. The final assembled mel is vocoded
33
+ in a single pass.
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ import time
39
+ import torch
40
+ from kanade_tokenizer import vocode
41
+
42
+
43
+ # Empirical constant: ~10 seconds of audio fit in 1 GB of VRAM budget for the
44
+ # Kanade-12.5hz model. Adjust downward if you observe OOM errors.
45
+ _SECONDS_PER_GB: float = 10.0
46
+
47
+ # Overlap window on each side of a chunk (seconds).
48
+ _OVERLAP_SECONDS: float = 0.5
49
+
50
+ # --------------------------------------------------------------------------
51
+ # RoPE safety ceiling β€” derived from the mel_decoder Transformer
52
+ # --------------------------------------------------------------------------
53
+ # mel_decoder seqlen = audio_length // hop_length + 1 (center-padding mel).
54
+ # Its RoPE freqs_cis is precomputed for _ROPE_MAX_FRAMES positions.
55
+ # hop_length comes directly from KanadeModelConfig (hop_length = 256).
56
+ _ROPE_MAX_FRAMES: int = 1024 # precomputed RoPE window (freqs_cis.shape[0])
57
+ _MEL_HOP_LENGTH: int = 256 # KanadeModelConfig.hop_length
58
+ _ROPE_SAFETY_MARGIN: float = 0.75
59
+
60
+ # Output mel frame rate β€” kept for reference only; NOT used for overlap trimming.
61
+ # Mel frames used internally are at sample_rate / hop_length (93.75 fps), not 12.5 fps.
62
+ _MEL_FPS: float = 12.5
63
+
64
+
65
def chunked_voice_conversion(
    kanade,
    vocoder_model,
    source_wav: torch.Tensor,
    ref_wav: torch.Tensor,
    sample_rate: int,
    vram_fraction: float = 0.9,
) -> torch.Tensor:
    """Convert *source_wav* to the reference voice in VRAM-safe chunks.

    Parameters
    ----------
    kanade:
        A loaded ``KanadeModel`` instance (already on the target device).
    vocoder_model:
        The vocoder loaded via ``load_vocoder`` (already on the target device).
    source_wav:
        Source waveform tensor of shape ``[T]`` or ``[1, T]``, on the same
        device as *kanade*.
    ref_wav:
        Reference waveform tensor of shape ``[T]`` or ``[1, T]``, on the same
        device as *kanade*.
    sample_rate:
        Audio sample rate in Hz (taken from ``kanade.config.sample_rate``).
    vram_fraction:
        Fraction of total VRAM to target per chunk. Default ``0.9`` → 90 %.

    Returns
    -------
    torch.Tensor
        Converted waveform as a 1-D CPU float32 tensor.
    """
    device: torch.device = source_wav.device
    n_samples: int = source_wav.shape[-1]
    _start = time.perf_counter()

    # ── 1. Determine chunk size ──────────────────────────────────────────────
    # The mel_decoder RoPE ceiling limits the total window (chunk + overlaps).
    # Max window in samples: (ROPE_MAX_FRAMES - 1) * MEL_HOP_LENGTH.
    # Subtract both overlap sides, then apply a safety margin.
    overlap_samples = int(_OVERLAP_SECONDS * sample_rate)
    rope_max_window = (_ROPE_MAX_FRAMES - 1) * _MEL_HOP_LENGTH  # 261,888 samples ≈ 10.9 s
    rope_safe_chunk = int((rope_max_window - 2 * overlap_samples) * _ROPE_SAFETY_MARGIN)
    rope_safe_seconds = rope_safe_chunk / sample_rate

    if device.type == "cuda":
        total_vram_bytes = torch.cuda.get_device_properties(device).total_memory
        budget_bytes = total_vram_bytes * vram_fraction
        budget_gb = budget_bytes / (1024 ** 3)

        # Never go below 5 s per chunk, even on tiny VRAM budgets.
        vram_chunk_samples = int(max(5.0, budget_gb * _SECONDS_PER_GB) * sample_rate)

        # Take the smaller of VRAM-based and RoPE-safe limits.
        chunk_samples = min(vram_chunk_samples, rope_safe_chunk)
        chunk_seconds = chunk_samples / sample_rate

        print(
            f"[chunked_convert] VRAM budget: {budget_gb:.2f} GB "
            f"({vram_fraction*100:.0f}% of {total_vram_bytes / (1024**3):.2f} GB) "
            f"→ chunk size: {chunk_seconds:.1f}s / {chunk_samples:,} samples "
            f"(RoPE ceiling: {rope_safe_seconds:.1f}s)"
        )
    else:
        # CPU: no VRAM limit, but still respect the RoPE ceiling for quality.
        chunk_samples = rope_safe_chunk

    # ── 2. Short-circuit when the whole file fits in one chunk ───────────────
    if n_samples <= chunk_samples:
        with torch.inference_mode():
            mel = kanade.voice_conversion(
                source_waveform=source_wav, reference_waveform=ref_wav
            )
            wav = vocode(vocoder_model, mel.unsqueeze(0))
        elapsed = time.perf_counter() - _start
        print(f"[chunked_convert] Completed in {elapsed:.1f}s")
        return wav.squeeze().cpu()

    # ── 3. Chunked processing with overlap ──────────────────────────────────
    # Trimming below converts sample offsets to mel frames via _MEL_HOP_LENGTH:
    # the mel output is at sample_rate / hop_length (93.75 fps), NOT _MEL_FPS.
    mel_parts: list[torch.Tensor] = []
    pos = 0

    while pos < n_samples:
        # The "core" of this chunk is [pos, core_end); the window is extended
        # on both sides by up to overlap_samples so the model has context at
        # each boundary.
        core_end = min(n_samples, pos + chunk_samples)
        win_start = max(0, pos - overlap_samples)
        win_end = min(n_samples, core_end + overlap_samples)

        chunk = source_wav[..., win_start:win_end]

        with torch.inference_mode():
            mel_chunk: torch.Tensor = kanade.voice_conversion(
                source_waveform=chunk, reference_waveform=ref_wav
            )

        # Move to CPU immediately so the GPU buffer is freed before the next chunk.
        mel_chunk = mel_chunk.cpu()

        # Trim the frames that were only present as boundary context. Trims are
        # derived from the *actual* context included in the window, which fixes
        # two boundary bugs of a fixed-size trim:
        #   * near the start, win_start may be clamped to 0 (less left context
        #     than overlap_samples) — a fixed trim would drop real frames;
        #   * near the end, win_end may hit EOF while core_end does not — a
        #     "no trim at EOF" rule would duplicate the tail region that the
        #     next iteration emits again.
        n_frames = mel_chunk.shape[-1]
        left_trim = (pos - win_start) // _MEL_HOP_LENGTH
        right_trim = n_frames - (win_end - core_end) // _MEL_HOP_LENGTH

        mel_parts.append(mel_chunk[..., left_trim:right_trim])

        pos = core_end

        if device.type == "cuda":
            torch.cuda.empty_cache()

    # ── 4. Assemble full mel and vocode in one pass ──────────────────────────
    full_mel = torch.cat(mel_parts, dim=-1).to(device)

    with torch.inference_mode():
        wav = vocode(vocoder_model, full_mel.unsqueeze(0))

    elapsed = time.perf_counter() - _start
    print(f"[chunked_convert] Completed in {elapsed:.1f}s")
    return wav.squeeze().cpu()