Spaces:
Running on Zero
Long-form chunking + RE-USE on reference
Browse files
Port long-form generation from the upstream DramaBox repo (commit 382a37c)
to the HF Space:
- Auto chunk-and-stitch when the prompt's estimated (or explicit) duration
exceeds the max_chunk_duration cap; quote-aware sentence splitter (new
src/text_chunker.py) and shared duration estimator (new
src/duration_estimator.py). Chunks are stitched with an equal-power
crossfade so independently-generated joins are inaudible.
- RE-USE input-side voice-reference denoise (new src/super_resolution.py).
Applied to the *reference* before VAE encoding so the model conditions on
a clean speaker / style anchor; the generated paralinguistic content
(laughs, breaths, sighs) stays untouched. Cached per session so chunked
runs don't re-denoise the same reference per chunk. Silently falls back
if the mamba_ssm / causal-conv1d kernels can't be loaded.
- /generate_audio API gains denoise_ref + chunking knobs with sensible
defaults, so the existing index.html client (which sends only the
original kwargs) keeps working.
- @spaces.GPU duration bumped 60s → 600s for multi-chunk runs.
- requirements.txt: add resampy / mamba-ssm / causal-conv1d (Linux only;
optional — failures fall back to skipping the denoise).
- app.py +16 -4
- requirements.txt +12 -0
- src/duration_estimator.py +140 -0
- src/inference.py +8 -145
- src/inference_server.py +216 -9
- src/model_downloader.py +38 -4
- src/super_resolution.py +232 -0
- src/text_chunker.py +198 -0
|
@@ -183,7 +183,7 @@ async def homepage():
|
|
| 183 |
|
| 184 |
|
| 185 |
@app.api()
|
| 186 |
-
@spaces.GPU(duration=
|
| 187 |
def generate_audio(
|
| 188 |
prompt: str,
|
| 189 |
audio_ref: FileData | None,
|
|
@@ -192,11 +192,15 @@ def generate_audio(
|
|
| 192 |
dur_mult: float,
|
| 193 |
gen_dur: float,
|
| 194 |
ref_dur: float,
|
| 195 |
-
seed: int
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
) -> FileData:
|
| 197 |
if not prompt or not prompt.strip():
|
| 198 |
raise gr.Error("Prompt is empty.")
|
| 199 |
-
|
| 200 |
t0 = time.time()
|
| 201 |
ref_path = None
|
| 202 |
if audio_ref:
|
|
@@ -206,8 +210,12 @@ def generate_audio(
|
|
| 206 |
ref_path = audio_ref.path
|
| 207 |
if ref_path and not os.path.exists(ref_path):
|
| 208 |
ref_path = None
|
| 209 |
-
|
| 210 |
output = tempfile.mktemp(suffix=".wav", prefix="dramabox_")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
tts.generate_to_file(
|
| 212 |
prompt=prompt,
|
| 213 |
output=output,
|
|
@@ -218,6 +226,10 @@ def generate_audio(
|
|
| 218 |
seed=int(seed),
|
| 219 |
gen_duration=float(gen_dur),
|
| 220 |
ref_duration=float(ref_dur),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
)
|
| 222 |
elapsed = time.time() - t0
|
| 223 |
logging.info(f"Generated in {elapsed:.2f}s -> {output}")
|
|
|
|
| 183 |
|
| 184 |
|
| 185 |
@app.api()
|
| 186 |
+
@spaces.GPU(duration=600)
|
| 187 |
def generate_audio(
|
| 188 |
prompt: str,
|
| 189 |
audio_ref: FileData | None,
|
|
|
|
| 192 |
dur_mult: float,
|
| 193 |
gen_dur: float,
|
| 194 |
ref_dur: float,
|
| 195 |
+
seed: int,
|
| 196 |
+
denoise_ref: bool = True,
|
| 197 |
+
max_chunk_duration: float = 45.0,
|
| 198 |
+
target_chunk_duration: float = 37.0,
|
| 199 |
+
crossfade_ms: float = 50.0,
|
| 200 |
) -> FileData:
|
| 201 |
if not prompt or not prompt.strip():
|
| 202 |
raise gr.Error("Prompt is empty.")
|
| 203 |
+
|
| 204 |
t0 = time.time()
|
| 205 |
ref_path = None
|
| 206 |
if audio_ref:
|
|
|
|
| 210 |
ref_path = audio_ref.path
|
| 211 |
if ref_path and not os.path.exists(ref_path):
|
| 212 |
ref_path = None
|
| 213 |
+
|
| 214 |
output = tempfile.mktemp(suffix=".wav", prefix="dramabox_")
|
| 215 |
+
# Long-form: generate_to_file auto-routes to the chunk-and-stitch path when
|
| 216 |
+
# the estimated (or explicit gen_dur) duration exceeds max_chunk_duration.
|
| 217 |
+
# denoise_ref runs RE-USE on the voice reference before VAE encoding so the
|
| 218 |
+
# model conditions on a cleaner speaker / style anchor.
|
| 219 |
tts.generate_to_file(
|
| 220 |
prompt=prompt,
|
| 221 |
output=output,
|
|
|
|
| 226 |
seed=int(seed),
|
| 227 |
gen_duration=float(gen_dur),
|
| 228 |
ref_duration=float(ref_dur),
|
| 229 |
+
denoise_ref=bool(denoise_ref),
|
| 230 |
+
max_chunk_duration=float(max_chunk_duration),
|
| 231 |
+
target_chunk_duration=float(target_chunk_duration),
|
| 232 |
+
crossfade_ms=float(crossfade_ms),
|
| 233 |
)
|
| 234 |
elapsed = time.time() - t0
|
| 235 |
logging.info(f"Generated in {elapsed:.2f}s -> {output}")
|
|
@@ -24,3 +24,15 @@ gradio==6.14.0
|
|
| 24 |
spaces>=0.30.0
|
| 25 |
soundfile>=0.12.0
|
| 26 |
resemble-perth @ git+https://github.com/resemble-ai/Perth.git@master
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
spaces>=0.30.0
|
| 25 |
soundfile>=0.12.0
|
| 26 |
resemble-perth @ git+https://github.com/resemble-ai/Perth.git@master
|
| 27 |
+
|
| 28 |
+
# ── Optional: NVIDIA RE-USE speech enhancement (input-side voice-ref denoise) ─
|
| 29 |
+
# RE-USE is applied to the uploaded voice reference before VAE encoding so the
|
| 30 |
+
# model conditions on a clean speaker / style anchor (generated paralinguistic
|
| 31 |
+
# events — laughs, breaths, sighs — stay untouched because the denoiser only
|
| 32 |
+
# touches the reference). resampy is used for its pre-resample step; mamba_ssm
|
| 33 |
+
# + causal-conv1d power the bi-Mamba kernels. The kernels have no pre-built
|
| 34 |
+
# wheels for every CUDA toolkit — if installs fail, the TTS server (src/inference_server.py) logs a warning
|
| 35 |
+
# once and silently skips the reference denoise for the rest of the session.
|
| 36 |
+
resampy>=0.4.0
|
| 37 |
+
mamba-ssm>=2.2.0 ; platform_system == "Linux"
|
| 38 |
+
causal-conv1d>=1.4.0 ; platform_system == "Linux"
|
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pure-Python speech-duration estimator for DramaBox prompts.
|
| 2 |
+
|
| 3 |
+
Originally lived in ``inference.py`` but pulled out so chunkers / tooling /
|
| 4 |
+
unit tests can import it without dragging torch + the LTX pipeline through
|
| 5 |
+
sys.path. ``inference.py`` and ``inference_server.py`` continue to import
|
| 6 |
+
``estimate_speech_duration`` from here.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import re
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
_LAUGH_VERBS = {
    # Base seconds per occurrence; scaled by the modifier found right after
    # the verb. Each regex covers the verb's inflections, e.g.
    # laugh/laughs/laughed/laughing.
    r"\blaugh(?:s|ed|ing)?\b": 1.5,
    r"\bcackl(?:e|es|ed|ing)\b": 1.5,
    r"\bchuckl(?:e|es|ed|ing)\b": 1.0,
    r"\bgiggl(?:e|es|ed|ing)\b": 1.0,
    r"\bsnicker(?:s|ed|ing)?\b": 0.8,
    # NOTE(review): "cruel laugh" is also matched by the plain laugh pattern
    # above, so the phrase contributes both budgets -- confirm this is intended.
    r"\bcru?el laugh\b": 1.5,
}

# Modifier patterns are matched against the text immediately following a
# laugh verb. "softly" / "quietly" describe volume, not length, so they are
# not listed and fall through to the default 1.0x.
_SHORT_LAUGH_MOD = re.compile(
    r"^\s*(?:[a-z]+ly )?(?:briefly|shortly|once|quickly)",
    re.IGNORECASE)
_LONG_LAUGH_MOD = re.compile(
    r"^\s*(?:[a-z]+ly )?(?:maniacally|heartily|uproariously|uncontrollably|"
    r"hysterically|darkly|wickedly|evilly|loudly|long)"
    r"|^\s*between phrases", re.IGNORECASE)

# Quoted speech: double quotes, then single quotes that tolerate apostrophes
# inside contractions (don't, can't, it's).
_QUOTED_SPEECH = (
    re.compile(r'"([^"]+)"'),
    re.compile(r"'((?:[^']|'(?![\s.,!?)\]]))+)'"),
)
# Three or more 'ha'/'he' syllables, each optionally followed by one space or
# hyphen. A single pattern is used on purpose: with a leading
# ``(?:h[ae]){3,}`` alternative the regex engine stops at the first
# separator, so "hahaha-ha-ha" would be counted as 3 syllables instead of 5.
_LAUGH_RUN = re.compile(r"(?:h[ae][ \-]?){3,}", re.IGNORECASE)
_LAUGH_SYLLABLE = re.compile(r"h[ae]", re.IGNORECASE)


def _contextual_laugh_duration(text: str) -> float:
    """Return the extra seconds to budget for laughter described in ``text``.

    Two contributions are summed:

    1. Laugh verbs. Every match of a ``_LAUGH_VERBS`` pattern adds its base
       duration, scaled by the modifier that directly follows the verb
       (only the next 40 characters are inspected):
         - short modifiers (briefly, shortly, once, quickly) -> 0.4x base
         - long modifiers (maniacally, heartily, ..., or the phrase
           "between phrases")                               -> 1.2x base
         - anything else (no modifier / neutral)            -> 1.0x base
    2. Phonetic laugh repetition inside quotes. A run of three or more
       'ha'/'he' syllables adds 0.2 s per syllable beyond the second:
         'Haha'     = 2 syllables (no run, no bonus)
         'Hahahaha' = 4 syllables (+0.4 s)

    Args:
        text: The full prompt (scene directions and quoted dialogue).

    Returns:
        Non-negative number of seconds; ``0.0`` when no laughter is found.
    """
    total = 0.0

    for pat, base_dur in _LAUGH_VERBS.items():
        for m in re.finditer(pat, text, re.IGNORECASE):
            # Only a modifier that immediately follows the verb is honoured.
            ctx = text[m.end(): m.end() + 40]
            if _SHORT_LAUGH_MOD.match(ctx):
                total += base_dur * 0.4
            elif _LONG_LAUGH_MOD.match(ctx):
                total += base_dur * 1.2
            else:
                total += base_dur

    # The single-quote pattern can re-capture text that already sits inside a
    # double-quoted line (the span between two contraction apostrophes), so
    # each run is keyed by its absolute start offset and counted only once.
    counted_starts = set()
    for quote_re in _QUOTED_SPEECH:
        for quote in quote_re.finditer(text):
            offset = quote.start(1)
            for run in _LAUGH_RUN.finditer(quote.group(1)):
                start = offset + run.start()
                if start in counted_starts:
                    continue
                counted_starts.add(start)
                syls = len(_LAUGH_SYLLABLE.findall(run.group()))
                total += 0.2 * max(syls - 2, 0)
    return total
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _estimate_nonverbal_duration(text: str) -> float:
    """Estimate extra duration for non-verbal sounds and actions in the prompt.

    Every case-insensitive match of a pattern in ``PATTERNS`` adds that
    pattern's fixed number of seconds. Laugh verbs are not listed here: they
    are handled by ``_contextual_laugh_duration`` so cackle / chuckle / laugh
    budgets scale with the adjective ("maniacally" vs "briefly") and with the
    repetition length of 'Ha'/'He' tokens inside quotes.

    Args:
        text: The full prompt (scene directions and quoted dialogue).

    Returns:
        Total extra seconds for all recognised cues, laughter included.
    """
    PATTERNS = {
        # Breathing / sighs
        r'\bsighs?\b': 0.8, r'\bshaky breath\b': 1.0, r'\bbreathing deeply\b': 1.0,
        r'\bgasps?\b': 0.5, r'\bburps?\b': 0.5, r'\byawns?\b': 1.0,
        r'\bpants?\b': 0.8, r'\bwheezes?\b': 0.8, r'\bcoughs?\b': 0.8,
        r'\bsniffles?\b': 0.5, r'\bsnorts?\b': 0.3, r'\bgroans?\b': 0.8,
        # Pauses. The generic pattern excludes "long pause" and
        # "pause(s) briefly" so each phrase is budgeted exactly once;
        # otherwise "pauses briefly" (0.3 + 0.5) would cost more than a
        # plain "pauses" (0.5).
        r'\blong pause\b': 1.0, r'\bpauses? briefly\b': 0.3,
        r'(?<!\blong )\bpauses?\b(?! briefly\b)': 0.5, r'\bsilence\b': 1.0,
        r'\blets? the .{1,20} hang\b': 1.0, r'\blets? .{1,20} sink in\b': 1.0,
        # Physical actions that produce sound
        r'\bslams?\b': 0.5, r'\bclaps?\b': 0.3,
        r'\bdraws? (?:his|her|a) sword\b': 0.5,
        r'\btakes? a (?:drag|swig|sip|drink)\b': 0.5,
        r'\bwhistles?\b': 1.0, r'\bhums?\b': 0.8,
        # Vocal actions (not in quotes but take time). "whispers" is listed
        # with a zero budget, i.e. it currently adds nothing.
        r'\bmutters?\b': 1.5, r'\bmumbles?\b': 1.0, r'\bwhispers?\b': 0.0,
        r'\bclears? (?:his|her) throat\b': 0.5, r'\bgulps?\b': 0.5,
        r'\bswallows?\b': 0.5,
        # Emotional transitions
        r'\bvoice (?:breaks?|cracks?|trembles?|drops?|rises?)\b': 0.5,
        r'\bsteadies? (?:him|her)self\b': 1.0,
        r'\bcatches? (?:his|her) breath\b': 1.0,
        r'\bcomposes? (?:him|her)self\b': 0.8,
        # Scene transitions that imply time
        r'\bdemeanor shifts?\b': 0.5, r'\bsettles? in\b': 0.5,
        r'\bleans? in\b': 0.3, r'\bwipes? (?:his|her) eyes\b': 0.5,
    }
    extra = sum(
        dur * len(re.findall(pattern, text, re.IGNORECASE))
        for pattern, dur in PATTERNS.items()
    )
    # Laughter is modifier-aware, so it is budgeted separately.
    extra += _contextual_laugh_duration(text)
    return extra
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def estimate_speech_duration(text: str, speed: float = 1.0) -> float:
    """Estimate speech duration from spoken content + non-verbal actions.

    Extracts spoken text by priority:
      1. Quoted text ('...' or "...") -- official prompt guide format
      2. Text after colon -- simple "Speaker: dialogue" format
      3. Full text -- fallback

    Also scans the full prompt for non-verbal cues (laughs, pauses, sighs,
    gasps, etc.) and adds estimated duration for each.

    Args:
        text: The full prompt.
        speed: Speaking-rate multiplier applied to the characters-per-second
            rate; must be greater than zero.

    Returns:
        Estimated seconds, rounded to one decimal, including 2.0 s of padding
        and never less than 3.0.

    Raises:
        ValueError: If ``speed`` is not a positive number. Previously zero
            surfaced as a bare ``ZeroDivisionError`` and a negative value was
            silently clamped to the 3.0 s floor.
    """
    # ``not speed > 0`` also rejects NaN, which ``speed <= 0`` would let through.
    if not speed > 0:
        raise ValueError(f"speed must be a positive number, got {speed!r}")

    # Try double quotes first (clean, no contraction issues).
    quotes = re.findall(r'"([^"]+)"', text)
    if not quotes:
        # Single quotes: an apostrophe NOT followed by whitespace/punctuation
        # is kept inside the match so contractions (don't, it's) survive.
        quotes = re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text)
    # Filter out short fragments (scene directions like "He pauses").
    quotes = [q for q in quotes if len(q.split()) > 3]
    if quotes:
        spoken = " ".join(quotes)
    elif ":" in text:
        spoken = text.split(":", 1)[1].strip()
    else:
        spoken = text

    CHARS_PER_SEC = 14.0
    text_len = len(spoken)

    # Shorter utterances are budgeted at a lower characters-per-second rate.
    if text_len < 40:
        chars_per_sec = CHARS_PER_SEC * 0.6
    elif text_len < 80:
        chars_per_sec = CHARS_PER_SEC * 0.8
    else:
        chars_per_sec = CHARS_PER_SEC

    chars_per_sec *= speed
    duration = text_len / chars_per_sec

    # 0.3 s per sentence-ending punctuation mark in the spoken text.
    sentence_count = spoken.count(".") + spoken.count("!") + spoken.count("?")
    duration += sentence_count * 0.3

    # Add time for non-verbal sounds/actions in the full prompt.
    duration += _estimate_nonverbal_duration(text)

    return max(3.0, round(duration + 2.0, 1))
|
|
@@ -74,151 +74,14 @@ def detect_model_type(checkpoint_path: str) -> str:
|
|
| 74 |
return "distilled"
|
| 75 |
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
r"\bcru?el laugh\b": 1.5,
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def _contextual_laugh_duration(text: str) -> float:
|
| 90 |
-
"""Context-aware laugh budget.
|
| 91 |
-
|
| 92 |
-
For each laugh verb in the prompt, look at the adjective/adverb that
|
| 93 |
-
modifies it and scale the base duration:
|
| 94 |
-
- short modifiers (briefly, softly, once) -> 0.4x base
|
| 95 |
-
- long modifiers (maniacally, heartily, ...) -> 1.2x base
|
| 96 |
-
- default (no mod / neutral) -> 1.0x base
|
| 97 |
-
Also reward phonetic repetition inside quotes -- 'Hahahahahaha' buys more
|
| 98 |
-
time than 'Haha' -- at ~0.2s per extra repeated syllable.
|
| 99 |
-
"""
|
| 100 |
-
# "softly" / "quietly" describe volume not length, so keep at default 1.0x.
|
| 101 |
-
short_mod = re.compile(
|
| 102 |
-
r"^\s*(?:[a-z]+ly )?(?:briefly|shortly|once|quickly)",
|
| 103 |
-
re.IGNORECASE)
|
| 104 |
-
long_mod = re.compile(
|
| 105 |
-
r"^\s*(?:[a-z]+ly )?(?:maniacally|heartily|uproariously|uncontrollably|"
|
| 106 |
-
r"hysterically|darkly|wickedly|evilly|loudly|long)"
|
| 107 |
-
r"|^\s*between phrases", re.IGNORECASE)
|
| 108 |
-
|
| 109 |
-
total = 0.0
|
| 110 |
-
for pat, base_dur in _LAUGH_VERBS.items():
|
| 111 |
-
for m in re.finditer(pat, text, re.IGNORECASE):
|
| 112 |
-
ctx = text[m.end(): m.end() + 40]
|
| 113 |
-
if short_mod.match(ctx):
|
| 114 |
-
total += base_dur * 0.4
|
| 115 |
-
elif long_mod.match(ctx):
|
| 116 |
-
total += base_dur * 1.2
|
| 117 |
-
else:
|
| 118 |
-
total += base_dur
|
| 119 |
-
|
| 120 |
-
# Phonetic laugh repetition inside quotes:
|
| 121 |
-
# 'Haha' = 2 syllables (base, no bonus)
|
| 122 |
-
# 'Hahahaha' = 4 syllables (+0.4s)
|
| 123 |
-
# 'Hehehehahahahahahahaha' ~ 10 syllables (+1.6s)
|
| 124 |
-
for q in re.findall(r'"([^"]+)"', text) + re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text):
|
| 125 |
-
for run in re.findall(r"(?:h[ae]){3,}|(?:h[ae][ \-]?){3,}", q, re.IGNORECASE):
|
| 126 |
-
syls = len(re.findall(r"h[ae]", run, re.IGNORECASE))
|
| 127 |
-
total += 0.2 * max(syls - 2, 0)
|
| 128 |
-
return total
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def _estimate_nonverbal_duration(text: str) -> float:
|
| 132 |
-
"""Estimate extra duration for non-verbal sounds and actions in the prompt.
|
| 133 |
-
|
| 134 |
-
Laugh-verb handling lives in ``_contextual_laugh_duration`` so cackle /
|
| 135 |
-
chuckle / laugh budgets scale with the adjective ("maniacally" vs
|
| 136 |
-
"briefly") and with the repetition length of 'Ha'/'He' tokens inside
|
| 137 |
-
quotes.
|
| 138 |
-
"""
|
| 139 |
-
PATTERNS = {
|
| 140 |
-
# Breathing / sighs
|
| 141 |
-
r'\bsighs?\b': 0.8, r'\bshaky breath\b': 1.0, r'\bbreathing deeply\b': 1.0,
|
| 142 |
-
r'\bgasps?\b': 0.5, r'\bburps?\b': 0.5, r'\byawns?\b': 1.0,
|
| 143 |
-
r'\bpants?\b': 0.8, r'\bwheezes?\b': 0.8, r'\bcoughs?\b': 0.8,
|
| 144 |
-
r'\bsniffles?\b': 0.5, r'\bsnorts?\b': 0.3, r'\bgroans?\b': 0.8,
|
| 145 |
-
# Pauses (trimmed; earlier values over-budgeted silence)
|
| 146 |
-
r'\blong pause\b': 1.0, r'\bpauses? briefly\b': 0.3,
|
| 147 |
-
r'\bpauses?\b': 0.5, r'\bsilence\b': 1.0,
|
| 148 |
-
r'\blets? the .{1,20} hang\b': 1.0, r'\blets? .{1,20} sink in\b': 1.0,
|
| 149 |
-
# Physical actions that produce sound
|
| 150 |
-
r'\bslams?\b': 0.5, r'\bclaps?\b': 0.3,
|
| 151 |
-
r'\bdraws? (?:his|her|a) sword\b': 0.5,
|
| 152 |
-
r'\btakes? a (?:drag|swig|sip|drink)\b': 0.5,
|
| 153 |
-
r'\bwhistles?\b': 1.0, r'\bhums?\b': 0.8,
|
| 154 |
-
# Vocal actions (not in quotes but take time)
|
| 155 |
-
r'\bmutters?\b': 1.5, r'\bmumbles?\b': 1.0, r'\bwhispers?\b': 0.0,
|
| 156 |
-
r'\bclears? (?:his|her) throat\b': 0.5, r'\bgulps?\b': 0.5,
|
| 157 |
-
r'\bswallows?\b': 0.5,
|
| 158 |
-
# (laugh / chuckle / cackle / giggle / snicker handled by
|
| 159 |
-
# _contextual_laugh_duration below -- modifier-aware, not flat.)
|
| 160 |
-
# Emotional transitions
|
| 161 |
-
r'\bvoice (?:breaks?|cracks?|trembles?|drops?|rises?)\b': 0.5,
|
| 162 |
-
r'\bsteadies? (?:him|her)self\b': 1.0,
|
| 163 |
-
r'\bcatches? (?:his|her) breath\b': 1.0,
|
| 164 |
-
r'\bcomposes? (?:him|her)self\b': 0.8,
|
| 165 |
-
# Scene transitions that imply time
|
| 166 |
-
r'\bdemeanor shifts?\b': 0.5, r'\bsettles? in\b': 0.5,
|
| 167 |
-
r'\bleans? in\b': 0.3, r'\bwipes? (?:his|her) eyes\b': 0.5,
|
| 168 |
-
}
|
| 169 |
-
extra = 0.0
|
| 170 |
-
for pattern, dur in PATTERNS.items():
|
| 171 |
-
extra += dur * len(re.findall(pattern, text, re.IGNORECASE))
|
| 172 |
-
extra += _contextual_laugh_duration(text)
|
| 173 |
-
return extra
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def estimate_speech_duration(text: str, speed: float = 1.0) -> float:
|
| 177 |
-
"""Estimate speech duration from spoken content + non-verbal actions.
|
| 178 |
-
|
| 179 |
-
Extracts spoken text by priority:
|
| 180 |
-
1. Quoted text ('...' or "...") -- official prompt guide format
|
| 181 |
-
2. Text after colon -- simple "Speaker: dialogue" format
|
| 182 |
-
3. Full text -- fallback
|
| 183 |
-
|
| 184 |
-
Also scans the full prompt for non-verbal cues (laughs, pauses, sighs,
|
| 185 |
-
gasps, etc.) and adds estimated duration for each.
|
| 186 |
-
"""
|
| 187 |
-
# Try double quotes first (clean, no contraction issues)
|
| 188 |
-
quotes = re.findall(r'"([^"]+)"', text)
|
| 189 |
-
if not quotes:
|
| 190 |
-
# Single quotes: allow apostrophes in contractions (don't, can't, it's)
|
| 191 |
-
# Match ' to ' but apostrophes NOT followed by space/punctuation are kept inside
|
| 192 |
-
quotes = re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text)
|
| 193 |
-
# Filter out short fragments (scene directions like "He pauses")
|
| 194 |
-
quotes = [q for q in quotes if len(q.split()) > 3]
|
| 195 |
-
if quotes:
|
| 196 |
-
spoken = " ".join(quotes)
|
| 197 |
-
elif ":" in text:
|
| 198 |
-
spoken = text.split(":", 1)[1].strip()
|
| 199 |
-
else:
|
| 200 |
-
spoken = text
|
| 201 |
-
|
| 202 |
-
CHARS_PER_SEC = 14.0
|
| 203 |
-
text_len = len(spoken)
|
| 204 |
-
|
| 205 |
-
if text_len < 40:
|
| 206 |
-
chars_per_sec = CHARS_PER_SEC * 0.6
|
| 207 |
-
elif text_len < 80:
|
| 208 |
-
chars_per_sec = CHARS_PER_SEC * 0.8
|
| 209 |
-
else:
|
| 210 |
-
chars_per_sec = CHARS_PER_SEC
|
| 211 |
-
|
| 212 |
-
chars_per_sec *= speed
|
| 213 |
-
duration = text_len / chars_per_sec
|
| 214 |
-
|
| 215 |
-
sentence_count = spoken.count(".") + spoken.count("!") + spoken.count("?")
|
| 216 |
-
duration += sentence_count * 0.3
|
| 217 |
-
|
| 218 |
-
# Add time for non-verbal sounds/actions in the full prompt
|
| 219 |
-
duration += _estimate_nonverbal_duration(text)
|
| 220 |
-
|
| 221 |
-
return max(3.0, round(duration + 2.0, 1))
|
| 222 |
|
| 223 |
|
| 224 |
def parse_args():
|
|
|
|
| 74 |
return "distilled"
|
| 75 |
|
| 76 |
|
| 77 |
+
# Duration estimator lives in duration_estimator.py so that text_chunker and
|
| 78 |
+
# other tooling can import it without dragging the torch / LTX pipeline.
|
| 79 |
+
from duration_estimator import ( # noqa: E402,F401
|
| 80 |
+
estimate_speech_duration,
|
| 81 |
+
_contextual_laugh_duration,
|
| 82 |
+
_estimate_nonverbal_duration,
|
| 83 |
+
_LAUGH_VERBS,
|
| 84 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
def parse_args():
|
|
@@ -14,6 +14,7 @@ import re
|
|
| 14 |
import sys
|
| 15 |
import time
|
| 16 |
from pathlib import Path
|
|
|
|
| 17 |
|
| 18 |
import torch
|
| 19 |
import torchaudio
|
|
@@ -53,13 +54,40 @@ DEFAULT_NEG = "worst quality, inconsistent, robotic, distorted, noise, static, m
|
|
| 53 |
|
| 54 |
|
| 55 |
def estimate_duration(prompt, multiplier=1.1):
|
| 56 |
-
"""Defer to the
|
| 57 |
-
|
| 58 |
-
from
|
| 59 |
base = estimate_speech_duration(prompt)
|
| 60 |
return max(3.0, round(base * multiplier, 1))
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
def auto_rescale_for_cfg(cfg: float) -> float:
|
| 64 |
"""CFG-aware std-rescale schedule that prevents output clipping at high cfg.
|
| 65 |
|
|
@@ -110,6 +138,11 @@ class TTSServer:
|
|
| 110 |
self._velocity_model = None
|
| 111 |
self._audio_conditioner = None
|
| 112 |
self._audio_decoder = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
logging.info(f"TTSServer loading on {device}...")
|
| 115 |
t0 = time.time()
|
|
@@ -203,10 +236,78 @@ class TTSServer:
|
|
| 203 |
)
|
| 204 |
logging.info(f" AudioDecoder (warm): {time.time()-t0:.1f}s")
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
@torch.inference_mode()
|
| 207 |
def generate(self, prompt, voice_ref=None, cfg_scale=2.5, stg_scale=1.5,
|
| 208 |
duration_multiplier=1.1, seed=42, ref_duration=10.0,
|
| 209 |
-
rescale_scale="auto", gen_duration: float = 0.0
|
|
|
|
| 210 |
"""Generate audio. Returns (waveform_path, duration_seconds).
|
| 211 |
|
| 212 |
rescale_scale: latent-side CFG std-rescale that prevents clipping at
|
|
@@ -214,6 +315,10 @@ class TTSServer:
|
|
| 214 |
float in [0, 1] for a fixed override, or 0 to disable.
|
| 215 |
gen_duration: explicit target duration in seconds. 0 (default) → auto
|
| 216 |
from prompt + duration_multiplier; >0 overrides everything else.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
"""
|
| 218 |
t_total = time.time()
|
| 219 |
|
|
@@ -236,6 +341,8 @@ class TTSServer:
|
|
| 236 |
if voice_ref and os.path.exists(voice_ref):
|
| 237 |
t0 = time.time()
|
| 238 |
voice = decode_audio_from_file(voice_ref, self.device, 0.0, ref_duration)
|
|
|
|
|
|
|
| 239 |
w = voice.waveform
|
| 240 |
if w.dim() == 2:
|
| 241 |
if w.shape[0] == 1:
|
|
@@ -323,15 +430,115 @@ class TTSServer:
|
|
| 323 |
|
| 324 |
t0 = time.time()
|
| 325 |
decoded = self._audio_decoder(latent)
|
| 326 |
-
|
|
|
|
| 327 |
|
| 328 |
total = time.time() - t_total
|
| 329 |
-
dur =
|
| 330 |
logging.info(f"Total: {total:.2f}s for {dur:.1f}s audio")
|
| 331 |
-
return
|
| 332 |
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
wav_cpu = waveform.cpu().float()
|
| 336 |
if watermark:
|
| 337 |
try:
|
|
|
|
| 14 |
import sys
|
| 15 |
import time
|
| 16 |
from pathlib import Path
|
| 17 |
+
from typing import Optional
|
| 18 |
|
| 19 |
import torch
|
| 20 |
import torchaudio
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
def estimate_duration(prompt, multiplier=1.1):
    """Return the generation length, in seconds, to budget for ``prompt``.

    The base figure comes from the shared sentence-aware + non-verbal action
    budget estimator (``duration_estimator.estimate_speech_duration``) so
    warm-server outputs match the lengths of the per-call CLI runs. It is
    scaled by ``multiplier``, rounded to one decimal and floored at 3.0.
    """
    from duration_estimator import estimate_speech_duration

    scaled = round(estimate_speech_duration(prompt) * multiplier, 1)
    return scaled if scaled > 3.0 else 3.0
|
| 62 |
|
| 63 |
|
| 64 |
+
def _equal_power_crossfade(prev: torch.Tensor, nxt: torch.Tensor,
                           sample_rate: int, fade_ms: float = 50.0) -> torch.Tensor:
    """Join ``prev`` and ``nxt`` along the last axis with an equal-power blend.

    The last ``fade_ms`` of ``prev`` is mixed with the first ``fade_ms`` of
    ``nxt`` using a cosine fade-out and a sine fade-in, whose squared gains
    always sum to one. The overlap is clamped to the shorter of the two
    inputs; when it works out to one sample or less the inputs are simply
    concatenated with no blend.

    Args:
        prev: Leading audio; time is the last dimension.
        nxt: Trailing audio; time is the last dimension.
        sample_rate: Samples per second, used to convert ``fade_ms``.
        fade_ms: Requested overlap length in milliseconds.

    Returns:
        Tensor of length ``T_prev + T_nxt - overlap`` on the last axis
        (``T_prev + T_nxt`` when no blend is applied).
    """
    overlap = min(int(round(fade_ms * 1e-3 * sample_rate)),
                  prev.shape[-1], nxt.shape[-1])
    if overlap <= 1:
        # Nothing worth blending: hard cut.
        return torch.cat([prev, nxt], dim=-1)

    ramp = torch.linspace(0.0, 1.0, overlap, device=prev.device, dtype=prev.dtype)
    angle = ramp * torch.pi / 2
    # cos goes 1.0 -> 0.0 (outgoing), sin goes 0.0 -> 1.0 (incoming).
    blended = prev[..., -overlap:] * torch.cos(angle) + nxt[..., :overlap] * torch.sin(angle)

    head = prev[..., :-overlap]
    tail = nxt[..., overlap:]
    return torch.cat([head, blended, tail], dim=-1)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
def auto_rescale_for_cfg(cfg: float) -> float:
|
| 92 |
"""CFG-aware std-rescale schedule that prevents output clipping at high cfg.
|
| 93 |
|
|
|
|
| 138 |
self._velocity_model = None
|
| 139 |
self._audio_conditioner = None
|
| 140 |
self._audio_decoder = None
|
| 141 |
+
# RE-USE denoiser for the voice reference (input-side denoise).
|
| 142 |
+
# Lazy-loaded on first use; the cleaned-waveform cache below keeps
|
| 143 |
+
# chunked generations from re-denoising the same 10 s clip per chunk.
|
| 144 |
+
self._ref_denoiser = None
|
| 145 |
+
self._ref_denoise_cache: dict[tuple, "torch.Tensor"] = {}
|
| 146 |
|
| 147 |
logging.info(f"TTSServer loading on {device}...")
|
| 148 |
t0 = time.time()
|
|
|
|
| 236 |
)
|
| 237 |
logging.info(f" AudioDecoder (warm): {time.time()-t0:.1f}s")
|
| 238 |
|
| 239 |
+
def _denoise_voice_ref(self, voice, voice_ref_path: str, ref_duration: float):
    """Run RE-USE on the loaded voice reference and replace its waveform
    with a cleaned mono signal.

    Why pre-condition rather than post-generate: applying RE-USE to the
    *output* suppresses paralinguistic events the model generates (laughs,
    gasps, breaths, sighs) because they're broadband, non-tonal — exactly
    what universal speech enhancement targets as "noise". Running it on
    the *reference* instead gives the model a clean speaker / style
    anchor, which it generalises from at inference time, while leaving
    the generated paralinguistic content untouched.

    Cached by ``(path, ref_duration, sampling_rate)`` so chunked
    generations don't re-denoise the same 10 s clip per chunk.

    Args:
        voice: The decoded reference; only ``.waveform`` and
            ``.sampling_rate`` are read.
        voice_ref_path: Path the reference was loaded from (cache key only).
        ref_duration: Seconds of reference that were decoded (cache key only).

    Returns:
        An ``Audio`` holding the cleaned waveform at the reference's sampling
        rate, or ``voice`` unchanged when RE-USE cannot be loaded.
    """
    # NOTE(review): the cache is never evicted and is keyed by path, not file
    # content — confirm that is acceptable for a long-running server.
    cache_key = (voice_ref_path, float(ref_duration), int(voice.sampling_rate))
    if cache_key in self._ref_denoise_cache:
        return Audio(
            waveform=self._ref_denoise_cache[cache_key],
            sampling_rate=voice.sampling_rate,
        )

    # Lazy-load the denoiser. target_sr = input sr → no librosa resample
    # round-trip; RE-USE does pure denoise. (The 48 kHz BWE that
    # REUSEUpsampler can do is irrelevant here — the VAE conditioner
    # resamples internally to whatever the audio branch expects.)
    # NOTE(review): target_sr is fixed by the first reference seen; later
    # references with a different sampling rate reuse the same instance —
    # confirm REUSEUpsampler handles that via ``in_sr``.
    if self._ref_denoiser is None:
        try:
            # The import lives inside the try so a missing module or an
            # import-time kernel failure takes the same fallback as a
            # constructor failure instead of aborting generation.
            from super_resolution import REUSEUpsampler
            self._ref_denoiser = REUSEUpsampler(
                target_sr=int(voice.sampling_rate),
                device=self.device,
                chunk_size_s=1.0,
            )
        except Exception as e:
            # Mamba kernels / weights missing → silently skip the denoise
            # rather than blocking generation. Surfaces once per session.
            logging.warning(f"Voice-ref denoise disabled (RE-USE unavailable: {e})")
            self._ref_denoiser = False  # sentinel: don't retry this session
            return voice

    if self._ref_denoiser is False:
        return voice

    w = voice.waveform
    # Collapse to mono — voice cloning is speaker-as-mono-source; we'll
    # re-broadcast back to stereo after the conditioner.
    if w.dim() == 3:
        mono = w[0].mean(dim=0)  # first batch item, averaged over dim 0
    elif w.dim() == 2:
        mono = w.mean(dim=0)
    else:
        mono = w
    mono = mono.contiguous()

    t0 = time.time()
    # The second return value is discarded; presumably the output sample
    # rate — TODO confirm against REUSEUpsampler.
    cleaned, _ = self._ref_denoiser(mono, in_sr=int(voice.sampling_rate))
    if cleaned.dim() == 2 and cleaned.shape[0] == 1:
        cleaned = cleaned[0]
    # Restore the (1, C=1, T) shape that the rest of the pipeline expects
    # to consume — downstream code re-expands channels via repeat().
    cleaned = cleaned.unsqueeze(0).unsqueeze(0).to(self.device, dtype=w.dtype)
    logging.info(f"Voice-ref denoise (RE-USE): {time.time() - t0:.2f}s")

    self._ref_denoise_cache[cache_key] = cleaned
    return Audio(waveform=cleaned, sampling_rate=voice.sampling_rate)
|
| 305 |
+
|
| 306 |
@torch.inference_mode()
|
| 307 |
def generate(self, prompt, voice_ref=None, cfg_scale=2.5, stg_scale=1.5,
|
| 308 |
duration_multiplier=1.1, seed=42, ref_duration=10.0,
|
| 309 |
+
rescale_scale="auto", gen_duration: float = 0.0,
|
| 310 |
+
denoise_ref: bool = True):
|
| 311 |
"""Generate audio. Returns (waveform_path, duration_seconds).
|
| 312 |
|
| 313 |
rescale_scale: latent-side CFG std-rescale that prevents clipping at
|
|
|
|
| 315 |
float in [0, 1] for a fixed override, or 0 to disable.
|
| 316 |
gen_duration: explicit target duration in seconds. 0 (default) → auto
|
| 317 |
from prompt + duration_multiplier; >0 overrides everything else.
|
| 318 |
+
denoise_ref: when True (default) and a voice reference is provided,
|
| 319 |
+
RE-USE is applied to the *reference* before VAE encoding so the
|
| 320 |
+
model conditions on a clean speaker / style anchor. Generated
|
| 321 |
+
output (24→48 kHz) always goes through the LTX BigVGAN BWE.
|
| 322 |
"""
|
| 323 |
t_total = time.time()
|
| 324 |
|
|
|
|
| 341 |
if voice_ref and os.path.exists(voice_ref):
|
| 342 |
t0 = time.time()
|
| 343 |
voice = decode_audio_from_file(voice_ref, self.device, 0.0, ref_duration)
|
| 344 |
+
if denoise_ref:
|
| 345 |
+
voice = self._denoise_voice_ref(voice, voice_ref, ref_duration)
|
| 346 |
w = voice.waveform
|
| 347 |
if w.dim() == 2:
|
| 348 |
if w.shape[0] == 1:
|
|
|
|
| 430 |
|
| 431 |
t0 = time.time()
|
| 432 |
decoded = self._audio_decoder(latent)
|
| 433 |
+
out_waveform, out_sr = decoded.waveform, decoded.sampling_rate
|
| 434 |
+
logging.info(f"Decode (LTX BWE): {time.time()-t0:.2f}s")
|
| 435 |
|
| 436 |
total = time.time() - t_total
|
| 437 |
+
dur = out_waveform.shape[-1] / out_sr
|
| 438 |
logging.info(f"Total: {total:.2f}s for {dur:.1f}s audio")
|
| 439 |
+
return out_waveform, out_sr
|
| 440 |
|
| 441 |
+
@torch.inference_mode()
def generate_long(self, prompt, max_chunk_duration: float = 45.0,
                  target_chunk_duration: float = 37.0,
                  crossfade_ms: float = 50.0,
                  progress_callback=None,
                  **kwargs):
    """Chunk-and-stitch generation for prompts whose estimated duration
    exceeds ``max_chunk_duration``.

    Splits ``prompt`` via :func:`text_chunker.chunk_prompt_for_duration`,
    generates each chunk through :meth:`generate` with the same forwarded
    keyword arguments (so the same voice reference / seed, if supplied, is
    reused for every chunk), and joins the waveforms with
    ``_equal_power_crossfade``.

    Args:
        prompt: Full prompt text to split and synthesize.
        max_chunk_duration: Hard per-chunk cap in seconds, forwarded to the
            chunker as ``max_duration_s``.
        target_chunk_duration: Soft per-chunk target in seconds, forwarded
            to the chunker as ``target_duration_s``.
        crossfade_ms: Crossfade length passed to ``_equal_power_crossfade``
            as ``fade_ms``.
        progress_callback: Optional callable invoked before each chunk as
            ``progress_callback(chunk_index, n_chunks, est_duration_s)``.
            Exceptions it raises are logged and ignored.
        **kwargs: Forwarded to :meth:`generate`. ``duration_multiplier`` is
            consumed here and re-applied per chunk; ``gen_duration`` is
            dropped because chunk durations are derived from the text.

    Returns:
        ``(waveform, sample_rate)`` matching :meth:`generate`; the waveform
        is a CPU float tensor.

    Raises:
        ValueError: If the chunker returns no chunks.
        RuntimeError: If two chunks come back with different sample rates.
    """
    from text_chunker import chunk_prompt_for_duration

    # gen_duration / duration_multiplier are per-chunk; pop them out so we
    # control sizing here and forward only the per-chunk values.
    per_chunk_mul = float(kwargs.pop("duration_multiplier", 1.1))
    # gen_duration coming in as a global target only makes sense for the
    # single-shot path; chunked generation derives durations per chunk.
    # Say so in the log instead of discarding the caller's value silently.
    dropped_gen_duration = kwargs.pop("gen_duration", None)
    if dropped_gen_duration:
        logging.info(
            f"Long-form: ignoring explicit gen_duration={dropped_gen_duration}; "
            f"chunk durations are derived from the prompt text"
        )

    chunks = chunk_prompt_for_duration(
        prompt,
        max_duration_s=max_chunk_duration,
        target_duration_s=target_chunk_duration,
        duration_multiplier=per_chunk_mul,
    )
    if not chunks:
        # Without this guard the function would fall through to
        # ``out_waveform.shape`` on None and fail with an opaque
        # AttributeError.
        raise ValueError("Long-form generation: prompt produced no chunks.")
    logging.info(f"Long-form: {len(chunks)} chunks (target {target_chunk_duration:.0f}s, "
                 f"max {max_chunk_duration:.0f}s)")

    out_waveform: Optional[torch.Tensor] = None
    out_sr: Optional[int] = None
    t_total = time.time()
    for idx, chunk in enumerate(chunks):
        logging.info(f" Chunk {idx + 1}/{len(chunks)}: est {chunk.est_duration_s:.1f}s, "
                     f"{len(chunk.text)} chars")
        if progress_callback is not None:
            # Progress reporting is best-effort; a broken callback must not
            # abort a long generation.
            try:
                progress_callback(idx, len(chunks), chunk.est_duration_s)
            except Exception as e:
                logging.warning(f"progress_callback raised, ignoring: {e}")
        wav, sr = self.generate(
            chunk.text,
            duration_multiplier=per_chunk_mul,
            **kwargs,
        )
        wav = wav.cpu().float()
        if out_waveform is None:
            out_waveform, out_sr = wav, sr
            continue

        if sr != out_sr:
            raise RuntimeError(f"Sample-rate mismatch between chunks: {out_sr} vs {sr}")
        # Align channel counts: stereo crossfade with a mono buddy
        # broadcasts cleanly via torch.cat after equalising dim 0.
        # NOTE(review): assumes (channels, samples) layout — confirm against
        # what generate() returns.
        if wav.shape[0] != out_waveform.shape[0]:
            if wav.shape[0] == 1:
                wav = wav.repeat(out_waveform.shape[0], 1)
            elif out_waveform.shape[0] == 1:
                out_waveform = out_waveform.repeat(wav.shape[0], 1)
        out_waveform = _equal_power_crossfade(out_waveform, wav, out_sr,
                                              fade_ms=crossfade_ms)

    total_dur = out_waveform.shape[-1] / out_sr
    logging.info(f"Long-form total: {time.time() - t_total:.2f}s wall, {total_dur:.1f}s audio")
    return out_waveform, out_sr
|
| 511 |
+
|
| 512 |
+
def generate_to_file(self, prompt, output, watermark: bool = True,
|
| 513 |
+
max_chunk_duration: float = 45.0,
|
| 514 |
+
target_chunk_duration: float = 37.0,
|
| 515 |
+
crossfade_ms: float = 50.0,
|
| 516 |
+
progress_callback=None,
|
| 517 |
+
**kwargs):
|
| 518 |
+
# Auto-route to generate_long when the requested duration (explicit
|
| 519 |
+
# gen_duration if set, otherwise prompt-estimated) exceeds the chunk
|
| 520 |
+
# cap. Single-shot path otherwise — same as before, no regression for
|
| 521 |
+
# short prompts.
|
| 522 |
+
explicit_dur = float(kwargs.get("gen_duration") or 0.0)
|
| 523 |
+
est_dur = explicit_dur if explicit_dur > 0 else estimate_duration(
|
| 524 |
+
prompt, kwargs.get("duration_multiplier", 1.1))
|
| 525 |
+
|
| 526 |
+
if est_dur > max_chunk_duration:
|
| 527 |
+
waveform, sr = self.generate_long(
|
| 528 |
+
prompt,
|
| 529 |
+
max_chunk_duration=max_chunk_duration,
|
| 530 |
+
target_chunk_duration=target_chunk_duration,
|
| 531 |
+
crossfade_ms=crossfade_ms,
|
| 532 |
+
progress_callback=progress_callback,
|
| 533 |
+
**kwargs,
|
| 534 |
+
)
|
| 535 |
+
else:
|
| 536 |
+
if progress_callback is not None:
|
| 537 |
+
try:
|
| 538 |
+
progress_callback(0, 1, est_dur)
|
| 539 |
+
except Exception:
|
| 540 |
+
pass
|
| 541 |
+
waveform, sr = self.generate(prompt, **kwargs)
|
| 542 |
wav_cpu = waveform.cpu().float()
|
| 543 |
if watermark:
|
| 544 |
try:
|
|
@@ -15,12 +15,10 @@ logger = logging.getLogger(__name__)
|
|
| 15 |
|
| 16 |
DRAMABOX_REPO = "ResembleAI/Dramabox"
|
| 17 |
GEMMA_REPO = "unsloth/gemma-3-12b-it-bnb-4bit"
|
|
|
|
| 18 |
|
| 19 |
# Default cache directory
|
| 20 |
-
DEFAULT_CACHE = os.environ.get(
|
| 21 |
-
"DRAMABOX_CACHE",
|
| 22 |
-
os.path.join(os.path.expanduser("~"), ".cache", "dramabox"),
|
| 23 |
-
)
|
| 24 |
|
| 25 |
# Model files in the HF repo (flat structure)
|
| 26 |
MODEL_FILES = {
|
|
@@ -75,6 +73,42 @@ def get_gemma_path(cache_dir: str = None) -> str:
|
|
| 75 |
return local_dir
|
| 76 |
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def get_all_paths(cache_dir: str = None) -> dict:
|
| 79 |
"""Download all required models and return paths dict.
|
| 80 |
|
|
|
|
| 15 |
|
| 16 |
DRAMABOX_REPO = "ResembleAI/Dramabox"
|
| 17 |
GEMMA_REPO = "unsloth/gemma-3-12b-it-bnb-4bit"
|
| 18 |
+
REUSE_REPO = "nvidia/RE-USE"
|
| 19 |
|
| 20 |
# Default cache directory
|
| 21 |
+
DEFAULT_CACHE = os.path.join(os.environ.get("HF_HOME", os.path.expanduser("~")), ".cache", "dramabox")
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Model files in the HF repo (flat structure)
|
| 24 |
MODEL_FILES = {
|
|
|
|
| 73 |
return local_dir
|
| 74 |
|
| 75 |
|
| 76 |
+
def get_reuse_code_path(cache_dir: str = None) -> str:
    """Locate (or fetch) the nvidia/RE-USE code + configs used by REUSEUpsampler.

    Resolution order:

    1. ``$REUSE_DIR`` — if set and an existing directory, it is returned
       as-is without touching the network.
    2. ``<repo_root>/third_party/RE-USE/`` — used when it already contains
       ``models/generator_SEMamba_time_d4.py`` (the model *code* file).
    3. Otherwise a filtered ``snapshot_download`` of ``REUSE_REPO`` into
       ``cache_dir`` (default: ``DEFAULT_CACHE``).

    The download filter only lists code / config patterns, so weight files
    such as ``model.safetensors`` are not requested here.

    Args:
        cache_dir: Cache directory for the snapshot; falls back to
            ``DEFAULT_CACHE`` when falsy.

    Returns:
        Path of the directory holding the RE-USE code tree, as a string.
    """
    override = os.environ.get("REUSE_DIR")
    if override and Path(override).is_dir():
        return override

    # Pre-vendored copy shipped next to the source tree takes priority over
    # any network fetch.
    vendored = Path(__file__).resolve().parent.parent / "third_party" / "RE-USE"
    marker = vendored / "models" / "generator_SEMamba_time_d4.py"
    if marker.is_file():
        return str(vendored)

    target_cache = cache_dir or DEFAULT_CACHE
    wanted = ["*.py", "*.yaml", "*.json",
              "recipes/*", "models/*.py", "utils/*.py"]
    logger.info(f"Fetching RE-USE code/configs from {REUSE_REPO}...")
    snapshot_path = snapshot_download(
        repo_id=REUSE_REPO,
        cache_dir=target_cache,
        token=os.environ.get("HF_TOKEN"),
        allow_patterns=wanted,
    )
    logger.info(f" -> {snapshot_path}")
    return snapshot_path
|
| 110 |
+
|
| 111 |
+
|
| 112 |
def get_all_paths(cache_dir: str = None) -> dict:
|
| 113 |
"""Download all required models and return paths dict.
|
| 114 |
|
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RE-USE (nvidia/RE-USE) speech-enhancement wrapper.
|
| 2 |
+
|
| 3 |
+
Used by ``TTSServer._denoise_voice_ref`` to denoise the input voice reference
|
| 4 |
+
before VAE conditioning. Lazy-loads weights + code on first call so importing
|
| 5 |
+
this module is cheap.
|
| 6 |
+
|
| 7 |
+
up = REUSEUpsampler(target_sr=48000, device="cuda")
|
| 8 |
+
clean, sr = up(wav, in_sr=24000) # wav: (C, T) or (T,) float
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# REUSE_DIR is resolved lazily via model_downloader.get_reuse_code_path on
|
| 21 |
+
# first use of REUSEUpsampler — it returns the vendored third_party/RE-USE/
|
| 22 |
+
# tree if present, otherwise snapshot-downloads just the code from HF.
|
| 23 |
+
_REUSE_DIR: Optional[Path] = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _resolve_reuse_dir() -> Path:
    """Return the RE-USE code directory, resolving it once per process.

    The first call imports ``model_downloader.get_reuse_code_path`` and
    stores its result in the module-level ``_REUSE_DIR``; later calls return
    the stored path without calling the downloader again.
    """
    global _REUSE_DIR
    if _REUSE_DIR is not None:
        return _REUSE_DIR

    # Imported here rather than at module top so importing this module
    # does not pull in the downloader.
    from model_downloader import get_reuse_code_path

    _REUSE_DIR = Path(get_reuse_code_path())
    return _REUSE_DIR
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class REUSEUpsampler:
|
| 35 |
+
"""Universal speech enhancement with optional bandwidth extension.
|
| 36 |
+
|
| 37 |
+
nvidia/RE-USE is a 9.6 M-param bidirectional-Mamba model that operates on
|
| 38 |
+
STFT amplitude+phase. With ``target_sr`` set it both denoises *and* extends
|
| 39 |
+
the bandwidth to that rate via librosa kaiser-best resample + restoration.
|
| 40 |
+
|
| 41 |
+
License: NSCLv1 (noncommercial). The base ``SEMamba`` class lives in the
|
| 42 |
+
HF repo under ``models/generator_SEMamba_time_d4.py`` and pulls in the
|
| 43 |
+
``mamba_ssm`` / ``causal-conv1d`` CUDA kernels.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
target_sr: int = 48000,
|
| 49 |
+
config_path: Optional[str] = None,
|
| 50 |
+
chunk_size_s: float = 1.0,
|
| 51 |
+
hop_portion: float = 0.5,
|
| 52 |
+
device: str | torch.device = "cuda",
|
| 53 |
+
) -> None:
|
| 54 |
+
# chunk_size_s: peak VRAM scales linearly with chunk length.
|
| 55 |
+
# 5.0s -> 2.95 GB | 2.5s -> 1.52 GB | 1.0s -> 0.67 GB (default).
|
| 56 |
+
# 1.0s is chosen as default so RE-USE fits comfortably on top of the
|
| 57 |
+
# rest of the DramaBox pipeline on any 24 GB-class GPU.
|
| 58 |
+
self.device = torch.device(device)
|
| 59 |
+
self.target_sr = int(target_sr)
|
| 60 |
+
self.chunk_size_s = float(chunk_size_s)
|
| 61 |
+
self.hop_portion = float(hop_portion)
|
| 62 |
+
# Config path is resolved lazily on first use (alongside the code tree)
|
| 63 |
+
# so importing this module never triggers a download.
|
| 64 |
+
self._config_path_override = Path(config_path) if config_path else None
|
| 65 |
+
self.config_path: Optional[Path] = None
|
| 66 |
+
self._model = None
|
| 67 |
+
self._cfg = None
|
| 68 |
+
self._stft_fns = None # (mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim)
|
| 69 |
+
|
| 70 |
+
@staticmethod
|
| 71 |
+
def _ensure_mamba_ssm_importable() -> None:
|
| 72 |
+
"""Import ``mamba_ssm`` cleanly, with a kernel-free fallback if needed.
|
| 73 |
+
|
| 74 |
+
Normal path (kernels present): just import — fast path uses
|
| 75 |
+
``selective_scan_cuda`` natively.
|
| 76 |
+
|
| 77 |
+
Fallback (kernels missing): the official package does an unconditional
|
| 78 |
+
``import selective_scan_cuda`` at module load. We stub it into
|
| 79 |
+
``sys.modules`` before importing, then redirect ``selective_scan_fn``
|
| 80 |
+
to the pure-PyTorch ``selective_scan_ref`` so the model still runs
|
| 81 |
+
(~5-10x slower).
|
| 82 |
+
"""
|
| 83 |
+
try:
|
| 84 |
+
import selective_scan_cuda # noqa: F401
|
| 85 |
+
import mamba_ssm # noqa: F401
|
| 86 |
+
return # Fast path: kernel present.
|
| 87 |
+
except ImportError:
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
import types
|
| 91 |
+
if "selective_scan_cuda" not in sys.modules:
|
| 92 |
+
stub = types.ModuleType("selective_scan_cuda")
|
| 93 |
+
def _missing(*a, **kw): # pragma: no cover - safety net only
|
| 94 |
+
raise NotImplementedError(
|
| 95 |
+
"selective_scan_cuda kernel missing; the call should have "
|
| 96 |
+
"been routed to selective_scan_ref via the runtime patch."
|
| 97 |
+
)
|
| 98 |
+
stub.fwd = _missing
|
| 99 |
+
stub.bwd = _missing
|
| 100 |
+
sys.modules["selective_scan_cuda"] = stub
|
| 101 |
+
|
| 102 |
+
from mamba_ssm.ops import selective_scan_interface as ssi
|
| 103 |
+
from mamba_ssm.modules import mamba_simple
|
| 104 |
+
if getattr(ssi, "_dramabox_kernel_free_patch_applied", False):
|
| 105 |
+
return
|
| 106 |
+
ssi.selective_scan_fn = ssi.selective_scan_ref
|
| 107 |
+
ssi.mamba_inner_fn = ssi.mamba_inner_ref
|
| 108 |
+
# mamba_simple imported these names by reference at module load -
|
| 109 |
+
# rebind there too, otherwise Mamba.forward keeps the original handles.
|
| 110 |
+
mamba_simple.selective_scan_fn = ssi.selective_scan_ref
|
| 111 |
+
mamba_simple.mamba_inner_fn = ssi.mamba_inner_ref
|
| 112 |
+
ssi._dramabox_kernel_free_patch_applied = True
|
| 113 |
+
logging.info(
|
| 114 |
+
"mamba_ssm kernel missing - using kernel-free fallback "
|
| 115 |
+
"(selective_scan_fn -> selective_scan_ref). Expect ~5-10x slowdown."
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def _lazy_load(self) -> None:
|
| 119 |
+
if self._model is not None:
|
| 120 |
+
return
|
| 121 |
+
|
| 122 |
+
# Prefer real CUDA kernels; gracefully fall back to pure-PyTorch impl.
|
| 123 |
+
self._ensure_mamba_ssm_importable()
|
| 124 |
+
|
| 125 |
+
# The RE-USE module imports `from models...` and `from utils...` —
|
| 126 |
+
# both relative to the repo root. Add to path during load.
|
| 127 |
+
reuse_dir = _resolve_reuse_dir()
|
| 128 |
+
if str(reuse_dir) not in sys.path:
|
| 129 |
+
sys.path.insert(0, str(reuse_dir))
|
| 130 |
+
|
| 131 |
+
if self.config_path is None:
|
| 132 |
+
self.config_path = self._config_path_override or (
|
| 133 |
+
reuse_dir / "recipes" /
|
| 134 |
+
"USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml"
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
from models.generator_SEMamba_time_d4 import SEMamba # type: ignore
|
| 138 |
+
from models.stfts import mag_phase_stft, mag_phase_istft # type: ignore
|
| 139 |
+
from utils.util import load_config, pad_or_trim_to_match # type: ignore
|
| 140 |
+
|
| 141 |
+
self._cfg = load_config(str(self.config_path))
|
| 142 |
+
compress_factor = self._cfg["model_cfg"]["compress_factor"]
|
| 143 |
+
self._stft_fns = (mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim_to_match)
|
| 144 |
+
|
| 145 |
+
# SEMamba is a PyTorchModelHubMixin; from_pretrained pulls weights from HF.
|
| 146 |
+
model = SEMamba.from_pretrained("nvidia/RE-USE", cfg=self._cfg).to(self.device)
|
| 147 |
+
model.train(False)
|
| 148 |
+
self._model = model
|
| 149 |
+
n_params = sum(p.numel() for p in model.parameters())
|
| 150 |
+
logging.info(f"RE-USE loaded: SEMamba ({n_params / 1e6:.1f}M params) -> {self.target_sr} Hz")
|
| 151 |
+
|
| 152 |
+
@staticmethod
|
| 153 |
+
def _make_even(v: float) -> int:
|
| 154 |
+
v = int(round(v))
|
| 155 |
+
return v if v % 2 == 0 else v + 1
|
| 156 |
+
|
| 157 |
+
@torch.inference_mode()
|
| 158 |
+
def __call__(self, waveform: torch.Tensor, in_sr: int = 16000) -> Tuple[torch.Tensor, int]:
|
| 159 |
+
"""Chunked overlap-add denoise / BWE (ports nvidia/RE-USE inference_chunk.py).
|
| 160 |
+
|
| 161 |
+
Peak VRAM is bounded by ``chunk_size_s * target_sr`` rather than the
|
| 162 |
+
whole clip, so a 60 s clip costs the same as a 5 s one. Crossfade is
|
| 163 |
+
a Hann-window normalized overlap-add with default 50% hop.
|
| 164 |
+
"""
|
| 165 |
+
import math
|
| 166 |
+
self._lazy_load()
|
| 167 |
+
import librosa
|
| 168 |
+
mag_phase_stft, mag_phase_istft, compress_factor, pad_or_trim_to_match = self._stft_fns
|
| 169 |
+
|
| 170 |
+
# STFT params are scaled relative to the config's training rate (8000).
|
| 171 |
+
base_n_fft = self._cfg["stft_cfg"]["n_fft"]
|
| 172 |
+
base_hop = self._cfg["stft_cfg"]["hop_size"]
|
| 173 |
+
base_win = self._cfg["stft_cfg"]["win_size"]
|
| 174 |
+
base_sr = self._cfg["stft_cfg"]["sampling_rate"]
|
| 175 |
+
|
| 176 |
+
if waveform.dim() == 1:
|
| 177 |
+
waveform = waveform.unsqueeze(0)
|
| 178 |
+
|
| 179 |
+
# 1. Resample to target rate first (skips if target_sr == in_sr).
|
| 180 |
+
if self.target_sr != in_sr:
|
| 181 |
+
wav_np = waveform.cpu().float().numpy()
|
| 182 |
+
wav_np = librosa.resample(
|
| 183 |
+
wav_np, orig_sr=in_sr, target_sr=self.target_sr, res_type="kaiser_best"
|
| 184 |
+
)
|
| 185 |
+
wav = torch.from_numpy(wav_np).to(self.device, dtype=torch.float32)
|
| 186 |
+
else:
|
| 187 |
+
wav = waveform.to(self.device, dtype=torch.float32)
|
| 188 |
+
|
| 189 |
+
op_sr = self.target_sr
|
| 190 |
+
n_fft = self._make_even(base_n_fft * op_sr // base_sr)
|
| 191 |
+
hop = self._make_even(base_hop * op_sr // base_sr)
|
| 192 |
+
win = self._make_even(base_win * op_sr // base_sr)
|
| 193 |
+
|
| 194 |
+
# 2. Chunked OLA with Hann analysis window. Mirrors inference_chunk.py.
|
| 195 |
+
chunk_size = int(self.chunk_size_s * op_sr)
|
| 196 |
+
hop_length = int(self.hop_portion * chunk_size)
|
| 197 |
+
window = torch.hann_window(chunk_size, device=self.device)
|
| 198 |
+
|
| 199 |
+
n_ch, total = wav.shape
|
| 200 |
+
enhanced = torch.zeros_like(wav)
|
| 201 |
+
window_sum = torch.zeros_like(wav)
|
| 202 |
+
n_chunks = max(1, math.ceil((total - chunk_size) / hop_length) + 1) if total > chunk_size else 1
|
| 203 |
+
|
| 204 |
+
for c in range(n_ch):
|
| 205 |
+
ch_in = wav[c : c + 1] # (1, T)
|
| 206 |
+
for i in range(n_chunks):
|
| 207 |
+
start = i * hop_length
|
| 208 |
+
end = min(start + chunk_size, total)
|
| 209 |
+
chunk = ch_in[:, start:end]
|
| 210 |
+
if chunk.shape[-1] < 2: # skip degenerate tail
|
| 211 |
+
continue
|
| 212 |
+
noisy_mag, noisy_pha, _ = mag_phase_stft(
|
| 213 |
+
chunk, n_fft=n_fft, hop_size=hop, win_size=win,
|
| 214 |
+
compress_factor=compress_factor, center=True, addeps=False,
|
| 215 |
+
)
|
| 216 |
+
amp_g, pha_g, _ = self._model(noisy_mag, noisy_pha)
|
| 217 |
+
# "Sweep artifact" filter — match the official inference.
|
| 218 |
+
mag = torch.expm1(torch.relu(amp_g))
|
| 219 |
+
zero_portion = (mag == 0).sum(dim=1) / mag.shape[1]
|
| 220 |
+
amp_g[:, :, (zero_portion > 0.5)[0]] = 0
|
| 221 |
+
|
| 222 |
+
audio_g = mag_phase_istft(amp_g, pha_g, n_fft, hop, win, compress_factor)
|
| 223 |
+
audio_g = pad_or_trim_to_match(chunk.detach(), audio_g, pad_value=1e-8)
|
| 224 |
+
|
| 225 |
+
w_slice = window[: audio_g.shape[-1]]
|
| 226 |
+
enhanced[c : c + 1, start : start + audio_g.shape[-1]] += audio_g * w_slice
|
| 227 |
+
window_sum[c : c + 1, start : start + audio_g.shape[-1]] += w_slice
|
| 228 |
+
|
| 229 |
+
# 3. Normalize where windows overlap. Avoid divide-by-zero at clip tails.
|
| 230 |
+
mask = window_sum > 1e-8
|
| 231 |
+
enhanced[mask] = enhanced[mask] / window_sum[mask]
|
| 232 |
+
return enhanced.clamp(-1.0, 1.0).cpu().float(), op_sr
|
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prompt chunking for long-form DramaBox generation.
|
| 2 |
+
|
| 3 |
+
The base LTX-2.3 audio DiT was trained on clips <= ~20 s. The silence-prior
|
| 4 |
+
patch in ``inference_server.py`` keeps generations sane up to ~45 s, but the
|
| 5 |
+
prior re-emerges past that boundary. For arbitrary-length prompts we split the
|
| 6 |
+
text into < 45 s chunks, generate each conditioned on the same voice reference,
|
| 7 |
+
and crossfade them back together.
|
| 8 |
+
|
| 9 |
+
Chunking is quote-aware (sentence terminators inside ``"..."`` don't count)
|
| 10 |
+
and preserves the speaker-description prefix on every chunk so the model keeps
|
| 11 |
+
the same persona / delivery style across joins.
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import re
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import List, Optional
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Matches the leading speaker description, ending at the first comma that's
|
| 21 |
+
# directly followed by a space + opening quote. Anything before that is treated
|
| 22 |
+
# as persona/style metadata and re-attached to every chunk.
|
| 23 |
+
# "A shadowy villain speaks with cold menace, \"You have entered...\""
|
| 24 |
+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 25 |
+
_PREFIX_RE = re.compile(r'^([^"\']{3,}?)(,\s*)(?=["\'])', re.DOTALL)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class PromptChunk:
    """One independently generated piece of a longer prompt."""

    # Prompt text handed to the generator for this chunk.
    text: str
    # Estimated duration of ``text`` in seconds, as computed by the chunker.
    est_duration_s: float
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def extract_speaker_prefix(prompt: str) -> tuple[Optional[str], str]:
    """Split ``prompt`` into ``(prefix, body)`` around the speaker description.

    When ``_PREFIX_RE`` matches at the start of the prompt (a quote-free
    persona followed by a comma and then an opening quote), the stripped
    persona is returned as the prefix and everything after the comma (and
    any whitespace following it) as the body. When it does not match, the
    result is ``(None, prompt)`` and the whole prompt is the body.
    """
    match = _PREFIX_RE.match(prompt)
    if match is None:
        return None, prompt

    persona = match.group(1).strip()
    remainder = prompt[match.end():]
    return persona, remainder
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def split_sentences_outside_quotes(text: str) -> List[str]:
    """Split ``text`` into sentences, ignoring terminators inside quotes.

    A "sentence" here is a span ending in ``.``/``!``/``?`` (optionally followed
    by a closing quote) at the top level — i.e. not inside an open ``"..."`` or
    ``'...'`` pair. A boundary is only recognised when the terminator (plus
    any trailing closers) is followed by whitespace or the end of the text.
    Empty / whitespace-only fragments are dropped.

    Apostrophes are only treated as an *opening* single quote at a word
    start (not preceded by a letter/digit, not followed by whitespace), so
    contractions and trailing apostrophes (``don't``, ``boys'``, ``runnin'``)
    do not switch the splitter into quoted mode. Leading-apostrophe elisions
    such as ``'em`` are still indistinguishable from an opening quote.

    Args:
        text: Text to split.

    Returns:
        Stripped, non-empty sentences in their original order.

    Examples:
        >>> split_sentences_outside_quotes('He says, "Hi, how are you?" Then leaves.')
        ['He says, "Hi, how are you?"', 'Then leaves.']
        >>> split_sentences_outside_quotes("The boys' room is empty. Nobody came.")
        ["The boys' room is empty.", 'Nobody came.']
    """
    sentences: List[str] = []
    buf: List[str] = []
    in_double = False
    in_single = False
    n = len(text)

    def _flush() -> None:
        # Emit the buffered characters as one sentence (if non-blank).
        sentence = "".join(buf).strip()
        if sentence:
            sentences.append(sentence)
        buf.clear()

    i = 0
    while i < n:
        ch = text[i]
        buf.append(ch)

        if ch == '"' and not in_single:
            was_inside = in_double
            in_double = not in_double
            # Treat the *closing* quote as a sentence boundary if the last
            # char inside it was a terminator (``...how are you?"``) and
            # whitespace / end-of-string follows.
            if (was_inside and len(buf) >= 2 and buf[-2] in ".!?"
                    and (i + 1 >= n or text[i + 1].isspace())):
                _flush()

        elif ch == "'" and not in_double:
            prev = text[i - 1] if i > 0 else " "
            nxt = text[i + 1] if i + 1 < n else " "
            if in_single:
                # Apostrophes inside a word (don't, it's) never close a quote.
                if not (prev.isalpha() and nxt.isalpha()):
                    in_single = False
            elif not prev.isalnum() and not nxt.isspace():
                # Only a word-start apostrophe opens a quote; a trailing one
                # (boys', runnin', 5') used to open a quote that never
                # closed, which suppressed every later sentence boundary.
                in_single = True

        elif ch in ".!?" and not in_double and not in_single:
            # Greedily eat trailing closing quotes / punctuation.
            j = i + 1
            while j < n and text[j] in '."\')]':
                buf.append(text[j])
                if text[j] == '"':
                    in_double = not in_double  # keep quote state in sync
                j += 1
            if j >= n or text[j].isspace():
                _flush()
            i = j
            continue

        i += 1

    # Whatever is left (no final terminator) is still a sentence.
    _flush()
    return sentences
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _assemble(prefix: Optional[str], sentences: List[str]) -> str:
|
| 115 |
+
body = " ".join(s.strip() for s in sentences if s.strip())
|
| 116 |
+
if not prefix:
|
| 117 |
+
return body
|
| 118 |
+
# Re-attach prefix in the canonical "persona, body" form. If the first
|
| 119 |
+
# sentence already starts with a stage direction (no opening quote), drop
|
| 120 |
+
# the comma + use a period so the syntax reads naturally.
|
| 121 |
+
if body.lstrip().startswith(("'", '"')):
|
| 122 |
+
return f"{prefix}, {body}"
|
| 123 |
+
return f"{prefix}. {body}"
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def chunk_prompt_for_duration(
    prompt: str,
    max_duration_s: float = 45.0,
    target_duration_s: float = 37.0,
    duration_multiplier: float = 1.1,
) -> List[PromptChunk]:
    """Split ``prompt`` into chunks estimated at <= ``max_duration_s`` each.

    Args:
        prompt: Full scene prompt (DramaBox format or plain text).
        max_duration_s: Hard cap per chunk, measured on the estimator output
            after ``duration_multiplier``. The only chunk allowed to exceed
            it is a single sentence that is too long on its own: that
            sentence is emitted alone rather than raising, and downstream
            generation is expected to truncate it.
        target_duration_s: Soft cap; we close the current chunk when adding
            the next sentence would push it past this. Leaving 5-10 s of
            headroom below ``max_duration_s`` keeps us safe against the
            estimator under-shooting by ~10-15% on action-heavy prompts.
            Values above ``max_duration_s`` are clamped down to it so the
            hard cap still holds when a caller lowers only the hard cap.
        duration_multiplier: Same breathing-room multiplier the inference
            server applies in ``estimate_duration``; matches the per-chunk
            target the model is actually asked to generate.

    Returns:
        Non-empty list of :class:`PromptChunk`. A prompt that already fits
        under ``max_duration_s`` returns a 1-element list with the original
        prompt unchanged; so does a prompt whose body yields nothing to
        split.
    """
    from duration_estimator import estimate_speech_duration

    def _est(t: str) -> float:
        return estimate_speech_duration(t) * duration_multiplier

    total = _est(prompt)
    if total <= max_duration_s:
        return [PromptChunk(text=prompt, est_duration_s=total)]

    # A soft target above the hard cap would let multi-sentence chunks grow
    # past the cap before being closed, so never let it exceed the cap.
    soft_cap = min(target_duration_s, max_duration_s)

    prefix, body = extract_speaker_prefix(prompt)
    sentences = split_sentences_outside_quotes(body)
    if not sentences:
        # Degenerate: no sentence boundaries. Fall back to whitespace-token
        # chunking so we still produce SOMETHING under the cap.
        sentences = body.split()

    chunks: List[PromptChunk] = []
    current: List[str] = []
    # Invariant: while ``current`` is non-empty, ``current_dur`` is the
    # estimate of ``_assemble(prefix, current)``. Reusing it avoids running
    # the estimator again on text that was already measured (this assumes
    # the estimator is deterministic).
    current_dur = 0.0

    for sent in sentences:
        cand_dur = _est(_assemble(prefix, current + [sent]))

        if current and cand_dur > soft_cap:
            # Close the current chunk before adding this sentence.
            chunks.append(
                PromptChunk(
                    text=_assemble(prefix, current),
                    est_duration_s=current_dur,
                )
            )
            current = [sent]
            current_dur = _est(_assemble(prefix, current))
        else:
            current.append(sent)
            current_dur = cand_dur

        # Pathological case: a single sentence whose estimator output is
        # already past max_duration_s. Emit it on its own and let downstream
        # generate() truncate the request at the model's hard limit; the user
        # gets a degraded but non-crashing result instead of an exception.
        if len(current) == 1 and current_dur > max_duration_s:
            chunks.append(
                PromptChunk(
                    text=_assemble(prefix, current),
                    est_duration_s=current_dur,
                )
            )
            current = []
            current_dur = 0.0

    if current:
        chunks.append(
            PromptChunk(
                text=_assemble(prefix, current),
                est_duration_s=current_dur,
            )
        )

    if not chunks:
        # Nothing to split (e.g. the body was empty after removing the
        # prefix). Hand back the whole prompt instead of an empty list so
        # callers always receive at least one chunk to generate.
        return [PromptChunk(text=prompt, est_duration_s=total)]

    return chunks
|