Manmay committed on
Commit
7e0eb32
·
1 Parent(s): e53641f

Long-form chunking + RE-USE on reference

Browse files

Port long-form generation from the upstream DramaBox repo (commit 382a37c)
to the HF Space:

- Auto chunk-and-stitch when the prompt's estimated (or explicit) duration
exceeds the max_chunk_duration cap; quote-aware sentence splitter (new
src/text_chunker.py) and shared duration estimator (new
src/duration_estimator.py). Chunks are stitched with an equal-power
crossfade so independently-generated joins are inaudible.

- RE-USE input-side voice-reference denoise (new src/super_resolution.py).
Applied to the *reference* before VAE encoding so the model conditions on
a clean speaker / style anchor; the generated paralinguistic content
(laughs, breaths, sighs) stays untouched. Cached per session so chunked
runs don't re-denoise the same reference per chunk. Falls back to the
original (un-denoised) reference, logging a single warning, if the
mamba_ssm / causal-conv1d kernels can't be loaded.

- /generate_audio API gains denoise_ref + chunking knobs with sensible
defaults, so the existing index.html client (which sends only the
original kwargs) keeps working.

- @spaces.GPU duration bumped 60s → 600s for multi-chunk runs.

- requirements.txt: add resampy, plus mamba-ssm / causal-conv1d (the latter two
Linux only; optional — failures fall back to skipping the denoise).

app.py CHANGED
@@ -183,7 +183,7 @@ async def homepage():
183
 
184
 
185
  @app.api()
