multimodalart commited on
Commit
cdc4405
·
1 Parent(s): c327e46

Initial Gradio ZeroGPU app for Scenema Audio

Browse files
README.md CHANGED
@@ -1,13 +1,38 @@
1
  ---
2
  title: Scenema Audio
3
- emoji: 🚀
4
  colorFrom: pink
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
- python_version: '3.13'
9
  app_file: app.py
10
  pinned: false
 
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Scenema Audio
3
+ emoji: 🎙️
4
  colorFrom: pink
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
+ python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
+ hardware: zero-a10g
12
+ short_description: Zero-shot expressive voice cloning and speech generation
13
+ suggested_storage: large
14
  ---
15
 
16
+ # Scenema Audio (ZeroGPU)
17
+
18
+ Gradio wrapper around [ScenemaAI/scenema-audio](https://github.com/ScenemaAI/scenema-audio).
19
+
20
+ Zero-shot expressive voice cloning and speech generation with emotion, pacing,
21
+ and breath control, built on an audio diffusion transformer extracted from
22
+ [LTX 2.3](https://github.com/Lightricks/LTX-2).
23
+
24
+ ## Cold start
25
+
26
+ First request downloads ~38 GB of model weights:
27
+ - `scenema-audio-transformer-int8.safetensors` (~4.9 GB)
28
+ - `scenema-audio-pipeline.safetensors` (~6.7 GB)
29
+ - `google/gemma-3-12b-it` (~24 GB, **gated** — requires `HF_TOKEN` secret)
30
+ - SeedVC + BigVGAN + Whisper checkpoints (~3 GB)
31
+ - MelBandRoFormer (~436 MB)
32
+
33
+ Set `HF_TOKEN` in the Space secrets with access to `google/gemma-3-12b-it`.
34
+
35
+ ## License
36
+
37
+ - **Model weights:** LTX-2 Community License Agreement
38
+ - **Code:** MIT
app.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenema Audio - ZeroGPU Gradio Space.
2
+
3
+ Wraps the ScenemaAI/scenema-audio AudioProcessor in a Gradio UI.
4
+ Heavy model weights (~38 GB) are downloaded on first cold-start and
5
+ cached on persistent storage; generation runs under @spaces.GPU.
6
+ """
7
+
8
+ import asyncio
9
+ import base64
10
+ import logging
11
+ import os
12
+ import sys
13
+ import tempfile
14
+ import uuid
15
+ from pathlib import Path
16
+
17
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# An explicit MODEL_DIR always wins. Otherwise weights go to /data/models when
# a /data mount exists, and to ./models next to the working directory if not.
_fallback_model_dir = "/data/models" if Path("/data").exists() else "./models"
MODEL_DIR = Path(os.environ.get("MODEL_DIR", _fallback_model_dir))
MODEL_DIR.mkdir(parents=True, exist_ok=True)
os.environ["MODEL_DIR"] = str(MODEL_DIR)

# Default model paths (must be set before AudioProcessor is imported).
# setdefault keeps any value that was already provided through the environment.
for _env_name, _env_value in (
    ("AUDIO_CKPT", MODEL_DIR / "scenema-audio-transformer-int8.safetensors"),
    ("PIPELINE_CKPT", MODEL_DIR / "scenema-audio-pipeline.safetensors"),
    ("VAE_ENCODER_CKPT", MODEL_DIR / "scenema-audio-vae-encoder.safetensors"),
    ("GEMMA_ROOT", MODEL_DIR / "gemma-3-12b-it"),
    ("MELBAND_MODEL_PATH", MODEL_DIR / "MelBandRoformer_fp16.safetensors"),
    ("SEEDVC_PATH", Path.cwd() / "seed-vc"),
    ("MELBAND_NODE_PATH", Path.cwd() / "melband_roformer_node"),
    ("HF_HUB_CACHE", MODEL_DIR / "hf_cache"),
    ("GEMMA_QUANTIZE", "nf4"),
):
    os.environ.setdefault(_env_name, str(_env_value))

# Make the bundled sources under ./src importable.
sys.path.insert(0, str(Path(__file__).parent / "src"))
47
+
48
+ import gradio as gr
49
+ import spaces
50
+ from huggingface_hub import hf_hub_download, snapshot_download
51
+
52
+ logging.basicConfig(
53
+ level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s"
54
+ )
55
+ logger = logging.getLogger("scenema-space")
56
+
57
+
58
+ # ── Model download (CPU phase, runs at import) ────────────────────────────
59
+
60
+ HF_REPO = "ScenemaAI/scenema-audio"
61
+ GEMMA_REPO = "google/gemma-3-12b-it"
62
+ SEEDVC_REPO = "Plachta/Seed-VC"
63
+ BIGVGAN_REPO = "nvidia/bigvgan_v2_22khz_80band_256x"
64
+ WHISPER_REPO = "openai/whisper-small"
65
+
66
+
67
def _download_all():
    """Download every model artifact that is not already on disk.

    Target locations come from the environment variables configured at the top
    of this module. ``HF_TOKEN`` (if set) is forwarded to the Hub downloads of
    the Scenema, MelBandRoFormer, Gemma and SeedVC files; the BigVGAN and
    Whisper snapshots are fetched without a token.
    """
    token = os.environ.get("HF_TOKEN")

    # (env var holding the target path, Hub repo, file name, progress message)
    single_files = (
        (
            "AUDIO_CKPT",
            HF_REPO,
            "scenema-audio-transformer-int8.safetensors",
            "Downloading audio transformer INT8 (~4.9 GB)...",
        ),
        (
            "PIPELINE_CKPT",
            HF_REPO,
            "scenema-audio-pipeline.safetensors",
            "Downloading pipeline checkpoint (~6.7 GB)...",
        ),
        (
            "VAE_ENCODER_CKPT",
            HF_REPO,
            "scenema-audio-vae-encoder.safetensors",
            "Downloading VAE encoder (~42 MB)...",
        ),
        (
            "MELBAND_MODEL_PATH",
            "Kijai/MelBandRoFormer_comfy",
            "MelBandRoformer_fp16.safetensors",
            "Downloading MelBandRoFormer (~436 MB)...",
        ),
    )
    for env_key, repo_id, filename, message in single_files:
        target = Path(os.environ[env_key])
        if target.exists():
            continue
        logger.info(message)
        hf_hub_download(
            repo_id,
            filename,
            local_dir=str(target.parent),
            token=token,
        )

    # Gemma is a full repo snapshot; consider it present once any weight shard exists.
    gemma_dir = Path(os.environ["GEMMA_ROOT"])
    gemma_ready = gemma_dir.exists() and any(gemma_dir.glob("*.safetensors"))
    if not gemma_ready:
        logger.info("Downloading Gemma 3 12B IT (~24 GB, gated)...")
        snapshot_download(
            GEMMA_REPO,
            local_dir=str(gemma_dir),
            ignore_patterns=["*.gguf"],
            token=token,
        )

    # SeedVC checkpoints live inside the seed-vc checkout, next to its code.
    ckpt_dir = Path(os.environ["SEEDVC_PATH"]) / "checkpoints"
    ckpt_ready = ckpt_dir.exists() and any(ckpt_dir.glob("*.pth"))
    if ckpt_ready:
        return

    logger.info("Downloading SeedVC checkpoints (~1.6 GB)...")
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    hub_cache = ckpt_dir / "hf_cache"
    hub_cache.mkdir(parents=True, exist_ok=True)
    # Repoints HF_HUB_CACHE for the rest of the process.
    os.environ["HF_HUB_CACHE"] = str(hub_cache)
    for seedvc_file in (
        "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
        "config_dit_mel_seed_uvit_whisper_small_wavenet.yml",
    ):
        hf_hub_download(
            SEEDVC_REPO,
            seedvc_file,
            local_dir=str(ckpt_dir),
            token=token,
        )
    snapshot_download(BIGVGAN_REPO, local_dir=str(hub_cache / "bigvgan"))
    snapshot_download(WHISPER_REPO, local_dir=str(hub_cache / "whisper-small"))
142
+
143
+
144
def _ensure_seedvc_repo():
    """Clone the seed-vc and MelBandRoFormer source trees if they are missing.

    Only the Python architecture code is fetched here (shallow clones).
    ``git`` is invoked with an argument list rather than a shell string, so
    target paths containing spaces or shell metacharacters are passed through
    verbatim. A failed clone is logged instead of being silently ignored; it
    does not abort start-up.
    """
    import subprocess  # local import: only needed for this one-off bootstrap

    def _shallow_clone(url, dest):
        # One-line purpose: clone `url` into `dest`, logging any failure.
        try:
            completed = subprocess.run(
                ["git", "clone", "--depth", "1", url, str(dest)],
                check=False,
            )
        except OSError as exc:  # e.g. git is not installed
            logger.error("Could not run git to clone %s: %s", url, exc)
            return
        if completed.returncode != 0:
            logger.error(
                "git clone of %s into %s failed with exit code %d",
                url,
                dest,
                completed.returncode,
            )

    seedvc = Path(os.environ["SEEDVC_PATH"])
    # The checkout counts as present only once its `modules` directory exists.
    if not (seedvc / "modules").exists():
        logger.info("Cloning seed-vc source...")
        _shallow_clone("https://github.com/Plachtaa/seed-vc.git", seedvc)

    melband_node = Path(os.environ["MELBAND_NODE_PATH"])
    if not melband_node.exists():
        logger.info("Cloning ComfyUI-MelBandRoFormer source...")
        _shallow_clone(
            "https://github.com/kijai/ComfyUI-MelBandRoFormer", melband_node
        )
157
+
158
+
159
+ _ensure_seedvc_repo()
160
+ _download_all()
161
+
162
+ # Import processor only after model paths/env are set
163
+ from audio_core.processor import AudioProcessor # noqa: E402
164
+ from common.handlers.base import ProcessJob # noqa: E402
165
+
166
+ _processor: AudioProcessor | None = None
167
+
168
+
169
def _get_processor() -> AudioProcessor:
    """Return the process-wide AudioProcessor, creating it on first use.

    The instance is published to the module-level cache only after
    ``startup()`` has returned. If start-up raises, nothing is cached and the
    next call retries instead of handing out a half-initialised processor.
    """
    global _processor
    if _processor is None:
        candidate = AudioProcessor()
        candidate.startup()
        _processor = candidate
    return _processor
175
+
176
+
177
+ # ── Generation ────────────────────────────────────────────────────────────
178
+
179
+
180
+ def _build_prompt(text, voice, gender, scene, language, shot, action, sound_before):
181
+ attrs = [f'voice="{voice}"', f'gender="{gender}"']
182
+ if scene:
183
+ attrs.append(f'scene="{scene}"')
184
+ if language and language != "en":
185
+ attrs.append(f'language="{language}"')
186
+ if shot:
187
+ attrs.append(f'shot="{shot}"')
188
+
189
+ inner = ""
190
+ if sound_before:
191
+ inner += f"<sound>{sound_before}</sound>"
192
+ if action:
193
+ inner += f"<action>{action}</action>"
194
+ inner += text
195
+
196
+ return f"<speak {' '.join(attrs)}>{inner}</speak>"
197
+
198
+
199
@spaces.GPU(duration=300)
def generate(
    text,
    voice,
    gender,
    scene,
    language,
    shot,
    action,
    sound_before,
    reference_audio,
    mode,
    seed,
    background_sfx,
    skip_vc,
    raw_xml,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one generation request and return ``(wav_path, info_text)``.

    When ``raw_xml`` is non-blank it is used as the prompt verbatim; otherwise
    the prompt is built from the individual fields and ``text`` is required.
    ``reference_audio`` is a local file path from the Gradio audio component
    (or ``None``).

    Raises:
        gr.Error: if no speech text was supplied, or if the processor reports
            an unsuccessful result.
    """
    # Validate the request first so that bad input fails immediately instead
    # of after a potentially multi-minute model load.
    if raw_xml and raw_xml.strip():
        prompt = raw_xml.strip()
    else:
        if not (text or "").strip():
            raise gr.Error("Speech text is required.")
        prompt = _build_prompt(text, voice, gender, scene, language, shot, action, sound_before)

    progress(0, desc="Loading models (cold start can take a few minutes)")
    processor = _get_processor()

    body = {
        "prompt": prompt,
        "mode": mode,
        "seed": int(seed) if seed is not None else -1,
        "background_sfx": bool(background_sfx),
        "skip_vc": bool(skip_vc),
        "validate": True,
    }

    # The processor takes the reference voice as a URL and fetches it through
    # `_download_reference`. Gradio hands us a local path, so it is wrapped in
    # a file:// URL and the download hook is patched (below) to unwrap it.
    if reference_audio:
        body["reference_voice_url"] = f"file://{reference_audio}"

    async def _run():
        original = processor._download_reference

        async def patched(url):
            # Local uploads short-circuit to their path; real URLs still download.
            if url.startswith("file://"):
                return url[len("file://"):]
            return await original(url)

        processor._download_reference = patched
        try:
            job = ProcessJob(job_id=str(uuid.uuid4()), input=body)
            return await processor.process(job)
        finally:
            # Always restore the hook, even when processing raises.
            processor._download_reference = original

    progress(0.1, desc="Generating audio")
    result = asyncio.run(_run())

    if not result.success:
        raise gr.Error(result.error or "Generation failed")

    # Write the returned bytes to a uniquely named temp file for the Audio output.
    out_path = Path(tempfile.gettempdir()) / f"scenema_{uuid.uuid4().hex}.wav"
    out_path.write_bytes(result.output.data)
    meta = result.output.metadata or {}
    info = (
        f"Duration: {meta.get('duration_s', 0)}s · "
        f"Seed: {meta.get('seed')} · "
        f"GPU: {meta.get('gpu', 'N/A')} · "
        f"Time: {meta.get('processing_ms', 0)} ms"
    )
    return str(out_path), info
282
+
283
+
284
+ # ── UI ────────────────────────────────────────────────────────────────────
285
+
286
# Each example row follows the order of the input components listed in
# `all_inputs` below: text, voice, gender, scene, language, shot, action,
# sound_before, reference_audio, mode, seed, background_sfx, skip_vc, raw_xml.
EXAMPLES = [
    [
        "The old lighthouse had stood on the cliff for over a century, its beam cutting through the fog like a blade of light.",
        "A warm, clear male voice with a slight British accent. Measured, thoughtful pacing.",
        "male", "", "en", "closeup", "", "",
        None, "generate", 42, False, False, "",
    ],
    [
        "The city never really sleeps. It just closes its eyes and pretends for a while.",
        "A young woman with a smoky, low register voice. Intimate, confessional tone.",
        "female", "", "en", "closeup", "", "",
        None, "voice_design", 7, False, False, "",
    ],
    [
        "Get the lines! She is pulling loose! Move! I said move!",
        "Male, mid 40s. Weathered. Urgent, projecting over wind.",
        "male", "Open dock in a thunderstorm, heavy rain", "en", "scene",
        "He shouts over the storm", "Heavy rain and wind howling",
        None, "generate", 11, True, False, "",
    ],
]

with gr.Blocks(title="Scenema Audio") as demo:
    gr.Markdown(
        """
        # Scenema Audio · Zero-shot Expressive TTS
        Generate expressive speech with emotion, scene, and voice cloning.
        Built on [ScenemaAI/scenema-audio](https://github.com/ScenemaAI/scenema-audio).

        **Note:** First request triggers a ~38 GB cold start. Subsequent requests are fast.
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            text = gr.Textbox(
                label="Speech text",
                lines=4,
                placeholder="What the voice should say...",
            )
            voice = gr.Textbox(
                label="Voice description",
                lines=2,
                placeholder='e.g. "A warm male voice with a slight British accent..."',
            )
            with gr.Row():
                gender = gr.Radio(["male", "female"], value="male", label="Gender")
                language = gr.Dropdown(
                    ["en", "es", "fr", "de", "it", "pt", "ja", "zh", "ko"],
                    value="en", label="Language",
                )
                shot = gr.Radio(
                    ["closeup", "wide", "scene"], value="closeup", label="Shot"
                )
            with gr.Accordion("Scene & direction (optional)", open=False):
                scene = gr.Textbox(label="Scene", placeholder="e.g. busy cafe at midday")
                action = gr.Textbox(label="Performance direction (<action>)")
                sound_before = gr.Textbox(label="Sound event before speech (<sound>)")
            with gr.Accordion("Raw XML override (optional)", open=False):
                raw_xml = gr.Textbox(
                    label="<speak> XML (overrides fields above when set)",
                    lines=4,
                )
            with gr.Accordion("Voice cloning (optional)", open=False):
                reference_audio = gr.Audio(
                    label="Reference voice (10-20s)",
                    type="filepath",
                )
            with gr.Row():
                mode = gr.Radio(
                    ["generate", "voice_design"], value="generate", label="Mode"
                )
                seed = gr.Number(value=42, precision=0, label="Seed (-1 = random)")
            with gr.Row():
                background_sfx = gr.Checkbox(value=False, label="Keep background SFX")
                skip_vc = gr.Checkbox(value=False, label="Skip SeedVC post-processing")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=2):
            out_audio = gr.Audio(label="Output", type="filepath")
            info = gr.Textbox(label="Info", interactive=False)

    # Single source of truth for the component order shared by the examples
    # table and the click handler; it must match generate()'s parameters.
    all_inputs = [
        text, voice, gender, scene, language, shot, action, sound_before,
        reference_audio, mode, seed, background_sfx, skip_vc, raw_xml,
    ]

    gr.Examples(examples=EXAMPLES, inputs=all_inputs)

    run_btn.click(
        generate,
        inputs=all_inputs,
        outputs=[out_audio, info],
    )


if __name__ == "__main__":
    demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy==2.2.6
2
+ transformers==4.57.6
3
+ accelerate==1.13.0
4
+ safetensors==0.7.0
5
+ sentencepiece==0.2.1
6
+ ltx-core @ git+https://github.com/Lightricks/LTX-2.git@41d924371612b692c0fd1e4d9d94c3dfb3c02cb3#subdirectory=packages/ltx-core
7
+ ltx-pipelines @ git+https://github.com/Lightricks/LTX-2.git@41d924371612b692c0fd1e4d9d94c3dfb3c02cb3#subdirectory=packages/ltx-pipelines
8
+ scipy==1.13.1
9
+ librosa==0.10.2
10
+ huggingface-hub==0.36.2
11
+ munch==4.0.0
12
+ einops==0.8.0
13
+ descript-audio-codec==1.0.0
14
+ pydub==0.25.1
15
+ soundfile==0.12.1
16
+ hydra-core==1.3.2
17
+ pyyaml==6.0.3
18
+ python-dotenv==1.2.2
19
+ diffusers==0.37.1
20
+ onnxruntime==1.25.0
21
+ funasr==1.3.1
22
+ rotary-embedding-torch==0.8.9
23
+ beartype==0.22.9
24
+ fastapi==0.136.1
25
+ httpx==0.28.1
26
+ psutil==7.2.2
27
+ bitsandbytes==0.49.2
28
+ kokoro==0.9.4
29
+ faster-whisper==1.2.1
30
+ ctranslate2==4.7.1
src/audio_core/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Scenema Audio: Expressive audio generation via LTX 2.3 audio diffusion."""
6
+
7
+ __version__ = "1.0.0"
src/audio_core/audio_utils.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Audio utility functions for Scenema Audio.
6
+
7
+ Silence trimming, volume normalization, wav I/O, format conversion.
8
+ """
9
+
10
+ import logging
11
+ import math
12
+
13
+ import numpy as np
14
+ import soundfile as sf
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def trim_silence(
    audio_np: np.ndarray,
    sr: int,
    max_silence: float = 0.5,
    threshold_db: float = -40,
) -> np.ndarray:
    """Trim silence exceeding max_silence from start and end of audio.

    Keeps up to max_silence seconds of silence at boundaries. Loudness is
    judged on the peak amplitude of consecutive 20 ms windows, including the
    final (possibly shorter) window so that voiced audio at the very end is
    never cut off.

    Args:
        audio_np: Audio samples, shape (samples,) or (samples, channels).
        sr: Sample rate in Hz.
        max_silence: Maximum silence to keep at head/tail in seconds.
        threshold_db: Amplitude threshold below which audio is considered silence.

    Returns:
        Trimmed audio array with the same number of dimensions as input.
        The input is returned unchanged when it is shorter than one window
        or contains no window above the threshold.
    """
    threshold = 10 ** (threshold_db / 20.0)
    max_silent_samples = int(max_silence * sr)
    # 20ms analysis window; at least one sample so the range() step is never 0.
    window = max(1, int(0.02 * sr))

    mono = audio_np.mean(axis=1) if audio_np.ndim == 2 else audio_np

    if len(mono) < window:
        return audio_np

    # Step all the way to len(mono): stopping at len(mono) - window would skip
    # the last window and treat everything in it as trimmable.
    energy = np.array(
        [
            np.abs(mono[i : i + window]).max()
            for i in range(0, len(mono), window)
        ]
    )

    voiced = np.where(energy > threshold)[0]
    if len(voiced) == 0:
        return audio_np

    first_voiced = max(0, voiced[0] * window - max_silent_samples)
    last_voiced = min(len(audio_np), (voiced[-1] + 1) * window + max_silent_samples)

    return audio_np[first_voiced:last_voiced]
65
+
66
+
67
def normalize_volume(
    audio_np: np.ndarray,
    sr: int,
    target_lufs: float = -23.0,
) -> np.ndarray:
    """Normalize audio volume to target LUFS (approximate via RMS).

    Uses a simplified RMS-based LUFS approximation suitable for
    per-chunk normalization before concatenation.

    Args:
        audio_np: Audio samples, shape (samples,) or (samples, channels).
        sr: Sample rate in Hz. Not used by the RMS approximation; kept for
            interface compatibility.
        target_lufs: Target loudness in LUFS (default -23, EBU R128).

    Returns:
        Volume-normalized audio array. The gain is limited to the range
        0.1x-10x, and the result is scaled down linearly if its peak would
        exceed 0.99. Empty or (near-)silent input is returned unchanged.
    """
    # Nothing to measure; also avoids calling .max() on an empty array below.
    if audio_np.size == 0:
        return audio_np

    mono = audio_np.mean(axis=1) if audio_np.ndim == 2 else audio_np

    rms = np.sqrt(np.mean(mono**2))
    if rms < 1e-8:
        # Effectively silent: amplifying it would only boost noise.
        return audio_np

    current_lufs = 20 * math.log10(rms) - 0.691
    gain_db = target_lufs - current_lufs
    gain = 10 ** (gain_db / 20.0)
    gain = max(0.1, min(gain, 10.0))

    result = audio_np * gain

    # Peak limiting by uniform rescale (not a soft clipper).
    peak = np.abs(result).max()
    if peak > 0.99:
        result = result * (0.99 / peak)

    return result
106
+
107
+
108
def extract_wav(audio_obj) -> tuple[np.ndarray, int]:
    """Pull the waveform out of an LTX Audio object as a numpy array.

    A leading batch axis is dropped from 3-D input, and 2-D data is
    transposed: (B,C,samples) -> (samples,C), (C,samples) -> (samples,C).
    1-D waveforms pass through as-is.

    Args:
        audio_obj: Object exposing ``.waveform`` (a tensor supporting
            ``.cpu().float().numpy()``) and ``.sampling_rate``.

    Returns:
        Tuple of (waveform as float32 numpy, sample_rate).
    """
    samples = audio_obj.waveform.cpu().float().numpy()
    if samples.ndim == 3:
        # Drop the batch axis (must be of size 1).
        samples = np.squeeze(samples, axis=0)
    if samples.ndim == 2:
        # Channels-first -> samples-first.
        samples = np.transpose(samples)
    return samples, audio_obj.sampling_rate
125
+
126
+
127
def save_wav(audio_np: np.ndarray, sr: int, path: str) -> None:
    """Write audio samples to ``path`` via soundfile.

    Args:
        audio_np: Audio samples, shape (samples,) or (samples, channels).
        sr: Sample rate in Hz.
        path: Output file path.
    """
    sf.write(path, audio_np, samplerate=sr)
136
+
137
+
138
def load_wav(path: str) -> tuple[np.ndarray, int]:
    """Read an audio file from ``path`` via soundfile.

    Args:
        path: Input file path.

    Returns:
        Tuple of (audio samples as float64 numpy, sample_rate).
    """
    samples, sample_rate = sf.read(path)
    return samples, sample_rate
149
+
150
+
151
def to_mono(audio_np: np.ndarray) -> np.ndarray:
    """Convert multi-channel audio to mono by averaging channels.

    Any 2-D (samples, channels) input is downmixed, matching how the other
    helpers in this module derive their mono signal. This includes a single
    channel stored as (samples, 1), which previously slipped through as 2-D.

    Args:
        audio_np: Audio samples, shape (samples, channels) or (samples,) for mono.

    Returns:
        Mono audio array, shape (samples,). 1-D input is returned unchanged.
    """
    if audio_np.ndim == 2:
        return audio_np.mean(axis=1)
    return audio_np
163
+
164
+
165
def shorten_long_silence(
    audio_np: np.ndarray,
    sr: int,
    max_duration: float = 1.0,
    target_duration: float = 0.3,
    threshold_db: float = -35,
) -> np.ndarray:
    """Shorten silence regions longer than max_duration to target_duration.

    Unlike silenceremove which deletes silence entirely, this preserves
    a natural pause of target_duration seconds. Prevents chunk boundary
    artifacts while keeping the audio flow natural.

    Silence is detected on the peak amplitude of consecutive 20 ms windows,
    including the final (possibly shorter) window, so voiced audio at the very
    end is never folded into a trailing silence region and dropped.

    Args:
        audio_np: Audio samples, shape (samples,) or (samples, channels).
        sr: Sample rate in Hz.
        max_duration: Silence longer than this is shortened.
        target_duration: Silence is shortened to this duration.
        threshold_db: Amplitude threshold below which audio is silence.

    Returns:
        Audio with long silence regions shortened. The input is returned
        unchanged when no region qualifies.
    """
    threshold = 10 ** (threshold_db / 20.0)
    # 20ms analysis window; at least one sample so the range() step is never 0.
    window = max(1, int(0.02 * sr))
    max_samples = int(max_duration * sr)
    target_samples = int(target_duration * sr)

    mono = audio_np.mean(axis=1) if audio_np.ndim == 2 else audio_np

    if len(mono) < window:
        return audio_np

    # Step all the way to len(mono) so the last window is analysed as well.
    energy = np.array(
        [
            np.abs(mono[i : i + window]).max()
            for i in range(0, len(mono), window)
        ]
    )
    is_silent = energy < threshold

    # Collect (start_sample, end_sample) for silent runs longer than max_samples.
    silence_regions = []
    run_start = None
    for i, silent in enumerate(is_silent):
        if silent:
            if run_start is None:
                run_start = i * window
        elif run_start is not None:
            run_end = i * window
            if run_end - run_start > max_samples:
                silence_regions.append((run_start, run_end))
            run_start = None
    if run_start is not None:
        # Audio ended while still silent.
        run_end = len(mono)
        if run_end - run_start > max_samples:
            silence_regions.append((run_start, run_end))

    if not silence_regions:
        return audio_np

    # Rebuild: audio between regions is kept, each region contributes only its
    # first target_samples samples (never more than the region itself holds).
    parts = []
    prev_end = 0
    for s_start, s_end in silence_regions:
        parts.append(audio_np[prev_end:s_start])
        keep = min(target_samples, s_end - s_start)
        parts.append(audio_np[s_start : s_start + keep])
        prev_end = s_end
    parts.append(audio_np[prev_end:])

    result = np.concatenate(parts, axis=0)
    shortened = (len(audio_np) - len(result)) / sr
    if shortened > 0:
        # Same logger object as the module-level one (looked up by module name).
        logging.getLogger(__name__).info(
            "Shortened %d silence regions, removed %.1fs",
            len(silence_regions),
            shortened,
        )
    return result
253
+
254
+
255
def ensure_stereo(audio_np: np.ndarray) -> np.ndarray:
    """Convert mono to stereo by duplicating the channel.

    Mono is accepted both as a 1-D array and as a single-column 2-D array;
    the latter used to be returned with one channel only.

    Args:
        audio_np: Audio samples, shape (samples,) or (samples, 1) for mono,
            or (samples, 2) for stereo.

    Returns:
        Stereo audio array, shape (samples, 2). Input that already has two
        (or more) channels is returned unchanged.
    """
    if audio_np.ndim == 1:
        return np.stack([audio_np, audio_np], axis=-1)
    if audio_np.ndim == 2 and audio_np.shape[1] == 1:
        return np.repeat(audio_np, 2, axis=1)
    return audio_np
src/audio_core/chunker.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Text chunking and duration estimation for Scenema Audio.
6
+
7
+ Splits long text into chunks at sentence boundaries using Kokoro TTS
8
+ phoneme-level timing as the source of truth for duration. No word counting.
9
+
10
+ Algorithm:
11
+ 1. Split text into sentences
12
+ 2. Estimate each sentence's duration via Kokoro (one call per sentence)
13
+ 3. Greedily merge: accumulate sentence durations, start a new chunk
14
+ when running_sum * LTX_MULTIPLIER exceeds MAX_CHUNK_DURATION_S
15
+ """
16
+
17
+ import logging
18
+ import random
19
+ from dataclasses import dataclass
20
+
21
+ from .compiler import compile_chunk_prompt, compile_prompt, extract_sentence_actions
22
+ from .validator import validate_prompt
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ FALLBACK_WORDS_PER_SEC = 2.2 # Test-environment-only fallback when Kokoro is mocked
27
+ ACTION_DURATION_S = 1.5 # Extra time per action block
28
+ MAX_CHUNK_DURATION_S = (
29
+ 15.0 # Safe generation limit β€” model trained on 20s but repeats beyond ~15s
30
+ )
31
+ LTX_MULTIPLIER = 1.5 # LTX speaks slower than Kokoro; overshoot for trimming
32
+
33
+ # Kokoro singleton (loaded once, reused)
34
+ _kokoro_pipeline = None
35
+ _kokoro_available: bool | None = None
36
+
37
+
38
def _get_kokoro():
    """Return the cached Kokoro pipeline, creating it on first use.

    Kokoro is 82M params, runs on CPU. Loaded once and cached.

    Returns:
        The pipeline, or ``None`` once Kokoro has been marked unavailable
        (a test mock was detected, or an earlier load attempt failed).

    Raises:
        RuntimeError: on the call where importing or constructing Kokoro
            fails for any reason other than the mock check.
    """
    global _kokoro_pipeline, _kokoro_available

    if _kokoro_available is False:
        return None
    if _kokoro_pipeline is not None:
        return _kokoro_pipeline

    try:
        from kokoro import KPipeline

        candidate = KPipeline(lang_code="a")
        # A genuine pipeline reports a module name containing "kokoro";
        # a missing attribute reads as "" and fails this check as well.
        module_name = str(getattr(candidate, "__module__", ""))
        if "kokoro" not in module_name:
            raise TypeError("Kokoro pipeline is not genuine (test mock)")
    except TypeError:
        # Test environment with mocks, fall back silently
        _kokoro_available = False
        return None
    except Exception as e:
        _kokoro_available = False
        logger.error("Kokoro is required but not available: %s", e)
        raise RuntimeError(
            f"Kokoro TTS is a required dependency for duration estimation. "
            f"Install it with: pip install kokoro. Error: {e}"
        ) from e

    _kokoro_pipeline = candidate
    _kokoro_available = True
    logger.info("Kokoro TTS loaded for duration estimation")
    return _kokoro_pipeline
76
+
77
+
78
def _kokoro_duration(text: str) -> float | None:
    """Measure how long Kokoro takes to speak ``text``.

    The text is synthesized with the ``af_heart`` voice and the lengths of
    all returned audio segments are added up.

    Args:
        text: Speech text to estimate duration for

    Returns:
        Duration in seconds, or None if Kokoro is unavailable or synthesis
        raised.
    """
    pipe = _get_kokoro()
    if pipe is None:
        return None

    try:
        sample_count = sum(
            len(segment.audio)
            for segment in pipe(text, voice="af_heart")
            if getattr(segment, "audio", None) is not None
        )
    except Exception as e:
        logger.warning("Kokoro estimation failed: %s", e)
        return None

    # Kokoro outputs at 24000Hz
    return sample_count / 24000.0
103
+
104
+
105
@dataclass
class ChunkSpec:
    """Description of a single generation chunk.

    NOTE(review): field meanings below are read off the names; the code that
    fills them in is outside this view — confirm against the producer.
    """

    # Presumably the fully compiled prompt text for this chunk.
    compiled_prompt: str
    # Duration of the chunk in seconds.
    duration_s: float
    # Seed associated with this chunk.
    seed: int
    # Presumably the plain speech text the chunk is expected to contain.
    expected_text: str
    # Language code; defaults to English.
    language: str = "en"
112
+
113
+
114
+ def _split_into_sentences(text: str) -> list[str]:
115
+ """Split text into individual sentences at .!? boundaries."""
116
+ sentences = []
117
+ current = ""
118
+ for char in text:
119
+ current += char
120
+ if char in ".!?":
121
+ stripped = current.strip()
122
+ if stripped:
123
+ sentences.append(stripped)
124
+ current = ""
125
+ if current.strip():
126
+ sentences.append(current.strip())
127
+ return sentences
128
+
129
+
130
def _estimate_sentence_durations(sentences: list[str]) -> list[float]:
    """Return one raw duration estimate (seconds) per sentence.

    Each sentence is measured with its own Kokoro call; the values are raw
    Kokoro durations, before any LTX multiplier. When Kokoro yields no
    estimate, a word-count heuristic is used instead (intended for test
    environments where Kokoro is mocked).
    """
    return [
        measured
        if (measured := _kokoro_duration(sentence)) is not None
        # Test environment fallback only
        else len(sentence.split()) / FALLBACK_WORDS_PER_SEC + 0.3
        for sentence in sentences
    ]
145
+
146
+
147
def split_text_by_duration(
    text: str,
    multiplier: float = LTX_MULTIPLIER,
    max_duration: float = MAX_CHUNK_DURATION_S,
) -> list[tuple[str, float]]:
    """Split text into chunks using Kokoro duration estimation.

    Kokoro is the source of truth for duration. No word counting.

    Algorithm:
    1. Split text into sentences
    2. Estimate each sentence's duration via Kokoro (one call per sentence)
    3. Sentences that exceed max_duration on their own are split at commas
       and regrouped into sub-chunks that fit
    4. Greedily merge: accumulate durations, start a new chunk when
       running_sum * multiplier would exceed max_duration

    Each sentence is synthesized only once: the estimate taken in step 2 is
    reused for the merge instead of being recomputed. Only the regrouped
    comma sub-chunks are measured again, as their text is new.

    Args:
        text: Full speech text.
        multiplier: LTX speaks slower than Kokoro; applied to estimates.
        max_duration: Max audio duration per chunk (model training limit).

    Returns:
        List of (chunk_text, estimated_ltx_duration) tuples; the duration is
        capped at max_duration. Empty list when the text has no sentences.
    """
    sentences = _split_into_sentences(text)
    if not sentences:
        return []

    # (text, raw Kokoro duration) units that feed the greedy merge below.
    units: list[tuple[str, float]] = []
    for sent, dur in zip(sentences, _estimate_sentence_durations(sentences)):
        if dur * multiplier <= max_duration or "," not in sent:
            units.append((sent, dur))
            continue

        # Over-long sentence: regroup its comma-separated clauses.
        clauses = [c.strip() for c in sent.split(",") if c.strip()]
        clause_durs = _estimate_sentence_durations(clauses)
        groups: list[str] = []
        sub_texts: list[str] = []
        sub_dur = 0.0
        for clause, cdur in zip(clauses, clause_durs):
            if sub_texts and (sub_dur + cdur) * multiplier > max_duration:
                groups.append(", ".join(sub_texts))
                sub_texts = []
                sub_dur = 0.0
            sub_texts.append(clause)
            sub_dur += cdur
        if sub_texts:
            groups.append(", ".join(sub_texts))
        units.extend(zip(groups, _estimate_sentence_durations(groups)))

    chunks: list[tuple[str, float]] = []
    current_texts: list[str] = []
    current_dur = 0.0

    for unit_text, unit_dur in units:
        # Close the running chunk before it would exceed the limit.
        if current_texts and (current_dur + unit_dur) * multiplier > max_duration:
            chunks.append(
                (" ".join(current_texts), min(current_dur * multiplier, max_duration))
            )
            current_texts = []
            current_dur = 0.0

        current_texts.append(unit_text)
        current_dur += unit_dur

    if current_texts:
        chunks.append(
            (" ".join(current_texts), min(current_dur * multiplier, max_duration))
        )

    return chunks
220
+
221
+
222
def estimate_duration(
    text: str,
    num_actions: int = 0,
    multiplier: float = LTX_MULTIPLIER,
) -> float:
    """Estimate the audio duration of one chunk of text, in seconds.

    Intended for single-chunk prompts that do not need splitting.

    Args:
        text: Speech text (no actions).
        num_actions: Number of action blocks; each adds
            ``ACTION_DURATION_S`` before scaling.
        multiplier: Duration multiplier (LTX speaks slower than Kokoro).

    Returns:
        The scaled estimate, capped at ``MAX_CHUNK_DURATION_S``.
    """
    measured = _kokoro_duration(text)

    if measured is None:
        # Estimator unavailable: fall back to a word-count heuristic.
        base_duration = len(text.split()) / FALLBACK_WORDS_PER_SEC + 0.5
    else:
        base_duration = measured
        logger.debug("Kokoro estimate: %.1fs for '%s'", measured, text[:40])

    scaled = (base_duration + num_actions * ACTION_DURATION_S) * multiplier
    return min(scaled, MAX_CHUNK_DURATION_S)
248
+
249
+
250
def plan_chunks(
    xml_string: str,
    base_seed: int = -1,
    pace: float = LTX_MULTIPLIER,
) -> list[ChunkSpec]:
    """Turn a <speak> XML prompt into an ordered list of generation chunks.

    The prompt is validated and compiled. If the whole speech text fits in
    one chunk, a single spec using the full compiled prompt is returned.
    Otherwise the text is split by estimated duration and a prompt is
    compiled per chunk. Only actions attached to a chunk's first sentence
    are carried into that chunk's prompt.

    Args:
        xml_string: Valid <speak> XML string.
        base_seed: Base seed; -1 picks a random one. Chunk ``i`` uses
            ``base_seed + i * 1000``.
        pace: Duration multiplier (defaults to ``LTX_MULTIPLIER``).
            Higher = slower speech.

    Raises:
        ValueError: If the prompt fails validation.
    """
    validation = validate_prompt(xml_string)
    if not validation.valid:
        raise ValueError(f"Invalid prompt: {'; '.join(validation.errors)}")

    compiled = compile_prompt(xml_string)
    speech = compiled.speech_text

    if base_seed == -1:
        base_seed = random.randint(0, 999999)

    # Uncapped estimate for the whole text, to see whether one chunk is enough.
    raw_total = _kokoro_duration(speech)
    if raw_total is None:
        raw_total = len(speech.split()) / FALLBACK_WORDS_PER_SEC + 0.5
    total_dur = raw_total * pace

    if total_dur <= MAX_CHUNK_DURATION_S:
        single = ChunkSpec(
            compiled_prompt=compiled.prompt,
            duration_s=total_dur,  # already within the cap on this branch
            seed=base_seed,
            expected_text=speech,
            language=compiled.language,
        )
        return [single]

    # Action lookup is keyed by global sentence index; build it before splitting.
    actions_by_sentence = extract_sentence_actions(xml_string)
    pieces = split_text_by_duration(speech, multiplier=pace)

    specs: list[ChunkSpec] = []
    first_sentence = 0  # global index of the current chunk's first sentence
    for offset, (piece_text, piece_dur) in enumerate(pieces):
        piece_prompt = compile_chunk_prompt(
            speech_text=piece_text,
            voice=compiled.voice,
            scene=compiled.scene,
            actions_before=actions_by_sentence.get(first_sentence),
            gender=compiled.gender,
            shot=compiled.shot,
        )
        specs.append(
            ChunkSpec(
                compiled_prompt=piece_prompt,
                duration_s=piece_dur,
                seed=base_seed + offset * 1000,
                expected_text=piece_text,
                language=compiled.language,
            )
        )
        # Advance past every sentence this chunk consumed.
        first_sentence += len(_split_into_sentences(piece_text))

    logger.info(
        "Planned %d chunks (%.1fs total estimated)",
        len(specs),
        sum(s.duration_s for s in specs),
    )
    return specs
src/audio_core/compiler.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """XML prompt compiler for Scenema Audio.
6
+
7
+ Compiles a <speak> XML prompt into the video-style flat text prompt
8
+ that the LTX 2.3 audio model expects.
9
+
10
+ Supports three block types inside <speak>:
11
+ <action> — delivery/performance cues (how the person speaks/acts)
12
+ <sound> — audio events that should be heard (SFX, ambient sounds)
13
+ Text — the actual speech content
14
+
15
+ And three shot modes via the shot attribute:
16
+ closeup (default) — speech-focused, no SFX, clean audio
17
+ wide — environment + speech, SFX prominent
18
+ scene — raw scene description, maximum SFX
19
+
20
+ Example (closeup mode):
21
+ Input:
22
+ <speak voice="Deep male voice" scene="A dimly lit room" gender="male">
23
+ <action>He takes a slow breath</action>
24
+ Many years later, as he faced the firing squad...
25
+ </speak>
26
+
27
+ Output:
28
+ Close-up in a dimly lit room. He takes a slow breath.
29
+ "Many years later, as he faced the firing squad..."
30
+ Deep male voice.
31
+
32
+ Example (scene mode with SFX):
33
+ Input:
34
+ <speak voice="Tense male whisper" scene="Dark room, heavy rain"
35
+ gender="male" shot="scene">
36
+ <sound>A phone rings twice then stops</sound>
37
+ <action>He picks up the receiver and speaks in a low whisper</action>
38
+ Its done. The package is at the location.
39
+ <sound>Thunder rumbles in the distance</sound>
40
+ <action>He continues urgently</action>
41
+ You have thirty minutes.
42
+ </speak>
43
+
44
+ Output:
45
+ Dark room, heavy rain. A phone rings twice then stops.
46
+ He picks up the receiver and speaks in a low whisper:
47
+ "Its done. The package is at the location."
48
+ Thunder rumbles in the distance. He continues urgently:
49
+ "You have thirty minutes."
50
+ Tense male whisper. Dark room, heavy rain.
51
+ """
52
+
53
+ import xml.etree.ElementTree as ET
54
+ from dataclasses import dataclass
55
+
56
+ DEFAULT_SCENE = "a person speaking to camera"
57
+
58
+
59
+ @dataclass
60
+ class CompiledPrompt:
61
+ prompt: str
62
+ speech_text: str
63
+ voice: str
64
+ scene: str | None
65
+ language: str
66
+ gender: str
67
+ shot: str
68
+
69
+
70
+ @dataclass
71
+ class TextBlock:
72
+ text: str
73
+
74
+
75
+ @dataclass
76
+ class ActionBlock:
77
+ text: str
78
+
79
+
80
+ @dataclass
81
+ class SoundBlock:
82
+ text: str
83
+
84
+
85
+ Block = TextBlock | ActionBlock | SoundBlock
86
+
87
+
88
def _extract_blocks(root: ET.Element) -> list[Block]:
    """Flatten the direct children of <speak> into blocks, in document order.

    Text before the first child and text following each child (its tail)
    become ``TextBlock``s. <action> and <sound> children with non-blank
    text become ``ActionBlock`` / ``SoundBlock``. Children with any other
    tag contribute nothing themselves, but their tail text is still kept.
    All text is whitespace-stripped; blank text is dropped.
    """
    block_for_tag = {"action": ActionBlock, "sound": SoundBlock}
    found: list[Block] = []

    leading = (root.text or "").strip()
    if leading:
        found.append(TextBlock(text=leading))

    for node in root:
        make_block = block_for_tag.get(node.tag)
        inner = (node.text or "").strip()
        if make_block is not None and inner:
            found.append(make_block(text=inner))

        # Speech that follows the element lives in its tail, not its text.
        trailing = (node.tail or "").strip()
        if trailing:
            found.append(TextBlock(text=trailing))

    return found
104
+
105
+
106
+ def _ensure_trailing_punctuation(text: str) -> str:
107
+ """Ensure text ends with sentence-ending punctuation."""
108
+ if text and text[-1] not in ".!?\"'":
109
+ return text + "."
110
+ return text
111
+
112
+
113
# Opening words placed before the scene description, keyed by the "shot"
# attribute. An empty prefix means the scene text opens the prompt as-is.
SHOT_PREFIXES = {
    "closeup": "Close-up in",
    "wide": "Wide shot of",
    "scene": "",
}
118
+
119
+
120
def _compile_blocks(
    blocks: list[Block],
    voice: str,
    scene: str | None,
    gender: str = "male",
    shot: str = "closeup",
) -> str:
    """Compile blocks into the video-style prompt string.

    Layout: opening scene sentence, then every block in order, then the
    voice description, and (scene/wide shots with an explicit scene only)
    the scene repeated as SFX reinforcement.

    Block rendering:
      * Sound: standalone sentence.
      * Action: standalone sentence, except in scene/wide mode when speech
        follows immediately; then it is emitted as ``<action>:`` so it
        leads into the quote. An action with no speech right after it no
        longer gets a dangling colon.
      * Text: quoted. In scene/wide mode the first speech is introduced
        with "<pronoun> speaks:" when no action has come before it. Actions
        that appear later in the prompt do not suppress the lead-in.

    The opening sentence gets a final "." only when the scene text does
    not already end in ``.``, ``!`` or ``?`` (avoids a doubled period).

    Args:
        blocks: Parsed prompt blocks in document order.
        voice: Voice description appended near the end of the prompt.
        scene: Scene description; ``DEFAULT_SCENE`` is used when falsy.
        gender: "female" selects "She"; anything else selects "He".
        shot: Key into ``SHOT_PREFIXES``; unknown values use the closeup
            prefix. "scene" and "wide" enable scene-mode rendering.

    Returns:
        All parts joined by single spaces.
    """
    is_scene_mode = shot in ("scene", "wide")
    pronoun = "She" if gender == "female" else "He"

    scene_text = scene if scene else DEFAULT_SCENE
    prefix = SHOT_PREFIXES.get(shot, SHOT_PREFIXES["closeup"])
    opening = f"{prefix} {scene_text}" if prefix else scene_text
    if not opening.endswith((".", "!", "?")):
        opening += "."
    parts: list[str] = [opening]

    action_seen = False
    speech_seen = False
    for idx, block in enumerate(blocks):
        if isinstance(block, SoundBlock):
            # Sound events compile as standalone sentences.
            parts.append(_ensure_trailing_punctuation(block.text))
        elif isinstance(block, ActionBlock):
            action_seen = True
            following = blocks[idx + 1] if idx + 1 < len(blocks) else None
            if is_scene_mode and isinstance(following, TextBlock):
                # Action flows into the quote; the colon is the connector.
                parts.append(block.text + ":")
            else:
                parts.append(_ensure_trailing_punctuation(block.text))
        elif isinstance(block, TextBlock):
            quoted = f'"{_ensure_trailing_punctuation(block.text)}"'
            if is_scene_mode and not speech_seen and not action_seen:
                # Nothing has introduced the speaker yet: add a lead-in.
                parts.append(f"{pronoun} speaks: {quoted}")
            else:
                parts.append(quoted)
            speech_seen = True

    parts.append(_ensure_trailing_punctuation(voice))

    # In scene/wide mode, repeat the scene as SFX reinforcement at the end.
    if is_scene_mode and scene:
        parts.append(_ensure_trailing_punctuation(scene))

    return " ".join(parts)
172
+
173
+
174
def _extract_speech_only(blocks: list[Block]) -> str:
    """Join the text of every ``TextBlock`` with single spaces.

    Action and sound blocks are left out; the result is what duration
    estimation works on.
    """
    return " ".join(
        block.text for block in blocks if isinstance(block, TextBlock)
    )
178
+
179
+
180
def compile_prompt(xml_string: str) -> CompiledPrompt:
    """Compile a <speak> XML prompt into a video-style text prompt.

    Root attributes are read with these defaults: voice "", language "en",
    gender "male", shot "closeup"; scene stays ``None`` when absent. All
    present attribute values are whitespace-stripped.

    Args:
        xml_string: Valid <speak> XML string (must pass validate_prompt
            first; parse errors from ElementTree are not caught here).

    Returns:
        CompiledPrompt with the compiled prompt and extracted metadata.
    """
    root = ET.fromstring(xml_string)

    raw_scene = root.get("scene")
    scene = raw_scene.strip() if raw_scene else raw_scene
    voice = root.get("voice", "").strip()
    language = root.get("language", "en").strip()
    gender = root.get("gender", "male").strip()
    shot = root.get("shot", "closeup").strip()

    blocks = _extract_blocks(root)

    return CompiledPrompt(
        prompt=_compile_blocks(blocks, voice, scene, gender, shot),
        speech_text=_extract_speech_only(blocks),
        voice=voice,
        scene=scene,
        language=language,
        gender=gender,
        shot=shot,
    )
212
+
213
+
214
def extract_sentence_actions(xml_string: str) -> dict[int, list[str]]:
    """Map sentence indices to the action blocks that precede them.

    Walks the XML blocks in order, collecting actions until a text block
    is reached; the sentence that text block starts in receives them.

    Sentence indices are counted over the space-joined speech text, using
    the same rule as sentence splitting elsewhere (a sentence ends at each
    ``.``, ``!`` or ``?``). A text block that does not end with one of
    those characters therefore does not close a sentence: the next text
    block continues it and shares its index. When several groups of
    actions land on the same sentence they are all kept, in document order.

    Args:
        xml_string: <speak> XML string.

    Returns:
        Dict mapping sentence index (0-based across all speech text) to
        the list of action strings that precede that sentence.
    """
    root = ET.fromstring(xml_string)
    blocks = _extract_blocks(root)

    sentence_actions: dict[int, list[str]] = {}
    pending_actions: list[str] = []
    sentence_idx = 0  # sentences already closed in the joined speech text

    for block in blocks:
        if isinstance(block, ActionBlock):
            pending_actions.append(block.text)
        elif isinstance(block, TextBlock):
            text = block.text.strip()
            if not text:
                continue

            if pending_actions:
                sentence_actions.setdefault(sentence_idx, []).extend(pending_actions)
                pending_actions = []

            # Only terminators close a sentence; an unterminated tail
            # carries over into the next text block.
            sentence_idx += sum(ch in ".!?" for ch in text)

    return sentence_actions
258
+
259
+
260
def extract_speech_text(xml_string: str) -> str:
    """Return just the spoken text of a <speak> prompt.

    Actions and sounds are ignored. Handy for duration estimation (Kokoro)
    when the full compiled prompt is not needed.
    """
    return _extract_speech_only(_extract_blocks(ET.fromstring(xml_string)))
268
+
269
+
270
def compile_chunk_prompt(
    speech_text: str,
    voice: str,
    scene: str | None = None,
    actions_before: list[str] | None = None,
    actions_after: list[str] | None = None,
    gender: str = "male",
    shot: str = "closeup",
) -> str:
    """Build the prompt for one chunk of already-split speech text.

    The chunker calls this once per chunk after splitting.

    Args:
        speech_text: The chunk's speech text portion.
        voice: Voice description string.
        scene: Scene description string (optional).
        actions_before: Action texts placed ahead of the speech.
        actions_after: Action texts placed after the speech.
        gender: Passed through to the block compiler.
        shot: Passed through to the block compiler.

    Returns:
        Compiled video-style prompt string.
    """
    ordered: list[Block] = [ActionBlock(text=a) for a in (actions_before or [])]
    ordered.append(TextBlock(text=speech_text))
    ordered.extend(ActionBlock(text=a) for a in (actions_after or []))
    return _compile_blocks(ordered, voice, scene, gender, shot)
src/audio_core/engine.py ADDED
@@ -0,0 +1,911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Audio generation engine for Scenema Audio.
6
+
7
+ Loads the LTX 2.3 audio-only checkpoint, Audio VAE encoder, and
8
+ Gemma 3 12B text encoder. VRAM management is auto-detected: models
9
+ are moved between GPU and CPU as needed per inference phase.
10
+ """
11
+
12
+ import gc
13
+ import json
14
+ import logging
15
+ import os
16
+ import time
17
+ from contextlib import contextmanager
18
+ from dataclasses import dataclass, replace as dc_replace
19
+
20
+ import numpy as np
21
+ import psutil
22
+ import torch
23
+ import torchaudio
24
+ from safetensors import safe_open
25
+ from safetensors.torch import load_file
26
+
27
+ from ltx_core.batch_split import BatchSplitAdapter, BatchedPerturbationConfig
28
+ from ltx_core.components.diffusion_steps import EulerDiffusionStep
29
+ from ltx_core.components.noisers import GaussianNoiser
30
+ from ltx_core.components.patchifiers import AudioPatchifier, VideoLatentPatchifier
31
+ from ltx_core.model.audio_vae.audio_vae import Audio, encode_audio
32
+ from ltx_core.model.audio_vae.model_configurator import AudioEncoderConfigurator
33
+ from ltx_core.model.transformer.model import X0Model
34
+ from ltx_core.model.transformer.model_configurator import LTXModelConfigurator
35
+ from ltx_core.model.transformer.transformer import BasicAVTransformerBlock, rms_norm
36
+ from ltx_core.tools import AudioLatentTools, LatentState, VideoLatentTools
37
+ from ltx_core.types import AudioLatentShape, VideoLatentShape, VideoPixelShape
38
+ from ltx_pipelines.distilled import DISTILLED_SIGMAS, DistilledPipeline
39
+ from ltx_pipelines.utils.blocks import ModalitySpec, _build_state
40
+ from ltx_pipelines.utils.denoisers import SimpleDenoiser
41
+ from ltx_pipelines.utils.samplers import euler_denoising_loop
42
+ from ltx_pipelines.utils.types import OffloadMode
43
+ from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer
44
+ import bitsandbytes # noqa: F401
45
+ from transformers import BitsAndBytesConfig, Gemma3ForConditionalGeneration
46
+
47
+ from .audio_utils import extract_wav
48
+
49
+ logger = logging.getLogger(__name__)
50
+
51
+ FPS = 24
52
+ MAX_REF_SECONDS = 5
53
+
54
+
55
+ class _Int8Linear(torch.nn.Module):
56
+ """Linear layer with INT8 weights, dequantized to input dtype during forward.
57
+
58
+ Keeps weights as int8 buffers in VRAM (~50% of bf16). Dequantization
59
+ happens per forward pass: weight = int8 * scale, then cast to input dtype.
60
+ Ported from bench_full_quantized.py.
61
+ """
62
+
63
+ def __init__(self, weight_int8, scale, bias=None):
64
+ super().__init__()
65
+ self.register_buffer("weight_int8", weight_int8)
66
+ self.register_buffer("scale", scale)
67
+ if bias is not None:
68
+ self.register_parameter("bias", torch.nn.Parameter(bias))
69
+ else:
70
+ self.bias = None
71
+
72
+ def forward(self, x):
73
+ w = self.weight_int8.float() * self.scale.unsqueeze(1)
74
+ w = w.to(x.dtype)
75
+ return torch.nn.functional.linear(x, w, self.bias)
76
+
77
+
78
+ # VRAM threshold: cards with this much VRAM keep all models GPU-resident
79
+ # (Gemma bf16 on GPU, no offloading, MelBandRoFormer + SeedVC preloaded).
80
+ # Below this: Gemma streams from CPU, models load/unload per request.
81
+ HIGH_VRAM_THRESHOLD_GB = 40
82
+
83
+
84
@dataclass
class AudioResult:
    """Generated audio together with its basic metadata."""

    # Audio samples: (samples,) or (samples, channels), float32.
    waveform_np: np.ndarray
    # Sample rate of ``waveform_np``.
    sample_rate: int
    # Length of the audio in seconds.
    duration_s: float
89
+
90
+
91
+ def _materialize_meta_tensors(module, device="cpu"):
92
+ """Replace meta tensors with zeros on the specified device."""
93
+ for name, param in list(module.named_parameters()):
94
+ if param.is_meta:
95
+ parts = name.split(".")
96
+ mod = module
97
+ for p in parts[:-1]:
98
+ mod = getattr(mod, p)
99
+ mod._parameters[parts[-1]] = torch.nn.Parameter(
100
+ torch.zeros(param.shape, dtype=torch.bfloat16, device=device)
101
+ )
102
+ for name, buf in list(module.named_buffers()):
103
+ if buf.is_meta:
104
+ parts = name.split(".")
105
+ mod = module
106
+ for p in parts[:-1]:
107
+ mod = getattr(mod, p)
108
+ mod._buffers[parts[-1]] = torch.zeros(
109
+ buf.shape, dtype=torch.bfloat16, device=device
110
+ )
111
+
112
+
113
def _audio_only_forward(self, video, audio, perturbations=None):
    """Replacement block forward that only processes the audio stream.

    Runs audio self-attention, text cross-attention and the audio
    feed-forward, each as a residual update. No video computation is
    performed; the video stream's ``x`` is written back unchanged.

    Args:
        video: Video modality state, or None.
        audio: Audio modality state, or None.
        perturbations: Optional perturbation config; an empty one is
            created when omitted.

    Returns:
        The same ``(video, audio)`` objects, with ``x`` reassigned.

    Raises:
        ValueError: If both modalities are None.
    """
    if video is None and audio is None:
        raise ValueError("Need at least one modality")

    batch_size = (video or audio).x.shape[0]
    if perturbations is None:
        perturbations = BatchedPerturbationConfig.empty(batch_size)

    video_x = None if video is None else video.x
    audio_x = None if audio is None else audio.x

    audio_active = audio is not None and audio.enabled and audio_x.numel() > 0
    if audio_active:
        # --- self-attention (adaptive values 0..2) ---
        shift_sa, scale_sa, gate_sa = self.get_ada_values(
            self.audio_scale_shift_table,
            audio_x.shape[0],
            audio.timesteps,
            slice(0, 3),
        )
        normed_sa = rms_norm(audio_x, eps=self.norm_eps) * (1 + scale_sa) + shift_sa
        # Intermediates are dropped as soon as they are no longer needed.
        del shift_sa, scale_sa
        gated_sa = (
            self.audio_attn1(
                normed_sa,
                pe=audio.positional_embeddings,
                mask=audio.self_attention_mask,
            )
            * gate_sa
        )
        audio_x = audio_x + gated_sa
        del gate_sa, normed_sa, gated_sa

        # --- text cross-attention ---
        audio_x = audio_x + self._apply_text_cross_attention(
            audio_x,
            audio.context,
            self.audio_attn2,
            self.audio_scale_shift_table,
            # Optional table: not every block defines it.
            getattr(self, "audio_prompt_scale_shift_table", None),
            audio.timesteps,
            audio.prompt_timestep,
            audio.context_mask,
            cross_attention_adaln=self.cross_attention_adaln,
        )

        # --- feed-forward (adaptive values 3..5) ---
        shift_ff, scale_ff, gate_ff = self.get_ada_values(
            self.audio_scale_shift_table,
            audio_x.shape[0],
            audio.timesteps,
            slice(3, 6),
        )
        normed_ff = rms_norm(audio_x, eps=self.norm_eps) * (1 + scale_ff) + shift_ff
        del shift_ff, scale_ff
        audio_x = audio_x + self.audio_ff(normed_ff) * gate_ff
        del gate_ff, normed_ff

    # object.__setattr__ assigns directly, skipping the class's own __setattr__.
    if video is not None:
        object.__setattr__(video, "x", video_x)
    if audio is not None:
        object.__setattr__(audio, "x", audio_x)
    return video, audio
164
+
165
+
166
+ # ── VRAM Manager ────────────────────────────────────────────────────────
167
+
168
+
169
class VRAMManager:
    """Manages model placement between GPU and CPU based on available VRAM.

    Tracks which registered models are on GPU and moves them as needed per
    inference phase. Whether offloading is used at all is decided once, in
    ``finalize()``, by comparing the total registered model size (plus a
    fixed Gemma overhead) against the VRAM budget.
    """

    def __init__(self, vram_gb: float):
        """Create a manager for a device with ``vram_gb`` GB of VRAM."""
        self.vram_gb = vram_gb
        self._models: dict[str, torch.nn.Module] = {}
        self._model_sizes: dict[str, float] = {}  # GB per model
        self._on_gpu: set[str] = set()
        self.needs_offload = False  # Determined after all models registered

    def register(self, name: str, model: torch.nn.Module, on_gpu: bool = True) -> None:
        """Register a model for VRAM management.

        The recorded size covers parameters and buffers. Buffers must be
        included because quantized layers such as ``_Int8Linear`` keep
        their weights in buffers; counting parameters alone would make an
        INT8 model look almost free.

        Args:
            name: Identifier for the model.
            model: The PyTorch module.
            on_gpu: Whether the model is currently on GPU.
        """
        self._models[name] = model
        param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
        self._model_sizes[name] = (param_bytes + buffer_bytes) / 1e9
        if on_gpu:
            self._on_gpu.add(name)

    def finalize(self) -> None:
        """Decide whether offloading is needed.

        Call after all models are registered. ``needs_offload`` becomes
        True when the registered models plus the Gemma overhead exceed the
        VRAM budget.
        """
        total_model_gb = sum(self._model_sizes.values())
        # Gemma overhead depends on quantization mode:
        #   bf16 streaming: ~16GB peak (13GB Gemma + 2GB embeddings + 1GB safety)
        #   NF4: ~11GB peak (8GB NF4 model on GPU + 2GB embeddings + 1GB safety)
        gemma_nf4 = os.environ.get("GEMMA_QUANTIZE", "").lower() == "nf4"
        gemma_overhead_gb = 11.0 if gemma_nf4 else 16.0
        self.needs_offload = (total_model_gb + gemma_overhead_gb) > self.vram_gb
        logger.info(
            "VRAM strategy: %.1f GB models + %.1f GB Gemma overhead (%s) vs %.1f GB VRAM -> offload=%s",
            total_model_gb,
            gemma_overhead_gb,
            "nf4" if gemma_nf4 else "bf16",
            self.vram_gb,
            "yes" if self.needs_offload else "no",
        )

    def to_gpu(self, *names: str) -> None:
        """Move specified models to GPU, offloading others if needed.

        When ``needs_offload`` is set, every model on GPU that is not in
        ``names`` is moved to CPU first and the CUDA cache is emptied.
        Names that were never registered are ignored.

        Args:
            names: Model names that should be on GPU for the current phase.
        """
        if self.needs_offload:
            # The set difference is a new set, so mutating _on_gpu is safe.
            for name in self._on_gpu - set(names):
                if name in self._models:
                    self._models[name].cpu()
                    self._on_gpu.discard(name)
                    logger.debug("Offloaded %s to CPU", name)
            torch.cuda.empty_cache()

        for name in names:
            if name not in self._on_gpu and name in self._models:
                self._models[name].cuda()
                self._on_gpu.add(name)
                logger.debug("Loaded %s to GPU", name)

    def free_all(self) -> None:
        """Move all models to CPU."""
        for name in list(self._on_gpu):
            if name in self._models:
                self._models[name].cpu()
        self._on_gpu.clear()
        torch.cuda.empty_cache()

    @contextmanager
    def phase(self, *names: str):
        """Context manager for a VRAM phase.

        Ensures the named models are on GPU for the duration. On exit,
        when offloading is active, models this phase brought in are moved
        back to CPU and models that were on GPU before are reloaded. With
        offloading inactive nothing is moved on exit.

        Args:
            names: Model names needed on GPU for this phase.
        """
        prev_on_gpu = set(self._on_gpu)
        self.to_gpu(*names)
        try:
            yield
        finally:
            # Restore previous state only if offloading is needed
            if self.needs_offload:
                requested = set(names)
                # Evict first so the reloads below have room.
                for name in requested - prev_on_gpu:
                    if name in self._models and name in self._on_gpu:
                        self._models[name].cpu()
                        self._on_gpu.discard(name)
                for name in prev_on_gpu - requested:
                    if name in self._models and name not in self._on_gpu:
                        self._models[name].cuda()
                        self._on_gpu.add(name)
                torch.cuda.empty_cache()
292
+
293
+
294
+ # ── Audio Engine ────────────────────────────────────────────────────────
295
+
296
+
297
+ class AudioEngine:
298
+ """LTX 2.3 audio-only generation engine.
299
+
300
+ Loads the baked audio checkpoint, Audio VAE encoder, and Gemma 3 12B
301
+ text encoder. VRAM is managed automatically per inference phase.
302
+ """
303
+
304
+ def __init__(
305
+ self,
306
+ audio_ckpt_path: str,
307
+ vae_encoder_path: str,
308
+ gemma_root: str,
309
+ pipeline_ckpt_path: str | None = None,
310
+ ):
311
+ """Initialize AudioEngine.
312
+
313
+ Args:
314
+ audio_ckpt_path: Path to the audio-only transformer checkpoint.
315
+ vae_encoder_path: Path to the standalone Audio VAE encoder checkpoint.
316
+ gemma_root: Path to the Gemma 3 12B model directory.
317
+ pipeline_ckpt_path: Path to checkpoint for DistilledPipeline.
318
+ """
319
+ self.audio_ckpt_path = audio_ckpt_path
320
+ self.vae_encoder_path = vae_encoder_path
321
+ self.gemma_root = gemma_root
322
+ self.pipeline_ckpt_path = pipeline_ckpt_path or audio_ckpt_path
323
+
324
+ self._config = None
325
+ self._mdl_wrapper = None
326
+ self._audio_encoder = None
327
+ self._pipeline = None
328
+ self._vram: VRAMManager | None = None
329
+ self._vae_sr = None
330
+ self._loaded = False
331
+
332
+ @property
333
+ def vae_sample_rate(self) -> int:
334
+ return self._vae_sr or 16000
335
+
336
+ def load(self) -> None:
337
+ """Load all models. Call once at startup."""
338
+ if self._loaded:
339
+ return
340
+
341
+ vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
342
+ ram_gb = psutil.virtual_memory().total / 1e9
343
+ logger.info(
344
+ "System: %.1f GB VRAM, %.1f GB RAM, GPU: %s",
345
+ vram_gb,
346
+ ram_gb,
347
+ torch.cuda.get_device_name(0),
348
+ )
349
+
350
+ if vram_gb < 11:
351
+ raise RuntimeError(
352
+ f"Insufficient VRAM: {vram_gb:.0f}GB. Minimum 11GB required."
353
+ )
354
+ if ram_gb < 24:
355
+ raise RuntimeError(
356
+ f"Insufficient RAM: {ram_gb:.0f}GB. Minimum 24GB required."
357
+ )
358
+
359
+ self._vram = VRAMManager(vram_gb)
360
+
361
+ self._load_audio_model()
362
+ self._load_vae_encoder()
363
+ self._patch_transformer_blocks()
364
+ self._build_pipeline()
365
+
366
+ # Determine offloading strategy based on actual model sizes vs VRAM
367
+ self._vram.finalize()
368
+
369
+ self._loaded = True
370
+ logger.info("AudioEngine loaded")
371
+
372
def _load_audio_model(self) -> None:
    """Load the audio-only checkpoint to GPU.

    Supports both bf16 and INT8 quantized checkpoints. INT8 checkpoints
    store weights as .weight.int8 (int8) + .weight.scale (float32) pairs.
    For INT8, nn.Linear layers are replaced with Int8Linear modules that
    keep weights quantized in VRAM (~5GB vs 9.8GB) and dequantize during
    the forward pass.

    For non-INT8 checkpoints, setting the environment variable
    ``TRANSFORMER_QUANTIZE=int8`` swaps every ``nn.Linear`` with more than
    1,000,000 weights for a bitsandbytes ``Linear8bitLt`` layer.

    Side effects: stores the checkpoint config in ``self._config``, the
    CUDA-resident wrapped model in ``self._mdl_wrapper``, and registers
    that model with ``self._vram`` under the name ``"audio_model"``.
    """
    t0 = time.time()

    # The model configuration is stored as JSON in the safetensors metadata.
    with safe_open(self.audio_ckpt_path, framework="pt") as f:
        self._config = json.loads(f.metadata()["config"])

    # Build the module skeleton on the meta device; real tensors are
    # attached below via load_state_dict(..., assign=True).
    with torch.device("meta"):
        mdl = LTXModelConfigurator.from_config(self._config)

    sd = load_file(self.audio_ckpt_path, device="cpu")

    # Detect INT8 checkpoint format
    # Both maps are keyed by the module path (suffix stripped) and point at
    # the full state-dict key.
    int8_map = {
        k.replace(".weight.int8", ""): k for k in sd if k.endswith(".weight.int8")
    }
    scale_map = {
        k.replace(".weight.scale", ""): k for k in sd if k.endswith(".weight.scale")
    }
    is_int8 = len(int8_map) > 0

    if is_int8:
        # Load only non-quantized keys first (biases, norms, embeddings)
        regular_sd = {
            k: v
            for k, v in sd.items()
            if not k.endswith(".int8") and not k.endswith(".scale")
        }
        mdl_wrapper = X0Model(mdl)
        mdl_wrapper.load_state_dict(regular_sd, strict=False, assign=True)

        # Replace nn.Linear with Int8Linear for quantized weights
        n_replaced = 0
        for name in int8_map:
            w_int8 = sd[int8_map[name]]
            w_scale = sd[scale_map[name]]
            # Walk the dotted path to find the parent that owns the layer.
            parts = name.split(".")
            parent = mdl_wrapper
            for p in parts[:-1]:
                parent = getattr(parent, p)
            old = getattr(parent, parts[-1])
            # Prefer the bias from the checkpoint; otherwise reuse the bias
            # already attached to the layer being replaced.
            bias_key = name + ".bias"
            bias = sd.get(bias_key)
            if bias is None and hasattr(old, "bias") and old.bias is not None:
                bias = old.bias.data
            setattr(parent, parts[-1], _Int8Linear(w_int8, w_scale, bias))
            n_replaced += 1

        logger.info("INT8: replaced %d Linear layers with Int8Linear", n_replaced)
    else:
        mdl_wrapper = X0Model(mdl)
        mdl_wrapper.load_state_dict(sd, strict=False, assign=True)

        # Runtime INT8 quantization via BnB (bf16 checkpoint -> INT8 on GPU)
        if os.environ.get("TRANSFORMER_QUANTIZE", "").lower() == "int8":
            import bitsandbytes as bnb

            n_quantized = 0
            # list(...) snapshots the module tree because children are
            # replaced while iterating.
            for name, module in list(mdl_wrapper.named_modules()):
                for cn, child in list(module.named_children()):
                    if (
                        isinstance(child, torch.nn.Linear)
                        and child.weight.numel() > 1_000_000
                    ):
                        int8_layer = bnb.nn.Linear8bitLt(
                            child.in_features,
                            child.out_features,
                            bias=child.bias is not None,
                            has_fp16_weights=False,
                        )
                        int8_layer.weight = bnb.nn.Int8Params(
                            child.weight.data,
                            requires_grad=False,
                            has_fp16_weights=False,
                        )
                        if child.bias is not None:
                            int8_layer.bias = child.bias
                        setattr(module, cn, int8_layer)
                        n_quantized += 1
            logger.info(
                "Runtime INT8: quantized %d Linear layers via BnB", n_quantized
            )

    # The raw state dict is no longer needed once the modules own the tensors.
    del sd
    gc.collect()

    # Swap these four submodules of every block for parameter-free Identity
    # modules. NOTE(review): presumably these are the video-branch layers the
    # audio-only forward never calls - confirm against _audio_only_forward.
    for block in mdl.transformer_blocks:
        block.attn1 = torch.nn.Identity()
        block.attn2 = torch.nn.Identity()
        block.ff = torch.nn.Identity()
        block.audio_to_video_attn = torch.nn.Identity()
    gc.collect()

    _materialize_meta_tensors(mdl_wrapper)

    # Use the larger of the video and audio max positions for the preprocessors.
    cross_pe = max(
        mdl.positional_embedding_max_pos[0],
        mdl.audio_positional_embedding_max_pos[0],
    )
    mdl._init_preprocessors(cross_pe)

    self._mdl_wrapper = mdl_wrapper.cuda().eval()
    self._vram.register("audio_model", self._mdl_wrapper, on_gpu=True)

    logger.info(
        "Audio model loaded: %.1f GB, %.1fs",
        torch.cuda.memory_allocated() / 1e9,
        time.time() - t0,
    )
488
+
489
def _load_vae_encoder(self) -> None:
    """Load the Audio VAE encoder from its standalone checkpoint.

    Reads the ``audio_vae`` section of ``self._config``, records the VAE
    sampling rate in ``self._vae_sr``, builds the encoder on the meta
    device, attaches the checkpoint weights, and moves the encoder to CUDA
    in bfloat16 eval mode. The result is stored in ``self._audio_encoder``
    and registered with ``self._vram`` as ``"vae_encoder"``.
    """
    started = time.time()
    vae_cfg = self._config["audio_vae"]
    self._vae_sr = vae_cfg["preprocessing"]["audio"]["sampling_rate"]

    with torch.device("meta"):
        encoder = AudioEncoderConfigurator().from_config(vae_cfg)

    weights = load_file(self.vae_encoder_path, device="cpu")
    encoder.load_state_dict(weights, strict=False, assign=True)

    # The statistics buffers use hyphenated names, so they are written
    # straight into the module's buffer dict when the checkpoint has them.
    stats = encoder.per_channel_statistics
    if "per_channel_statistics.std-of-means" in weights:
        for buffer_name in ("std-of-means", "mean-of-means"):
            stats._buffers[buffer_name] = weights[
                "per_channel_statistics." + buffer_name
            ]
    del weights

    encoder.mel_bins = vae_cfg["model"]["params"]["ddconfig"]["mel_bins"]
    encoder.mid.attn_1 = torch.nn.Identity()

    _materialize_meta_tensors(encoder, device="cpu")

    encoder = encoder.cuda().eval().to(torch.bfloat16)
    self._audio_encoder = encoder
    self._vram.register("vae_encoder", encoder, on_gpu=True)

    param_count = sum(p.numel() for p in encoder.parameters())
    logger.info(
        "Audio VAE encoder loaded: %.1fM params, %.1fs",
        param_count / 1e6,
        time.time() - started,
    )
522
+
523
def _patch_transformer_blocks(self) -> None:
    """Replace ``BasicAVTransformerBlock.forward`` with the audio-only forward.

    The patch is applied at class level, so it affects every instance of
    ``BasicAVTransformerBlock`` in the process.
    """
    setattr(BasicAVTransformerBlock, "forward", _audio_only_forward)
    logger.info("Transformer blocks patched for audio-only forward")
527
+
528
def _build_pipeline(self) -> None:
    """Build the DistilledPipeline and set up the Gemma text-encoding path.

    The pipeline's transformer context is replaced with one that always
    yields the already-loaded audio model (``self._mdl_wrapper``).

    One of three Gemma strategies is then selected:

    * ``GEMMA_QUANTIZE=nf4``: load an NF4-quantized Gemma on GPU plus a
      CUDA embeddings processor and tokenizer.
    * VRAM >= ``HIGH_VRAM_THRESHOLD_GB``: build the pipeline's bf16 text
      encoder and embeddings processor once and keep both on GPU.
    * otherwise: nothing is cached here; encoding goes through
      ``self._pipeline.prompt_encoder``.

    Sets ``self._pipeline``, ``self._gemma_nf4`` and ``self._gemma_on_gpu``;
    depending on the strategy also ``self._cached_emb_proc``,
    ``self._cached_tokenizer`` and ``self._resident_text_encoder``.
    """
    t0 = time.time()
    mdl_wrapper = self._mdl_wrapper

    # Use NONE offload when VRAM is sufficient so Gemma stays GPU-resident
    # for fast encoding (~0.5s vs ~7s streaming). Fall back to CPU streaming
    # on smaller cards.
    offload = (
        OffloadMode.NONE
        if self._vram.vram_gb >= HIGH_VRAM_THRESHOLD_GB
        else OffloadMode.CPU
    )
    self._pipeline = DistilledPipeline(
        distilled_checkpoint_path=self.pipeline_ckpt_path,
        gemma_root=self.gemma_root,
        spatial_upsampler_path=None,
        loras=[],
        offload_mode=offload,
    )

    # Stand-in transformer context: ignores its keyword arguments and hands
    # back the model loaded by _load_audio_model.
    @contextmanager
    def _gpu_ctx(**kw):
        yield mdl_wrapper

    self._pipeline.stage._transformer_ctx = _gpu_ctx

    pe = self._pipeline.prompt_encoder

    # Gemma loading strategy:
    # NF4: BitsAndBytes int4 quantization (~8GB on GPU, ~0.1s encode)
    # bf16 GPU: full precision on GPU (~24GB, ~1-2s encode) - when VRAM >= 40GB
    # bf16 streaming: streams from CPU RAM layer-by-layer (~7s encode) - when VRAM < 40GB
    # NOTE(review): the 40GB figure assumes HIGH_VRAM_THRESHOLD_GB == 40;
    # the constant is defined outside this method - confirm.
    self._gemma_nf4 = os.environ.get("GEMMA_QUANTIZE", "").lower() == "nf4"
    self._gemma_on_gpu = False

    if self._gemma_nf4:
        self._build_nf4_gemma()
        # NF4 needs its own embeddings processor and tokenizer
        self._cached_emb_proc = pe._embeddings_processor_builder.build(
            device="cuda",
            dtype=torch.bfloat16,
        ).eval()
        self._cached_tokenizer = LTXVGemmaTokenizer(self.gemma_root)
        logger.info("Embeddings processor cached on CUDA (NF4 mode)")
    elif self._vram.vram_gb >= HIGH_VRAM_THRESHOLD_GB:
        # Build pipeline's text encoder ONCE on GPU and keep it resident.
        # This uses the same builder as pipeline.prompt_encoder but
        # avoids the build/destroy cycle that makes each call ~30s.
        t_gemma = time.time()
        self._resident_text_encoder = pe._text_encoder_builder.build(
            device=torch.device("cuda"),
            dtype=torch.bfloat16,
        ).eval()
        self._cached_emb_proc = pe._embeddings_processor_builder.build(
            device="cuda",
            dtype=torch.bfloat16,
        ).eval()
        self._gemma_on_gpu = True
        vram_gb = torch.cuda.memory_allocated() / (1024**3)
        logger.info(
            "Gemma bf16 (pipeline encoder) GPU-resident: %.1fGB VRAM, %.1fs",
            vram_gb,
            time.time() - t_gemma,
        )
    else:
        # Low VRAM: pipeline.prompt_encoder streams from CPU (~7s/encode)
        logger.info("Gemma managed by pipeline prompt_encoder (CPU streaming)")

    logger.info("Pipeline built: %.1fs", time.time() - t0)
603
+
604
def _build_nf4_gemma(self) -> None:
    """Load Gemma from ``self.gemma_root`` with BitsAndBytes NF4 quantization.

    The model is placed on CUDA with a bfloat16 compute dtype, switched to
    eval mode and stored in ``self._nf4_gemma_model``.
    ``self._cached_text_encoder`` is set to ``None`` because no streaming
    text encoder is used in this mode.
    """
    started = time.time()

    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    )
    gemma = Gemma3ForConditionalGeneration.from_pretrained(
        self.gemma_root,
        quantization_config=nf4_config,
        device_map="cuda",
        dtype=torch.bfloat16,
    )
    self._nf4_gemma_model = gemma.eval()

    # No streaming text encoder needed - _cached_text_encoder stays None
    self._cached_text_encoder = None

    allocated_gib = torch.cuda.memory_allocated() / (1024**3)
    logger.info(
        "Gemma NF4 loaded on GPU: %.1fGB VRAM, %.1fs",
        allocated_gib,
        time.time() - started,
    )
631
+
632
def _build_bf16_gemma_gpu(self) -> None:
    """Load Gemma from ``self.gemma_root`` in bfloat16 directly on CUDA.

    The model is switched to eval mode, ``self._cached_text_encoder`` is
    set to ``None`` and ``self._gemma_on_gpu`` is set to ``True``.

    NOTE(review): the model is stored in ``self._nf4_gemma_model`` although
    it is not NF4-quantized, and this method does not set
    ``self._resident_text_encoder``, which ``encode_text`` reads when
    ``_gemma_on_gpu`` is true. No caller is visible in this part of the
    file - confirm whether the method is still in use.
    """
    started = time.time()

    gemma = Gemma3ForConditionalGeneration.from_pretrained(
        self.gemma_root,
        device_map="cuda",
        torch_dtype=torch.bfloat16,
    )
    self._nf4_gemma_model = gemma.eval()

    self._cached_text_encoder = None
    self._gemma_on_gpu = True

    allocated_gib = torch.cuda.memory_allocated() / (1024**3)
    logger.info(
        "Gemma bf16 loaded on GPU: %.1fGB VRAM, %.1fs",
        allocated_gib,
        time.time() - started,
    )
652
+
653
def unload(self) -> None:
    """Free all GPU and CPU memory held by the engine.

    Asks the VRAM manager to free its registered models, tears down the
    cached text encoder if there is one, drops every model reference the
    engine holds, marks the engine as not loaded, then runs the garbage
    collector and empties the CUDA cache.
    """
    if self._vram:
        self._vram.free_all()
    if (
        hasattr(self, "_cached_text_encoder")
        and self._cached_text_encoder is not None
    ):
        self._cached_text_encoder.teardown()
        self._cached_text_encoder = None

    # These attributes only exist after _build_pipeline ran, and which ones
    # exist depends on the Gemma strategy, so reset only those present.
    # "_resident_text_encoder" is the GPU-resident bf16 Gemma; without
    # releasing it here its memory would survive unload().
    for attr in (
        "_nf4_gemma_model",
        "_resident_text_encoder",
        "_cached_emb_proc",
        "_cached_tokenizer",
    ):
        if hasattr(self, attr):
            setattr(self, attr, None)

    self._mdl_wrapper = None
    self._audio_encoder = None
    self._pipeline = None
    self._vram = None
    self._loaded = False
    gc.collect()
    torch.cuda.empty_cache()
    logger.info("AudioEngine unloaded")
678
+
679
def encode_text(self, prompt: str):
    """Encode a text prompt into video and audio conditioning contexts.

    The path taken depends on how the pipeline was built:

    * ``self._gemma_nf4``: tokenize with the cached tokenizer, run the NF4
      Gemma model and feed its hidden states to the cached embeddings
      processor.
    * ``self._gemma_on_gpu``: use the GPU-resident text encoder followed by
      the cached embeddings processor.
    * otherwise: call the pipeline's own prompt encoder.

    Args:
        prompt: Compiled video-style text prompt.

    Returns:
        Tuple of (video_context, audio_context) tensors for denoising.
    """
    started = time.time()
    with torch.inference_mode():
        if self._gemma_nf4:
            # NF4: use BitsAndBytes quantized Gemma (fast, ~0.1s)
            tokens = self._cached_tokenizer.tokenize_with_weights(prompt)["gemma"]
            input_ids = torch.tensor([[tok[0] for tok in tokens]], device="cuda")
            attn_mask = torch.tensor([[tok[1] for tok in tokens]], device="cuda")
            outputs = self._nf4_gemma_model.model(
                input_ids=input_ids,
                attention_mask=attn_mask,
                output_hidden_states=True,
            )
            hidden = outputs.hidden_states
            del outputs, input_ids
            emb = self._cached_emb_proc.process_hidden_states(hidden, attn_mask)
            del hidden, attn_mask
        elif self._gemma_on_gpu:
            # bf16 GPU-resident: use pipeline's text encoder (fast, ~0.5s)
            hidden, attn_mask = self._resident_text_encoder.encode(prompt)
            emb = self._cached_emb_proc.process_hidden_states(hidden, attn_mask)
            del hidden, attn_mask
        else:
            # CPU streaming: use pipeline's prompt encoder (~7s)
            (emb,) = self._pipeline.prompt_encoder([prompt])

        vc = emb.video_encoding
        ac = emb.audio_encoding

    logger.info("Gemma encode: %.1fs", time.time() - started)
    return vc, ac
729
+
730
def encode_reference(self, waveform_np: np.ndarray, sample_rate: int):
    """Encode reference audio to latent via Audio VAE encoder.

    Mono input, given either as ``(samples,)`` or as a single-column
    ``(samples, 1)`` array, is duplicated into two identical channels.
    Two-column input is transposed to channel-first. Any other 2-D layout
    is passed through unchanged. The audio is resampled to the VAE sample
    rate when needed and truncated to ``MAX_REF_SECONDS``.

    Args:
        waveform_np: Audio samples, shape (samples,) or (samples, channels).
        sample_rate: Sample rate of the input audio in Hz.

    Returns:
        Reference latent tensor [B, C, T, F] on GPU.
    """
    # Ensure VAE encoder is on GPU
    self._vram.to_gpu("vae_encoder")

    # A single-column array is mono. Flatten it so it takes the same
    # duplication path as 1-D input instead of being read as channel-first
    # data with one sample per "channel".
    if waveform_np.ndim == 2 and waveform_np.shape[1] == 1:
        waveform_np = waveform_np[:, 0]

    if waveform_np.ndim == 1:
        waveform_np = np.stack([waveform_np, waveform_np], axis=-1)

    if waveform_np.ndim == 2 and waveform_np.shape[1] == 2:
        # (samples, 2) -> (2, samples)
        wav = torch.from_numpy(waveform_np.T).float()
    else:
        wav = torch.from_numpy(waveform_np).float()

    if sample_rate != self._vae_sr:
        wav = torchaudio.functional.resample(wav, sample_rate, self._vae_sr)

    # int() keeps the slice bound valid even if MAX_REF_SECONDS is a float.
    max_samples = int(MAX_REF_SECONDS * self._vae_sr)
    if wav.shape[1] > max_samples:
        wav = wav[:, :max_samples]

    audio_obj = Audio(waveform=wav.unsqueeze(0), sampling_rate=self._vae_sr)
    with torch.inference_mode():
        latent = encode_audio(audio_obj, self._audio_encoder)

    logger.info("Reference encoded: %s", latent.shape)
    return latent
764
+
765
def generate(
    self,
    vc,
    ac,
    duration: float,
    seed: int,
    ref_latent=None,
) -> AudioResult:
    """Generate audio, optionally conditioned on a reference latent.

    Public entry point that forwards every argument unchanged to
    ``_generate_impl``.

    Args:
        vc: Video context from encode_text().
        ac: Audio context from encode_text().
        duration: Target duration in seconds.
        seed: Random seed for reproducibility.
        ref_latent: Optional reference latent from encode_reference()
            for A2V voice conditioning.

    Returns:
        AudioResult with waveform numpy array and metadata.
    """
    result = self._generate_impl(vc, ac, duration, seed, ref_latent)
    return result
787
+
788
@torch.inference_mode()
def _generate_impl(self, vc, ac, duration, seed, ref_latent=None):
    """Run the denoising loop and decode the result to a waveform.

    Builds a video state and an audio state for the requested duration. If
    ``ref_latent`` is given, it is prepended to the audio latent as frames
    whose denoise mask is zero. The Euler denoising loop is then run, the
    reference frames are cut off again, and the audio latent is decoded.

    Args:
        vc: Video context from encode_text().
        ac: Audio context from encode_text().
        duration: Target duration in seconds.
        seed: Seed for the CUDA random generator(s).
        ref_latent: Optional reference latent; its dimension 2 is used as
            the number of reference frames.

    Returns:
        AudioResult holding the waveform, its sample rate and its duration
        in seconds.
    """
    # Ensure audio model is on GPU for denoising
    self._vram.to_gpu("audio_model")

    # Round the frame count up to a multiple of 8, then add one.
    num_frames = ((int(duration * FPS) + 7) // 8) * 8 + 1
    device = torch.device("cuda")

    gen = torch.Generator(device=device).manual_seed(seed)
    noiser = GaussianNoiser(generator=gen)
    sigmas = DISTILLED_SIGMAS.to(dtype=torch.float32, device=device)

    # The loop still needs a video state; it is built at a fixed 64x64
    # resolution and its output is discarded after denoising.
    pixel_shape = VideoPixelShape(
        batch=1, frames=num_frames, width=64, height=64, fps=FPS
    )

    v_shape = VideoLatentShape.from_pixel_shape(pixel_shape)
    video_tools = VideoLatentTools(
        VideoLatentPatchifier(patch_size=1), v_shape, fps=FPS
    )
    video_state = _build_state(
        ModalitySpec(context=vc, conditionings=[]),
        video_tools,
        noiser,
        torch.bfloat16,
        device,
    )

    a_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
    audio_tools = AudioLatentTools(AudioPatchifier(patch_size=1), a_shape)
    audio_state = _build_state(
        ModalitySpec(context=ac),
        audio_tools,
        noiser,
        torch.bfloat16,
        device,
    )

    ref_frames = 0
    if ref_latent is not None:
        ref = ref_latent.to(device=device, dtype=torch.bfloat16)
        ref_frames = ref.shape[2]
        total_t = ref_frames + audio_state.latent.shape[1]

        # Flatten the reference to one token per frame so it can be
        # concatenated in front of the audio latent along dimension 1.
        ref_patchified = ref.permute(0, 2, 1, 3).reshape(1, ref_frames, -1)
        combined_latent = torch.cat([ref_patchified, audio_state.latent], dim=1)

        # Zero mask for the reference frames.
        # NOTE(review): presumably a zero denoise mask keeps those frames
        # fixed during denoising - confirm in euler_denoising_loop.
        ref_mask = torch.zeros(
            1, ref_frames, 1, device=device, dtype=audio_state.denoise_mask.dtype
        )
        combined_mask = torch.cat([ref_mask, audio_state.denoise_mask], dim=1)
        combined_clean = torch.cat(
            [ref_patchified, torch.zeros_like(audio_state.clean_latent)], dim=1
        )

        # NOTE(review): channels=8 and mel_bins=16 are hard-coded here
        # rather than taken from a_shape or the config - confirm they match
        # the loaded checkpoint.
        combined_a_shape = AudioLatentShape(
            batch=1, channels=8, frames=total_t, mel_bins=16
        )
        combined_audio_tools = AudioLatentTools(
            AudioPatchifier(patch_size=1), combined_a_shape
        )
        # A throwaway state with its own generator is built only to obtain
        # positions for the combined (reference + target) length.
        gen2 = torch.Generator(device=device).manual_seed(seed)
        noiser2 = GaussianNoiser(generator=gen2)
        tmp_state = _build_state(
            ModalitySpec(context=ac),
            combined_audio_tools,
            noiser2,
            torch.bfloat16,
            device,
        )
        combined_positions = tmp_state.positions
        del tmp_state

        audio_state_final = LatentState(
            latent=combined_latent,
            denoise_mask=combined_mask,
            positions=combined_positions,
            clean_latent=combined_clean,
            attention_mask=None,
        )
    else:
        audio_state_final = audio_state

    stepper = EulerDiffusionStep()
    with self._pipeline.stage._transformer_ctx() as transformer:
        wrapped = BatchSplitAdapter(transformer, max_batch_size=1)
        t0 = time.time()
        # Only the audio state is kept; the video state is discarded.
        _, audio_state_out = euler_denoising_loop(
            sigmas=sigmas,
            video_state=video_state,
            audio_state=audio_state_final,
            stepper=stepper,
            transformer=wrapped,
            denoiser=SimpleDenoiser(vc, ac),
        )
        logger.debug("Denoise: %.2fs", time.time() - t0)

    # Cut the reference frames off again so only generated frames remain.
    if ref_latent is not None and audio_state_out is not None and ref_frames > 0:
        audio_state_out = dc_replace(
            audio_state_out,
            latent=audio_state_out.latent[:, ref_frames:],
            denoise_mask=audio_state_out.denoise_mask[:, ref_frames:],
            positions=audio_state_out.positions[:, :, ref_frames:],
            clean_latent=(
                audio_state_out.clean_latent[:, ref_frames:]
                if audio_state_out.clean_latent is not None
                else None
            ),
        )

    audio_state_out = audio_tools.clear_conditioning(audio_state_out)
    audio_state_out = audio_tools.unpatchify(audio_state_out)

    # A NaN latent is only reported; decoding continues regardless.
    if torch.isnan(audio_state_out.latent).any():
        logger.warning("NaN detected in denoised latent")

    # Offload audio model before VAE decode (pipeline handles decoder GPU usage)
    # NOTE(review): this calls to_gpu() with no model name; whether that
    # offloads the registered models depends on VRAMManager - confirm.
    self._vram.to_gpu()
    audio = self._pipeline.audio_decoder(audio_state_out.latent)
    # Restore audio model after decode
    self._vram.to_gpu("audio_model")

    w, sr = extract_wav(audio)
    return AudioResult(waveform_np=w, sample_rate=sr, duration_s=w.shape[0] / sr)
src/audio_core/enhancer.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """VoiceFixer audio post-processing for Scenema Audio.
6
+
7
+ Applies neural speech restoration to improve clarity, remove artifacts,
8
+ and bring speech to studio quality. Runs on GPU after SeedVC as the
9
+ final processing step.
10
+
11
+ Model is downloaded on first use and cached to disk for subsequent runs.
12
+ """
13
+
14
+ import logging
15
+ import os
16
+ import subprocess
17
+ import sys
18
+ import tempfile
19
+
20
+ import numpy as np
21
+ import soundfile as sf
22
+ import torchaudio
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ _voicefixer = None
27
+
28
+
29
def _ensure_installed():
    """Make sure the ``voicefixer`` package can be imported.

    When the import fails, the package is installed with pip using the
    current interpreter.

    Raises:
        ImportError: If the pip installation fails.
    """
    try:
        import voicefixer  # noqa: F401
    except ImportError:
        logger.info("Installing voicefixer...")
        pip_command = [sys.executable, "-m", "pip", "install", "voicefixer", "--quiet"]
        try:
            subprocess.check_call(pip_command)
        except subprocess.CalledProcessError:
            logger.warning("Failed to install voicefixer, enhancement will be skipped")
            raise ImportError("voicefixer not available")
        else:
            logger.info("voicefixer installed")
43
+
44
+
45
def _get_voicefixer():
    """Return the shared VoiceFixer instance, creating it on first use.

    The instance is kept in the module-level ``_voicefixer`` variable so
    later calls reuse it. Creation first ensures the package is installed.
    """
    global _voicefixer

    if _voicefixer is None:
        _ensure_installed()

        from voicefixer import VoiceFixer  # noqa: E402

        _voicefixer = VoiceFixer()
        logger.info("VoiceFixer model loaded")

    return _voicefixer
62
+
63
+
64
def enhance_audio(audio_np: np.ndarray, sr: int) -> np.ndarray:
    """Apply VoiceFixer to audio for studio-quality output.

    VoiceFixer works on WAV files, so the audio is written to a temporary
    file, processed, and read back. The result is resampled to ``sr`` when
    VoiceFixer produced a different rate. If the input was stereo and the
    result is mono, the mono signal is duplicated into two channels.

    Enhancement is best-effort: if VoiceFixer is unavailable or fails, the
    input array is returned unchanged.

    Args:
        audio_np: Audio array (mono or stereo), any sample rate.
        sr: Sample rate.

    Returns:
        Enhanced audio array at original sample rate.
    """
    try:
        vf = _get_voicefixer()
    except Exception as e:
        logger.warning("VoiceFixer unavailable: %s, skipping", e)
        return audio_np

    is_stereo = audio_np.ndim == 2 and audio_np.shape[1] == 2

    with tempfile.TemporaryDirectory() as tmp:
        input_path = os.path.join(tmp, "input.wav")
        output_path = os.path.join(tmp, "output.wav")

        sf.write(input_path, audio_np, sr)

        try:
            vf.restore(
                input=input_path,
                output=output_path,
                cuda=True,
                mode=0,  # 0=general, 1=speech-specific
            )

            enhanced, enhanced_sr = sf.read(output_path)

            # Resample back to original sr if needed
            if enhanced_sr != sr:
                import torch

                # torchaudio expects (channels, samples); soundfile returns
                # (samples,) or (samples, channels).
                t = torch.from_numpy(
                    enhanced.T if enhanced.ndim == 2 else enhanced
                ).float()
                if t.ndim == 1:
                    t = t.unsqueeze(0)
                t = torchaudio.functional.resample(t, enhanced_sr, sr)
                enhanced = t.squeeze(0).numpy()
                if enhanced.ndim == 2:
                    # Back to soundfile's (samples, channels) layout.
                    enhanced = enhanced.T

            # Restore the caller's channel layout whether or not resampling
            # happened, so stereo input never comes back as mono.
            if enhanced.ndim == 1 and is_stereo:
                enhanced = np.stack([enhanced, enhanced], axis=1)

            logger.info("Enhanced audio: %.1fs", len(enhanced) / sr)
            return enhanced

        except Exception as e:
            logger.warning("VoiceFixer failed: %s, returning original", e)
            return audio_np
src/audio_core/inference.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Inference orchestration for Scenema Audio.
6
+
7
+ Generates audio for planned chunks with A2V voice conditioning between
8
+ chunks and concatenates the results. A2V reference from each chunk's tail
9
+ guides the next chunk toward a consistent voice, which SeedVC then
10
+ polishes for exact identity matching.
11
+ """
12
+
13
+ import logging
14
+
15
+ import numpy as np
16
+
17
+ from .audio_utils import normalize_volume, trim_silence
18
+ from .chunker import ChunkSpec
19
+ from .engine import AudioEngine, AudioResult
20
+ from .whisper_aligner import validate_text
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ REF_TAIL_SECONDS = 3.0
25
+ MAX_RETRIES = 3
26
+ RETRY_DURATION_FACTOR = 1.3
27
+ MIN_WORD_MATCH_RATIO = 0.90
28
+
29
+
30
def generate_chunks(
    engine: AudioEngine,
    chunks: list[ChunkSpec],
    ref_latent=None,
    ref_duration_s: float = REF_TAIL_SECONDS,
    validate: bool = False,
    min_match_ratio: float = MIN_WORD_MATCH_RATIO,
    anchor_ref: bool = False,
) -> list[AudioResult]:
    """Generate audio for all chunks with A2V voice conditioning.

    Each chunk gets its own Gemma encode (since each has different text).
    The tail of each chunk's accepted audio is encoded via Audio VAE and
    used as A2V reference for the next chunk, guiding voice consistency.

    Args:
        engine: AudioEngine instance
        chunks: List of ChunkSpec from plan_chunks()
        ref_latent: Initial reference latent (from user-provided voice URL)
        ref_duration_s: Seconds of tail audio to use as A2V reference
        validate: If True, run Whisper validation with retry loop.
            If False (default), generate once without validation.
        min_match_ratio: Minimum word-match ratio passed to the validator;
            only used when validate is True.
        anchor_ref: If True, every chunk uses ref_latent instead of
            chaining from the previous chunk's tail. Keeps voice
            anchored to the external reference.

    Returns:
        One AudioResult per chunk, in chunk order. With validation enabled
        this is the attempt with the highest word-match ratio.
    """
    results: list[AudioResult] = []

    for i, chunk in enumerate(chunks):
        label = "with ref" if ref_latent is not None else "no ref"
        logger.info(
            "Chunk %d/%d (%s, %.1fs): %s",
            i + 1,
            len(chunks),
            label,
            chunk.duration_s,
            chunk.expected_text[:60] + ("..." if len(chunk.expected_text) > 60 else ""),
        )

        # Gemma encode once per chunk (reused across retries)
        logger.info("Compiled prompt: %s", chunk.compiled_prompt)
        vc, ac = engine.encode_text(chunk.compiled_prompt)

        duration = chunk.duration_s
        seed = chunk.seed

        if not validate:
            # Single generation, no whisper validation
            best_result = engine.generate(vc, ac, duration, seed, ref_latent=ref_latent)
        else:
            # Validation retry loop with whisper
            best_result = None
            best_ratio = -1.0

            for attempt in range(MAX_RETRIES + 1):
                result = engine.generate(vc, ac, duration, seed, ref_latent=ref_latent)

                passed, transcribed, ratio = validate_text(
                    result.waveform_np,
                    result.sample_rate,
                    chunk.expected_text,
                    language=chunk.language,
                    min_word_ratio=min_match_ratio,
                )

                if ratio > best_ratio:
                    best_result = result
                    best_ratio = ratio

                if passed:
                    logger.info(
                        "  Chunk %d validated: %.0f%% word match",
                        i + 1,
                        ratio * 100,
                    )
                    break

                if attempt < MAX_RETRIES:
                    # Retry with more time and a different seed.
                    duration = min(duration * RETRY_DURATION_FACTOR, 20.0)
                    seed += 1
                    logger.info(
                        "  Chunk %d retry %d: %.0f%% match, extending to %.1fs, seed=%d",
                        i + 1,
                        attempt + 1,
                        ratio * 100,
                        duration,
                        seed,
                    )
                else:
                    logger.warning(
                        "  Chunk %d: best %.0f%% match after %d retries, accepting",
                        i + 1,
                        best_ratio * 100,
                        MAX_RETRIES,
                    )

        results.append(best_result)

        # A2V: use tail of this chunk as reference for the next
        # In anchor mode, keep using the original ref_latent for every chunk
        if i < len(chunks) - 1 and not anchor_ref:
            # Take the tail from the accepted result, not from the last
            # attempt: after failed retries the last attempt may differ
            # from the audio that ends up in the output.
            tail_samples = int(ref_duration_s * best_result.sample_rate)
            tail_wav = best_result.waveform_np[-tail_samples:]
            ref_latent = engine.encode_reference(tail_wav, best_result.sample_rate)

    return results
138
+
139
+
140
def concatenate_chunks(
    results: list[AudioResult],
    trim: bool = True,
    normalize: bool = True,
) -> tuple[np.ndarray, int]:
    """Join per-chunk waveforms into a single waveform.

    Each chunk is optionally silence-trimmed (max 0.5 s of silence kept)
    and volume-normalized, then all chunks are hard-concatenated along the
    sample axis with no crossfade. The sample rate of the first chunk is
    used for every chunk.

    Args:
        results: List of AudioResult from generate_chunks().
        trim: Whether to trim silence from chunk boundaries.
        normalize: Whether to normalize volume per chunk.

    Returns:
        Tuple of (concatenated waveform numpy array, sample_rate).

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError("No chunks to concatenate")

    sr = results[0].sample_rate
    processed: list[np.ndarray] = []

    for index, chunk in enumerate(results):
        wave = chunk.waveform_np
        wave = trim_silence(wave, sr, max_silence=0.5) if trim else wave
        wave = normalize_volume(wave, sr) if normalize else wave
        processed.append(wave)
        logger.debug(
            "Chunk %d: %.1fs -> %.1fs",
            index,
            chunk.duration_s,
            wave.shape[0] / sr,
        )

    joined = np.concatenate(processed, axis=0)
    total_seconds = joined.shape[0] / sr
    logger.info("Concatenated: %.1fs from %d chunks", total_seconds, len(processed))
    return joined, sr
src/audio_core/main.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Scenema Audio entry point.
6
+
7
+ CRITICAL: CUDA memory config must happen before torch imports.
8
+ """
9
+
10
+ import os
11
+
12
# Enable expandable segments in the CUDA allocator unless the variable already
# mentions that option. Existing settings are kept and the option is appended.
_alloc = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if "expandable_segments" not in _alloc:
    if _alloc:
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = _alloc + ",expandable_segments:True"
    else:
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
17
+
18
+ import logging
19
+
20
# DEBUG level when the DEBUG environment variable is set to a non-empty value,
# INFO otherwise.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
    level=logging.INFO if not os.environ.get("DEBUG") else logging.DEBUG,
)
logger = logging.getLogger(__name__)
25
+
26
+
27
def main():
    """Start the Scenema Audio service.

    Logs the handler mode taken from the ``HANDLER_MODE`` environment
    variable (default ``"http"``), creates an AudioProcessor and hands it
    to the common runner.
    """
    # Imported here rather than at module level so the CUDA allocator
    # configuration above runs before torch is imported
    # (processor -> engine -> torch).
    from common.runner import run

    from .processor import AudioProcessor

    logger.info(
        "Starting Scenema Audio in %s mode",
        os.environ.get("HANDLER_MODE", "http"),
    )

    run(AudioProcessor(), service_type="scenema_audio")
39
+
40
+
41
+ if __name__ == "__main__":
42
+ main()
src/audio_core/processor.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Scenema Audio processor. Processor protocol implementation.
6
+
7
+ Handles HTTP sync/async requests for audio generation and voice design.
8
+ Follows the pattern of gpu_x2v/processor.py.
9
+ """
10
+
11
+ import io
12
+ import logging
13
+ import os
14
+ import random
15
+ import shutil
16
+ import tempfile
17
+ import time
18
+ from datetime import datetime, timezone
19
+
20
+ import httpx
21
+ import numpy as np
22
+ import psutil
23
+ import soundfile as sf
24
+ import torch
25
+ import torchaudio
26
+
27
+ from common.handlers.base import ProcessJob, ProcessOutput, ProcessResult
28
+
29
+ from .audio_utils import (
30
+ ensure_stereo,
31
+ load_wav,
32
+ normalize_volume,
33
+ shorten_long_silence,
34
+ save_wav,
35
+ to_mono,
36
+ trim_silence,
37
+ )
38
+ from .chunker import plan_chunks
39
+ from .compiler import compile_prompt
40
+ from .engine import AudioEngine, HIGH_VRAM_THRESHOLD_GB
41
+ from .inference import concatenate_chunks, generate_chunks
42
+ from .seedvc import SeedVC
43
+ from .validate_and_patch import validate_and_patch
44
+ from .validator import validate_prompt
45
+ from .vocal_separator import VocalSeparator
46
+
47
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Duration argument (seconds) passed to the engine when generating the sample
# in "voice_design" mode.
VOICE_DESIGN_DURATION_S = 15.0
50
+
51
+
52
class AudioProcessor:
    """Processor for Scenema Audio generation.

    Implements the Processor protocol (startup/shutdown/process). A single
    instance owns the audio engine, the vocal separator and the SeedVC voice
    converter, and turns a ``ProcessJob`` into WAV bytes plus metadata.
    """

    def __init__(self):
        self.engine: AudioEngine | None = None
        self.vocal_separator = None
        self.seedvc = None
        self._http_client = None
        # Decided for real in startup(). Defaulting to on-demand loading keeps
        # the helper methods usable if they run before startup().
        self._keep_resident = False

    def startup(self) -> None:
        """Load models. Called once by handler at startup.

        Checkpoint locations come from the ``AUDIO_CKPT``, ``VAE_ENCODER_CKPT``,
        ``GEMMA_ROOT`` and ``PIPELINE_CKPT`` environment variables, falling
        back to paths under ``/app/models``. Calling it again once the engine
        is loaded is a no-op.
        """
        if self.engine is not None:
            return

        audio_ckpt = os.environ.get(
            "AUDIO_CKPT",
            "/app/models/scenema-audio-transformer.safetensors",
        )
        vae_encoder = os.environ.get(
            "VAE_ENCODER_CKPT",
            "/app/models/scenema-audio-vae-encoder.safetensors",
        )
        gemma_root = os.environ.get(
            "GEMMA_ROOT",
            "/app/models/gemma-3-12b-it",
        )
        pipeline_ckpt = os.environ.get(
            "PIPELINE_CKPT",
            "/app/models/ltx-2.3-22b-distilled.safetensors",
        )

        engine = AudioEngine(
            audio_ckpt_path=audio_ckpt,
            vae_encoder_path=vae_encoder,
            gemma_root=gemma_root,
            pipeline_ckpt_path=pipeline_ckpt,
        )
        engine.load()
        # Publish the engine only after load() succeeded; otherwise a failed
        # load would leave a half-initialised engine that the early return
        # above (and process()) would treat as ready.
        self.engine = engine

        self.vocal_separator = VocalSeparator()
        self.seedvc = SeedVC()

        # Preload all models on high-VRAM cards (>= 40GB), keep resident
        vram_gb = (
            torch.cuda.get_device_properties(0).total_memory / 1e9
            if torch.cuda.is_available()
            else 0
        )
        self._keep_resident = vram_gb >= HIGH_VRAM_THRESHOLD_GB
        if self._keep_resident:
            self.vocal_separator.load()
            self.seedvc.load()
            logger.info("All models preloaded and resident (%.0fGB VRAM)", vram_gb)
        else:
            logger.info("Low VRAM (%.0fGB), models loaded on-demand", vram_gb)

        logger.info("AudioProcessor ready")

    def shutdown(self) -> None:
        """Unload all models and drop the references to them."""
        if self.engine:
            self.engine.unload()
            self.engine = None
        if self.vocal_separator:
            self.vocal_separator.unload()
            self.vocal_separator = None
        if self.seedvc:
            # SeedVC.unload() is a no-op when nothing is loaded, so there is no
            # need to look at its private state here.
            self.seedvc.unload()
            self.seedvc = None
        logger.info("AudioProcessor shutdown")

    async def process(self, job: ProcessJob) -> ProcessResult:
        """Process an audio generation job.

        Never raises for job-level failures: any exception is logged and
        returned as a ``ProcessResult`` with ``success=False``.
        """
        start_time = time.time()
        started_at = datetime.now(timezone.utc).isoformat()

        try:
            # Guarded like every other CUDA query in this class so a CPU-only
            # host fails (if at all) with a proper error result.
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()

            if self.engine is None:
                self.startup()

            config = self._parse_input(job)

            if config["mode"] == "voice_design":
                wav_np, sr = await self._voice_design(config)
            else:
                wav_np, sr = await self._generate(config)

            wav_bytes = self._encode_wav(wav_np, sr)
            processing_ms = int((time.time() - start_time) * 1000)

            return ProcessResult(
                job_id=job.job_id,
                success=True,
                output=ProcessOutput(
                    success=True,
                    data=wav_bytes,
                    content_type="audio/wav",
                    metadata=self._build_metadata(
                        config, wav_np, sr, processing_ms, started_at
                    ),
                ),
                processing_ms=processing_ms,
            )
        except Exception as e:
            logger.error("Processing failed: %s", e, exc_info=True)
            processing_ms = int((time.time() - start_time) * 1000)
            return ProcessResult(
                job_id=job.job_id,
                success=False,
                output=ProcessOutput(success=False, error=str(e)),
                error=str(e),
                processing_ms=processing_ms,
            )

    def _parse_input(self, job: ProcessJob) -> dict:
        """Parse and validate job input.

        Input schema (keys of ``job.input``):
            prompt: str - Required. <speak> XML string.
            mode: str - "generate" (default) or "voice_design".
            reference_voice_url: str | None - URL to reference audio for voice cloning.
            background_sfx: bool - Keep background SFX (default: false, strips via MelBandRoFormer).
            validate: bool - Enable Whisper speech validation (default: true).
                Forwarded to ``generate_chunks`` together with
                ``min_match_ratio``. Per the original notes, each chunk is
                transcribed and regenerated when too few expected words are
                recognised; pass false to generate each chunk once.
            min_match_ratio: float - Validation threshold (default: 0.90).
            seed: int - Base seed (-1 or null for random).
            pace: float - Pacing factor passed to the chunk planner (default: 1.5).
            vc_cfg_rate: float - SeedVC guidance rate (default: 0.5).
            vc_steps: int - SeedVC diffusion steps (default: 25).
            skip_vc: bool - Skip SeedVC; with a reference voice, anchor every
                chunk on the reference tail instead (default: false).

        Raises:
            ValueError: if the prompt is missing or invalid, or the mode is unknown.
        """
        inp = job.input

        prompt = inp.get("prompt")
        if not prompt:
            raise ValueError("Missing required 'prompt' field")

        mode = inp.get("mode", "generate")
        if mode not in ("generate", "voice_design"):
            raise ValueError(
                f"Invalid mode: {mode}. Must be 'generate' or 'voice_design'"
            )

        result = validate_prompt(prompt)
        if not result.valid:
            raise ValueError(f"Invalid prompt XML: {'; '.join(result.errors)}")

        seed = inp.get("seed", -1)
        # An explicit JSON null is treated like the -1 sentinel.
        if seed is None or seed == -1:
            seed = random.randint(0, 999999)

        # NOTE(review): _generate() reads config.get("patch", False), but no
        # "patch" key is produced here, so that step can currently never run.
        return {
            "prompt": prompt,
            "mode": mode,
            "reference_voice_url": inp.get("reference_voice_url"),
            "background_sfx": inp.get("background_sfx", False),
            "validate": inp.get("validate", True),
            "seed": seed,
            "pace": inp.get("pace", 1.5),
            "min_match_ratio": inp.get("min_match_ratio", 0.90),
            "vc_cfg_rate": inp.get("vc_cfg_rate", 0.5),
            "vc_steps": inp.get("vc_steps", 25),
            "skip_vc": inp.get("skip_vc", False),
        }

    async def _voice_design(self, config: dict) -> tuple[np.ndarray, int]:
        """Generate a 15s voice sample for voice design.

        Returns the post-processed waveform and its sample rate.
        """
        compiled = compile_prompt(config["prompt"])
        vc, ac = self.engine.encode_text(compiled.prompt)
        result = self.engine.generate(vc, ac, VOICE_DESIGN_DURATION_S, config["seed"])

        wav = result.waveform_np
        sr = result.sample_rate

        if not config["background_sfx"]:
            wav = self._strip_background(wav, sr)

        wav = trim_silence(wav, sr)
        wav = shorten_long_silence(wav, sr)
        wav = normalize_volume(wav, sr)

        return wav, sr

    async def _generate(self, config: dict) -> tuple[np.ndarray, int]:
        """Full generation pipeline with chunking and post-processing.

        Returns the final stereo waveform and its sample rate. A downloaded
        reference file is removed on every exit path, including failures.
        """
        chunks = plan_chunks(
            config["prompt"], base_seed=config["seed"], pace=config["pace"]
        )
        logger.info("Planned %d chunk(s)", len(chunks))

        ref_wav_path = None
        try:
            if config["reference_voice_url"]:
                ref_wav_path = await self._download_reference(
                    config["reference_voice_url"]
                )

            # skip_vc: seed every chunk with the reference audio's tail latent,
            # identical to how inter-chunk chaining works. The model sees the
            # reference as "what I just generated" and continues in that voice.
            # Disables the normal chaining (each chunk chains from the ref, not
            # from the previous chunk) to keep the voice anchored to the reference.
            anchor_latent = None
            if config["skip_vc"] and ref_wav_path:
                ref_wav, ref_sr = load_wav(ref_wav_path)
                ref_mono = to_mono(ref_wav)
                tail_seconds = 3.0
                tail_samples = int(tail_seconds * ref_sr)
                if ref_mono.shape[0] > tail_samples:
                    ref_tail = ref_mono[-tail_samples:]
                else:
                    ref_tail = ref_mono
                anchor_latent = self.engine.encode_reference(ref_tail, ref_sr)
                logger.info(
                    "Anchor mode: every chunk seeded from %.1fs reference tail",
                    ref_tail.shape[0] / ref_sr,
                )

            with torch.inference_mode():
                results = generate_chunks(
                    self.engine,
                    chunks,
                    ref_latent=anchor_latent,
                    anchor_ref=anchor_latent is not None,
                    validate=config["validate"],
                    min_match_ratio=config["min_match_ratio"],
                )

            wav, sr = concatenate_chunks(results)

            # Strip background music/SFX from the concatenated audio (single pass)
            if not config["background_sfx"]:
                wav = self._strip_background(wav, sr)

            # Cap silence - scale with pace
            max_silence = min(0.5 * config["pace"], 1.5)
            wav = shorten_long_silence(
                wav, sr, max_duration=max_silence, target_duration=max_silence * 0.6
            )

            # Apply SeedVC when: reference voice provided, or multiple chunks
            # (voice consistency). Skip for single-chunk generations without
            # reference (preserves SFX).
            needs_vc = ref_wav_path or len(results) > 1
            if not config["skip_vc"] and needs_vc:
                wav = self._apply_seedvc(
                    wav,
                    sr,
                    results,
                    ref_wav_path,
                    vc_steps=config["vc_steps"],
                    vc_cfg_rate=config["vc_cfg_rate"],
                )

            # Post-SeedVC alignment trimming (disabled by default, needs refinement)
            if config.get("patch", False):
                expected_text = " ".join(c.expected_text for c in chunks)
                wav = validate_and_patch(wav, sr, expected_text)

            # Ensure stereo final output
            wav = ensure_stereo(wav)

            return wav, sr
        finally:
            # Runs on success and on failure, so the temp download never leaks.
            if ref_wav_path and os.path.exists(ref_wav_path):
                os.unlink(ref_wav_path)

    def _strip_background(self, wav_np: np.ndarray, sr: int) -> np.ndarray:
        """Strip background music/SFX using MelBandRoFormer.

        Unless models are kept resident, loads the separator on demand and
        unloads it afterwards to free VRAM. Best effort: if separation fails
        the input audio is returned unchanged.
        """
        if self.vocal_separator is None:
            return wav_np

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            input_path = f.name
        # Derive the sibling output path from the extension only; a plain
        # str.replace would also rewrite a ".wav" elsewhere in the path.
        stem, ext = os.path.splitext(input_path)
        vocals_path = f"{stem}_vocals{ext}"

        try:
            if not self._keep_resident:
                self.vocal_separator.load()
            stereo = ensure_stereo(wav_np)
            save_wav(stereo, sr, input_path)
            self.vocal_separator.separate(input_path, vocals_path, None)
            vocals, _ = load_wav(vocals_path)
            return vocals
        except Exception as e:
            logger.warning("Vocal separation failed: %s", e)
            return wav_np
        finally:
            if not self._keep_resident:
                self.vocal_separator.unload()
            for p in [input_path, vocals_path]:
                if os.path.exists(p):
                    os.unlink(p)

    def _apply_seedvc(
        self,
        wav: np.ndarray,
        sr: int,
        chunk_results: list,
        ref_wav_path: str | None,
        vc_steps: int = 20,
        vc_cfg_rate: float = 0.5,
    ) -> np.ndarray:
        """Apply SeedVC voice cloning.

        If reference_voice_url provided: convert against reference.
        If no reference: convert all against chunk 0 (first chunk sets identity).

        Source and target are resampled to 22050 Hz mono for SeedVC and the
        result is resampled back to ``sr`` and made stereo. Best effort: on any
        failure the error is logged and the unconverted ``wav`` is returned.
        """
        if self.seedvc is None:
            logger.info("SeedVC not available, skipping voice cloning")
            return wav

        try:
            if not self._keep_resident:
                self.seedvc.load()
            with tempfile.TemporaryDirectory() as tmp:
                source_path = os.path.join(tmp, "source_22k.wav")
                target_path = os.path.join(tmp, "target_22k.wav")

                source_mono = to_mono(wav)
                source_t = torch.from_numpy(source_mono).float().unsqueeze(0)
                source_22k = torchaudio.functional.resample(source_t, sr, 22050)
                save_wav(source_22k.squeeze(0).numpy(), 22050, source_path)

                if ref_wav_path:
                    target_wav, target_sr = load_wav(ref_wav_path)
                    target_mono = to_mono(target_wav)
                    target_t = torch.from_numpy(target_mono).float().unsqueeze(0)
                    target_22k = torchaudio.functional.resample(
                        target_t, target_sr, 22050
                    )
                    save_wav(target_22k.squeeze(0).numpy(), 22050, target_path)
                else:
                    chunk0 = chunk_results[0].waveform_np
                    chunk0_mono = to_mono(chunk0)
                    chunk0_t = torch.from_numpy(chunk0_mono).float().unsqueeze(0)
                    chunk0_22k = torchaudio.functional.resample(
                        chunk0_t, chunk_results[0].sample_rate, 22050
                    )
                    save_wav(chunk0_22k.squeeze(0).numpy(), 22050, target_path)

                converted = self.seedvc.convert(
                    source_path,
                    target_path,
                    diffusion_steps=vc_steps,
                    cfg_rate=vc_cfg_rate,
                )

                conv_t = torch.from_numpy(converted).float().unsqueeze(0)
                result = torchaudio.functional.resample(conv_t, 22050, sr)
                wav = result.squeeze(0).numpy()
                wav = ensure_stereo(wav)

        except Exception as e:
            logger.error("SeedVC failed: %s", e, exc_info=True)
        finally:
            if not self._keep_resident:
                try:
                    self.seedvc.unload()
                except Exception as e:
                    # Still best effort, but leave a trace: a failed unload
                    # usually means VRAM was not released.
                    logger.warning("SeedVC unload failed: %s", e)

        return wav

    async def _download_reference(self, url: str) -> str:
        """Download reference audio from URL to temp file.

        Returns the path of the temp file; the caller is responsible for
        deleting it. Raises whatever ``httpx`` raises for transport errors or
        non-success HTTP status codes.
        """
        if self._http_client is None:
            self._http_client = httpx.AsyncClient(timeout=60.0, follow_redirects=True)

        response = await self._http_client.get(url)
        response.raise_for_status()

        # Pick the file suffix from the URL or the (case-insensitive) MIME type.
        content_type = response.headers.get("content-type", "").lower()
        suffix = ".wav"
        if "mp3" in url.lower() or "mpeg" in content_type:
            suffix = ".mp3"

        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
            f.write(response.content)
            logger.info(
                "Downloaded reference: %d bytes to %s", len(response.content), f.name
            )
            return f.name

    def _encode_wav(self, wav_np: np.ndarray, sr: int) -> bytes:
        """Encode numpy array to WAV bytes."""
        buf = io.BytesIO()
        sf.write(buf, wav_np, sr, format="WAV")
        return buf.getvalue()

    def _build_metadata(
        self,
        config: dict,
        wav_np: np.ndarray,
        sr: int,
        processing_ms: int,
        started_at: str = "",
    ) -> dict:
        """Build comprehensive metadata matching x2v pattern.

        Combines request settings, output duration, timing and a snapshot of
        the host (GPU name and VRAM, CPU count, RAM, disk usage of "/").
        """
        gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"
        vram_total_mb = 0
        vram_peak_mb = 0
        if torch.cuda.is_available():
            vram_total_mb = round(
                torch.cuda.get_device_properties(0).total_memory / 1024**2
            )
            vram_peak_mb = round(torch.cuda.max_memory_allocated() / 1024**2)

        cpu_cores_total = os.cpu_count() or 0
        system_ram_gb = round(psutil.virtual_memory().total / 1024**3)
        disk = shutil.disk_usage("/")

        return {
            # Duration uses the first axis as the sample axis.
            "duration_s": round(wav_np.shape[0] / sr, 2),
            "sample_rate": sr,
            "mode": config["mode"],
            "seed": config["seed"],
            "background_sfx": config["background_sfx"],
            "has_reference_voice": config["reference_voice_url"] is not None,
            "validate": config["validate"],
            "processing_ms": processing_ms,
            "vram_peak_mb": vram_peak_mb,
            "vram_total_mb": vram_total_mb,
            "gpu": gpu_name,
            "cpu_cores_total": cpu_cores_total,
            "system_ram_gb": system_ram_gb,
            "disk_total_gb": round(disk.total / 1024**3, 1),
            "disk_free_gb": round(disk.free / 1024**3, 1),
            "started_at": started_at,
            "completed_at": datetime.now(timezone.utc).isoformat(),
        }
src/audio_core/seedvc.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """SeedVC voice conversion for Scenema Audio.
6
+
7
+ Converts the voice identity of generated audio to match a reference speaker
8
+ while preserving prosody, rhythm, and emotion. Uses the Seed-VC model with
9
+ DiT backbone, CAMPPlus speaker encoder, and BigVGAN vocoder.
10
+
11
+ Expects 22050Hz mono WAV input for both source and target.
12
+ """
13
+
14
+ import inspect
15
+ import logging
16
+ import os
17
+ import sys
18
+ import types
19
+ from argparse import Namespace
20
+ from pathlib import Path
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Location of the Seed-VC checkout; override with the SEEDVC_PATH environment
# variable (read once, at import time).
DEFAULT_SEEDVC_PATH = Path(os.environ.get("SEEDVC_PATH", "/app/seed-vc"))
# Defaults for SeedVC.convert().
DEFAULT_DIFFUSION_STEPS = 25
DEFAULT_CFG_RATE = 0.5
30
+
31
+
32
class SeedVC:
    """Voice conversion engine using Seed-VC.

    Converts source audio voice identity to match a target speaker
    while preserving the source's delivery, emotion, and pacing.

    Note that a loaded instance changes the process working directory to
    ``seedvc_path``; ``unload()`` restores the previous one.
    """

    def __init__(self, seedvc_path: Path = DEFAULT_SEEDVC_PATH):
        self.seedvc_path = seedvc_path
        self._loaded = False
        self._original_cwd: str | None = None
        self._app_vc = None

    def load(self) -> None:
        """Load SeedVC models to GPU.

        Changes working directory to seedvc_path (required by SeedVC internals),
        stubs gradio, and loads all models via app_vc.load_models(). Does
        nothing when already loaded. If loading fails, the previous working
        directory is restored before the exception propagates.
        """
        if self._loaded:
            return

        logger.info("Loading SeedVC from %s", self.seedvc_path)

        original_cwd = os.getcwd()
        os.chdir(self.seedvc_path)
        try:
            app_vc = self._load_app_vc()
        except BaseException:
            # Without this, a failed load would strand the process inside the
            # Seed-VC directory, and a retry would then record that directory
            # as the "original" one.
            os.chdir(original_cwd)
            raise

        self._original_cwd = original_cwd
        self._app_vc = app_vc
        self._loaded = True
        logger.info("SeedVC loaded: sr=%d, device=%s", app_vc.sr, app_vc.device)

    def _load_app_vc(self):
        """Import Seed-VC's ``app_vc`` module and populate its model globals."""
        # app_vc imports gradio at module level; a bare stub is enough because
        # only its model-loading and conversion functions are used here.
        if "gradio" not in sys.modules:
            sys.modules["gradio"] = types.ModuleType("gradio")

        seedvc_str = str(self.seedvc_path)
        if seedvc_str not in sys.path:
            sys.path.insert(0, seedvc_str)

        os.environ.setdefault(
            "HF_HUB_CACHE",
            str(self.seedvc_path / "checkpoints" / "hf_cache"),
        )

        # Patch BigVGAN for huggingface_hub compat (same as gpu_vc)
        import modules.bigvgan.bigvgan as _bigvgan_mod

        # Patch only once per process: load() runs again after every unload(),
        # and re-wrapping would stack one more wrapper each time.
        if not getattr(_bigvgan_mod.BigVGAN, "_scenema_patched", False):
            _orig = _bigvgan_mod.BigVGAN._from_pretrained

            @classmethod
            def _patched(cls, **kwargs):
                kwargs.setdefault("proxies", None)
                kwargs.setdefault("resume_download", False)
                return _orig.__func__(cls, **kwargs)

            _bigvgan_mod.BigVGAN._from_pretrained = _patched
            _bigvgan_mod.BigVGAN._scenema_patched = True

        # Load models (exact pattern from gpu_vc/seedvc_engine.py)
        import app_vc

        app_vc.device = torch.device("cuda")

        args = Namespace(checkpoint=None, config=None, fp16=True, gpu=0)
        (
            app_vc.model,
            app_vc.semantic_fn,
            app_vc.vocoder_fn,
            app_vc.campplus_model,
            app_vc.to_mel,
            app_vc.mel_fn_args,
        ) = app_vc.load_models(args)

        app_vc.max_context_window = app_vc.sr // app_vc.hop_length * 30
        app_vc.overlap_wave_len = app_vc.overlap_frame_len * app_vc.hop_length

        return app_vc

    def unload(self) -> None:
        """Free SeedVC models from GPU and restore the working directory.

        Does nothing when no models are loaded.
        """
        if not self._loaded:
            return

        if self._app_vc is not None:
            # Drop the module-level references so the models can be collected.
            for attr in [
                "model",
                "semantic_fn",
                "vocoder_fn",
                "campplus_model",
                "to_mel",
            ]:
                if hasattr(self._app_vc, attr):
                    delattr(self._app_vc, attr)
            self._app_vc = None

        torch.cuda.empty_cache()

        if self._original_cwd:
            os.chdir(self._original_cwd)
            self._original_cwd = None

        self._loaded = False
        logger.info("SeedVC unloaded")

    def convert(
        self,
        source_wav_path: str,
        target_wav_path: str,
        diffusion_steps: int = DEFAULT_DIFFUSION_STEPS,
        cfg_rate: float = DEFAULT_CFG_RATE,
    ) -> np.ndarray:
        """Convert voice identity of source to match target.

        Both files must be 22050Hz mono WAV.

        Args:
            source_wav_path: Path to source audio (generated speech)
            target_wav_path: Path to target audio (reference voice)
            diffusion_steps: Number of diffusion steps (quality vs speed)
            cfg_rate: Classifier-free guidance rate

        Returns:
            Converted audio as float32 numpy array at 22050Hz mono

        Raises:
            RuntimeError: if load() has not been called, or Seed-VC yields no
                audio result.
        """
        if not self._loaded:
            raise RuntimeError("SeedVC not loaded. Call load() first.")

        logger.info(
            "Converting voice: %s -> %s (%d steps, cfg_rate=%.2f)",
            source_wav_path,
            target_wav_path,
            diffusion_steps,
            cfg_rate,
        )

        audio_tuple = None
        vc_kwargs = {
            "source": source_wav_path,
            "target": target_wav_path,
            "diffusion_steps": diffusion_steps,
            "length_adjust": 1.0,
            "inference_cfg_rate": cfg_rate,
        }
        # n_quantizers removed in newer SeedVC versions
        sig = inspect.signature(self._app_vc.voice_conversion)
        if "n_quantizers" in sig.parameters:
            vc_kwargs["n_quantizers"] = 3
        # voice_conversion is consumed as a generator; the last 2-tuple it
        # yields carries the final (sample_rate, samples) pair.
        for result in self._app_vc.voice_conversion(**vc_kwargs):
            if isinstance(result, tuple) and len(result) == 2:
                _, audio_tuple = result

        if audio_tuple is None:
            raise RuntimeError("SeedVC produced no output")

        sample_rate, samples = audio_tuple

        if samples.dtype == np.int16:
            samples = samples.astype(np.float32) / 32768.0
        elif samples.dtype != np.float32:
            samples = samples.astype(np.float32)

        # Only scale down when clipping; max() of an empty array would raise.
        peak = np.abs(samples).max() if samples.size else 0.0
        if peak > 1.0:
            samples = samples / peak

        logger.info("Converted: %.1fs at %dHz", len(samples) / sample_rate, sample_rate)
        return samples
src/audio_core/validate_and_patch.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Forced alignment and hallucination trimming for Scenema Audio.
6
+
7
+ Uses Needleman-Wunsch sequence alignment (same algorithm as DNA matching)
8
+ to optimally align Whisper-transcribed words against expected text. Words
9
+ in the transcription that are INSERTIONS (not in the expected text) are
10
+ trimmed at silence boundaries. Substitutions (misrecognized words) are kept.
11
+ """
12
+
13
+ import logging
14
+ import re
15
+
16
+ import numpy as np
17
+
18
+ from .audio_utils import to_mono
19
+ from .whisper_aligner import _get_whisper
20
+
21
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# RMS level below which a short slice is treated as silence when searching for
# a cut point (see _find_silence_boundary).
SILENCE_THRESHOLD = 0.015
# Seconds by which each trimmed span is widened on both sides
# (see _build_trim_mask).
TRIM_PAD_S = 0.02

# Alignment scoring
MATCH_SCORE = 2
MISMATCH_SCORE = -1
GAP_SCORE = -1  # Cost of insertion or deletion
30
+
31
+
32
+ def _normalize_words(text: str) -> list[str]:
33
+ """Normalize text to lowercase words without punctuation."""
34
+ text = text.lower()
35
+ text = re.sub(r"[^\w\s]", "", text)
36
+ return text.split()
37
+
38
+
39
+ def _fuzzy_match(a: str, b: str) -> bool:
40
+ """Check if two words are similar enough (edit distance based)."""
41
+ if a == b:
42
+ return True
43
+ if not a or not b or len(a) < 4 or len(b) < 4:
44
+ return False
45
+ m, n = len(a), len(b)
46
+ dp = list(range(n + 1))
47
+ for i in range(1, m + 1):
48
+ prev = dp[0]
49
+ dp[0] = i
50
+ for j in range(1, n + 1):
51
+ temp = dp[j]
52
+ dp[j] = prev if a[i - 1] == b[j - 1] else 1 + min(prev, dp[j], dp[j - 1])
53
+ prev = temp
54
+ return 1 - (dp[n] / max(m, n)) >= 0.5
55
+
56
+
57
def _score(a: str, b: str) -> int:
    """Return the alignment score for pairing word *a* with word *b*.

    Exact and fuzzy matches both earn MATCH_SCORE; anything else earns
    MISMATCH_SCORE.
    """
    if a == b or _fuzzy_match(a, b):
        return MATCH_SCORE
    return MISMATCH_SCORE
64
+
65
+
66
def _needleman_wunsch(
    transcribed: list[str],
    expected: list[str],
) -> list[str]:
    """Globally align transcribed words to expected words (Needleman-Wunsch).

    Returns one label per transcribed word, in order:
    - "match": the word pairs with an expected word (exact or fuzzy)
    - "substitution": the word pairs with an expected word it does not match
    - "insertion": the word pairs with nothing in the expected text

    Expected words left unpaired (deletions) produce no label, because only
    transcribed words are labelled.
    """
    n_rows = len(transcribed) + 1
    n_cols = len(expected) + 1

    # grid[r][c] = best score aligning the first r transcribed words with the
    # first c expected words. The borders are pure gap runs.
    grid = [[0] * n_cols for _ in range(n_rows)]
    for r in range(1, n_rows):
        grid[r][0] = grid[r - 1][0] + GAP_SCORE
    for c in range(1, n_cols):
        grid[0][c] = grid[0][c - 1] + GAP_SCORE

    for r, heard in enumerate(transcribed, start=1):
        above = grid[r - 1]
        row = grid[r]
        for c, wanted in enumerate(expected, start=1):
            row[c] = max(
                above[c - 1] + _score(heard, wanted),  # pair the two words
                above[c] + GAP_SCORE,  # transcribed word is an insertion
                row[c - 1] + GAP_SCORE,  # expected word is a deletion
            )

    # Walk back from the bottom-right corner, preferring the diagonal, then an
    # insertion, then a deletion.
    labels = []
    r, c = len(transcribed), len(expected)
    while r > 0 or c > 0:
        if r > 0 and c > 0:
            pair_score = _score(transcribed[r - 1], expected[c - 1])
            if grid[r][c] == grid[r - 1][c - 1] + pair_score:
                labels.append(
                    "match" if pair_score == MATCH_SCORE else "substitution"
                )
                r -= 1
                c -= 1
                continue
        if r > 0 and grid[r][c] == grid[r - 1][c] + GAP_SCORE:
            labels.append("insertion")
            r -= 1
        else:
            c -= 1  # Deletion in expected: nothing to label

    labels.reverse()
    return labels
119
+
120
+
121
def _transcribe_with_timestamps(
    audio_mono: np.ndarray,
    sr: int,
    language: str,
) -> list[dict]:
    """Transcribe mono audio and return its words with timestamps.

    The audio is resampled to 16 kHz first when needed. Each returned dict
    has the keys "word" (stripped and lowercased), "start" and "end".
    """
    audio_16k = audio_mono
    if sr != 16000:
        import librosa

        audio_16k = librosa.resample(audio_mono, orig_sr=sr, target_sr=16000)

    whisper = _get_whisper()
    segments, _ = whisper.transcribe(
        audio_16k,
        language=language,
        word_timestamps=True,
        vad_filter=True,
    )

    # Segments without word-level output are skipped.
    return [
        {
            "word": item.word.strip().lower(),
            "start": item.start,
            "end": item.end,
        }
        for segment in segments
        if segment.words
        for item in segment.words
    ]
154
+
155
+
156
+ def _find_silence_boundary(
157
+ audio: np.ndarray,
158
+ sr: int,
159
+ center_sample: int,
160
+ direction: str = "left",
161
+ window_s: float = 0.3,
162
+ ) -> int:
163
+ """Find nearest silence point from center position."""
164
+ hop = int(0.01 * sr)
165
+ window_samples = int(window_s * sr)
166
+
167
+ if direction == "left":
168
+ positions = range(center_sample, max(0, center_sample - window_samples), -hop)
169
+ else:
170
+ positions = range(
171
+ center_sample, min(len(audio), center_sample + window_samples), hop
172
+ )
173
+
174
+ for pos in positions:
175
+ chunk = audio[max(0, pos - hop // 2) : min(len(audio), pos + hop // 2)]
176
+ if (
177
+ len(chunk) > 0
178
+ and np.sqrt(np.mean(chunk.astype(np.float64) ** 2)) < SILENCE_THRESHOLD
179
+ ):
180
+ return pos
181
+
182
+ return center_sample
183
+
184
+
185
+ def _merge_ranges(
186
+ ranges: list[tuple[float, float]], gap: float = 0.15
187
+ ) -> list[tuple[float, float]]:
188
+ """Merge consecutive time ranges that are close together."""
189
+ if not ranges:
190
+ return []
191
+ merged = []
192
+ for start, end in sorted(ranges):
193
+ if merged and start - merged[-1][1] < gap:
194
+ merged[-1] = (merged[-1][0], end)
195
+ else:
196
+ merged.append((start, end))
197
+ return merged
198
+
199
+
200
def _detect_audio_repetition(
    mono: np.ndarray,
    sr: int,
    expected_words: list[str],
    min_duration_s: float = 1.5,
    similarity_threshold: float = 0.85,
) -> list[tuple[float, float]]:
    """Detect repeated audio segments by comparing energy envelopes.

    Computes an STFT magnitude, averages it into a single 1-D envelope, and
    slides windows of 3.0 s, 2.0 s and 1.5 s across it in half-window steps.
    Each window is compared against every later, non-overlapping window by
    cosine similarity; when the similarity reaches *similarity_threshold*, the
    later window's time range is recorded as a repetition candidate.

    NOTE(review): despite the variable name ``mel_spec``, no mel filterbank is
    applied, and *expected_words* is never read, so candidates are NOT
    cross-checked against the expected text.

    Args:
        mono: Mono audio samples.
        sr: Sample rate in Hz.
        expected_words: Currently unused.
        min_duration_s: Only used for the early exit: audio shorter than three
            times this value yields no candidates. The window lengths
            themselves are fixed.
        similarity_threshold: Minimum cosine similarity for a repetition.

    Returns:
        Candidate (start_s, end_s) ranges merged with a 0.5 s gap, or an empty
        list when nothing is found or the STFT fails.
    """
    # Imported locally rather than at module level.
    import torch

    total_s = len(mono) / sr
    if total_s < min_duration_s * 3:
        return []

    # Compute STFT magnitude (no mel projection is applied)
    hop_length = int(0.02 * sr)  # 20ms hops
    n_fft = int(0.04 * sr)  # 40ms window
    audio_t = torch.from_numpy(mono).float()

    try:
        mel_spec = torch.stft(
            audio_t,
            n_fft=n_fft,
            hop_length=hop_length,
            window=torch.hann_window(n_fft),
            return_complex=True,
        ).abs()
    except Exception:
        # Best effort: any STFT failure disables the repetition check.
        return []

    # Reduce to energy per time frame
    energy = mel_spec.mean(dim=0).numpy()  # (time_frames,)
    frames_per_sec = sr / hop_length

    # Slide window: check segments of varying length
    repeated_ranges = []

    for window_s in [3.0, 2.0, 1.5]:
        win_frames = int(window_s * frames_per_sec)
        if win_frames >= len(energy):
            continue

        step = win_frames // 2
        for i in range(0, len(energy) - win_frames, step):
            seg_a = energy[i : i + win_frames]
            norm_a = np.linalg.norm(seg_a)
            if norm_a < 1e-6:
                # Near-silent window: cosine similarity would be meaningless.
                continue

            # Compare only against windows starting at or after the end of
            # seg_a, so the two never overlap.
            for j in range(i + win_frames, len(energy) - win_frames, step):
                seg_b = energy[j : j + win_frames]
                norm_b = np.linalg.norm(seg_b)
                if norm_b < 1e-6:
                    continue

                similarity = np.dot(seg_a, seg_b) / (norm_a * norm_b)
                if similarity >= similarity_threshold:
                    # Record the later window; the first occurrence is kept.
                    start_s = j / frames_per_sec
                    end_s = (j + win_frames) / frames_per_sec
                    repeated_ranges.append((start_s, end_s))

    # Deduplicate overlapping ranges
    if not repeated_ranges:
        return []

    merged = _merge_ranges(repeated_ranges, gap=0.5)
    logger.debug("Audio fingerprint candidates: %d segments", len(merged))
    return merged
277
+
278
+
279
def _build_trim_mask(
    mono: np.ndarray,
    sr: int,
    insertion_ranges: list[tuple[float, float]],
) -> np.ndarray:
    """Return a boolean keep-mask with every insertion range switched off.

    Each range is first snapped outward to a nearby silence point on both
    sides, then widened by TRIM_PAD_S seconds, and finally clamped to the
    audio length. True means "keep this sample".
    """
    n_samples = len(mono)
    mask = np.ones(n_samples, dtype=bool)
    pad = int(TRIM_PAD_S * sr)

    for begin_s, finish_s in insertion_ranges:
        cut_from = _find_silence_boundary(mono, sr, int(begin_s * sr), "left")
        cut_to = _find_silence_boundary(mono, sr, int(finish_s * sr), "right")
        mask[max(0, cut_from - pad) : min(n_samples, cut_to + pad)] = False

    return mask
297
+
298
+
299
def validate_and_patch(
    audio_np: np.ndarray,
    sr: int,
    expected_text: str,
    language: str = "en",
) -> np.ndarray:
    """Cut hallucinated speech out of generated audio.

    Steps:
    1. Transcribe the audio with Whisper, keeping word timestamps.
    2. Align the transcribed words with the expected words (Needleman-Wunsch).
    3. Treat words labelled "insertion" as hallucinated; matches and
       substitutions are kept.
    4. Add segments flagged by the audio repetition detector.
    5. Remove all flagged spans, cutting at silence boundaries.

    The input is returned untouched when the expected text has no words,
    transcription fails or yields nothing, or nothing is flagged.

    Args:
        audio_np: Audio array (mono or stereo).
        sr: Sample rate.
        expected_text: Full expected plain text.
        language: Language code.

    Returns:
        Trimmed audio array.
    """
    expected_words = _normalize_words(expected_text)
    if not expected_words:
        return audio_np

    mono = to_mono(audio_np).astype(np.float32)

    try:
        transcribed = _transcribe_with_timestamps(mono, sr, language)
    except Exception as e:
        logger.warning("Forced alignment failed: %s, skipping", e)
        return audio_np

    if not transcribed:
        logger.info("No words transcribed, skipping trim")
        return audio_np

    # Strip non-word characters once, remembering which original entry each
    # surviving word came from so its timestamps can be looked up later.
    cleaned = [
        (position, re.sub(r"[^\w]", "", entry["word"]))
        for position, entry in enumerate(transcribed)
    ]
    survivors = [(position, word) for position, word in cleaned if word]
    word_indices = [position for position, _ in survivors]
    transcribed_words = [word for _, word in survivors]

    # Run Needleman-Wunsch alignment
    labels = _needleman_wunsch(transcribed_words, expected_words)

    # Collect insertion ranges (hallucinated words) and tally the labels
    insertion_ranges = []
    n_match = n_sub = n_ins = 0
    for orig_idx, label in zip(word_indices, labels):
        if label == "insertion":
            entry = transcribed[orig_idx]
            insertion_ranges.append((entry["start"], entry["end"]))
            n_ins += 1
        elif label == "match":
            n_match += 1
        else:
            n_sub += 1

    logger.info(
        "NW alignment: %d matched, %d substituted, %d inserted (of %d transcribed vs %d expected)",
        n_match,
        n_sub,
        n_ins,
        len(transcribed_words),
        len(expected_words),
    )

    # Audio fingerprint: detect repeated audio segments that Whisper missed
    fingerprint_ranges = _detect_audio_repetition(mono, sr, expected_words)
    if fingerprint_ranges:
        logger.info(
            "Audio fingerprint found %d repeated segments", len(fingerprint_ranges)
        )
        insertion_ranges.extend(fingerprint_ranges)

    if not insertion_ranges:
        logger.info("No insertions detected, audio clean")
        return audio_np

    # Merge consecutive insertions and trim
    keep_mask = _build_trim_mask(mono, sr, _merge_ranges(insertion_ranges))
    trimmed_audio = audio_np[keep_mask]

    kept_samples = np.sum(keep_mask)
    logger.info(
        "Trimmed %.1fs of hallucinated content (%.1fs -> %.1fs)",
        (len(mono) - kept_samples) / sr,
        len(mono) / sr,
        kept_samples / sr,
    )

    return trimmed_audio
src/audio_core/validator.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """XML prompt validation for Scenema Audio.
6
+
7
+ Validates the <speak> XML format:
8
+ <speak voice="..." scene="..." language="...">
9
+ <action>delivery/stage direction</action>
10
+ Speech text here.
11
+ <action>more direction</action>
12
+ More speech text.
13
+ </speak>
14
+
15
+ Only <speak> root with <action> and <sound> children allowed. All content is freeform.
16
+ """
17
+
18
+ import xml.etree.ElementTree as ET
19
+ from dataclasses import dataclass, field
20
+
21
# Tags that may appear as direct children of <speak>.
ALLOWED_CHILD_TAGS = {
    "action",
    "sound",
}
22
+
23
+
24
+ @dataclass
25
+ class ValidationResult:
26
+ valid: bool
27
+ errors: list[str] = field(default_factory=list)
28
+ voice: str | None = None
29
+ scene: str | None = None
30
+ language: str | None = None
31
+
32
+
33
def validate_prompt(xml_string: str) -> ValidationResult:
    """Validate a Scenema Audio XML prompt.

    The following checks are applied:

    * the input is not empty or whitespace-only,
    * it parses as XML,
    * the root element is ``<speak>``,
    * ``voice`` is present and non-blank,
    * ``gender`` is present and is ``male`` or ``female``,
    * every root attribute is one of ``voice``, ``scene``, ``language``,
      ``gender`` or ``shot``,
    * every child tag is in ``ALLOWED_CHILD_TAGS`` and has no nested elements,
    * at least one non-blank text node sits directly under ``<speak>``.

    The first three checks return immediately on failure; the remaining
    ones accumulate so that all problems are reported together.

    Args:
        xml_string: Raw XML string to validate.

    Returns:
        ValidationResult with ``valid=True`` and the stripped ``voice``,
        ``scene`` and ``language`` (default ``"en"``) attributes, or
        ``valid=False`` and the list of errors.
    """
    if not xml_string or not xml_string.strip():
        return ValidationResult(valid=False, errors=["Prompt is empty"])

    # NOTE(review): the stdlib XML parsers are documented as not hardened
    # against maliciously constructed input; if prompts come from untrusted
    # users, consider a hardened parser.
    try:
        root = ET.fromstring(xml_string)
    except ET.ParseError as e:
        return ValidationResult(valid=False, errors=[f"Invalid XML: {e}"])

    if root.tag != "speak":
        return ValidationResult(
            valid=False,
            errors=[f"Root element must be <speak>, got <{root.tag}>"],
        )

    errors: list[str] = []

    voice = root.get("voice")
    if not voice or not voice.strip():
        errors.append("Missing required 'voice' attribute on <speak>")

    gender = root.get("gender")
    if not gender or gender.strip() not in ("male", "female"):
        errors.append(
            "Missing or invalid 'gender' attribute on <speak>. Must be 'male' or 'female'"
        )

    scene = root.get("scene")
    language = root.get("language", "en")

    allowed_attrs = {"voice", "scene", "language", "gender", "shot"}
    for attr in root.attrib:
        if attr not in allowed_attrs:
            errors.append(f"Unknown attribute '{attr}' on <speak>")

    for child in root:
        if child.tag not in ALLOWED_CHILD_TAGS:
            errors.append(
                f"Unsupported tag <{child.tag}>. Only <action> and <sound> are allowed inside <speak>"
            )
        # len() of an Element is its number of direct sub-elements.
        if len(child) > 0:
            errors.append(f"<{child.tag}> must contain only text, no nested elements")

    # Speech is the text directly under <speak>: either before the first
    # child (root.text) or following any child element (child.tail).
    has_text = bool(root.text and root.text.strip()) or any(
        child.tail and child.tail.strip() for child in root
    )
    if not has_text:
        errors.append("Prompt must contain at least one speech text node")

    if errors:
        return ValidationResult(valid=False, errors=errors)

    return ValidationResult(
        valid=True,
        voice=voice.strip() if voice else None,
        scene=scene.strip() if scene else None,
        language=language.strip() if language else None,
    )
src/audio_core/vocal_separator.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """MelBandRoFormer vocal separation for Scenema Audio.
6
+
7
+ Separates vocals from background music/SFX in audio. Used to clean
8
+ generated audio that may contain unwanted background sounds from the
9
+ diffusion model (which was trained on video with ambient audio).
10
+
11
+ Expects stereo 44100Hz input. Processes in overlapping chunks for
12
+ smooth transitions.
13
+ """
14
+
15
+ import logging
16
+ import os
17
+ import subprocess
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ import numpy as np
22
+ import torch
23
+ from safetensors.torch import load_file
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
# Locations of the weights file and of the directory that provides the
# model architecture package; both can be overridden via the environment.
DEFAULT_MODEL_PATH = Path(
    os.getenv("MELBAND_MODEL_PATH", "/app/models/MelBandRoformer_fp16.safetensors")
)
DEFAULT_NODE_PATH = Path(os.getenv("MELBAND_NODE_PATH", "/app/melband_roformer_node"))

# Keyword arguments handed to the MelBandRoformer constructor.
MODEL_CONFIG = dict(
    dim=384,
    depth=6,
    stereo=True,
    num_stems=1,
    time_transformer_depth=1,
    freq_transformer_depth=1,
    num_bands=60,
    dim_head=64,
    heads=8,
    attn_dropout=0,
    ff_dropout=0,
    flash_attn=True,
    dim_freqs_in=1025,
    sample_rate=44100,
    stft_n_fft=2048,
    stft_hop_length=441,
    stft_win_length=2048,
    stft_normalized=False,
    mask_estimator_depth=2,
    multi_stft_resolution_loss_weight=1.0,
    multi_stft_resolutions_window_sizes=(4096, 2048, 1024, 512, 256),
    multi_stft_hop_size=147,
    multi_stft_normalized=False,
)

# 352800 samples, i.e. 8 seconds at 44100Hz.
CHUNK_SIZE = 8 * 44100
# Consecutive chunks overlap by CHUNK_SIZE // OVERLAP_FACTOR samples.
OVERLAP_FACTOR = 2
62
+
63
+
64
class VocalSeparator:
    """Separates vocals from background audio using MelBandRoFormer.

    Audio is decoded and encoded through the ``ffmpeg`` executable and run
    through the model in overlapping chunks that are cross-faded together.
    Once :meth:`load` has run, the model stays on the GPU until
    :meth:`unload` is called.
    """

    def __init__(
        self,
        model_path: Path = DEFAULT_MODEL_PATH,
        node_path: Path = DEFAULT_NODE_PATH,
    ):
        """Store paths only; nothing is loaded until :meth:`load`.

        Args:
            model_path: safetensors file holding the model weights.
            node_path: Directory added to ``sys.path`` so that
                ``model.mel_band_roformer`` can be imported.
        """
        self.model_path = model_path
        self.node_path = node_path
        self._model = None
        self._loaded = False

    def load(self) -> None:
        """Load MelBandRoFormer model to GPU. No-op when already loaded."""
        if self._loaded:
            return

        # Lazy import: model architecture only available after node_path added to sys.path
        node_str = str(self.node_path)
        if node_str not in sys.path:
            sys.path.insert(0, node_str)
        from model.mel_band_roformer import MelBandRoformer

        logger.info("Loading MelBandRoFormer from %s", self.model_path)

        model = MelBandRoformer(**MODEL_CONFIG)
        sd = load_file(str(self.model_path))
        model.load_state_dict(sd)
        del sd  # drop the extra copy of the weights

        self._model = model.cuda().eval().float()
        self._loaded = True

        param_count = sum(p.numel() for p in self._model.parameters())
        logger.info("MelBandRoFormer loaded: %.1fM params", param_count / 1e6)

    def unload(self) -> None:
        """Free model from GPU. No-op when not loaded."""
        if not self._loaded:
            return

        self._model = None
        torch.cuda.empty_cache()
        self._loaded = False
        logger.info("MelBandRoFormer unloaded")

    def separate(
        self,
        input_path: str,
        vocals_path: str,
        sfx_path: str | None = None,
    ) -> dict:
        """Separate vocals from background audio.

        Args:
            input_path: Path to input audio file (any format ffmpeg supports)
            vocals_path: Output path for isolated vocals
            sfx_path: Output path for isolated SFX/background (optional)

        Returns:
            Dict with metadata: input_duration, sample_rate

        Raises:
            RuntimeError: If :meth:`load` has not been called.
            subprocess.CalledProcessError: If ffmpeg exits with an error.
        """
        if not self._loaded:
            raise RuntimeError("VocalSeparator not loaded. Call load() first.")

        sr = MODEL_CONFIG["sample_rate"]

        audio = self._load_audio_ffmpeg(input_path, sr)
        input_duration = audio.shape[1] / sr

        logger.info("Separating: %.1fs audio", input_duration)

        with torch.inference_mode():
            vocals = self._chunked_inference(audio, sr)

        self._save_audio_ffmpeg(vocals, sr, vocals_path)

        if sfx_path:
            # Background is whatever remains after removing the vocal estimate.
            sfx = audio - vocals
            self._save_audio_ffmpeg(sfx, sr, sfx_path)

        return {
            "input_duration": input_duration,
            "sample_rate": sr,
        }

    def _chunked_inference(self, audio: np.ndarray, sr: int) -> np.ndarray:
        """Run model inference in overlapping chunks with fade windows.

        Args:
            audio: Array indexed as (channels, samples), as produced by
                ``_load_audio_ffmpeg``.
            sr: Sample rate; currently unused by this method.

        Returns:
            Array of the same shape as ``audio`` holding the weighted
            average of the per-chunk model outputs.
        """
        total_samples = audio.shape[1]
        chunk_size = CHUNK_SIZE
        overlap = chunk_size // OVERLAP_FACTOR
        step = chunk_size - overlap

        fade_in = np.linspace(0, 1, overlap, dtype=np.float32)
        fade_out = np.linspace(1, 0, overlap, dtype=np.float32)

        result = np.zeros_like(audio)
        weight = np.zeros(total_samples, dtype=np.float32)

        pos = 0
        while pos < total_samples:
            end = min(pos + chunk_size, total_samples)
            chunk_len = end - pos
            chunk = audio[:, pos:end]

            # The model is always fed a full-size chunk; zero-pad the tail.
            if chunk_len < chunk_size:
                chunk = np.pad(chunk, ((0, 0), (0, chunk_size - chunk_len)))

            # copy(): the slice may be a view of the read-only buffer that
            # np.frombuffer returned in _load_audio_ffmpeg.
            chunk_t = torch.from_numpy(chunk.copy()).unsqueeze(0).cuda().float()
            out = self._model(chunk_t)
            out_np = out.squeeze(0).cpu().float().numpy()[:, :chunk_len]

            # Fade in unless this is the first chunk; fade out unless this
            # chunk reaches the end of the audio.
            w = np.ones(chunk_len, dtype=np.float32)
            fade_len = min(overlap, chunk_len)
            if pos > 0:
                w[:fade_len] *= fade_in[:fade_len]
            if end < total_samples:
                w[-fade_len:] *= fade_out[:fade_len]

            result[:, pos:end] += out_np * w[np.newaxis, :]
            weight[pos:end] += w

            pos += step

        # Normalize by accumulated window weight; the floor avoids 0/0.
        weight = np.maximum(weight, 1e-8)
        result /= weight[np.newaxis, :]

        return result

    def _load_audio_ffmpeg(self, path: str, target_sr: int) -> np.ndarray:
        """Load audio to stereo float32 numpy via ffmpeg.

        Returns:
            Array of shape (2, samples).

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits with an error.
        """
        cmd = [
            "ffmpeg",
            "-v",
            "quiet",
            "-i",
            path,
            "-f",
            "f32le",
            "-acodec",
            "pcm_f32le",
            "-ac",
            "2",
            "-ar",
            str(target_sr),
            "pipe:1",
        ]
        # stdin is detached so ffmpeg cannot read from (or block on) the
        # parent's standard input while decoding.
        proc = subprocess.run(
            cmd,
            stdin=subprocess.DEVNULL,
            capture_output=True,
            check=True,
        )
        audio = np.frombuffer(proc.stdout, dtype=np.float32)
        return audio.reshape(-1, 2).T  # (2, samples)

    def _save_audio_ffmpeg(self, audio: np.ndarray, sr: int, path: str) -> None:
        """Save stereo float32 numpy to 16-bit PCM audio via ffmpeg.

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits with an error.
        """
        # (2, samples) -> interleaved L/R float32 bytes.
        interleaved = audio.T.astype(np.float32).tobytes()
        cmd = [
            "ffmpeg",
            "-y",
            # Log level goes before the output file: options placed after
            # the last output are trailing options.
            "-v",
            "quiet",
            "-f",
            "f32le",
            "-acodec",
            "pcm_f32le",
            "-ac",
            "2",
            "-ar",
            str(sr),
            "-i",
            "pipe:0",
            "-acodec",
            "pcm_s16le",
            path,
        ]
        subprocess.run(cmd, input=interleaved, check=True)
src/audio_core/whisper_aligner.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Whisper alignment for audio validation in Scenema Audio.
6
+
7
+ Uses faster-whisper (CTranslate2) on GPU to transcribe generated audio
8
+ and validate that the expected text was spoken. Whisper-small is 244M
9
+ params (~1GB VRAM, float16). Runs after denoise when VRAM is free.
10
+ """
11
+
12
+ import logging
13
+ import re
14
+ import unicodedata
15
+
16
+ import numpy as np
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
# Process-wide cache for the whisper model; filled in on first use.
_whisper_model = None


def _get_whisper():
    """Return the shared whisper-small model, creating it on first call.

    The model is built with ``device="cuda"`` and ``compute_type="float16"``
    and then kept in the module-level cache for the process lifetime.
    ``faster_whisper`` is imported lazily, only when the model actually
    has to be created.
    """
    global _whisper_model

    if _whisper_model is None:
        from faster_whisper import WhisperModel

        logger.info("Loading whisper-small for alignment validation (GPU, float16)...")
        _whisper_model = WhisperModel("small", device="cuda", compute_type="float16")
        logger.info("whisper-small loaded (GPU)")

    return _whisper_model
43
+
44
+
45
def transcribe(audio_np: np.ndarray, sr: int, language: str = "en") -> str:
    """Transcribe audio and return the text.

    The audio is averaged down to mono float32 and resampled to 16 kHz
    before being handed to the cached Whisper model.

    Args:
        audio_np: Audio samples, shape (samples,) or (samples, channels).
        sr: Sample rate in Hz.
        language: Language code for transcription.

    Returns:
        Transcribed text string. Empty when Whisper raised ``ValueError``
        or ``TypeError``; that failure is logged as a warning.
    """
    model = _get_whisper()

    # Convert to mono float32 if needed
    if audio_np.ndim == 2:
        audio_mono = audio_np.mean(axis=1).astype(np.float32)
    else:
        audio_mono = audio_np.astype(np.float32)

    # Resample to 16kHz if needed
    if sr != 16000:
        import librosa

        audio_mono = librosa.resample(audio_mono, orig_sr=sr, target_sr=16000)

    try:
        segments, _ = model.transcribe(
            audio_mono,
            language=language,
            word_timestamps=False,
            vad_filter=True,
        )
        text = " ".join(seg.text.strip() for seg in segments).strip()
    except (ValueError, TypeError):
        # Still best-effort (a mocked model in tests returns wrong types),
        # but logged loudly with the traceback: outside of tests an empty
        # transcript would otherwise look like a silent validation failure.
        logger.warning(
            "Whisper transcribe returned unexpected type (test env?)",
            exc_info=True,
        )
        text = ""

    return text
84
+
85
+
86
def validate_text(
    audio_np: np.ndarray,
    sr: int,
    expected_text: str,
    language: str = "en",
    min_word_ratio: float = 0.6,
) -> tuple[bool, str, float]:
    """Check that the generated audio speaks the expected text.

    The audio is transcribed, both texts are reduced to sets of
    accent-stripped, lower-cased, punctuation-free words, and the share
    of expected words found in the transcription is compared against
    ``min_word_ratio``. Word order and repetition are not considered.

    Args:
        audio_np: Audio samples.
        sr: Sample rate.
        expected_text: The text that should have been spoken.
        language: Language code.
        min_word_ratio: Minimum fraction of expected words that must
            appear in transcription (0.0 to 1.0).

    Returns:
        Tuple of (passed, transcribed_text, word_match_ratio). When the
        expected text contains no words the result is
        ``(True, transcribed_text, 1.0)``.
    """

    def _word_set(text):
        # NFD splits accented letters into base letter + combining mark;
        # dropping the "Mn" marks lets accented and plain spellings match.
        decomposed = unicodedata.normalize("NFD", text)
        unaccented = "".join(
            ch for ch in decomposed if unicodedata.category(ch) != "Mn"
        )
        cleaned = re.sub(r"[^\w\s]", "", unaccented.lower())
        return set(cleaned.split())

    transcribed = transcribe(audio_np, sr, language)

    wanted = _word_set(expected_text)
    heard = _word_set(transcribed)

    if not wanted:
        return True, transcribed, 1.0

    ratio = len(wanted & heard) / len(wanted)
    passed = ratio >= min_word_ratio

    if passed:
        return passed, transcribed, ratio

    logger.warning(
        "Validation failed: %.0f%% word match (need %.0f%%). "
        "Expected: %s... Got: %s...",
        ratio * 100,
        min_word_ratio * 100,
        expected_text[:60],
        transcribed[:60],
    )
    return passed, transcribed, ratio
src/common/__init__.py ADDED
File without changes
src/common/handlers/__init__.py ADDED
File without changes
src/common/handlers/base.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Minimal handler types for standalone deployment.
6
+
7
+ Drop-in replacement for the production common.handlers.base module.
8
+ Provides ProcessJob, ProcessOutput, and ProcessResult so that
9
+ audio_core.processor imports resolve without modification.
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+ from typing import Any, Optional
14
+
15
+
16
+ @dataclass
17
+ class ProcessJob:
18
+ job_id: str
19
+ input: dict[str, Any]
20
+ upload_url: Optional[str] = None
21
+ webhook_url: Optional[str] = None
22
+
23
+
24
+ @dataclass
25
+ class ProcessOutput:
26
+ success: bool = True
27
+ data: Optional[bytes] = None
28
+ content_type: Optional[str] = None
29
+ result: Optional[dict] = None
30
+ metadata: Optional[dict] = None
31
+ error: Optional[str] = None
32
+
33
+
34
@dataclass
class ProcessResult:
    """Final result reported for one job.

    ``job_id`` and ``success`` are required. ``output`` and ``error``
    default to ``None`` and ``processing_ms`` defaults to ``0``.
    """

    job_id: str
    success: bool
    output: ProcessOutput | None = None
    processing_ms: int = 0
    error: str | None = None
src/server.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Scenema AI
2
+ # https://scenema.ai
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """Scenema Audio standalone server.
6
+
7
+ Thin FastAPI wrapper around the production AudioProcessor.
8
+ """
9
+
10
+ import asyncio
11
+ import base64
12
+ import logging
13
+ import os
14
+ import uuid
15
+ from contextlib import asynccontextmanager
16
+ from pathlib import Path
17
+
18
+ from fastapi import FastAPI, Request
19
+ from fastapi.responses import JSONResponse
20
+ from huggingface_hub import hf_hub_download, snapshot_download
21
+ import uvicorn
22
+
23
+ logger = logging.getLogger("scenema-audio")
24
+
25
+ # Must be set before any torch import
26
+ os.environ.setdefault(
27
+ "PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"
28
+ )
29
+
30
+ from audio_core.processor import AudioProcessor # noqa: E402
31
+ from common.handlers.base import ProcessJob # noqa: E402
32
+
33
+ # ── Model download ──────────────────────────────────────────────
34
+
35
# Hugging Face repositories the checkpoints are fetched from.
HF_REPO = "ScenemaAI/scenema-audio"
GEMMA_REPO = "google/gemma-3-12b-it"
SEEDVC_REPO = "Plachta/Seed-VC"
BIGVGAN_REPO = "nvidia/bigvgan_v2_22khz_80band_256x"
WHISPER_REPO = "openai/whisper-small"

# Root directory for downloaded checkpoints; override with MODEL_DIR.
MODEL_DIR = Path(os.getenv("MODEL_DIR", "/app/models"))
42
+
43
+
44
def _ensure_hf_checkpoint(env_var, filename, description, token):
    """Download ``filename`` from HF_REPO unless the checkpoint already exists.

    The expected location is read from ``env_var`` and falls back to
    ``MODEL_DIR / filename``. The file is downloaded into that location's
    parent directory under its repository name.
    """
    # NOTE(review): if env_var points at a different file name, the download
    # still lands as ``filename`` in the same directory — confirm intended.
    ckpt = Path(os.environ.get(env_var, str(MODEL_DIR / filename)))
    if ckpt.exists():
        return

    logger.info("Downloading %s...", description)
    hf_hub_download(
        HF_REPO,
        filename,
        local_dir=str(ckpt.parent),
        token=token,
    )


def _download_models():
    """Download missing model checkpoints from HuggingFace.

    Covers the audio transformer, the pipeline checkpoint, the VAE
    encoder, Gemma 3 12B IT, and the SeedVC checkpoints together with
    BigVGAN and Whisper-small. Anything already present on disk is
    skipped. ``HF_TOKEN`` from the environment, when set, is passed to
    every download.
    """
    token = os.environ.get("HF_TOKEN")

    # Audio transformer (INT8 by default)
    _ensure_hf_checkpoint(
        "AUDIO_CKPT",
        "scenema-audio-transformer-int8.safetensors",
        "audio transformer (INT8, ~4.9 GB)",
        token,
    )

    # Pipeline checkpoint
    _ensure_hf_checkpoint(
        "PIPELINE_CKPT",
        "scenema-audio-pipeline.safetensors",
        "pipeline checkpoint (~7.1 GB)",
        token,
    )

    # VAE encoder (small, may already be baked)
    _ensure_hf_checkpoint(
        "VAE_ENCODER_CKPT",
        "scenema-audio-vae-encoder.safetensors",
        "VAE encoder (~42 MB)",
        token,
    )

    # Gemma 3 12B IT
    gemma_root = Path(os.environ.get("GEMMA_ROOT", str(MODEL_DIR / "gemma-3-12b-it")))
    if not gemma_root.exists() or not any(gemma_root.glob("*.safetensors")):
        if not token:
            logger.warning(
                "HF_TOKEN is not set; downloading gated model %s may fail",
                GEMMA_REPO,
            )
        logger.info("Downloading Gemma 3 12B IT (~24 GB, gated model)...")
        snapshot_download(
            GEMMA_REPO,
            local_dir=str(gemma_root),
            ignore_patterns=["*.gguf"],
            token=token,
        )

    # SeedVC
    seedvc_path = Path(os.environ.get("SEEDVC_PATH", "/app/seed-vc"))
    seedvc_cache = seedvc_path / "checkpoints"
    if not seedvc_cache.exists() or not any(seedvc_cache.glob("*.pth")):
        logger.info("Downloading SeedVC checkpoints (~1.6 GB)...")
        hf_cache = seedvc_cache / "hf_cache"
        hf_cache.mkdir(parents=True, exist_ok=True)
        # NOTE(review): set after huggingface_hub was imported, so this
        # presumably targets code that reads the variable later — confirm.
        os.environ["HF_HUB_CACHE"] = str(hf_cache)
        for filename in (
            "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
            "config_dit_mel_seed_uvit_whisper_small_wavenet.yml",
        ):
            hf_hub_download(
                SEEDVC_REPO,
                filename,
                local_dir=str(seedvc_cache),
                token=token,
            )
        # Pass the token here as well, consistent with every other download.
        snapshot_download(
            BIGVGAN_REPO,
            local_dir=str(hf_cache / "bigvgan"),
            token=token,
        )
        snapshot_download(
            WHISPER_REPO,
            local_dir=str(hf_cache / "whisper-small"),
            token=token,
        )
124
+
125
+
126
+ # ── FastAPI app ─────────────────────────────────────────────────
127
+
128
# One shared processor instance; the single-slot semaphore lets only one
# processing call run at a time.
processor = AudioProcessor()
_semaphore = asyncio.Semaphore(value=1)
130
+
131
+
132
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prepare models on startup, release on shutdown.

    Startup creates ``MODEL_DIR``, downloads any missing checkpoints and
    starts the processor. Shutdown stops the processor.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    _download_models()
    processor.startup()
    logger.info("Scenema Audio ready on port %s", os.environ.get("PORT", "8000"))
    try:
        yield
    finally:
        # Runs even when an exception or cancellation is thrown in at the
        # yield, so a started processor is always shut down.
        processor.shutdown()
140
+
141
+
142
# ASGI application; startup and shutdown are handled by ``lifespan``.
app = FastAPI(
    title="Scenema Audio",
    lifespan=lifespan,
)
143
+
144
+
145
@app.get("/health")
async def health():
    """Health endpoint: always answers with a static ``ok`` status."""
    payload = {"status": "ok"}
    return payload
148
+
149
+
150
@app.post("/generate")
async def generate(request: Request):
    """Run one generation job and return the audio base64-encoded.

    The JSON request body is passed unchanged as the job input. Jobs are
    serialized through the module-level semaphore.

    Returns:
        On success, a dict with ``status``, ``audio`` (base64 string or
        ``None``), ``content_type`` and ``metadata``. A body that is not
        valid JSON, or is not a JSON object, yields HTTP 400. A failed
        job, or a job that produced no output, yields HTTP 500. Error
        responses carry ``status`` and ``error``.
    """
    try:
        body = await request.json()
    except ValueError:
        # Covers malformed JSON and undecodable bytes (both ValueError
        # subclasses): that is a client error, not a server fault.
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "error": "Request body must be valid JSON",
            },
        )

    # ProcessJob.input is declared as a dict, so reject other JSON values.
    if not isinstance(body, dict):
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "error": "Request body must be a JSON object",
            },
        )

    job = ProcessJob(
        job_id=str(uuid.uuid4()),
        input=body,
    )

    async with _semaphore:
        result = await processor.process(job)

    output = result.output
    # ``output`` is optional on the result; treat a missing output like a
    # failed job instead of dereferencing None below.
    if not result.success or output is None:
        return JSONResponse(
            status_code=500,
            content={
                "status": "failed",
                "error": result.error or "Generation failed",
            },
        )

    audio_b64 = base64.b64encode(output.data).decode() if output.data else None

    return {
        "status": "succeeded",
        "audio": audio_b64,
        "content_type": output.content_type or "audio/wav",
        "metadata": output.metadata or {},
    }
180
+
181
+
182
if __name__ == "__main__":
    # Script entry point: configure logging, then serve on all interfaces
    # at $PORT (default 8000).
    logging.basicConfig(
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
        level=logging.INFO,
    )
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "8000")),
    )