186
- @spaces.GPU(duration=60)
187
  def generate_audio(
188
  prompt: str,
189
  audio_ref: FileData | None,
@@ -192,11 +192,15 @@ def generate_audio(
192
  dur_mult: float,
193
  gen_dur: float,
194
  ref_dur: float,
195
- seed: int
 
 
 
 
196
  ) -> FileData:
197
  if not prompt or not prompt.strip():
198
  raise gr.Error("Prompt is empty.")
199
-
200
  t0 = time.time()
201
  ref_path = None
202
  if audio_ref:
@@ -206,8 +210,12 @@ def generate_audio(
206
  ref_path = audio_ref.path
207
  if ref_path and not os.path.exists(ref_path):
208
  ref_path = None
209
-
210
  output = tempfile.mktemp(suffix=".wav", prefix="dramabox_")
 
 
 
 
211
  tts.generate_to_file(
212
  prompt=prompt,
213
  output=output,
@@ -218,6 +226,10 @@ def generate_audio(
218
  seed=int(seed),
219
  gen_duration=float(gen_dur),
220
  ref_duration=float(ref_dur),
 
 
 
 
221
  )
222
  elapsed = time.time() - t0
223
  logging.info(f"Generated in {elapsed:.2f}s -> {output}")
 
183
 
184
 
185
  @app.api()
186
+ @spaces.GPU(duration=600)
187
  def generate_audio(
188
  prompt: str,
189
  audio_ref: FileData | None,
 
192
  dur_mult: float,
193
  gen_dur: float,
194
  ref_dur: float,
195
+ seed: int,
196
+ denoise_ref: bool = True,
197
+ max_chunk_duration: float = 45.0,
198
+ target_chunk_duration: float = 37.0,
199
+ crossfade_ms: float = 50.0,
200
  ) -> FileData:
201
  if not prompt or not prompt.strip():
202
  raise gr.Error("Prompt is empty.")
203
+
204
  t0 = time.time()
205
  ref_path = None
206
  if audio_ref:
 
210
  ref_path = audio_ref.path
211
  if ref_path and not os.path.exists(ref_path):
212
  ref_path = None
213
+
214
  output = tempfile.mktemp(suffix=".wav", prefix="dramabox_")
215
+ # Long-form: generate_to_file auto-routes to the chunk-and-stitch path when
216
+ # the estimated (or explicit gen_dur) duration exceeds max_chunk_duration.
217
+ # denoise_ref runs RE-USE on the voice reference before VAE encoding so the
218
+ # model conditions on a cleaner speaker / style anchor.
219
  tts.generate_to_file(
220
  prompt=prompt,
221
  output=output,
 
226
  seed=int(seed),
227
  gen_duration=float(gen_dur),
228
  ref_duration=float(ref_dur),
229
+ denoise_ref=bool(denoise_ref),
230
+ max_chunk_duration=float(max_chunk_duration),
231
+ target_chunk_duration=float(target_chunk_duration),
232
+ crossfade_ms=float(crossfade_ms),
233
  )
234
  elapsed = time.time() - t0
235
  logging.info(f"Generated in {elapsed:.2f}s -> {output}")
requirements.txt CHANGED
@@ -24,3 +24,15 @@ gradio==6.14.0
24
  spaces>=0.30.0
25
  soundfile>=0.12.0
26
  resemble-perth @ git+https://github.com/resemble-ai/Perth.git@master
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  spaces>=0.30.0
25
  soundfile>=0.12.0
26
  resemble-perth @ git+https://github.com/resemble-ai/Perth.git@master
27
+
28
+ # ── Optional: NVIDIA RE-USE speech enhancement (input-side voice-ref denoise) ─
29
+ # RE-USE is applied to the uploaded voice reference before VAE encoding so the
30
+ # model conditions on a clean speaker / style anchor (generated paralinguistic
31
+ # events — laughs, breaths, sighs — stay untouched because the denoiser only
32
+ # touches the reference). resampy is used for its pre-resample step; mamba_ssm
33
+ # + causal-conv1d power the bi-Mamba kernels. The kernels have no pre-built
34
+ # wheels for every CUDA toolkit — if installs fail, the TTS server logs a warning
35
+ # once and silently skips the reference denoise for the rest of the session.
36
+ resampy>=0.4.0
37
+ mamba-ssm>=2.2.0 ; platform_system == "Linux"
38
+ causal-conv1d>=1.4.0 ; platform_system == "Linux"
src/duration_estimator.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pure-Python speech-duration estimator for DramaBox prompts.
2
+
3
+ Originally lived in ``inference.py`` but pulled out so chunkers / tooling /
4
+ unit tests can import it without dragging torch + the LTX pipeline through
5
+ sys.path. ``inference.py`` and ``inference_server.py`` continue to import
6
+ ``estimate_speech_duration`` from here.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import re
11
+
12
+
13
+ _LAUGH_VERBS = {
14
+ # base seconds per occurrence; gets scaled by the modifier found nearby.
15
+ # Verb regex covers inflections: laugh/laughs/laughed/laughing.
16
+ r"\blaugh(?:s|ed|ing)?\b": 1.5,
17
+ r"\bcackl(?:e|es|ed|ing)\b": 1.5,
18
+ r"\bchuckl(?:e|es|ed|ing)\b": 1.0,
19
+ r"\bgiggl(?:e|es|ed|ing)\b": 1.0,
20
+ r"\bsnicker(?:s|ed|ing)?\b": 0.8,
21
+ r"\bcru?el laugh\b": 1.5,
22
+ }
23
+
24
+
25
+ def _contextual_laugh_duration(text: str) -> float:
26
+ """Context-aware laugh budget.
27
+
28
+ For each laugh verb in the prompt, look at the adjective/adverb that
29
+ modifies it and scale the base duration:
30
+ - short modifiers (briefly, softly, once) -> 0.4x base
31
+ - long modifiers (maniacally, heartily, ...) -> 1.2x base
32
+ - default (no mod / neutral) -> 1.0x base
33
+ Also reward phonetic repetition inside quotes -- 'Hahahahahaha' buys more
34
+ time than 'Haha' -- at ~0.2s per extra repeated syllable.
35
+ """
36
+ short_mod = re.compile(
37
+ r"^\s*(?:[a-z]+ly )?(?:briefly|shortly|once|quickly)",
38
+ re.IGNORECASE)
39
+ long_mod = re.compile(
40
+ r"^\s*(?:[a-z]+ly )?(?:maniacally|heartily|uproariously|uncontrollably|"
41
+ r"hysterically|darkly|wickedly|evilly|loudly|long)"
42
+ r"|^\s*between phrases", re.IGNORECASE)
43
+
44
+ total = 0.0
45
+ for pat, base_dur in _LAUGH_VERBS.items():
46
+ for m in re.finditer(pat, text, re.IGNORECASE):
47
+ ctx = text[m.end(): m.end() + 40]
48
+ if short_mod.match(ctx):
49
+ total += base_dur * 0.4
50
+ elif long_mod.match(ctx):
51
+ total += base_dur * 1.2
52
+ else:
53
+ total += base_dur
54
+
55
+ # Phonetic laugh repetition inside quotes.
56
+ for q in re.findall(r'"([^"]+)"', text) + re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text):
57
+ for run in re.findall(r"(?:h[ae]){3,}|(?:h[ae][ \-]?){3,}", q, re.IGNORECASE):
58
+ syls = len(re.findall(r"h[ae]", run, re.IGNORECASE))
59
+ total += 0.2 * max(syls - 2, 0)
60
+ return total
61
+
62
+
63
def _estimate_nonverbal_duration(text: str) -> float:
    """Return extra seconds to budget for non-verbal cues found in ``text``.

    Each cue regex below is matched case-insensitively against the whole
    prompt; every hit adds that cue's fixed number of seconds. Laugh verbs
    are not in this table -- their modifier-aware budget comes from
    ``_contextual_laugh_duration`` and is added on top.
    """
    cue_budgets = (
        # Breathing / sighs
        (r'\bsighs?\b', 0.8),
        (r'\bshaky breath\b', 1.0),
        (r'\bbreathing deeply\b', 1.0),
        (r'\bgasps?\b', 0.5),
        (r'\bburps?\b', 0.5),
        (r'\byawns?\b', 1.0),
        (r'\bpants?\b', 0.8),
        (r'\bwheezes?\b', 0.8),
        (r'\bcoughs?\b', 0.8),
        (r'\bsniffles?\b', 0.5),
        (r'\bsnorts?\b', 0.3),
        (r'\bgroans?\b', 0.8),
        # Pauses
        (r'\blong pause\b', 1.0),
        (r'\bpauses? briefly\b', 0.3),
        (r'\bpauses?\b', 0.5),
        (r'\bsilence\b', 1.0),
        (r'\blets? the .{1,20} hang\b', 1.0),
        (r'\blets? .{1,20} sink in\b', 1.0),
        # Physical actions that produce sound
        (r'\bslams?\b', 0.5),
        (r'\bclaps?\b', 0.3),
        (r'\bdraws? (?:his|her|a) sword\b', 0.5),
        (r'\btakes? a (?:drag|swig|sip|drink)\b', 0.5),
        (r'\bwhistles?\b', 1.0),
        (r'\bhums?\b', 0.8),
        # Vocal actions outside the quoted dialogue
        (r'\bmutters?\b', 1.5),
        (r'\bmumbles?\b', 1.0),
        (r'\bwhispers?\b', 0.0),
        (r'\bclears? (?:his|her) throat\b', 0.5),
        (r'\bgulps?\b', 0.5),
        (r'\bswallows?\b', 0.5),
        # Emotional transitions
        (r'\bvoice (?:breaks?|cracks?|trembles?|drops?|rises?)\b', 0.5),
        (r'\bsteadies? (?:him|her)self\b', 1.0),
        (r'\bcatches? (?:his|her) breath\b', 1.0),
        (r'\bcomposes? (?:him|her)self\b', 0.8),
        # Scene transitions that imply time
        (r'\bdemeanor shifts?\b', 0.5),
        (r'\bsettles? in\b', 0.5),
        (r'\bleans? in\b', 0.3),
        (r'\bwipes? (?:his|her) eyes\b', 0.5),
    )

    budget = 0.0
    for cue_regex, seconds in cue_budgets:
        hits = sum(1 for _ in re.finditer(cue_regex, text, re.IGNORECASE))
        budget += seconds * hits
    return budget + _contextual_laugh_duration(text)
98
+
99
+
100
def estimate_speech_duration(text: str, speed: float = 1.0) -> float:
    """Estimate speech duration from spoken content + non-verbal actions.

    The spoken text is extracted by priority:
      1. Quoted text ("..." first, then '...') -- official prompt guide
         format. Quotes of three words or fewer are treated as scene
         directions and dropped.
      2. Text after the first colon -- simple "Speaker: dialogue" format.
      3. The full text -- fallback.

    The base time is ``len(spoken) / chars_per_sec``. The speaking rate is
    14 chars/s, slowed to 0.6x for texts under 40 characters and 0.8x for
    texts under 80 characters, then multiplied by ``speed``. Each ``.``,
    ``!`` and ``?`` in the spoken text adds 0.3s. The whole prompt is then
    scanned for non-verbal cues (laughs, pauses, sighs, gasps, ...), which
    add their own budget.

    Args:
        text: The full prompt.
        speed: Speaking-rate multiplier; must be greater than zero.

    Returns:
        Estimated seconds, padded by 2.0s, rounded to one decimal place and
        never less than 3.0.

    Raises:
        ValueError: If ``speed`` is zero or negative. Zero used to surface as
            a bare ZeroDivisionError, and a negative value silently collapsed
            the estimate to the 3.0s floor.
    """
    if speed <= 0:
        raise ValueError(f"speed must be greater than zero, got {speed!r}")

    # Try double quotes first (clean, no contraction issues).
    quotes = re.findall(r'"([^"]+)"', text)
    if not quotes:
        # Single quotes: an apostrophe NOT followed by whitespace/punctuation
        # is kept inside the quote, so contractions (don't, it's) survive.
        quotes = re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text)
    # Filter out short fragments (scene directions like "He pauses").
    quotes = [q for q in quotes if len(q.split()) > 3]
    if quotes:
        spoken = " ".join(quotes)
    elif ":" in text:
        spoken = text.split(":", 1)[1].strip()
    else:
        spoken = text

    CHARS_PER_SEC = 14.0
    text_len = len(spoken)

    # Shorter utterances get a slower effective rate.
    if text_len < 40:
        chars_per_sec = CHARS_PER_SEC * 0.6
    elif text_len < 80:
        chars_per_sec = CHARS_PER_SEC * 0.8
    else:
        chars_per_sec = CHARS_PER_SEC

    chars_per_sec *= speed
    duration = text_len / chars_per_sec

    sentence_count = spoken.count(".") + spoken.count("!") + spoken.count("?")
    duration += sentence_count * 0.3

    # Non-verbal cues are searched in the full prompt, not just the dialogue.
    duration += _estimate_nonverbal_duration(text)

    return max(3.0, round(duration + 2.0, 1))
src/inference.py CHANGED
@@ -74,151 +74,14 @@ def detect_model_type(checkpoint_path: str) -> str:
74
  return "distilled"
75
 
76
 
77
- _LAUGH_VERBS = {
78
- # base seconds per occurrence; gets scaled by the modifier found nearby.
79
- # Verb regex covers inflections: laugh/laughs/laughed/laughing.
80
- r"\blaugh(?:s|ed|ing)?\b": 1.5,
81
- r"\bcackl(?:e|es|ed|ing)\b": 1.5,
82
- r"\bchuckl(?:e|es|ed|ing)\b": 1.0,
83
- r"\bgiggl(?:e|es|ed|ing)\b": 1.0,
84
- r"\bsnicker(?:s|ed|ing)?\b": 0.8,
85
- r"\bcru?el laugh\b": 1.5,
86
- }
87
-
88
-
89
- def _contextual_laugh_duration(text: str) -> float:
90
- """Context-aware laugh budget.
91
-
92
- For each laugh verb in the prompt, look at the adjective/adverb that
93
- modifies it and scale the base duration:
94
- - short modifiers (briefly, softly, once) -> 0.4x base
95
- - long modifiers (maniacally, heartily, ...) -> 1.2x base
96
- - default (no mod / neutral) -> 1.0x base
97
- Also reward phonetic repetition inside quotes -- 'Hahahahahaha' buys more
98
- time than 'Haha' -- at ~0.2s per extra repeated syllable.
99
- """
100
- # "softly" / "quietly" describe volume not length, so keep at default 1.0x.
101
- short_mod = re.compile(
102
- r"^\s*(?:[a-z]+ly )?(?:briefly|shortly|once|quickly)",
103
- re.IGNORECASE)
104
- long_mod = re.compile(
105
- r"^\s*(?:[a-z]+ly )?(?:maniacally|heartily|uproariously|uncontrollably|"
106
- r"hysterically|darkly|wickedly|evilly|loudly|long)"
107
- r"|^\s*between phrases", re.IGNORECASE)
108
-
109
- total = 0.0
110
- for pat, base_dur in _LAUGH_VERBS.items():
111
- for m in re.finditer(pat, text, re.IGNORECASE):
112
- ctx = text[m.end(): m.end() + 40]
113
- if short_mod.match(ctx):
114
- total += base_dur * 0.4
115
- elif long_mod.match(ctx):
116
- total += base_dur * 1.2
117
- else:
118
- total += base_dur
119
-
120
- # Phonetic laugh repetition inside quotes:
121
- # 'Haha' = 2 syllables (base, no bonus)
122
- # 'Hahahaha' = 4 syllables (+0.4s)
123
- # 'Hehehehahahahahahahaha' ~ 10 syllables (+1.6s)
124
- for q in re.findall(r'"([^"]+)"', text) + re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text):
125
- for run in re.findall(r"(?:h[ae]){3,}|(?:h[ae][ \-]?){3,}", q, re.IGNORECASE):
126
- syls = len(re.findall(r"h[ae]", run, re.IGNORECASE))
127
- total += 0.2 * max(syls - 2, 0)
128
- return total
129
-
130
-
131
- def _estimate_nonverbal_duration(text: str) -> float:
132
- """Estimate extra duration for non-verbal sounds and actions in the prompt.
133
-
134
- Laugh-verb handling lives in ``_contextual_laugh_duration`` so cackle /
135
- chuckle / laugh budgets scale with the adjective ("maniacally" vs
136
- "briefly") and with the repetition length of 'Ha'/'He' tokens inside
137
- quotes.
138
- """
139
- PATTERNS = {
140
- # Breathing / sighs
141
- r'\bsighs?\b': 0.8, r'\bshaky breath\b': 1.0, r'\bbreathing deeply\b': 1.0,
142
- r'\bgasps?\b': 0.5, r'\bburps?\b': 0.5, r'\byawns?\b': 1.0,
143
- r'\bpants?\b': 0.8, r'\bwheezes?\b': 0.8, r'\bcoughs?\b': 0.8,
144
- r'\bsniffles?\b': 0.5, r'\bsnorts?\b': 0.3, r'\bgroans?\b': 0.8,
145
- # Pauses (trimmed; earlier values over-budgeted silence)
146
- r'\blong pause\b': 1.0, r'\bpauses? briefly\b': 0.3,
147
- r'\bpauses?\b': 0.5, r'\bsilence\b': 1.0,
148
- r'\blets? the .{1,20} hang\b': 1.0, r'\blets? .{1,20} sink in\b': 1.0,
149
- # Physical actions that produce sound
150
- r'\bslams?\b': 0.5, r'\bclaps?\b': 0.3,
151
- r'\bdraws? (?:his|her|a) sword\b': 0.5,
152
- r'\btakes? a (?:drag|swig|sip|drink)\b': 0.5,
153
- r'\bwhistles?\b': 1.0, r'\bhums?\b': 0.8,
154
- # Vocal actions (not in quotes but take time)
155
- r'\bmutters?\b': 1.5, r'\bmumbles?\b': 1.0, r'\bwhispers?\b': 0.0,
156
- r'\bclears? (?:his|her) throat\b': 0.5, r'\bgulps?\b': 0.5,
157
- r'\bswallows?\b': 0.5,
158
- # (laugh / chuckle / cackle / giggle / snicker handled by
159
- # _contextual_laugh_duration below -- modifier-aware, not flat.)
160
- # Emotional transitions
161
- r'\bvoice (?:breaks?|cracks?|trembles?|drops?|rises?)\b': 0.5,
162
- r'\bsteadies? (?:him|her)self\b': 1.0,
163
- r'\bcatches? (?:his|her) breath\b': 1.0,
164
- r'\bcomposes? (?:him|her)self\b': 0.8,
165
- # Scene transitions that imply time
166
- r'\bdemeanor shifts?\b': 0.5, r'\bsettles? in\b': 0.5,
167
- r'\bleans? in\b': 0.3, r'\bwipes? (?:his|her) eyes\b': 0.5,
168
- }
169
- extra = 0.0
170
- for pattern, dur in PATTERNS.items():
171
- extra += dur * len(re.findall(pattern, text, re.IGNORECASE))
172
- extra += _contextual_laugh_duration(text)
173
- return extra
174
-
175
-
176
- def estimate_speech_duration(text: str, speed: float = 1.0) -> float:
177
- """Estimate speech duration from spoken content + non-verbal actions.
178
-
179
- Extracts spoken text by priority:
180
- 1. Quoted text ('...' or "...") -- official prompt guide format
181
- 2. Text after colon -- simple "Speaker: dialogue" format
182
- 3. Full text -- fallback
183
-
184
- Also scans the full prompt for non-verbal cues (laughs, pauses, sighs,
185
- gasps, etc.) and adds estimated duration for each.
186
- """
187
- # Try double quotes first (clean, no contraction issues)
188
- quotes = re.findall(r'"([^"]+)"', text)
189
- if not quotes:
190
- # Single quotes: allow apostrophes in contractions (don't, can't, it's)
191
- # Match ' to ' but apostrophes NOT followed by space/punctuation are kept inside
192
- quotes = re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text)
193
- # Filter out short fragments (scene directions like "He pauses")
194
- quotes = [q for q in quotes if len(q.split()) > 3]
195
- if quotes:
196
- spoken = " ".join(quotes)
197
- elif ":" in text:
198
- spoken = text.split(":", 1)[1].strip()
199
- else:
200
- spoken = text
201
-
202
- CHARS_PER_SEC = 14.0
203
- text_len = len(spoken)
204
-
205
- if text_len < 40:
206
- chars_per_sec = CHARS_PER_SEC * 0.6
207
- elif text_len < 80:
208
- chars_per_sec = CHARS_PER_SEC * 0.8
209
- else:
210
- chars_per_sec = CHARS_PER_SEC
211
-
212
- chars_per_sec *= speed
213
- duration = text_len / chars_per_sec
214
-
215
- sentence_count = spoken.count(".") + spoken.count("!") + spoken.count("?")
216
- duration += sentence_count * 0.3
217
-
218
- # Add time for non-verbal sounds/actions in the full prompt
219
- duration += _estimate_nonverbal_duration(text)
220
-
221
- return max(3.0, round(duration + 2.0, 1))
222
 
223
 
224
  def parse_args():
 
74
  return "distilled"
75
 
76
 
77
+ # Duration estimator lives in duration_estimator.py so that text_chunker and
78
+ # other tooling can import it without dragging the torch / LTX pipeline.
79
+ from duration_estimator import ( # noqa: E402,F401
80
+ estimate_speech_duration,
81
+ _contextual_laugh_duration,
82
+ _estimate_nonverbal_duration,
83
+ _LAUGH_VERBS,
84
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
  def parse_args():
src/inference_server.py CHANGED
@@ -14,6 +14,7 @@ import re
14
  import sys
15
  import time
16
  from pathlib import Path
 
17
 
18
  import torch
19
  import torchaudio
@@ -53,13 +54,40 @@ DEFAULT_NEG = "worst quality, inconsistent, robotic, distorted, noise, static, m
53
 
54
 
55
  def estimate_duration(prompt, multiplier=1.1):
56
- """Defer to the richer CLI estimator (sentence-aware + non-verbal action
57
- budget) so warm-server outputs match the lengths of the per-call CLI runs."""
58
- from inference import estimate_speech_duration
59
  base = estimate_speech_duration(prompt)
60
  return max(3.0, round(base * multiplier, 1))
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def auto_rescale_for_cfg(cfg: float) -> float:
64
  """CFG-aware std-rescale schedule that prevents output clipping at high cfg.
65
 
@@ -110,6 +138,11 @@ class TTSServer:
110
  self._velocity_model = None
111
  self._audio_conditioner = None
112
  self._audio_decoder = None
 
 
 
 
 
113
 
114
  logging.info(f"TTSServer loading on {device}...")
115
  t0 = time.time()
@@ -203,10 +236,78 @@ class TTSServer:
203
  )
204
  logging.info(f" AudioDecoder (warm): {time.time()-t0:.1f}s")
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  @torch.inference_mode()
207
  def generate(self, prompt, voice_ref=None, cfg_scale=2.5, stg_scale=1.5,
208
  duration_multiplier=1.1, seed=42, ref_duration=10.0,
209
- rescale_scale="auto", gen_duration: float = 0.0):
 
210
  """Generate audio. Returns (waveform_path, duration_seconds).
211
 
212
  rescale_scale: latent-side CFG std-rescale that prevents clipping at
@@ -214,6 +315,10 @@ class TTSServer:
214
  float in [0, 1] for a fixed override, or 0 to disable.
215
  gen_duration: explicit target duration in seconds. 0 (default) → auto
216
  from prompt + duration_multiplier; >0 overrides everything else.
 
 
 
 
217
  """
218
  t_total = time.time()
219
 
@@ -236,6 +341,8 @@ class TTSServer:
236
  if voice_ref and os.path.exists(voice_ref):
237
  t0 = time.time()
238
  voice = decode_audio_from_file(voice_ref, self.device, 0.0, ref_duration)
 
 
239
  w = voice.waveform
240
  if w.dim() == 2:
241
  if w.shape[0] == 1:
@@ -323,15 +430,115 @@ class TTSServer:
323
 
324
  t0 = time.time()
325
  decoded = self._audio_decoder(latent)
326
- logging.info(f"Decode: {time.time()-t0:.2f}s")
 
327
 
328
  total = time.time() - t_total
329
- dur = decoded.waveform.shape[-1] / decoded.sampling_rate
330
  logging.info(f"Total: {total:.2f}s for {dur:.1f}s audio")
331
- return decoded.waveform, decoded.sampling_rate
332
 
333
- def generate_to_file(self, prompt, output, watermark: bool = True, **kwargs):
334
- waveform, sr = self.generate(prompt, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  wav_cpu = waveform.cpu().float()
336
  if watermark:
337
  try:
 
14
  import sys
15
  import time
16
  from pathlib import Path
17
+ from typing import Optional
18
 
19
  import torch
20
  import torchaudio
 
54
 
55
 
56
  def estimate_duration(prompt, multiplier=1.1):
57
+ """Defer to the shared sentence-aware + non-verbal action budget estimator
58
+ so warm-server outputs match the lengths of the per-call CLI runs."""
59
+ from duration_estimator import estimate_speech_duration
60
  base = estimate_speech_duration(prompt)
61
  return max(3.0, round(base * multiplier, 1))
62
 
63
 
64
+ def _equal_power_crossfade(prev: torch.Tensor, nxt: torch.Tensor,
65
+ sample_rate: int, fade_ms: float = 50.0) -> torch.Tensor:
66
+ """Equal-power crossfade concat: ``[prev | nxt]`` with a smooth boundary.
67
+
68
+ Both tensors are (C, T). Returns (C, T_prev + T_nxt - T_fade).
69
+
70
+ Equal-power (cos/sin envelopes) keeps perceived loudness constant through
71
+ the join — unlike a linear fade, which dips by ~3 dB in the middle when
72
+ the two sources are uncorrelated. Default 50 ms is short enough to be
73
+ inaudible on speech while still masking any waveform-level discontinuity
74
+ between independently-generated chunks.
75
+ """
76
+ fade_samples = int(round(fade_ms * 1e-3 * sample_rate))
77
+ fade_samples = max(1, min(fade_samples, prev.shape[-1], nxt.shape[-1]))
78
+ if fade_samples <= 1:
79
+ return torch.cat([prev, nxt], dim=-1)
80
+
81
+ t = torch.linspace(0.0, 1.0, fade_samples, device=prev.device, dtype=prev.dtype)
82
+ fade_out = torch.cos(t * torch.pi / 2) # 1.0 -> 0.0
83
+ fade_in = torch.sin(t * torch.pi / 2) # 0.0 -> 1.0
84
+
85
+ prev_tail = prev[..., -fade_samples:] * fade_out
86
+ nxt_head = nxt[..., :fade_samples] * fade_in
87
+ mixed = prev_tail + nxt_head
88
+ return torch.cat([prev[..., :-fade_samples], mixed, nxt[..., fade_samples:]], dim=-1)
89
+
90
+
91
  def auto_rescale_for_cfg(cfg: float) -> float:
92
  """CFG-aware std-rescale schedule that prevents output clipping at high cfg.
93
 
 
138
  self._velocity_model = None
139
  self._audio_conditioner = None
140
  self._audio_decoder = None
141
+ # RE-USE denoiser for the voice reference (input-side denoise).
142
+ # Lazy-loaded on first use; the cleaned-waveform cache below keeps
143
+ # chunked generations from re-denoising the same 10 s clip per chunk.
144
+ self._ref_denoiser = None
145
+ self._ref_denoise_cache: dict[tuple, "torch.Tensor"] = {}
146
 
147
  logging.info(f"TTSServer loading on {device}...")
148
  t0 = time.time()
 
236
  )
237
  logging.info(f" AudioDecoder (warm): {time.time()-t0:.1f}s")
238
 
239
+ def _denoise_voice_ref(self, voice, voice_ref_path: str, ref_duration: float):
240
+ """Run RE-USE on the loaded voice reference and replace its waveform
241
+ with a cleaned mono signal.
242
+
243
+ Why pre-condition rather than post-generate: applying RE-USE to the
244
+ *output* suppresses paralinguistic events the model generates (laughs,
245
+ gasps, breaths, sighs) because they're broadband, non-tonal — exactly
246
+ what universal speech enhancement targets as "noise". Running it on
247
+ the *reference* instead gives the model a clean speaker / style
248
+ anchor, which it generalises from at inference time, while leaving
249
+ the generated paralinguistic content untouched.
250
+
251
+ Cached by ``(path, ref_duration, sampling_rate)`` so chunked
252
+ generations don't re-denoise the same 10 s clip per chunk.
253
+ """
254
+ cache_key = (voice_ref_path, float(ref_duration), int(voice.sampling_rate))
255
+ if cache_key in self._ref_denoise_cache:
256
+ return Audio(
257
+ waveform=self._ref_denoise_cache[cache_key],
258
+ sampling_rate=voice.sampling_rate,
259
+ )
260
+
261
+ # Lazy-load the denoiser. target_sr = input sr → no librosa resample
262
+ # round-trip; RE-USE does pure denoise. (The 48 kHz BWE that
263
+ # REUSEUpsampler can do is irrelevant here — the VAE conditioner
264
+ # resamples internally to whatever the audio branch expects.)
265
+ if self._ref_denoiser is None:
266
+ from super_resolution import REUSEUpsampler
267
+ try:
268
+ self._ref_denoiser = REUSEUpsampler(
269
+ target_sr=int(voice.sampling_rate),
270
+ device=self.device,
271
+ chunk_size_s=1.0,
272
+ )
273
+ except Exception as e:
274
+ # Mamba kernels / weights missing → silently skip the denoise
275
+ # rather than blocking generation. Surfaces once per session.
276
+ logging.warning(f"Voice-ref denoise disabled (RE-USE unavailable: {e})")
277
+ self._ref_denoiser = False # sentinel: don't retry this session
278
+ return voice
279
+
280
+ if self._ref_denoiser is False:
281
+ return voice
282
+
283
+ w = voice.waveform
284
+ # Collapse to mono — voice cloning is speaker-as-mono-source; we'll
285
+ # re-broadcast back to stereo after the conditioner.
286
+ if w.dim() == 3:
287
+ mono = w[0].mean(dim=0)
288
+ elif w.dim() == 2:
289
+ mono = w.mean(dim=0)
290
+ else:
291
+ mono = w
292
+ mono = mono.contiguous()
293
+
294
+ t0 = time.time()
295
+ cleaned, _ = self._ref_denoiser(mono, in_sr=int(voice.sampling_rate))
296
+ if cleaned.dim() == 2 and cleaned.shape[0] == 1:
297
+ cleaned = cleaned[0]
298
+ # Restore the (1, C=1, T) shape that the rest of the pipeline expects
299
+ # to consume — downstream code re-expands channels via repeat().
300
+ cleaned = cleaned.unsqueeze(0).unsqueeze(0).to(self.device, dtype=w.dtype)
301
+ logging.info(f"Voice-ref denoise (RE-USE): {time.time() - t0:.2f}s")
302
+
303
+ self._ref_denoise_cache[cache_key] = cleaned
304
+ return Audio(waveform=cleaned, sampling_rate=voice.sampling_rate)
305
+
306
  @torch.inference_mode()
307
  def generate(self, prompt, voice_ref=None, cfg_scale=2.5, stg_scale=1.5,
308
  duration_multiplier=1.1, seed=42, ref_duration=10.0,
309
+ rescale_scale="auto", gen_duration: float = 0.0,
310
+ denoise_ref: bool = True):
311
  """Generate audio. Returns (waveform_path, duration_seconds).
312
 
313
  rescale_scale: latent-side CFG std-rescale that prevents clipping at
 
315
  float in [0, 1] for a fixed override, or 0 to disable.
316
  gen_duration: explicit target duration in seconds. 0 (default) → auto
317
  from prompt + duration_multiplier; >0 overrides everything else.
318
+ denoise_ref: when True (default) and a voice reference is provided,
319
+ RE-USE is applied to the *reference* before VAE encoding so the
320
+ model conditions on a clean speaker / style anchor. Generated
321
+ output (24→48 kHz) always goes through the LTX BigVGAN BWE.
322
  """
323
  t_total = time.time()
324
 
 
341
  if voice_ref and os.path.exists(voice_ref):
342
  t0 = time.time()
343
  voice = decode_audio_from_file(voice_ref, self.device, 0.0, ref_duration)
344
+ if denoise_ref:
345
+ voice = self._denoise_voice_ref(voice, voice_ref, ref_duration)
346
  w = voice.waveform
347
  if w.dim() == 2:
348
  if w.shape[0] == 1:
 
430
 
431
  t0 = time.time()
432
  decoded = self._audio_decoder(latent)
433
+ out_waveform, out_sr = decoded.waveform, decoded.sampling_rate
434
+ logging.info(f"Decode (LTX BWE): {time.time()-t0:.2f}s")
435
 
436
  total = time.time() - t_total
437
+ dur = out_waveform.shape[-1] / out_sr
438
  logging.info(f"Total: {total:.2f}s for {dur:.1f}s audio")
439
+ return out_waveform, out_sr
440
 
441
@torch.inference_mode()
def generate_long(self, prompt, max_chunk_duration: float = 45.0,
                  target_chunk_duration: float = 37.0,
                  crossfade_ms: float = 50.0,
                  progress_callback=None,
                  **kwargs):
    """Chunk-and-stitch generation for long prompts.

    Splits ``prompt`` with ``text_chunker.chunk_prompt_for_duration``, runs
    every chunk through :meth:`generate` with the same forwarded keyword
    arguments (so the same voice reference and seed apply to each chunk),
    and joins the waveforms with an equal-power crossfade.

    Args:
        prompt: Full text to synthesise.
        max_chunk_duration: Upper bound, in seconds, passed to the chunker.
        target_chunk_duration: Preferred chunk length, in seconds, passed to
            the chunker.
        crossfade_ms: Overlap used at every join.
        progress_callback: Optional callable invoked before each chunk as
            ``progress_callback(index, total, est_duration_s)``. Exceptions
            it raises are logged and ignored.
        **kwargs: Forwarded to :meth:`generate`. ``duration_multiplier`` is
            applied per chunk; ``gen_duration`` is discarded because a global
            target does not apply to individual chunks.

    Returns:
        ``(waveform, sample_rate)``; the waveform is a float CPU tensor.

    Raises:
        ValueError: If the chunker returns no chunks.
        RuntimeError: If chunks disagree on sample rate, or on channel count
            in a way that cannot be fixed by repeating a mono chunk.
    """
    from text_chunker import chunk_prompt_for_duration

    # Sizing is controlled here, so pull the duration knobs out of kwargs and
    # forward only the per-chunk multiplier.
    per_chunk_mul = float(kwargs.pop("duration_multiplier", 1.1))
    kwargs.pop("gen_duration", None)

    chunks = chunk_prompt_for_duration(
        prompt,
        max_duration_s=max_chunk_duration,
        target_duration_s=target_chunk_duration,
        duration_multiplier=per_chunk_mul,
    )
    if not chunks:
        # Without this guard the code below would fail on ``None.shape``.
        raise ValueError("Long-form chunker produced no chunks for the prompt.")
    logging.info(f"Long-form: {len(chunks)} chunks (target {target_chunk_duration:.0f}s, "
                 f"max {max_chunk_duration:.0f}s)")

    out_waveform = None
    out_sr = None
    t_total = time.time()
    for idx, chunk in enumerate(chunks):
        logging.info(f"  Chunk {idx + 1}/{len(chunks)}: est {chunk.est_duration_s:.1f}s, "
                     f"{len(chunk.text)} chars")
        if progress_callback is not None:
            try:
                progress_callback(idx, len(chunks), chunk.est_duration_s)
            except Exception as e:
                # Progress reporting is best-effort; never fail a generation
                # because of a UI callback.
                logging.warning(f"progress_callback raised, ignoring: {e}")
        wav, sr = self.generate(
            chunk.text,
            duration_multiplier=per_chunk_mul,
            **kwargs,
        )
        wav = wav.cpu().float()
        if out_waveform is None:
            out_waveform, out_sr = wav, sr
            continue

        if sr != out_sr:
            raise RuntimeError(f"Sample-rate mismatch between chunks: {out_sr} vs {sr}")
        # Align channel counts by repeating whichever side is mono.
        if wav.shape[0] != out_waveform.shape[0]:
            if wav.shape[0] == 1:
                wav = wav.repeat(out_waveform.shape[0], 1)
            elif out_waveform.shape[0] == 1:
                out_waveform = out_waveform.repeat(wav.shape[0], 1)
            else:
                raise RuntimeError(
                    "Channel-count mismatch between chunks: "
                    f"{out_waveform.shape[0]} vs {wav.shape[0]}")
        out_waveform = _equal_power_crossfade(out_waveform, wav, out_sr,
                                              fade_ms=crossfade_ms)

    total_dur = out_waveform.shape[-1] / out_sr
    logging.info(f"Long-form total: {time.time() - t_total:.2f}s wall, {total_dur:.1f}s audio")
    return out_waveform, out_sr
511
+
512
+ def generate_to_file(self, prompt, output, watermark: bool = True,
513
+ max_chunk_duration: float = 45.0,
514
+ target_chunk_duration: float = 37.0,
515
+ crossfade_ms: float = 50.0,
516
+ progress_callback=None,
517
+ **kwargs):
518
+ # Auto-route to generate_long when the requested duration (explicit
519
+ # gen_duration if set, otherwise prompt-estimated) exceeds the chunk
520
+ # cap. Single-shot path otherwise — same as before, no regression for
521
+ # short prompts.
522
+ explicit_dur = float(kwargs.get("gen_duration") or 0.0)
523
+ est_dur = explicit_dur if explicit_dur > 0 else estimate_duration(
524
+ prompt, kwargs.get("duration_multiplier", 1.1))
525
+
526
+ if est_dur > max_chunk_duration:
527
+ waveform, sr = self.generate_long(
528
+ prompt,
529
+ max_chunk_duration=max_chunk_duration,
530
+ target_chunk_duration=target_chunk_duration,
531
+ crossfade_ms=crossfade_ms,
532
+ progress_callback=progress_callback,
533
+ **kwargs,
534
+ )
535
+ else:
536
+ if progress_callback is not None:
537
+ try:
538
+ progress_callback(0, 1, est_dur)
539
+ except Exception:
540
+ pass
541
+ waveform, sr = self.generate(prompt, **kwargs)
542
  wav_cpu = waveform.cpu().float()
543
  if watermark:
544
  try:
src/model_downloader.py CHANGED
@@ -15,12 +15,10 @@ logger = logging.getLogger(__name__)
15
 
16
  DRAMABOX_REPO = "ResembleAI/Dramabox"
17
  GEMMA_REPO = "unsloth/gemma-3-12b-it-bnb-4bit"
 
18
 
19
  # Default cache directory
20
- DEFAULT_CACHE = os.environ.get(
21
- "DRAMABOX_CACHE",
22
- os.path.join(os.path.expanduser("~"), ".cache", "dramabox"),
23
- )
24
 
25
  # Model files in the HF repo (flat structure)
26
  MODEL_FILES = {
@@ -75,6 +73,42 @@ def get_gemma_path(cache_dir: str = None) -> str:
75
  return local_dir
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def get_all_paths(cache_dir: str = None) -> dict:
79
  """Download all required models and return paths dict.
80
 
 
15
 
16
  DRAMABOX_REPO = "ResembleAI/Dramabox"
17
  GEMMA_REPO = "unsloth/gemma-3-12b-it-bnb-4bit"
18
+ REUSE_REPO = "nvidia/RE-USE"
19
 
20
  # Default cache directory
21
+ DEFAULT_CACHE = os.path.join(os.environ.get("HF_HOME", os.path.expanduser("~")), ".cache", "dramabox")
 
 
 
22
 
23
  # Model files in the HF repo (flat structure)
24
  MODEL_FILES = {
 
73
  return local_dir
74
 
75
 
76
def get_reuse_code_path(cache_dir: str = None) -> str:
    """Locate (or fetch) the nvidia/RE-USE source tree used by REUSEUpsampler.

    Resolution order:

    1. ``$REUSE_DIR`` -- returned verbatim when it is set and points at an
       existing directory; nothing else is checked and no network is used.
    2. ``<repo root>/third_party/RE-USE`` -- used when it already holds
       ``models/generator_SEMamba_time_d4.py`` (the generator source file).
    3. Otherwise the code/config files are snapshot-downloaded from the
       ``REUSE_REPO`` Hugging Face repo into ``cache_dir`` (``DEFAULT_CACHE``
       when ``cache_dir`` is falsy).

    Only the patterns listed in ``allow_patterns`` below are requested, so
    weight files such as ``model.safetensors`` are not part of this download;
    the upsampler obtains weights separately via ``from_pretrained``.

    Args:
        cache_dir: Cache directory handed to ``snapshot_download``.

    Returns:
        Filesystem path of the directory containing the RE-USE code.
    """
    override = os.environ.get("REUSE_DIR")
    if override and Path(override).is_dir():
        return override

    # Vendored copy lives two levels above this file, under third_party/.
    vendored = Path(__file__).resolve().parent.parent / "third_party" / "RE-USE"
    generator_src = vendored / "models" / "generator_SEMamba_time_d4.py"
    if generator_src.is_file():
        return str(vendored)

    if not cache_dir:
        cache_dir = DEFAULT_CACHE
    logger.info(f"Fetching RE-USE code/configs from {REUSE_REPO}...")
    wanted = ["*.py", "*.yaml", "*.json",
              "recipes/*", "models/*.py", "utils/*.py"]
    local_dir = snapshot_download(
        repo_id=REUSE_REPO,
        cache_dir=cache_dir,
        token=os.environ.get("HF_TOKEN"),
        allow_patterns=wanted,
    )
    logger.info(f"  -> {local_dir}")
    return local_dir
110
+
111
+
112
  def get_all_paths(cache_dir: str = None) -> dict:
113
  """Download all required models and return paths dict.
114
 
src/super_resolution.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RE-USE (nvidia/RE-USE) speech-enhancement wrapper.
2
+
3
+ Used by ``TTSServer._denoise_voice_ref`` to denoise the input voice reference
4
+ before VAE conditioning. Lazy-loads weights + code on first call so importing
5
+ this module is cheap.
6
+
7
+ up = REUSEUpsampler(target_sr=48000, device="cuda")
8
+ clean, sr = up(wav, in_sr=24000) # wav: (C, T) or (T,) float
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ import sys
14
+ from pathlib import Path
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+
19
+
20
# Process-wide memo for the RE-USE code directory. It stays ``None`` until
# the first REUSEUpsampler load asks for it, so importing this module never
# touches the filesystem or the network.
_REUSE_DIR: Optional[Path] = None


def _resolve_reuse_dir() -> Path:
    """Return the RE-USE code directory, resolving it once and caching it.

    The first call delegates to ``model_downloader.get_reuse_code_path`` and
    stores the result in the module-level ``_REUSE_DIR``; every later call
    returns that stored ``Path`` without calling the downloader again.
    """
    global _REUSE_DIR
    if _REUSE_DIR is not None:
        return _REUSE_DIR
    # Imported here rather than at module top so the import cost is only paid
    # when the upsampler is actually used.
    from model_downloader import get_reuse_code_path
    _REUSE_DIR = Path(get_reuse_code_path())
    return _REUSE_DIR
32
+
33
+
34
class REUSEUpsampler:
    """Universal speech enhancement with optional bandwidth extension.

    nvidia/RE-USE is a 9.6 M-param bidirectional-Mamba model that operates on
    STFT amplitude+phase. With ``target_sr`` set it both denoises *and* extends
    the bandwidth to that rate via librosa kaiser-best resample + restoration.

    License: NSCLv1 (noncommercial). The base ``SEMamba`` class lives in the
    HF repo under ``models/generator_SEMamba_time_d4.py`` and pulls in the
    ``mamba_ssm`` / ``causal-conv1d`` CUDA kernels.

    Construction only stores settings; code, config and weights are loaded on
    the first ``__call__``.
    """

    def __init__(
        self,
        target_sr: int = 48000,
        config_path: Optional[str] = None,
        chunk_size_s: float = 1.0,
        hop_portion: float = 0.5,
        device: str | torch.device = "cuda",
    ) -> None:
        """Store settings and validate the chunking geometry.

        Args:
            target_sr: Sample rate (Hz) the audio is resampled to and
                processed at; also the rate of the returned audio.
            config_path: Optional path to a RE-USE recipe YAML. When omitted
                the default recipe inside the RE-USE code tree is used.
            chunk_size_s: Length in seconds of each overlap-add chunk.
            hop_portion: Hop between chunks as a fraction of the chunk
                length; must lie in ``(0, 1]`` so consecutive chunks leave
                no uncovered samples.
            device: Torch device the model and intermediate tensors live on.

        Raises:
            ValueError: If ``target_sr``, ``chunk_size_s`` or ``hop_portion``
                would yield an empty chunk, a zero hop, or gaps between
                chunks.
        """
        import math

        # chunk_size_s: peak VRAM scales linearly with chunk length.
        #   5.0s -> 2.95 GB | 2.5s -> 1.52 GB | 1.0s -> 0.67 GB (default).
        # 1.0s is chosen as default so RE-USE fits comfortably on top of the
        # rest of the DramaBox pipeline on any 24 GB-class GPU.
        self.device = torch.device(device)
        self.target_sr = int(target_sr)
        self.chunk_size_s = float(chunk_size_s)
        self.hop_portion = float(hop_portion)

        # Fail fast on geometry that would otherwise surface much later as a
        # ZeroDivisionError (hop of 0 samples) or as silent zero-filled gaps
        # in the output (hop larger than the chunk).
        if self.target_sr <= 0:
            raise ValueError(f"target_sr must be positive, got {target_sr!r}")
        if not (math.isfinite(self.chunk_size_s) and self.chunk_size_s > 0.0):
            raise ValueError(
                f"chunk_size_s must be a positive finite number, got {chunk_size_s!r}"
            )
        if not 0.0 < self.hop_portion <= 1.0:
            raise ValueError(f"hop_portion must be in (0, 1], got {hop_portion!r}")
        chunk_samples = int(self.chunk_size_s * self.target_sr)
        if chunk_samples < 2 or int(self.hop_portion * chunk_samples) < 1:
            raise ValueError(
                "chunk_size_s / hop_portion are too small for target_sr: "
                f"{chunk_samples} samples per chunk"
            )

        # Config path is resolved lazily on first use (alongside the code tree)
        # so importing this module never triggers a download.
        self._config_path_override = Path(config_path) if config_path else None
        self.config_path: Optional[Path] = None
        self._model = None
        self._cfg = None
        self._stft_fns = None  # (mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim)

    @staticmethod
    def _ensure_mamba_ssm_importable() -> None:
        """Import ``mamba_ssm`` cleanly, with a kernel-free fallback if needed.

        Normal path (kernels present): just import -- fast path uses
        ``selective_scan_cuda`` natively.

        Fallback (kernels missing): the official package does an unconditional
        ``import selective_scan_cuda`` at module load. We stub it into
        ``sys.modules`` before importing, then redirect ``selective_scan_fn``
        to the pure-PyTorch ``selective_scan_ref`` so the model still runs
        (~5-10x slower).

        Raises:
            ImportError: If ``mamba_ssm`` itself cannot be imported.
        """
        try:
            import selective_scan_cuda
            import mamba_ssm  # noqa: F401
            # A stub left behind by an earlier fallback attempt is importable
            # too; only the real kernel qualifies for the early return, so a
            # half-finished earlier patch gets completed below.
            if not getattr(selective_scan_cuda, "_dramabox_stub", False):
                return  # Fast path: kernel present.
        except ImportError:
            pass

        import types
        if "selective_scan_cuda" not in sys.modules:
            stub = types.ModuleType("selective_scan_cuda")

            def _missing(*a, **kw):  # pragma: no cover - safety net only
                raise NotImplementedError(
                    "selective_scan_cuda kernel missing; the call should have "
                    "been routed to selective_scan_ref via the runtime patch."
                )

            stub.fwd = _missing
            stub.bwd = _missing
            stub._dramabox_stub = True
            sys.modules["selective_scan_cuda"] = stub

        from mamba_ssm.ops import selective_scan_interface as ssi
        from mamba_ssm.modules import mamba_simple
        if getattr(ssi, "_dramabox_kernel_free_patch_applied", False):
            return
        ssi.selective_scan_fn = ssi.selective_scan_ref
        ssi.mamba_inner_fn = ssi.mamba_inner_ref
        # mamba_simple imported these names by reference at module load -
        # rebind there too, otherwise Mamba.forward keeps the original handles.
        mamba_simple.selective_scan_fn = ssi.selective_scan_ref
        mamba_simple.mamba_inner_fn = ssi.mamba_inner_ref
        ssi._dramabox_kernel_free_patch_applied = True
        logging.info(
            "mamba_ssm kernel missing - using kernel-free fallback "
            "(selective_scan_fn -> selective_scan_ref). Expect ~5-10x slowdown."
        )

    def _lazy_load(self) -> None:
        """Load the RE-USE code, config and weights once; no-op afterwards.

        Side effects: may prepend the RE-USE code directory to ``sys.path``
        (it is not removed again), and populates ``config_path``, ``_cfg``,
        ``_stft_fns`` and ``_model``.
        """
        if self._model is not None:
            return

        # Prefer real CUDA kernels; gracefully fall back to pure-PyTorch impl.
        self._ensure_mamba_ssm_importable()

        # The RE-USE module imports `from models...` and `from utils...` -
        # both relative to the repo root, so that root must be importable.
        reuse_dir = _resolve_reuse_dir()
        if str(reuse_dir) not in sys.path:
            sys.path.insert(0, str(reuse_dir))

        if self.config_path is None:
            self.config_path = self._config_path_override or (
                reuse_dir / "recipes" /
                "USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml"
            )

        from models.generator_SEMamba_time_d4 import SEMamba  # type: ignore
        from models.stfts import mag_phase_stft, mag_phase_istft  # type: ignore
        from utils.util import load_config, pad_or_trim_to_match  # type: ignore

        self._cfg = load_config(str(self.config_path))
        compress_factor = self._cfg["model_cfg"]["compress_factor"]
        self._stft_fns = (mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim_to_match)

        # SEMamba is a PyTorchModelHubMixin; from_pretrained pulls weights from HF.
        model = SEMamba.from_pretrained("nvidia/RE-USE", cfg=self._cfg).to(self.device)
        model.train(False)
        self._model = model
        n_params = sum(p.numel() for p in model.parameters())
        logging.info(f"RE-USE loaded: SEMamba ({n_params / 1e6:.1f}M params) -> {self.target_sr} Hz")

    @staticmethod
    def _make_even(v: float) -> int:
        """Round ``v`` to the nearest integer, bumping odd results up by one."""
        v = int(round(v))
        return v if v % 2 == 0 else v + 1

    @torch.inference_mode()
    def __call__(self, waveform: torch.Tensor, in_sr: int = 16000) -> Tuple[torch.Tensor, int]:
        """Chunked overlap-add denoise / BWE (ports nvidia/RE-USE inference_chunk.py).

        Peak VRAM is bounded by ``chunk_size_s * target_sr`` rather than the
        whole clip, so a 60 s clip costs the same as a 5 s one. Crossfade is
        a Hann-window normalized overlap-add with default 50% hop.

        Args:
            waveform: Audio as a 1-D ``(T,)`` or 2-D ``(C, T)`` tensor; a 1-D
                input is treated as a single channel.
            in_sr: Sample rate of ``waveform`` in Hz.

        Returns:
            ``(enhanced, sr)`` where ``enhanced`` is a float32 CPU tensor of
            shape ``(C, T')`` clamped to ``[-1, 1]`` and ``sr`` equals
            ``target_sr``.
        """
        import math
        self._lazy_load()
        import librosa
        mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim_to_match = self._stft_fns

        # STFT params are scaled from the config's sampling rate to target_sr.
        base_n_fft = self._cfg["stft_cfg"]["n_fft"]
        base_hop = self._cfg["stft_cfg"]["hop_size"]
        base_win = self._cfg["stft_cfg"]["win_size"]
        base_sr = self._cfg["stft_cfg"]["sampling_rate"]

        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # 1. Resample to target rate first (skips if target_sr == in_sr).
        if self.target_sr != in_sr:
            wav_np = waveform.cpu().float().numpy()
            wav_np = librosa.resample(
                wav_np, orig_sr=in_sr, target_sr=self.target_sr, res_type="kaiser_best"
            )
            wav = torch.from_numpy(wav_np).to(self.device, dtype=torch.float32)
        else:
            wav = waveform.to(self.device, dtype=torch.float32)

        op_sr = self.target_sr
        n_fft = self._make_even(base_n_fft * op_sr // base_sr)
        hop = self._make_even(base_hop * op_sr // base_sr)
        win = self._make_even(base_win * op_sr // base_sr)

        # 2. Chunked OLA with Hann analysis window. Mirrors inference_chunk.py.
        chunk_size = int(self.chunk_size_s * op_sr)
        # Never let the hop collapse to 0 samples (attributes may have been
        # changed after construction); that would divide by zero below.
        hop_length = max(1, int(self.hop_portion * chunk_size))
        window = torch.hann_window(chunk_size, device=self.device)

        n_ch, total = wav.shape
        enhanced = torch.zeros_like(wav)
        window_sum = torch.zeros_like(wav)
        if total > chunk_size:
            n_chunks = max(1, math.ceil((total - chunk_size) / hop_length) + 1)
        else:
            n_chunks = 1

        for c in range(n_ch):
            ch_in = wav[c : c + 1]  # (1, T)
            for i in range(n_chunks):
                start = i * hop_length
                end = min(start + chunk_size, total)
                chunk = ch_in[:, start:end]
                if chunk.shape[-1] < 2:  # skip degenerate tail
                    continue
                noisy_mag, noisy_pha, _ = mag_phase_stft(
                    chunk, n_fft=n_fft, hop_size=hop, win_size=win,
                    compress_factor=compress_factor, center=True, addeps=False,
                )
                amp_g, pha_g, _ = self._model(noisy_mag, noisy_pha)
                # "Sweep artifact" filter - match the official inference:
                # frames where more than half the bins are exactly zero are
                # zeroed out entirely.
                mag = torch.expm1(torch.relu(amp_g))
                zero_portion = (mag == 0).sum(dim=1) / mag.shape[1]
                amp_g[:, :, (zero_portion > 0.5)[0]] = 0

                audio_g = mag_phase_istft(amp_g, pha_g, n_fft, hop, win, compress_factor)
                audio_g = pad_or_trim_to_match(chunk.detach(), audio_g, pad_value=1e-8)

                seg_len = audio_g.shape[-1]
                w_slice = window[:seg_len]
                enhanced[c : c + 1, start : start + seg_len] += audio_g * w_slice
                window_sum[c : c + 1, start : start + seg_len] += w_slice

        # 3. Normalize where windows overlap. Avoid divide-by-zero at clip tails.
        mask = window_sum > 1e-8
        enhanced[mask] = enhanced[mask] / window_sum[mask]
        return enhanced.clamp(-1.0, 1.0).cpu().float(), op_sr
src/text_chunker.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prompt chunking for long-form DramaBox generation.
2
+
3
+ The base LTX-2.3 audio DiT was trained on clips <= ~20 s. The silence-prior
4
+ patch in ``inference_server.py`` keeps generations sane up to ~45 s, but the
5
+ prior re-emerges past that boundary. For arbitrary-length prompts we split the
6
+ text into < 45 s chunks, generate each conditioned on the same voice reference,
7
+ and crossfade them back together.
8
+
9
+ Chunking is quote-aware (sentence terminators inside ``"..."`` don't count)
10
+ and preserves the speaker-description prefix on every chunk so the model keeps
11
+ the same persona / delivery style across joins.
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import re
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional
18
+
19
+
20
# Matches the leading speaker description, ending at the first comma that's
# directly followed by optional whitespace + an opening quote. Anything before
# that is treated as persona/style metadata and re-attached to every chunk.
#   "A shadowy villain speaks with cold menace, \"You have entered...\""
#    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The description may not contain quote characters, with one exception: an
# apostrophe sandwiched between two word characters (can't, villain's, 30's)
# is ordinary text, not a quote, and must not prevent the persona from being
# recognised.
_PREFIX_RE = re.compile(
    r'^((?:[^"\']|(?<=\w)\'(?=\w)){3,}?)(,\s*)(?=["\'])',
    re.DOTALL,
)


@dataclass
class PromptChunk:
    """One generation unit produced by the prompt chunker."""

    # Prompt text for this chunk.
    text: str
    # Estimated spoken duration of ``text`` in seconds.
    est_duration_s: float


def extract_speaker_prefix(prompt: str) -> tuple[Optional[str], str]:
    """Return ``(prefix, body)`` where ``prefix`` is the speaker description.

    If the prompt has the canonical ``<persona>, "<dialogue>"...`` form, the
    persona (stripped, without the trailing comma) is returned as the prefix
    and everything from the opening quote onwards as the body. In-word
    apostrophes are allowed inside the persona.

    Otherwise ``(None, prompt)`` is returned -- there is no prefix to
    propagate and the whole prompt is treated as a single body.

    Examples:
        >>> extract_speaker_prefix('A tired man who can\\'t sleep, "Ugh."')
        ("A tired man who can't sleep", '"Ugh."')
    """
    m = _PREFIX_RE.match(prompt)
    if m is None:
        return None, prompt
    return m.group(1).strip(), prompt[m.end():]
46
+
47
+
48
def split_sentences_outside_quotes(text: str) -> List[str]:
    """Split ``text`` into sentences, ignoring terminators inside quotes.

    A "sentence" here is a span ending in ``.``/``!``/``?`` (optionally followed
    by a closing quote) at the top level -- i.e. not inside an open ``"..."`` or
    ``'...'`` pair -- and followed by whitespace or end of text. Empty /
    whitespace-only fragments are dropped; any trailing text without a
    terminator is returned as a final fragment.

    Apostrophes are only treated as quote marks when they look like one:
    an apostrophe between two letters (``don't``) is never a quote, and one
    that directly follows a letter or digit (``sayin'``, ``dogs'``) never
    *opens* a quote. Without that rule a single dropped-g or plural possessive
    would open a phantom quote and suppress every later split.

    Examples:
        >>> split_sentences_outside_quotes('He says, "Hi, how are you?" Then leaves.')
        ['He says, "Hi, how are you?"', 'Then leaves.']
        >>> split_sentences_outside_quotes("Just sayin' that. OK.")
        ["Just sayin' that.", 'OK.']
    """
    sentences: List[str] = []
    buf: List[str] = []
    in_double = False
    in_single = False
    i = 0
    n = len(text)

    def _flush() -> None:
        # Emit the buffered characters as one sentence (if non-blank).
        sentence = "".join(buf).strip()
        if sentence:
            sentences.append(sentence)
        buf.clear()

    while i < n:
        ch = text[i]
        buf.append(ch)
        nxt = text[i + 1] if i + 1 < n else ""
        # A boundary needs whitespace / end-of-string right after it.
        at_break = not nxt or nxt.isspace()
        # True when the char just before ``ch`` in the buffer is a terminator.
        after_terminator = len(buf) >= 2 and buf[-2] in ".!?"

        if ch == '"' and not in_single:
            was_inside = in_double
            in_double = not in_double
            # Treat the *closing* quote as a sentence boundary if the last
            # char inside it was a terminator: ``...how are you?"``.
            if was_inside and after_terminator and at_break:
                _flush()

        elif ch == "'" and not in_double:
            prev = text[i - 1] if i > 0 else ""
            mid_word = prev.isalpha() and nxt.isalpha()
            if in_single:
                # Apostrophes inside a word (don't, it's) never close a quote.
                if not mid_word:
                    in_single = False
                    # Same closing-quote boundary rule as for double quotes.
                    if after_terminator and at_break:
                        _flush()
            elif not prev.isalnum() and nxt and not nxt.isspace():
                # Opening quote: starts a token and is followed by content.
                in_single = True

        elif ch in ".!?" and not in_double and not in_single:
            # Greedily eat trailing closing quotes / punctuation.
            j = i + 1
            while j < n and text[j] in '."\')]':
                buf.append(text[j])
                if text[j] == '"':
                    in_double = not in_double
                j += 1
            if j >= n or text[j].isspace():
                _flush()
            # Everything up to j is already buffered; resume from there.
            i = j
            continue

        i += 1

    _flush()
    return sentences
112
+
113
+
114
def _assemble(prefix: Optional[str], sentences: List[str]) -> str:
    """Join ``sentences`` into one body and re-attach the speaker ``prefix``.

    Each sentence is stripped, blank ones are dropped, and the rest are joined
    with single spaces. With a falsy ``prefix`` the body is returned as is.
    Otherwise the separator depends on how the body starts: ``", "`` when it
    opens with a quote character (the canonical ``persona, "dialogue"`` form),
    ``". "`` when it opens with anything else (e.g. a stage direction).
    """
    trimmed = (piece.strip() for piece in sentences)
    body = " ".join(piece for piece in trimmed if piece)
    if not prefix:
        return body
    opens_with_quote = body.lstrip().startswith(("'", '"'))
    separator = ", " if opens_with_quote else ". "
    return f"{prefix}{separator}{body}"
124
+
125
+
126
def chunk_prompt_for_duration(
    prompt: str,
    max_duration_s: float = 45.0,
    target_duration_s: float = 37.0,
    duration_multiplier: float = 1.1,
) -> List[PromptChunk]:
    """Split ``prompt`` into chunks sized for ``max_duration_s``.

    Durations are ``estimate_speech_duration(text) * duration_multiplier``.
    A prompt whose estimate is within ``max_duration_s`` is returned unchanged
    as a single chunk. Otherwise the speaker prefix is separated from the
    body, the body is split into top-level sentences, and sentences are packed
    greedily: a chunk is closed as soon as adding the next sentence would push
    its estimate past ``target_duration_s``. The prefix is re-attached to
    every chunk.

    Args:
        prompt: Full scene prompt (DramaBox format or plain text).
        max_duration_s: Cap used for the single-chunk shortcut and for
            detecting an oversize lone sentence.
        target_duration_s: Soft cap that decides when a chunk is closed.
            Keeping it below ``max_duration_s`` leaves headroom for the
            estimator under-shooting.
        duration_multiplier: Scale factor applied to every estimate.

    Returns:
        List of :class:`PromptChunk`. A single sentence whose own estimate
        exceeds ``max_duration_s`` is still emitted, alone, as one chunk --
        it is not split further, so that chunk can exceed the cap.
    """
    from duration_estimator import estimate_speech_duration

    def scaled_estimate(piece: str) -> float:
        return estimate_speech_duration(piece) * duration_multiplier

    whole_dur = scaled_estimate(prompt)
    if whole_dur <= max_duration_s:
        return [PromptChunk(text=prompt, est_duration_s=whole_dur)]

    prefix, body = extract_speaker_prefix(prompt)
    # Whitespace-token fallback for a body that yields no sentences at all.
    # NOTE(review): the splitter appears to return any non-blank text as at
    # least one fragment, so this fallback likely only sees a blank body.
    sentences = split_sentences_outside_quotes(body) or body.split()

    result: List[PromptChunk] = []
    pending: List[str] = []
    pending_dur = 0.0

    for sent in sentences:
        trial_dur = scaled_estimate(_assemble(prefix, pending + [sent]))

        if pending and trial_dur > target_duration_s:
            # Adding this sentence would overshoot: close what we have and
            # start a fresh chunk with the sentence.
            closed_text = _assemble(prefix, pending)
            result.append(
                PromptChunk(text=closed_text, est_duration_s=scaled_estimate(closed_text))
            )
            pending = [sent]
            pending_dur = scaled_estimate(_assemble(prefix, pending))
        else:
            pending.append(sent)
            pending_dur = trial_dur

        # Pathological case: a lone sentence already past max_duration_s.
        # Emit it on its own and let downstream generation cope with the
        # over-long request, rather than raising here.
        if len(pending) == 1 and pending_dur > max_duration_s:
            result.append(
                PromptChunk(text=_assemble(prefix, pending), est_duration_s=pending_dur)
            )
            pending = []
            pending_dur = 0.0

    if pending:
        tail_text = _assemble(prefix, pending)
        result.append(PromptChunk(text=tail_text, est_duration_s=scaled_estimate(tail_text)))

    return result