Spaces:
Runtime error
Runtime error
multimodalart commited on
Commit Β·
cdc4405
1
Parent(s): c327e46
Initial Gradio ZeroGPU app for Scenema Audio
Browse files- README.md +28 -3
- app.py +385 -0
- requirements.txt +30 -0
- src/audio_core/__init__.py +7 -0
- src/audio_core/audio_utils.py +266 -0
- src/audio_core/chunker.py +334 -0
- src/audio_core/compiler.py +305 -0
- src/audio_core/engine.py +911 -0
- src/audio_core/enhancer.py +121 -0
- src/audio_core/inference.py +183 -0
- src/audio_core/main.py +42 -0
- src/audio_core/processor.py +484 -0
- src/audio_core/seedvc.py +194 -0
- src/audio_core/validate_and_patch.py +402 -0
- src/audio_core/validator.py +105 -0
- src/audio_core/vocal_separator.py +244 -0
- src/audio_core/whisper_aligner.py +139 -0
- src/common/__init__.py +0 -0
- src/common/handlers/__init__.py +0 -0
- src/common/handlers/base.py +40 -0
- src/server.py +188 -0
README.md
CHANGED
|
@@ -1,13 +1,38 @@
|
|
| 1 |
---
|
| 2 |
title: Scenema Audio
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: pink
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.14.0
|
| 8 |
-
python_version: '3.
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
|
|
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Scenema Audio
|
| 3 |
+
emoji: ποΈ
|
| 4 |
colorFrom: pink
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.14.0
|
| 8 |
+
python_version: '3.12'
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
+
hardware: zero-a10g
|
| 12 |
+
short_description: Zero-shot expressive voice cloning and speech generation
|
| 13 |
+
suggested_storage: large
|
| 14 |
---
|
| 15 |
|
| 16 |
+
# Scenema Audio (ZeroGPU)
|
| 17 |
+
|
| 18 |
+
Gradio wrapper around [ScenemaAI/scenema-audio](https://github.com/ScenemaAI/scenema-audio).
|
| 19 |
+
|
| 20 |
+
Zero-shot expressive voice cloning and speech generation with emotion, pacing,
|
| 21 |
+
and breath control, built on an audio diffusion transformer extracted from
|
| 22 |
+
[LTX 2.3](https://github.com/Lightricks/LTX-2).
|
| 23 |
+
|
| 24 |
+
## Cold start
|
| 25 |
+
|
| 26 |
+
First request downloads ~38 GB of model weights:
|
| 27 |
+
- `scenema-audio-transformer-int8.safetensors` (~4.9 GB)
|
| 28 |
+
- `scenema-audio-pipeline.safetensors` (~6.7 GB)
|
| 29 |
+
- `google/gemma-3-12b-it` (~24 GB, **gated** β requires `HF_TOKEN` secret)
|
| 30 |
+
- SeedVC + BigVGAN + Whisper checkpoints (~3 GB)
|
| 31 |
+
- MelBandRoFormer (~436 MB)
|
| 32 |
+
|
| 33 |
+
Set `HF_TOKEN` in the Space secrets with access to `google/gemma-3-12b-it`.
|
| 34 |
+
|
| 35 |
+
## License
|
| 36 |
+
|
| 37 |
+
- **Model weights:** LTX-2 Community License Agreement
|
| 38 |
+
- **Code:** MIT
|
app.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenema Audio - ZeroGPU Gradio Space.
|
| 2 |
+
|
| 3 |
+
Wraps the ScenemaAI/scenema-audio AudioProcessor in a Gradio UI.
|
| 4 |
+
Heavy model weights (~38 GB) are downloaded on first cold-start and
|
| 5 |
+
cached on persistent storage; generation runs under @spaces.GPU.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import base64
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import tempfile
|
| 14 |
+
import uuid
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 18 |
+
|
| 19 |
+
# Allow tweaking via env, but default to repo-local cache so weights persist
|
| 20 |
+
# across worker restarts on Spaces persistent storage if mounted at /data.
|
| 21 |
+
MODEL_DIR = Path(os.environ.get("MODEL_DIR", "/data/models")) \
|
| 22 |
+
if Path("/data").exists() else Path(os.environ.get("MODEL_DIR", "./models"))
|
| 23 |
+
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 24 |
+
os.environ["MODEL_DIR"] = str(MODEL_DIR)
|
| 25 |
+
|
| 26 |
+
# Default model paths (must be set before AudioProcessor is imported)
|
| 27 |
+
os.environ.setdefault(
|
| 28 |
+
"AUDIO_CKPT", str(MODEL_DIR / "scenema-audio-transformer-int8.safetensors")
|
| 29 |
+
)
|
| 30 |
+
os.environ.setdefault(
|
| 31 |
+
"PIPELINE_CKPT", str(MODEL_DIR / "scenema-audio-pipeline.safetensors")
|
| 32 |
+
)
|
| 33 |
+
os.environ.setdefault(
|
| 34 |
+
"VAE_ENCODER_CKPT", str(MODEL_DIR / "scenema-audio-vae-encoder.safetensors")
|
| 35 |
+
)
|
| 36 |
+
os.environ.setdefault("GEMMA_ROOT", str(MODEL_DIR / "gemma-3-12b-it"))
|
| 37 |
+
os.environ.setdefault(
|
| 38 |
+
"MELBAND_MODEL_PATH", str(MODEL_DIR / "MelBandRoformer_fp16.safetensors")
|
| 39 |
+
)
|
| 40 |
+
os.environ.setdefault("SEEDVC_PATH", str(Path.cwd() / "seed-vc"))
|
| 41 |
+
os.environ.setdefault("MELBAND_NODE_PATH", str(Path.cwd() / "melband_roformer_node"))
|
| 42 |
+
os.environ.setdefault("HF_HUB_CACHE", str(MODEL_DIR / "hf_cache"))
|
| 43 |
+
os.environ.setdefault("GEMMA_QUANTIZE", "nf4")
|
| 44 |
+
|
| 45 |
+
# Make repo source importable
|
| 46 |
+
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
| 47 |
+
|
| 48 |
+
import gradio as gr
|
| 49 |
+
import spaces
|
| 50 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 51 |
+
|
| 52 |
+
logging.basicConfig(
|
| 53 |
+
level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s"
|
| 54 |
+
)
|
| 55 |
+
logger = logging.getLogger("scenema-space")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ββ Model download (CPU phase, runs at import) ββββββββββββββββββββββββββββ
|
| 59 |
+
|
| 60 |
+
HF_REPO = "ScenemaAI/scenema-audio"
|
| 61 |
+
GEMMA_REPO = "google/gemma-3-12b-it"
|
| 62 |
+
SEEDVC_REPO = "Plachta/Seed-VC"
|
| 63 |
+
BIGVGAN_REPO = "nvidia/bigvgan_v2_22khz_80band_256x"
|
| 64 |
+
WHISPER_REPO = "openai/whisper-small"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _download_all():
|
| 68 |
+
token = os.environ.get("HF_TOKEN")
|
| 69 |
+
|
| 70 |
+
audio_ckpt = Path(os.environ["AUDIO_CKPT"])
|
| 71 |
+
if not audio_ckpt.exists():
|
| 72 |
+
logger.info("Downloading audio transformer INT8 (~4.9 GB)...")
|
| 73 |
+
hf_hub_download(
|
| 74 |
+
HF_REPO,
|
| 75 |
+
"scenema-audio-transformer-int8.safetensors",
|
| 76 |
+
local_dir=str(audio_ckpt.parent),
|
| 77 |
+
token=token,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
pipeline_ckpt = Path(os.environ["PIPELINE_CKPT"])
|
| 81 |
+
if not pipeline_ckpt.exists():
|
| 82 |
+
logger.info("Downloading pipeline checkpoint (~6.7 GB)...")
|
| 83 |
+
hf_hub_download(
|
| 84 |
+
HF_REPO,
|
| 85 |
+
"scenema-audio-pipeline.safetensors",
|
| 86 |
+
local_dir=str(pipeline_ckpt.parent),
|
| 87 |
+
token=token,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
vae = Path(os.environ["VAE_ENCODER_CKPT"])
|
| 91 |
+
if not vae.exists():
|
| 92 |
+
logger.info("Downloading VAE encoder (~42 MB)...")
|
| 93 |
+
hf_hub_download(
|
| 94 |
+
HF_REPO,
|
| 95 |
+
"scenema-audio-vae-encoder.safetensors",
|
| 96 |
+
local_dir=str(vae.parent),
|
| 97 |
+
token=token,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
melband = Path(os.environ["MELBAND_MODEL_PATH"])
|
| 101 |
+
if not melband.exists():
|
| 102 |
+
logger.info("Downloading MelBandRoFormer (~436 MB)...")
|
| 103 |
+
hf_hub_download(
|
| 104 |
+
"Kijai/MelBandRoFormer_comfy",
|
| 105 |
+
"MelBandRoformer_fp16.safetensors",
|
| 106 |
+
local_dir=str(melband.parent),
|
| 107 |
+
token=token,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
gemma = Path(os.environ["GEMMA_ROOT"])
|
| 111 |
+
if not gemma.exists() or not any(gemma.glob("*.safetensors")):
|
| 112 |
+
logger.info("Downloading Gemma 3 12B IT (~24 GB, gated)...")
|
| 113 |
+
snapshot_download(
|
| 114 |
+
GEMMA_REPO,
|
| 115 |
+
local_dir=str(gemma),
|
| 116 |
+
ignore_patterns=["*.gguf"],
|
| 117 |
+
token=token,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
seedvc_path = Path(os.environ["SEEDVC_PATH"])
|
| 121 |
+
seedvc_ckpts = seedvc_path / "checkpoints"
|
| 122 |
+
if not seedvc_ckpts.exists() or not any(seedvc_ckpts.glob("*.pth")):
|
| 123 |
+
logger.info("Downloading SeedVC checkpoints (~1.6 GB)...")
|
| 124 |
+
seedvc_ckpts.mkdir(parents=True, exist_ok=True)
|
| 125 |
+
hf_cache = seedvc_ckpts / "hf_cache"
|
| 126 |
+
hf_cache.mkdir(parents=True, exist_ok=True)
|
| 127 |
+
os.environ["HF_HUB_CACHE"] = str(hf_cache)
|
| 128 |
+
hf_hub_download(
|
| 129 |
+
SEEDVC_REPO,
|
| 130 |
+
"DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
|
| 131 |
+
local_dir=str(seedvc_ckpts),
|
| 132 |
+
token=token,
|
| 133 |
+
)
|
| 134 |
+
hf_hub_download(
|
| 135 |
+
SEEDVC_REPO,
|
| 136 |
+
"config_dit_mel_seed_uvit_whisper_small_wavenet.yml",
|
| 137 |
+
local_dir=str(seedvc_ckpts),
|
| 138 |
+
token=token,
|
| 139 |
+
)
|
| 140 |
+
snapshot_download(BIGVGAN_REPO, local_dir=str(hf_cache / "bigvgan"))
|
| 141 |
+
snapshot_download(WHISPER_REPO, local_dir=str(hf_cache / "whisper-small"))
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _ensure_seedvc_repo():
|
| 145 |
+
"""Clone the seed-vc python source if missing (architecture code)."""
|
| 146 |
+
seedvc = Path(os.environ["SEEDVC_PATH"])
|
| 147 |
+
if not (seedvc / "modules").exists():
|
| 148 |
+
logger.info("Cloning seed-vc source...")
|
| 149 |
+
os.system(f"git clone --depth 1 https://github.com/Plachtaa/seed-vc.git {seedvc}")
|
| 150 |
+
|
| 151 |
+
melband_node = Path(os.environ["MELBAND_NODE_PATH"])
|
| 152 |
+
if not melband_node.exists():
|
| 153 |
+
logger.info("Cloning ComfyUI-MelBandRoFormer source...")
|
| 154 |
+
os.system(
|
| 155 |
+
f"git clone --depth 1 https://github.com/kijai/ComfyUI-MelBandRoFormer {melband_node}"
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
_ensure_seedvc_repo()
|
| 160 |
+
_download_all()
|
| 161 |
+
|
| 162 |
+
# Import processor only after model paths/env are set
|
| 163 |
+
from audio_core.processor import AudioProcessor # noqa: E402
|
| 164 |
+
from common.handlers.base import ProcessJob # noqa: E402
|
| 165 |
+
|
| 166 |
+
_processor: AudioProcessor | None = None
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def _get_processor() -> AudioProcessor:
|
| 170 |
+
global _processor
|
| 171 |
+
if _processor is None:
|
| 172 |
+
_processor = AudioProcessor()
|
| 173 |
+
_processor.startup()
|
| 174 |
+
return _processor
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# ββ Generation ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _build_prompt(text, voice, gender, scene, language, shot, action, sound_before):
|
| 181 |
+
attrs = [f'voice="{voice}"', f'gender="{gender}"']
|
| 182 |
+
if scene:
|
| 183 |
+
attrs.append(f'scene="{scene}"')
|
| 184 |
+
if language and language != "en":
|
| 185 |
+
attrs.append(f'language="{language}"')
|
| 186 |
+
if shot:
|
| 187 |
+
attrs.append(f'shot="{shot}"')
|
| 188 |
+
|
| 189 |
+
inner = ""
|
| 190 |
+
if sound_before:
|
| 191 |
+
inner += f"<sound>{sound_before}</sound>"
|
| 192 |
+
if action:
|
| 193 |
+
inner += f"<action>{action}</action>"
|
| 194 |
+
inner += text
|
| 195 |
+
|
| 196 |
+
return f"<speak {' '.join(attrs)}>{inner}</speak>"
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@spaces.GPU(duration=300)
|
| 200 |
+
def generate(
|
| 201 |
+
text,
|
| 202 |
+
voice,
|
| 203 |
+
gender,
|
| 204 |
+
scene,
|
| 205 |
+
language,
|
| 206 |
+
shot,
|
| 207 |
+
action,
|
| 208 |
+
sound_before,
|
| 209 |
+
reference_audio,
|
| 210 |
+
mode,
|
| 211 |
+
seed,
|
| 212 |
+
background_sfx,
|
| 213 |
+
skip_vc,
|
| 214 |
+
raw_xml,
|
| 215 |
+
progress=gr.Progress(track_tqdm=True),
|
| 216 |
+
):
|
| 217 |
+
progress(0, desc="Loading models (cold start can take a few minutes)")
|
| 218 |
+
processor = _get_processor()
|
| 219 |
+
|
| 220 |
+
if raw_xml and raw_xml.strip():
|
| 221 |
+
prompt = raw_xml.strip()
|
| 222 |
+
else:
|
| 223 |
+
if not text.strip():
|
| 224 |
+
raise gr.Error("Speech text is required.")
|
| 225 |
+
prompt = _build_prompt(text, voice, gender, scene, language, shot, action, sound_before)
|
| 226 |
+
|
| 227 |
+
# If reference audio is a local file (gradio path), upload-less: we copy into
|
| 228 |
+
# a temp http-less path that AudioProcessor expects URL. Easiest: serve via
|
| 229 |
+
# a file:// URL β but httpx doesn't support file://. Instead, patch path by
|
| 230 |
+
# writing input to a known place and using a fake URL handler via temp.
|
| 231 |
+
body = {
|
| 232 |
+
"prompt": prompt,
|
| 233 |
+
"mode": mode,
|
| 234 |
+
"seed": int(seed) if seed is not None else -1,
|
| 235 |
+
"background_sfx": bool(background_sfx),
|
| 236 |
+
"skip_vc": bool(skip_vc),
|
| 237 |
+
"validate": True,
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
# Reference voice: AudioProcessor downloads from URL. We bypass by directly
|
| 241 |
+
# placing a local path; the _generate function uses `reference_voice_url`
|
| 242 |
+
# and calls `_download_reference`. Workaround: monkey-patch download to
|
| 243 |
+
# return the local path if a file:// URL is given.
|
| 244 |
+
ref_local_path = None
|
| 245 |
+
if reference_audio:
|
| 246 |
+
ref_local_path = reference_audio
|
| 247 |
+
body["reference_voice_url"] = f"file://{ref_local_path}"
|
| 248 |
+
|
| 249 |
+
async def _run():
|
| 250 |
+
# Patch _download_reference for this call to handle file:// URLs
|
| 251 |
+
original = processor._download_reference
|
| 252 |
+
|
| 253 |
+
async def patched(url):
|
| 254 |
+
if url.startswith("file://"):
|
| 255 |
+
return url[len("file://"):]
|
| 256 |
+
return await original(url)
|
| 257 |
+
|
| 258 |
+
processor._download_reference = patched
|
| 259 |
+
try:
|
| 260 |
+
job = ProcessJob(job_id=str(uuid.uuid4()), input=body)
|
| 261 |
+
return await processor.process(job)
|
| 262 |
+
finally:
|
| 263 |
+
processor._download_reference = original
|
| 264 |
+
|
| 265 |
+
progress(0.1, desc="Generating audio")
|
| 266 |
+
result = asyncio.run(_run())
|
| 267 |
+
|
| 268 |
+
if not result.success:
|
| 269 |
+
raise gr.Error(result.error or "Generation failed")
|
| 270 |
+
|
| 271 |
+
# Write to temp wav and return path
|
| 272 |
+
out_path = Path(tempfile.gettempdir()) / f"scenema_{uuid.uuid4().hex}.wav"
|
| 273 |
+
out_path.write_bytes(result.output.data)
|
| 274 |
+
meta = result.output.metadata or {}
|
| 275 |
+
info = (
|
| 276 |
+
f"Duration: {meta.get('duration_s', 0)}s Β· "
|
| 277 |
+
f"Seed: {meta.get('seed')} Β· "
|
| 278 |
+
f"GPU: {meta.get('gpu', 'N/A')} Β· "
|
| 279 |
+
f"Time: {meta.get('processing_ms', 0)} ms"
|
| 280 |
+
)
|
| 281 |
+
return str(out_path), info
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# ββ UI ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 285 |
+
|
| 286 |
+
EXAMPLES = [
|
| 287 |
+
[
|
| 288 |
+
"The old lighthouse had stood on the cliff for over a century, its beam cutting through the fog like a blade of light.",
|
| 289 |
+
"A warm, clear male voice with a slight British accent. Measured, thoughtful pacing.",
|
| 290 |
+
"male", "", "en", "closeup", "", "",
|
| 291 |
+
None, "generate", 42, False, False, "",
|
| 292 |
+
],
|
| 293 |
+
[
|
| 294 |
+
"The city never really sleeps. It just closes its eyes and pretends for a while.",
|
| 295 |
+
"A young woman with a smoky, low register voice. Intimate, confessional tone.",
|
| 296 |
+
"female", "", "en", "closeup", "", "",
|
| 297 |
+
None, "voice_design", 7, False, False, "",
|
| 298 |
+
],
|
| 299 |
+
[
|
| 300 |
+
"Get the lines! She is pulling loose! Move! I said move!",
|
| 301 |
+
"Male, mid 40s. Weathered. Urgent, projecting over wind.",
|
| 302 |
+
"male", "Open dock in a thunderstorm, heavy rain", "en", "scene",
|
| 303 |
+
"He shouts over the storm", "Heavy rain and wind howling",
|
| 304 |
+
None, "generate", 11, True, False, "",
|
| 305 |
+
],
|
| 306 |
+
]
|
| 307 |
+
|
| 308 |
+
with gr.Blocks(title="Scenema Audio") as demo:
|
| 309 |
+
gr.Markdown(
|
| 310 |
+
"""
|
| 311 |
+
# Scenema Audio Β· Zero-shot Expressive TTS
|
| 312 |
+
Generate expressive speech with emotion, scene, and voice cloning.
|
| 313 |
+
Built on [ScenemaAI/scenema-audio](https://github.com/ScenemaAI/scenema-audio).
|
| 314 |
+
|
| 315 |
+
**Note:** First request triggers a ~38 GB cold start. Subsequent requests are fast.
|
| 316 |
+
"""
|
| 317 |
+
)
|
| 318 |
+
with gr.Row():
|
| 319 |
+
with gr.Column(scale=3):
|
| 320 |
+
text = gr.Textbox(
|
| 321 |
+
label="Speech text",
|
| 322 |
+
lines=4,
|
| 323 |
+
placeholder="What the voice should say...",
|
| 324 |
+
)
|
| 325 |
+
voice = gr.Textbox(
|
| 326 |
+
label="Voice description",
|
| 327 |
+
lines=2,
|
| 328 |
+
placeholder='e.g. "A warm male voice with a slight British accent..."',
|
| 329 |
+
)
|
| 330 |
+
with gr.Row():
|
| 331 |
+
gender = gr.Radio(["male", "female"], value="male", label="Gender")
|
| 332 |
+
language = gr.Dropdown(
|
| 333 |
+
["en", "es", "fr", "de", "it", "pt", "ja", "zh", "ko"],
|
| 334 |
+
value="en", label="Language",
|
| 335 |
+
)
|
| 336 |
+
shot = gr.Radio(
|
| 337 |
+
["closeup", "wide", "scene"], value="closeup", label="Shot"
|
| 338 |
+
)
|
| 339 |
+
with gr.Accordion("Scene & direction (optional)", open=False):
|
| 340 |
+
scene = gr.Textbox(label="Scene", placeholder="e.g. busy cafe at midday")
|
| 341 |
+
action = gr.Textbox(label="Performance direction (<action>)")
|
| 342 |
+
sound_before = gr.Textbox(label="Sound event before speech (<sound>)")
|
| 343 |
+
with gr.Accordion("Raw XML override (optional)", open=False):
|
| 344 |
+
raw_xml = gr.Textbox(
|
| 345 |
+
label="<speak> XML (overrides fields above when set)",
|
| 346 |
+
lines=4,
|
| 347 |
+
)
|
| 348 |
+
with gr.Accordion("Voice cloning (optional)", open=False):
|
| 349 |
+
reference_audio = gr.Audio(
|
| 350 |
+
label="Reference voice (10-20s)",
|
| 351 |
+
type="filepath",
|
| 352 |
+
)
|
| 353 |
+
with gr.Row():
|
| 354 |
+
mode = gr.Radio(
|
| 355 |
+
["generate", "voice_design"], value="generate", label="Mode"
|
| 356 |
+
)
|
| 357 |
+
seed = gr.Number(value=42, precision=0, label="Seed (-1 = random)")
|
| 358 |
+
with gr.Row():
|
| 359 |
+
background_sfx = gr.Checkbox(value=False, label="Keep background SFX")
|
| 360 |
+
skip_vc = gr.Checkbox(value=False, label="Skip SeedVC post-processing")
|
| 361 |
+
run_btn = gr.Button("Generate", variant="primary")
|
| 362 |
+
with gr.Column(scale=2):
|
| 363 |
+
out_audio = gr.Audio(label="Output", type="filepath")
|
| 364 |
+
info = gr.Textbox(label="Info", interactive=False)
|
| 365 |
+
|
| 366 |
+
gr.Examples(
|
| 367 |
+
examples=EXAMPLES,
|
| 368 |
+
inputs=[
|
| 369 |
+
text, voice, gender, scene, language, shot, action, sound_before,
|
| 370 |
+
reference_audio, mode, seed, background_sfx, skip_vc, raw_xml,
|
| 371 |
+
],
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
run_btn.click(
|
| 375 |
+
generate,
|
| 376 |
+
inputs=[
|
| 377 |
+
text, voice, gender, scene, language, shot, action, sound_before,
|
| 378 |
+
reference_audio, mode, seed, background_sfx, skip_vc, raw_xml,
|
| 379 |
+
],
|
| 380 |
+
outputs=[out_audio, info],
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
if __name__ == "__main__":
|
| 385 |
+
demo.queue().launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy==2.2.6
|
| 2 |
+
transformers==4.57.6
|
| 3 |
+
accelerate==1.13.0
|
| 4 |
+
safetensors==0.7.0
|
| 5 |
+
sentencepiece==0.2.1
|
| 6 |
+
ltx-core @ git+https://github.com/Lightricks/LTX-2.git@41d924371612b692c0fd1e4d9d94c3dfb3c02cb3#subdirectory=packages/ltx-core
|
| 7 |
+
ltx-pipelines @ git+https://github.com/Lightricks/LTX-2.git@41d924371612b692c0fd1e4d9d94c3dfb3c02cb3#subdirectory=packages/ltx-pipelines
|
| 8 |
+
scipy==1.13.1
|
| 9 |
+
librosa==0.10.2
|
| 10 |
+
huggingface-hub==0.36.2
|
| 11 |
+
munch==4.0.0
|
| 12 |
+
einops==0.8.0
|
| 13 |
+
descript-audio-codec==1.0.0
|
| 14 |
+
pydub==0.25.1
|
| 15 |
+
soundfile==0.12.1
|
| 16 |
+
hydra-core==1.3.2
|
| 17 |
+
pyyaml==6.0.3
|
| 18 |
+
python-dotenv==1.2.2
|
| 19 |
+
diffusers==0.37.1
|
| 20 |
+
onnxruntime==1.25.0
|
| 21 |
+
funasr==1.3.1
|
| 22 |
+
rotary-embedding-torch==0.8.9
|
| 23 |
+
beartype==0.22.9
|
| 24 |
+
fastapi==0.136.1
|
| 25 |
+
httpx==0.28.1
|
| 26 |
+
psutil==7.2.2
|
| 27 |
+
bitsandbytes==0.49.2
|
| 28 |
+
kokoro==0.9.4
|
| 29 |
+
faster-whisper==1.2.1
|
| 30 |
+
ctranslate2==4.7.1
|
src/audio_core/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Scenema Audio: Expressive audio generation via LTX 2.3 audio diffusion."""
|
| 6 |
+
|
| 7 |
+
__version__ = "1.0.0"
|
src/audio_core/audio_utils.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Audio utility functions for Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Silence trimming, volume normalization, wav I/O, format conversion.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import math
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import soundfile as sf
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def trim_silence(
|
| 20 |
+
audio_np: np.ndarray,
|
| 21 |
+
sr: int,
|
| 22 |
+
max_silence: float = 0.5,
|
| 23 |
+
threshold_db: float = -40,
|
| 24 |
+
) -> np.ndarray:
|
| 25 |
+
"""Trim silence exceeding max_silence from start and end of audio.
|
| 26 |
+
|
| 27 |
+
Keeps up to max_silence seconds of silence at boundaries.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
audio_np: Audio samples, shape (samples,) or (samples, channels).
|
| 31 |
+
sr: Sample rate in Hz.
|
| 32 |
+
max_silence: Maximum silence to keep at head/tail in seconds.
|
| 33 |
+
threshold_db: Amplitude threshold below which audio is considered silence.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
Trimmed audio array with the same number of dimensions as input.
|
| 37 |
+
"""
|
| 38 |
+
threshold = 10 ** (threshold_db / 20.0)
|
| 39 |
+
max_silent_samples = int(max_silence * sr)
|
| 40 |
+
window = int(0.02 * sr) # 20ms analysis window
|
| 41 |
+
|
| 42 |
+
if audio_np.ndim == 2:
|
| 43 |
+
mono = audio_np.mean(axis=1)
|
| 44 |
+
else:
|
| 45 |
+
mono = audio_np
|
| 46 |
+
|
| 47 |
+
if len(mono) < window:
|
| 48 |
+
return audio_np
|
| 49 |
+
|
| 50 |
+
energy = np.array(
|
| 51 |
+
[
|
| 52 |
+
np.abs(mono[i : i + window]).max()
|
| 53 |
+
for i in range(0, len(mono) - window, window)
|
| 54 |
+
]
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
voiced = np.where(energy > threshold)[0]
|
| 58 |
+
if len(voiced) == 0:
|
| 59 |
+
return audio_np
|
| 60 |
+
|
| 61 |
+
first_voiced = max(0, voiced[0] * window - max_silent_samples)
|
| 62 |
+
last_voiced = min(len(audio_np), (voiced[-1] + 1) * window + max_silent_samples)
|
| 63 |
+
|
| 64 |
+
return audio_np[first_voiced:last_voiced]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def normalize_volume(
|
| 68 |
+
audio_np: np.ndarray,
|
| 69 |
+
sr: int,
|
| 70 |
+
target_lufs: float = -23.0,
|
| 71 |
+
) -> np.ndarray:
|
| 72 |
+
"""Normalize audio volume to target LUFS (approximate via RMS).
|
| 73 |
+
|
| 74 |
+
Uses a simplified RMS-based LUFS approximation suitable for
|
| 75 |
+
per-chunk normalization before concatenation.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
audio_np: Audio samples, shape (samples,) or (samples, channels).
|
| 79 |
+
sr: Sample rate in Hz.
|
| 80 |
+
target_lufs: Target loudness in LUFS (default -23, EBU R128).
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Volume-normalized audio array, soft-clipped to prevent distortion.
|
| 84 |
+
"""
|
| 85 |
+
if audio_np.ndim == 2:
|
| 86 |
+
mono = audio_np.mean(axis=1)
|
| 87 |
+
else:
|
| 88 |
+
mono = audio_np
|
| 89 |
+
|
| 90 |
+
rms = np.sqrt(np.mean(mono**2))
|
| 91 |
+
if rms < 1e-8:
|
| 92 |
+
return audio_np
|
| 93 |
+
|
| 94 |
+
current_lufs = 20 * math.log10(rms) - 0.691
|
| 95 |
+
gain_db = target_lufs - current_lufs
|
| 96 |
+
gain = 10 ** (gain_db / 20.0)
|
| 97 |
+
gain = max(0.1, min(gain, 10.0))
|
| 98 |
+
|
| 99 |
+
result = audio_np * gain
|
| 100 |
+
|
| 101 |
+
peak = np.abs(result).max()
|
| 102 |
+
if peak > 0.99:
|
| 103 |
+
result = result * (0.99 / peak)
|
| 104 |
+
|
| 105 |
+
return result
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def extract_wav(audio_obj) -> tuple[np.ndarray, int]:
|
| 109 |
+
"""Extract numpy waveform from an LTX Audio object.
|
| 110 |
+
|
| 111 |
+
Handles shapes: (B,C,samples) -> (samples,C), (C,samples) -> (samples,C).
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
audio_obj: LTX pipeline Audio object with .waveform and .sampling_rate.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
Tuple of (waveform as float32 numpy, sample_rate).
|
| 118 |
+
"""
|
| 119 |
+
w = audio_obj.waveform.cpu().float().numpy()
|
| 120 |
+
if w.ndim == 3:
|
| 121 |
+
w = w.squeeze(0)
|
| 122 |
+
if w.ndim == 2:
|
| 123 |
+
w = w.T
|
| 124 |
+
return w, audio_obj.sampling_rate
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def save_wav(audio_np: np.ndarray, sr: int, path: str) -> None:
|
| 128 |
+
"""Save audio to WAV file.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
audio_np: Audio samples, shape (samples,) or (samples, channels).
|
| 132 |
+
sr: Sample rate in Hz.
|
| 133 |
+
path: Output file path.
|
| 134 |
+
"""
|
| 135 |
+
sf.write(path, audio_np, sr)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def load_wav(path: str) -> tuple[np.ndarray, int]:
|
| 139 |
+
"""Load audio from WAV file.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
path: Input file path.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
Tuple of (audio samples as float64 numpy, sample_rate).
|
| 146 |
+
"""
|
| 147 |
+
data, sr = sf.read(path)
|
| 148 |
+
return data, sr
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def to_mono(audio_np: np.ndarray) -> np.ndarray:
|
| 152 |
+
"""Convert stereo to mono by averaging channels.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
audio_np: Audio samples, shape (samples, 2) for stereo or (samples,) for mono.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
Mono audio array, shape (samples,).
|
| 159 |
+
"""
|
| 160 |
+
if audio_np.ndim == 2 and audio_np.shape[1] == 2:
|
| 161 |
+
return audio_np.mean(axis=1)
|
| 162 |
+
return audio_np
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def shorten_long_silence(
|
| 166 |
+
audio_np: np.ndarray,
|
| 167 |
+
sr: int,
|
| 168 |
+
max_duration: float = 1.0,
|
| 169 |
+
target_duration: float = 0.3,
|
| 170 |
+
threshold_db: float = -35,
|
| 171 |
+
) -> np.ndarray:
|
| 172 |
+
"""Shorten silence regions longer than max_duration to target_duration.
|
| 173 |
+
|
| 174 |
+
Unlike silenceremove which deletes silence entirely, this preserves
|
| 175 |
+
a natural pause of target_duration seconds. Prevents chunk boundary
|
| 176 |
+
artifacts while keeping the audio flow natural.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
audio_np: Audio samples, shape (samples,) or (samples, channels).
|
| 180 |
+
sr: Sample rate in Hz.
|
| 181 |
+
max_duration: Silence longer than this is shortened.
|
| 182 |
+
target_duration: Silence is shortened to this duration.
|
| 183 |
+
threshold_db: Amplitude threshold below which audio is silence.
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
Audio with long silence regions shortened.
|
| 187 |
+
"""
|
| 188 |
+
threshold = 10 ** (threshold_db / 20.0)
|
| 189 |
+
window = int(0.02 * sr) # 20ms analysis window
|
| 190 |
+
max_samples = int(max_duration * sr)
|
| 191 |
+
target_samples = int(target_duration * sr)
|
| 192 |
+
|
| 193 |
+
if audio_np.ndim == 2:
|
| 194 |
+
mono = audio_np.mean(axis=1)
|
| 195 |
+
else:
|
| 196 |
+
mono = audio_np
|
| 197 |
+
|
| 198 |
+
if len(mono) < window:
|
| 199 |
+
return audio_np
|
| 200 |
+
|
| 201 |
+
# Find silent regions
|
| 202 |
+
energy = np.array(
|
| 203 |
+
[
|
| 204 |
+
np.abs(mono[i : i + window]).max()
|
| 205 |
+
for i in range(0, len(mono) - window, window)
|
| 206 |
+
]
|
| 207 |
+
)
|
| 208 |
+
is_silent = energy < threshold
|
| 209 |
+
|
| 210 |
+
# Build list of (start_sample, end_sample) for silence regions
|
| 211 |
+
silence_regions = []
|
| 212 |
+
in_silence = False
|
| 213 |
+
start = 0
|
| 214 |
+
for i, silent in enumerate(is_silent):
|
| 215 |
+
if silent and not in_silence:
|
| 216 |
+
start = i * window
|
| 217 |
+
in_silence = True
|
| 218 |
+
elif not silent and in_silence:
|
| 219 |
+
end = i * window
|
| 220 |
+
if end - start > max_samples:
|
| 221 |
+
silence_regions.append((start, end))
|
| 222 |
+
in_silence = False
|
| 223 |
+
if in_silence:
|
| 224 |
+
end = len(mono)
|
| 225 |
+
if end - start > max_samples:
|
| 226 |
+
silence_regions.append((start, end))
|
| 227 |
+
|
| 228 |
+
if not silence_regions:
|
| 229 |
+
return audio_np
|
| 230 |
+
|
| 231 |
+
# Build output by keeping non-silence and shortening long silence
|
| 232 |
+
parts = []
|
| 233 |
+
prev_end = 0
|
| 234 |
+
for s_start, s_end in silence_regions:
|
| 235 |
+
# Keep audio before this silence
|
| 236 |
+
parts.append(audio_np[prev_end:s_start])
|
| 237 |
+
# Add shortened silence (target_duration worth)
|
| 238 |
+
parts.append(audio_np[s_start : s_start + target_samples])
|
| 239 |
+
prev_end = s_end
|
| 240 |
+
|
| 241 |
+
# Keep remaining audio after last silence
|
| 242 |
+
parts.append(audio_np[prev_end:])
|
| 243 |
+
|
| 244 |
+
result = np.concatenate(parts, axis=0)
|
| 245 |
+
shortened = (len(audio_np) - len(result)) / sr
|
| 246 |
+
if shortened > 0:
|
| 247 |
+
logger.info(
|
| 248 |
+
"Shortened %d silence regions, removed %.1fs",
|
| 249 |
+
len(silence_regions),
|
| 250 |
+
shortened,
|
| 251 |
+
)
|
| 252 |
+
return result
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def ensure_stereo(audio_np: np.ndarray) -> np.ndarray:
|
| 256 |
+
"""Convert mono to stereo by duplicating the channel.
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
audio_np: Audio samples, shape (samples,) for mono or (samples, 2) for stereo.
|
| 260 |
+
|
| 261 |
+
Returns:
|
| 262 |
+
Stereo audio array, shape (samples, 2).
|
| 263 |
+
"""
|
| 264 |
+
if audio_np.ndim == 1:
|
| 265 |
+
return np.stack([audio_np, audio_np], axis=-1)
|
| 266 |
+
return audio_np
|
src/audio_core/chunker.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Text chunking and duration estimation for Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Splits long text into chunks at sentence boundaries using Kokoro TTS
|
| 8 |
+
phoneme-level timing as the source of truth for duration. No word counting.
|
| 9 |
+
|
| 10 |
+
Algorithm:
|
| 11 |
+
1. Split text into sentences
|
| 12 |
+
2. Estimate each sentence's duration via Kokoro (one call per sentence)
|
| 13 |
+
3. Greedily merge: accumulate sentence durations, start a new chunk
|
| 14 |
+
when running_sum * LTX_MULTIPLIER exceeds MAX_CHUNK_DURATION_S
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import random
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
|
| 21 |
+
from .compiler import compile_chunk_prompt, compile_prompt, extract_sentence_actions
|
| 22 |
+
from .validator import validate_prompt
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
FALLBACK_WORDS_PER_SEC = 2.2 # Test-environment-only fallback when Kokoro is mocked
|
| 27 |
+
ACTION_DURATION_S = 1.5 # Extra time per action block
|
| 28 |
+
MAX_CHUNK_DURATION_S = (
|
| 29 |
+
15.0 # Safe generation limit β model trained on 20s but repeats beyond ~15s
|
| 30 |
+
)
|
| 31 |
+
LTX_MULTIPLIER = 1.5 # LTX speaks slower than Kokoro; overshoot for trimming
|
| 32 |
+
|
| 33 |
+
# Kokoro singleton (loaded once, reused)
|
| 34 |
+
_kokoro_pipeline = None
|
| 35 |
+
_kokoro_available: bool | None = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _get_kokoro():
|
| 39 |
+
"""Get or initialize the Kokoro TTS pipeline for duration estimation.
|
| 40 |
+
|
| 41 |
+
Kokoro is 82M params, runs on CPU. Loaded once and cached.
|
| 42 |
+
Falls back to word-count heuristic only in test environments.
|
| 43 |
+
"""
|
| 44 |
+
global _kokoro_pipeline, _kokoro_available
|
| 45 |
+
|
| 46 |
+
if _kokoro_available is False:
|
| 47 |
+
return None
|
| 48 |
+
|
| 49 |
+
if _kokoro_pipeline is not None:
|
| 50 |
+
return _kokoro_pipeline
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
from kokoro import KPipeline
|
| 54 |
+
|
| 55 |
+
pipe = KPipeline(lang_code="a")
|
| 56 |
+
# Verify it's a real Kokoro pipeline (not a mock in tests)
|
| 57 |
+
if not hasattr(pipe, "__module__") or "kokoro" not in str(
|
| 58 |
+
getattr(pipe, "__module__", "")
|
| 59 |
+
):
|
| 60 |
+
raise TypeError("Kokoro pipeline is not genuine (test mock)")
|
| 61 |
+
_kokoro_pipeline = pipe
|
| 62 |
+
_kokoro_available = True
|
| 63 |
+
logger.info("Kokoro TTS loaded for duration estimation")
|
| 64 |
+
return _kokoro_pipeline
|
| 65 |
+
except TypeError:
|
| 66 |
+
# Test environment with mocks, fall back silently
|
| 67 |
+
_kokoro_available = False
|
| 68 |
+
return None
|
| 69 |
+
except (ImportError, Exception) as e:
|
| 70 |
+
_kokoro_available = False
|
| 71 |
+
logger.error("Kokoro is required but not available: %s", e)
|
| 72 |
+
raise RuntimeError(
|
| 73 |
+
f"Kokoro TTS is a required dependency for duration estimation. "
|
| 74 |
+
f"Install it with: pip install kokoro. Error: {e}"
|
| 75 |
+
) from e
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _kokoro_duration(text: str) -> float | None:
|
| 79 |
+
"""Estimate speech duration using Kokoro TTS phoneme-level timing.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
text: Speech text to estimate duration for
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Duration in seconds, or None if Kokoro unavailable
|
| 86 |
+
"""
|
| 87 |
+
pipe = _get_kokoro()
|
| 88 |
+
if pipe is None:
|
| 89 |
+
return None
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
total_frames = 0
|
| 93 |
+
for result in pipe(text, voice="af_heart"):
|
| 94 |
+
if hasattr(result, "audio") and result.audio is not None:
|
| 95 |
+
total_frames += len(result.audio)
|
| 96 |
+
|
| 97 |
+
# Kokoro outputs at 24000Hz
|
| 98 |
+
duration = total_frames / 24000.0
|
| 99 |
+
return duration
|
| 100 |
+
except Exception as e:
|
| 101 |
+
logger.warning("Kokoro estimation failed: %s", e)
|
| 102 |
+
return None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@dataclass
|
| 106 |
+
class ChunkSpec:
|
| 107 |
+
compiled_prompt: str
|
| 108 |
+
duration_s: float
|
| 109 |
+
seed: int
|
| 110 |
+
expected_text: str
|
| 111 |
+
language: str = "en"
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _split_into_sentences(text: str) -> list[str]:
|
| 115 |
+
"""Split text into individual sentences at .!? boundaries."""
|
| 116 |
+
sentences = []
|
| 117 |
+
current = ""
|
| 118 |
+
for char in text:
|
| 119 |
+
current += char
|
| 120 |
+
if char in ".!?":
|
| 121 |
+
stripped = current.strip()
|
| 122 |
+
if stripped:
|
| 123 |
+
sentences.append(stripped)
|
| 124 |
+
current = ""
|
| 125 |
+
if current.strip():
|
| 126 |
+
sentences.append(current.strip())
|
| 127 |
+
return sentences
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _estimate_sentence_durations(sentences: list[str]) -> list[float]:
|
| 131 |
+
"""Estimate Kokoro duration for each sentence individually.
|
| 132 |
+
|
| 133 |
+
One Kokoro call per sentence. Returns raw Kokoro durations (before
|
| 134 |
+
LTX multiplier). Falls back to word-count heuristic per sentence
|
| 135 |
+
only in test environments where Kokoro is mocked.
|
| 136 |
+
"""
|
| 137 |
+
durations = []
|
| 138 |
+
for sent in sentences:
|
| 139 |
+
dur = _kokoro_duration(sent)
|
| 140 |
+
if dur is None:
|
| 141 |
+
# Test environment fallback only
|
| 142 |
+
dur = len(sent.split()) / FALLBACK_WORDS_PER_SEC + 0.3
|
| 143 |
+
durations.append(dur)
|
| 144 |
+
return durations
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def split_text_by_duration(
|
| 148 |
+
text: str,
|
| 149 |
+
multiplier: float = LTX_MULTIPLIER,
|
| 150 |
+
max_duration: float = MAX_CHUNK_DURATION_S,
|
| 151 |
+
) -> list[tuple[str, float]]:
|
| 152 |
+
"""Split text into chunks using Kokoro duration estimation.
|
| 153 |
+
|
| 154 |
+
Kokoro is the source of truth for duration. No word counting.
|
| 155 |
+
|
| 156 |
+
Algorithm:
|
| 157 |
+
1. Split text into sentences
|
| 158 |
+
2. Estimate each sentence's duration via Kokoro (one call per sentence)
|
| 159 |
+
3. Greedily merge: accumulate durations, start a new chunk when
|
| 160 |
+
running_sum * multiplier would exceed max_duration
|
| 161 |
+
|
| 162 |
+
Duration is additive across sentences because Kokoro estimates are
|
| 163 |
+
phoneme-level with no cross-sentence dependencies.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
text: Full speech text.
|
| 167 |
+
multiplier: LTX speaks slower than Kokoro; applied to estimates.
|
| 168 |
+
max_duration: Max audio duration per chunk (model training limit).
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
List of (chunk_text, estimated_ltx_duration) tuples.
|
| 172 |
+
"""
|
| 173 |
+
sentences = _split_into_sentences(text)
|
| 174 |
+
if not sentences:
|
| 175 |
+
return []
|
| 176 |
+
|
| 177 |
+
# Split long sentences at commas if they exceed max_duration on their own
|
| 178 |
+
expanded = []
|
| 179 |
+
for sent in sentences:
|
| 180 |
+
dur = _estimate_sentence_durations([sent])[0]
|
| 181 |
+
if dur * multiplier > max_duration and "," in sent:
|
| 182 |
+
# Split at commas and re-estimate
|
| 183 |
+
clauses = [c.strip() for c in sent.split(",") if c.strip()]
|
| 184 |
+
clause_durs = _estimate_sentence_durations(clauses)
|
| 185 |
+
sub_texts: list[str] = []
|
| 186 |
+
sub_dur = 0.0
|
| 187 |
+
for clause, cdur in zip(clauses, clause_durs):
|
| 188 |
+
if sub_texts and (sub_dur + cdur) * multiplier > max_duration:
|
| 189 |
+
expanded.append(", ".join(sub_texts))
|
| 190 |
+
sub_texts = []
|
| 191 |
+
sub_dur = 0.0
|
| 192 |
+
sub_texts.append(clause)
|
| 193 |
+
sub_dur += cdur
|
| 194 |
+
if sub_texts:
|
| 195 |
+
expanded.append(", ".join(sub_texts))
|
| 196 |
+
else:
|
| 197 |
+
expanded.append(sent)
|
| 198 |
+
|
| 199 |
+
durations = _estimate_sentence_durations(expanded)
|
| 200 |
+
|
| 201 |
+
chunks: list[tuple[str, float]] = []
|
| 202 |
+
current_texts: list[str] = []
|
| 203 |
+
current_dur = 0.0
|
| 204 |
+
|
| 205 |
+
for sent, dur in zip(expanded, durations):
|
| 206 |
+
if current_texts and (current_dur + dur) * multiplier > max_duration:
|
| 207 |
+
chunk_text = " ".join(current_texts)
|
| 208 |
+
chunks.append((chunk_text, min(current_dur * multiplier, max_duration)))
|
| 209 |
+
current_texts = []
|
| 210 |
+
current_dur = 0.0
|
| 211 |
+
|
| 212 |
+
current_texts.append(sent)
|
| 213 |
+
current_dur += dur
|
| 214 |
+
|
| 215 |
+
if current_texts:
|
| 216 |
+
chunk_text = " ".join(current_texts)
|
| 217 |
+
chunks.append((chunk_text, min(current_dur * multiplier, max_duration)))
|
| 218 |
+
|
| 219 |
+
return chunks
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def estimate_duration(
|
| 223 |
+
text: str,
|
| 224 |
+
num_actions: int = 0,
|
| 225 |
+
multiplier: float = LTX_MULTIPLIER,
|
| 226 |
+
) -> float:
|
| 227 |
+
"""Estimate audio duration for a single chunk of text.
|
| 228 |
+
|
| 229 |
+
Used for single-chunk prompts that don't need splitting.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
text: Speech text (no actions)
|
| 233 |
+
num_actions: Number of action blocks (adds time for breaths/pauses)
|
| 234 |
+
multiplier: Duration multiplier (LTX speaks slower than Kokoro)
|
| 235 |
+
"""
|
| 236 |
+
kokoro_dur = _kokoro_duration(text)
|
| 237 |
+
|
| 238 |
+
if kokoro_dur is not None:
|
| 239 |
+
base_duration = kokoro_dur
|
| 240 |
+
logger.debug("Kokoro estimate: %.1fs for '%s'", kokoro_dur, text[:40])
|
| 241 |
+
else:
|
| 242 |
+
words = len(text.split())
|
| 243 |
+
base_duration = words / FALLBACK_WORDS_PER_SEC + 0.5
|
| 244 |
+
|
| 245 |
+
action_time = num_actions * ACTION_DURATION_S
|
| 246 |
+
duration = (base_duration + action_time) * multiplier
|
| 247 |
+
return min(duration, MAX_CHUNK_DURATION_S)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def plan_chunks(
|
| 251 |
+
xml_string: str,
|
| 252 |
+
base_seed: int = -1,
|
| 253 |
+
pace: float = LTX_MULTIPLIER,
|
| 254 |
+
) -> list[ChunkSpec]:
|
| 255 |
+
"""Plan generation chunks from an XML prompt.
|
| 256 |
+
|
| 257 |
+
Validates XML, extracts text, splits into duration-based chunks
|
| 258 |
+
using Kokoro, and builds per-chunk compiled prompts.
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
xml_string: Valid <speak> XML string
|
| 262 |
+
base_seed: Base seed (-1 for random, otherwise sequential per chunk)
|
| 263 |
+
pace: Duration multiplier (default 1.5). Higher = slower speech.
|
| 264 |
+
"""
|
| 265 |
+
result = validate_prompt(xml_string)
|
| 266 |
+
if not result.valid:
|
| 267 |
+
raise ValueError(f"Invalid prompt: {'; '.join(result.errors)}")
|
| 268 |
+
|
| 269 |
+
compiled = compile_prompt(xml_string)
|
| 270 |
+
|
| 271 |
+
if base_seed == -1:
|
| 272 |
+
base_seed = random.randint(0, 999999)
|
| 273 |
+
|
| 274 |
+
# Check if entire text fits in a single chunk (uncapped duration for this check)
|
| 275 |
+
kokoro_dur = _kokoro_duration(compiled.speech_text)
|
| 276 |
+
if kokoro_dur is not None:
|
| 277 |
+
total_dur = kokoro_dur * pace
|
| 278 |
+
else:
|
| 279 |
+
words = len(compiled.speech_text.split())
|
| 280 |
+
total_dur = (words / FALLBACK_WORDS_PER_SEC + 0.5) * pace
|
| 281 |
+
|
| 282 |
+
if total_dur <= MAX_CHUNK_DURATION_S:
|
| 283 |
+
return [
|
| 284 |
+
ChunkSpec(
|
| 285 |
+
compiled_prompt=compiled.prompt,
|
| 286 |
+
duration_s=min(total_dur, MAX_CHUNK_DURATION_S),
|
| 287 |
+
seed=base_seed,
|
| 288 |
+
expected_text=compiled.speech_text,
|
| 289 |
+
language=compiled.language,
|
| 290 |
+
)
|
| 291 |
+
]
|
| 292 |
+
|
| 293 |
+
# Extract action-to-sentence mapping before splitting
|
| 294 |
+
sentence_action_map = extract_sentence_actions(xml_string)
|
| 295 |
+
|
| 296 |
+
# Split by Kokoro-estimated duration
|
| 297 |
+
text_chunks = split_text_by_duration(compiled.speech_text, multiplier=pace)
|
| 298 |
+
|
| 299 |
+
# Track which global sentence index each chunk starts at
|
| 300 |
+
global_sentence_idx = 0
|
| 301 |
+
|
| 302 |
+
specs: list[ChunkSpec] = []
|
| 303 |
+
for i, (chunk_text, chunk_dur) in enumerate(text_chunks):
|
| 304 |
+
# Find actions that belong to this chunk's first sentence
|
| 305 |
+
actions_before = sentence_action_map.get(global_sentence_idx)
|
| 306 |
+
|
| 307 |
+
chunk_prompt = compile_chunk_prompt(
|
| 308 |
+
speech_text=chunk_text,
|
| 309 |
+
voice=compiled.voice,
|
| 310 |
+
scene=compiled.scene,
|
| 311 |
+
actions_before=actions_before,
|
| 312 |
+
gender=compiled.gender,
|
| 313 |
+
shot=compiled.shot,
|
| 314 |
+
)
|
| 315 |
+
specs.append(
|
| 316 |
+
ChunkSpec(
|
| 317 |
+
compiled_prompt=chunk_prompt,
|
| 318 |
+
duration_s=chunk_dur,
|
| 319 |
+
seed=base_seed + i * 1000,
|
| 320 |
+
expected_text=chunk_text,
|
| 321 |
+
language=compiled.language,
|
| 322 |
+
)
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# Count sentences in this chunk to advance global index
|
| 326 |
+
chunk_sentences = _split_into_sentences(chunk_text)
|
| 327 |
+
global_sentence_idx += len(chunk_sentences)
|
| 328 |
+
|
| 329 |
+
logger.info(
|
| 330 |
+
"Planned %d chunks (%.1fs total estimated)",
|
| 331 |
+
len(specs),
|
| 332 |
+
sum(s.duration_s for s in specs),
|
| 333 |
+
)
|
| 334 |
+
return specs
|
src/audio_core/compiler.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""XML prompt compiler for Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Compiles a <speak> XML prompt into the video-style flat text prompt
|
| 8 |
+
that the LTX 2.3 audio model expects.
|
| 9 |
+
|
| 10 |
+
Supports three block types inside <speak>:
|
| 11 |
+
<action> β delivery/performance cues (how the person speaks/acts)
|
| 12 |
+
<sound> β audio events that should be heard (SFX, ambient sounds)
|
| 13 |
+
Text β the actual speech content
|
| 14 |
+
|
| 15 |
+
And three shot modes via the shot attribute:
|
| 16 |
+
closeup (default) β speech-focused, no SFX, clean audio
|
| 17 |
+
wide β environment + speech, SFX prominent
|
| 18 |
+
scene β raw scene description, maximum SFX
|
| 19 |
+
|
| 20 |
+
Example (closeup mode):
|
| 21 |
+
Input:
|
| 22 |
+
<speak voice="Deep male voice" scene="A dimly lit room" gender="male">
|
| 23 |
+
<action>He takes a slow breath</action>
|
| 24 |
+
Many years later, as he faced the firing squad...
|
| 25 |
+
</speak>
|
| 26 |
+
|
| 27 |
+
Output:
|
| 28 |
+
Close-up in a dimly lit room. He takes a slow breath.
|
| 29 |
+
"Many years later, as he faced the firing squad..."
|
| 30 |
+
Deep male voice.
|
| 31 |
+
|
| 32 |
+
Example (scene mode with SFX):
|
| 33 |
+
Input:
|
| 34 |
+
<speak voice="Tense male whisper" scene="Dark room, heavy rain"
|
| 35 |
+
gender="male" shot="scene">
|
| 36 |
+
<sound>A phone rings twice then stops</sound>
|
| 37 |
+
<action>He picks up the receiver and speaks in a low whisper</action>
|
| 38 |
+
Its done. The package is at the location.
|
| 39 |
+
<sound>Thunder rumbles in the distance</sound>
|
| 40 |
+
<action>He continues urgently</action>
|
| 41 |
+
You have thirty minutes.
|
| 42 |
+
</speak>
|
| 43 |
+
|
| 44 |
+
Output:
|
| 45 |
+
Dark room, heavy rain. A phone rings twice then stops.
|
| 46 |
+
He picks up the receiver and speaks in a low whisper:
|
| 47 |
+
"Its done. The package is at the location."
|
| 48 |
+
Thunder rumbles in the distance. He continues urgently:
|
| 49 |
+
"You have thirty minutes."
|
| 50 |
+
Tense male whisper. Dark room, heavy rain.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
import xml.etree.ElementTree as ET
|
| 54 |
+
from dataclasses import dataclass
|
| 55 |
+
|
| 56 |
+
DEFAULT_SCENE = "a person speaking to camera"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class CompiledPrompt:
|
| 61 |
+
prompt: str
|
| 62 |
+
speech_text: str
|
| 63 |
+
voice: str
|
| 64 |
+
scene: str | None
|
| 65 |
+
language: str
|
| 66 |
+
gender: str
|
| 67 |
+
shot: str
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dataclass
|
| 71 |
+
class TextBlock:
|
| 72 |
+
text: str
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class ActionBlock:
|
| 77 |
+
text: str
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@dataclass
|
| 81 |
+
class SoundBlock:
|
| 82 |
+
text: str
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
Block = TextBlock | ActionBlock | SoundBlock
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _extract_blocks(root: ET.Element) -> list[Block]:
|
| 89 |
+
"""Walk <speak> children in document order, extract text, action, and sound blocks."""
|
| 90 |
+
blocks: list[Block] = []
|
| 91 |
+
|
| 92 |
+
if root.text and root.text.strip():
|
| 93 |
+
blocks.append(TextBlock(text=root.text.strip()))
|
| 94 |
+
|
| 95 |
+
for child in root:
|
| 96 |
+
if child.tag == "action" and child.text and child.text.strip():
|
| 97 |
+
blocks.append(ActionBlock(text=child.text.strip()))
|
| 98 |
+
elif child.tag == "sound" and child.text and child.text.strip():
|
| 99 |
+
blocks.append(SoundBlock(text=child.text.strip()))
|
| 100 |
+
if child.tail and child.tail.strip():
|
| 101 |
+
blocks.append(TextBlock(text=child.tail.strip()))
|
| 102 |
+
|
| 103 |
+
return blocks
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _ensure_trailing_punctuation(text: str) -> str:
|
| 107 |
+
"""Ensure text ends with sentence-ending punctuation."""
|
| 108 |
+
if text and text[-1] not in ".!?\"'":
|
| 109 |
+
return text + "."
|
| 110 |
+
return text
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
SHOT_PREFIXES = {
|
| 114 |
+
"closeup": "Close-up in",
|
| 115 |
+
"wide": "Wide shot of",
|
| 116 |
+
"scene": "",
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _compile_blocks(
|
| 121 |
+
blocks: list[Block],
|
| 122 |
+
voice: str,
|
| 123 |
+
scene: str | None,
|
| 124 |
+
gender: str = "male",
|
| 125 |
+
shot: str = "closeup",
|
| 126 |
+
) -> str:
|
| 127 |
+
"""Compile blocks into the video-style prompt string."""
|
| 128 |
+
parts: list[str] = []
|
| 129 |
+
is_scene_mode = shot in ("scene", "wide")
|
| 130 |
+
pronoun = "She" if gender == "female" else "He"
|
| 131 |
+
|
| 132 |
+
scene_text = scene if scene else DEFAULT_SCENE
|
| 133 |
+
prefix = SHOT_PREFIXES.get(shot, SHOT_PREFIXES["closeup"])
|
| 134 |
+
if prefix:
|
| 135 |
+
parts.append(f"{prefix} {scene_text}.")
|
| 136 |
+
else:
|
| 137 |
+
parts.append(f"{scene_text}.")
|
| 138 |
+
|
| 139 |
+
first_speech = True
|
| 140 |
+
for block in blocks:
|
| 141 |
+
if isinstance(block, SoundBlock):
|
| 142 |
+
# Sound events compile as standalone sentences
|
| 143 |
+
parts.append(_ensure_trailing_punctuation(block.text))
|
| 144 |
+
elif isinstance(block, ActionBlock):
|
| 145 |
+
if is_scene_mode:
|
| 146 |
+
# In scene/wide mode, action flows into speech with connector
|
| 147 |
+
# Don't add punctuation β the colon before the quote handles it
|
| 148 |
+
parts.append(block.text + ":")
|
| 149 |
+
else:
|
| 150 |
+
# In closeup mode, action is a standalone sentence
|
| 151 |
+
parts.append(_ensure_trailing_punctuation(block.text))
|
| 152 |
+
elif isinstance(block, TextBlock):
|
| 153 |
+
clean_text = _ensure_trailing_punctuation(block.text)
|
| 154 |
+
if (
|
| 155 |
+
is_scene_mode
|
| 156 |
+
and first_speech
|
| 157 |
+
and not any(isinstance(b, ActionBlock) for b in blocks)
|
| 158 |
+
):
|
| 159 |
+
# No action before first speech in scene mode β add pronoun
|
| 160 |
+
parts.append(f'{pronoun} speaks: "{clean_text}"')
|
| 161 |
+
else:
|
| 162 |
+
parts.append(f'"{clean_text}"')
|
| 163 |
+
first_speech = False
|
| 164 |
+
|
| 165 |
+
parts.append(_ensure_trailing_punctuation(voice))
|
| 166 |
+
|
| 167 |
+
# In scene/wide mode, repeat scene as SFX reinforcement at the end
|
| 168 |
+
if is_scene_mode and scene:
|
| 169 |
+
parts.append(_ensure_trailing_punctuation(scene))
|
| 170 |
+
|
| 171 |
+
return " ".join(parts)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _extract_speech_only(blocks: list[Block]) -> str:
|
| 175 |
+
"""Extract only speech text (no actions or sounds) for duration estimation."""
|
| 176 |
+
texts = [b.text for b in blocks if isinstance(b, TextBlock)]
|
| 177 |
+
return " ".join(texts)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def compile_prompt(xml_string: str) -> CompiledPrompt:
|
| 181 |
+
"""Compile a <speak> XML prompt into a video-style text prompt.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
xml_string: Valid <speak> XML string (must pass validate_prompt first)
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
CompiledPrompt with the compiled prompt and extracted metadata
|
| 188 |
+
"""
|
| 189 |
+
root = ET.fromstring(xml_string)
|
| 190 |
+
|
| 191 |
+
voice = root.get("voice", "").strip()
|
| 192 |
+
scene = root.get("scene")
|
| 193 |
+
if scene:
|
| 194 |
+
scene = scene.strip()
|
| 195 |
+
language = root.get("language", "en").strip()
|
| 196 |
+
gender = root.get("gender", "male").strip()
|
| 197 |
+
shot = root.get("shot", "closeup").strip()
|
| 198 |
+
|
| 199 |
+
blocks = _extract_blocks(root)
|
| 200 |
+
prompt = _compile_blocks(blocks, voice, scene, gender, shot)
|
| 201 |
+
speech_text = _extract_speech_only(blocks)
|
| 202 |
+
|
| 203 |
+
return CompiledPrompt(
|
| 204 |
+
prompt=prompt,
|
| 205 |
+
speech_text=speech_text,
|
| 206 |
+
voice=voice,
|
| 207 |
+
scene=scene,
|
| 208 |
+
language=language,
|
| 209 |
+
gender=gender,
|
| 210 |
+
shot=shot,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def extract_sentence_actions(xml_string: str) -> dict[int, list[str]]:
|
| 215 |
+
"""Map sentence indices to their preceding action blocks.
|
| 216 |
+
|
| 217 |
+
Walks the XML blocks in order, tracking the most recent action(s).
|
| 218 |
+
When a text block is encountered, its sentences inherit the pending actions.
|
| 219 |
+
Only the first sentence of each text block gets the actions (the action
|
| 220 |
+
precedes the text block in the XML).
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
Dict mapping sentence index (0-based across all speech text) to a list
|
| 224 |
+
of action strings that precede that sentence.
|
| 225 |
+
"""
|
| 226 |
+
root = ET.fromstring(xml_string)
|
| 227 |
+
blocks = _extract_blocks(root)
|
| 228 |
+
|
| 229 |
+
sentence_actions: dict[int, list[str]] = {}
|
| 230 |
+
pending_actions: list[str] = []
|
| 231 |
+
sentence_idx = 0
|
| 232 |
+
|
| 233 |
+
for block in blocks:
|
| 234 |
+
if isinstance(block, ActionBlock):
|
| 235 |
+
pending_actions.append(block.text)
|
| 236 |
+
elif isinstance(block, TextBlock):
|
| 237 |
+
# Split this text block into sentences to count them
|
| 238 |
+
text = block.text.strip()
|
| 239 |
+
sentences = []
|
| 240 |
+
current = ""
|
| 241 |
+
for char in text:
|
| 242 |
+
current += char
|
| 243 |
+
if char in ".!?":
|
| 244 |
+
s = current.strip()
|
| 245 |
+
if s:
|
| 246 |
+
sentences.append(s)
|
| 247 |
+
current = ""
|
| 248 |
+
if current.strip():
|
| 249 |
+
sentences.append(current.strip())
|
| 250 |
+
|
| 251 |
+
if pending_actions and sentences:
|
| 252 |
+
sentence_actions[sentence_idx] = pending_actions.copy()
|
| 253 |
+
pending_actions.clear()
|
| 254 |
+
|
| 255 |
+
sentence_idx += len(sentences)
|
| 256 |
+
|
| 257 |
+
return sentence_actions
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def extract_speech_text(xml_string: str) -> str:
|
| 261 |
+
"""Extract only the speech text from XML, ignoring actions and sounds.
|
| 262 |
+
|
| 263 |
+
Useful for duration estimation (Kokoro) without compiling the full prompt.
|
| 264 |
+
"""
|
| 265 |
+
root = ET.fromstring(xml_string)
|
| 266 |
+
blocks = _extract_blocks(root)
|
| 267 |
+
return _extract_speech_only(blocks)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def compile_chunk_prompt(
|
| 271 |
+
speech_text: str,
|
| 272 |
+
voice: str,
|
| 273 |
+
scene: str | None = None,
|
| 274 |
+
actions_before: list[str] | None = None,
|
| 275 |
+
actions_after: list[str] | None = None,
|
| 276 |
+
gender: str = "male",
|
| 277 |
+
shot: str = "closeup",
|
| 278 |
+
) -> str:
|
| 279 |
+
"""Compile a single chunk's prompt from pre-split text.
|
| 280 |
+
|
| 281 |
+
Used by the chunker to build per-chunk prompts after text splitting.
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
speech_text: The chunk's speech text portion.
|
| 285 |
+
voice: Voice description string.
|
| 286 |
+
scene: Scene description string (optional).
|
| 287 |
+
actions_before: Action blocks to prepend before speech.
|
| 288 |
+
actions_after: Action blocks to append after speech.
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
Compiled video-style prompt string.
|
| 292 |
+
"""
|
| 293 |
+
blocks: list[Block] = []
|
| 294 |
+
|
| 295 |
+
if actions_before:
|
| 296 |
+
for a in actions_before:
|
| 297 |
+
blocks.append(ActionBlock(text=a))
|
| 298 |
+
|
| 299 |
+
blocks.append(TextBlock(text=speech_text))
|
| 300 |
+
|
| 301 |
+
if actions_after:
|
| 302 |
+
for a in actions_after:
|
| 303 |
+
blocks.append(ActionBlock(text=a))
|
| 304 |
+
|
| 305 |
+
return _compile_blocks(blocks, voice, scene, gender, shot)
|
src/audio_core/engine.py
ADDED
|
@@ -0,0 +1,911 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Audio generation engine for Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Loads the LTX 2.3 audio-only checkpoint, Audio VAE encoder, and
|
| 8 |
+
Gemma 3 12B text encoder. VRAM management is auto-detected: models
|
| 9 |
+
are moved between GPU and CPU as needed per inference phase.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import gc
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
from contextlib import contextmanager
|
| 18 |
+
from dataclasses import dataclass, replace as dc_replace
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import psutil
|
| 22 |
+
import torch
|
| 23 |
+
import torchaudio
|
| 24 |
+
from safetensors import safe_open
|
| 25 |
+
from safetensors.torch import load_file
|
| 26 |
+
|
| 27 |
+
from ltx_core.batch_split import BatchSplitAdapter, BatchedPerturbationConfig
|
| 28 |
+
from ltx_core.components.diffusion_steps import EulerDiffusionStep
|
| 29 |
+
from ltx_core.components.noisers import GaussianNoiser
|
| 30 |
+
from ltx_core.components.patchifiers import AudioPatchifier, VideoLatentPatchifier
|
| 31 |
+
from ltx_core.model.audio_vae.audio_vae import Audio, encode_audio
|
| 32 |
+
from ltx_core.model.audio_vae.model_configurator import AudioEncoderConfigurator
|
| 33 |
+
from ltx_core.model.transformer.model import X0Model
|
| 34 |
+
from ltx_core.model.transformer.model_configurator import LTXModelConfigurator
|
| 35 |
+
from ltx_core.model.transformer.transformer import BasicAVTransformerBlock, rms_norm
|
| 36 |
+
from ltx_core.tools import AudioLatentTools, LatentState, VideoLatentTools
|
| 37 |
+
from ltx_core.types import AudioLatentShape, VideoLatentShape, VideoPixelShape
|
| 38 |
+
from ltx_pipelines.distilled import DISTILLED_SIGMAS, DistilledPipeline
|
| 39 |
+
from ltx_pipelines.utils.blocks import ModalitySpec, _build_state
|
| 40 |
+
from ltx_pipelines.utils.denoisers import SimpleDenoiser
|
| 41 |
+
from ltx_pipelines.utils.samplers import euler_denoising_loop
|
| 42 |
+
from ltx_pipelines.utils.types import OffloadMode
|
| 43 |
+
from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer
|
| 44 |
+
import bitsandbytes # noqa: F401
|
| 45 |
+
from transformers import BitsAndBytesConfig, Gemma3ForConditionalGeneration
|
| 46 |
+
|
| 47 |
+
from .audio_utils import extract_wav
|
| 48 |
+
|
| 49 |
+
logger = logging.getLogger(__name__)
|
| 50 |
+
|
| 51 |
+
FPS = 24
|
| 52 |
+
MAX_REF_SECONDS = 5
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class _Int8Linear(torch.nn.Module):
|
| 56 |
+
"""Linear layer with INT8 weights, dequantized to input dtype during forward.
|
| 57 |
+
|
| 58 |
+
Keeps weights as int8 buffers in VRAM (~50% of bf16). Dequantization
|
| 59 |
+
happens per forward pass: weight = int8 * scale, then cast to input dtype.
|
| 60 |
+
Ported from bench_full_quantized.py.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(self, weight_int8, scale, bias=None):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.register_buffer("weight_int8", weight_int8)
|
| 66 |
+
self.register_buffer("scale", scale)
|
| 67 |
+
if bias is not None:
|
| 68 |
+
self.register_parameter("bias", torch.nn.Parameter(bias))
|
| 69 |
+
else:
|
| 70 |
+
self.bias = None
|
| 71 |
+
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
w = self.weight_int8.float() * self.scale.unsqueeze(1)
|
| 74 |
+
w = w.to(x.dtype)
|
| 75 |
+
return torch.nn.functional.linear(x, w, self.bias)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# VRAM threshold: cards with this much VRAM keep all models GPU-resident
|
| 79 |
+
# (Gemma bf16 on GPU, no offloading, MelBandRoFormer + SeedVC preloaded).
|
| 80 |
+
# Below this: Gemma streams from CPU, models load/unload per request.
|
| 81 |
+
HIGH_VRAM_THRESHOLD_GB = 40
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclass
|
| 85 |
+
class AudioResult:
|
| 86 |
+
waveform_np: np.ndarray # (samples,) or (samples, channels) float32
|
| 87 |
+
sample_rate: int
|
| 88 |
+
duration_s: float
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _materialize_meta_tensors(module, device="cpu"):
|
| 92 |
+
"""Replace meta tensors with zeros on the specified device."""
|
| 93 |
+
for name, param in list(module.named_parameters()):
|
| 94 |
+
if param.is_meta:
|
| 95 |
+
parts = name.split(".")
|
| 96 |
+
mod = module
|
| 97 |
+
for p in parts[:-1]:
|
| 98 |
+
mod = getattr(mod, p)
|
| 99 |
+
mod._parameters[parts[-1]] = torch.nn.Parameter(
|
| 100 |
+
torch.zeros(param.shape, dtype=torch.bfloat16, device=device)
|
| 101 |
+
)
|
| 102 |
+
for name, buf in list(module.named_buffers()):
|
| 103 |
+
if buf.is_meta:
|
| 104 |
+
parts = name.split(".")
|
| 105 |
+
mod = module
|
| 106 |
+
for p in parts[:-1]:
|
| 107 |
+
mod = getattr(mod, p)
|
| 108 |
+
mod._buffers[parts[-1]] = torch.zeros(
|
| 109 |
+
buf.shape, dtype=torch.bfloat16, device=device
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _audio_only_forward(self, video, audio, perturbations=None):
|
| 114 |
+
"""Monkey-patched forward for audio-only transformer blocks.
|
| 115 |
+
|
| 116 |
+
Skips all video computation (attn1, attn2, ff, audio_to_video_attn)
|
| 117 |
+
and only runs audio self-attention, cross-attention, and feedforward.
|
| 118 |
+
"""
|
| 119 |
+
if video is None and audio is None:
|
| 120 |
+
raise ValueError("Need at least one modality")
|
| 121 |
+
batch_size = (video or audio).x.shape[0]
|
| 122 |
+
if perturbations is None:
|
| 123 |
+
perturbations = BatchedPerturbationConfig.empty(batch_size)
|
| 124 |
+
vx = video.x if video is not None else None
|
| 125 |
+
ax = audio.x if audio is not None else None
|
| 126 |
+
run_ax = audio is not None and audio.enabled and ax.numel() > 0
|
| 127 |
+
if run_ax:
|
| 128 |
+
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
|
| 129 |
+
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
|
| 130 |
+
)
|
| 131 |
+
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
| 132 |
+
del ashift_msa, ascale_msa
|
| 133 |
+
ax = (
|
| 134 |
+
ax
|
| 135 |
+
+ self.audio_attn1(
|
| 136 |
+
norm_ax, pe=audio.positional_embeddings, mask=audio.self_attention_mask
|
| 137 |
+
)
|
| 138 |
+
* agate_msa
|
| 139 |
+
)
|
| 140 |
+
del agate_msa, norm_ax
|
| 141 |
+
ax = ax + self._apply_text_cross_attention(
|
| 142 |
+
ax,
|
| 143 |
+
audio.context,
|
| 144 |
+
self.audio_attn2,
|
| 145 |
+
self.audio_scale_shift_table,
|
| 146 |
+
getattr(self, "audio_prompt_scale_shift_table", None),
|
| 147 |
+
audio.timesteps,
|
| 148 |
+
audio.prompt_timestep,
|
| 149 |
+
audio.context_mask,
|
| 150 |
+
cross_attention_adaln=self.cross_attention_adaln,
|
| 151 |
+
)
|
| 152 |
+
ashift_ff, ascale_ff, agate_ff = self.get_ada_values(
|
| 153 |
+
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
|
| 154 |
+
)
|
| 155 |
+
norm_ax_ff = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_ff) + ashift_ff
|
| 156 |
+
del ashift_ff, ascale_ff
|
| 157 |
+
ax = ax + self.audio_ff(norm_ax_ff) * agate_ff
|
| 158 |
+
del agate_ff, norm_ax_ff
|
| 159 |
+
if video is not None:
|
| 160 |
+
object.__setattr__(video, "x", vx)
|
| 161 |
+
if audio is not None:
|
| 162 |
+
object.__setattr__(audio, "x", ax)
|
| 163 |
+
return video, audio
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ββ VRAM Manager ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class VRAMManager:
|
| 170 |
+
"""Manages model placement between GPU and CPU based on available VRAM.
|
| 171 |
+
|
| 172 |
+
Tracks which models are on GPU and moves them as needed per inference phase.
|
| 173 |
+
Offloading is determined by comparing total registered model size against
|
| 174 |
+
available VRAM. If all models fit, no offloading occurs.
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
def __init__(self, vram_gb: float):
|
| 178 |
+
self.vram_gb = vram_gb
|
| 179 |
+
self._models: dict[str, torch.nn.Module] = {}
|
| 180 |
+
self._model_sizes: dict[str, float] = {} # GB per model
|
| 181 |
+
self._on_gpu: set[str] = set()
|
| 182 |
+
self.needs_offload = False # Determined after all models registered
|
| 183 |
+
|
| 184 |
+
def register(self, name: str, model: torch.nn.Module, on_gpu: bool = True) -> None:
|
| 185 |
+
"""Register a model for VRAM management.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
name: Identifier for the model.
|
| 189 |
+
model: The PyTorch module.
|
| 190 |
+
on_gpu: Whether the model is currently on GPU.
|
| 191 |
+
"""
|
| 192 |
+
self._models[name] = model
|
| 193 |
+
size_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9
|
| 194 |
+
self._model_sizes[name] = size_gb
|
| 195 |
+
if on_gpu:
|
| 196 |
+
self._on_gpu.add(name)
|
| 197 |
+
|
| 198 |
+
def finalize(self) -> None:
|
| 199 |
+
"""Determine offloading strategy based on total model size vs VRAM.
|
| 200 |
+
|
| 201 |
+
Call after all models are registered. Sets needs_offload based on
|
| 202 |
+
whether all registered models fit in VRAM simultaneously with
|
| 203 |
+
headroom for activations and pipeline overhead (~5GB).
|
| 204 |
+
"""
|
| 205 |
+
total_model_gb = sum(self._model_sizes.values())
|
| 206 |
+
# Gemma overhead depends on quantization mode:
|
| 207 |
+
# bf16 streaming: ~16GB peak (13GB Gemma + 2GB embeddings + 1GB safety)
|
| 208 |
+
# NF4: ~11GB peak (8GB NF4 model on GPU + 2GB embeddings + 1GB safety)
|
| 209 |
+
gemma_nf4 = os.environ.get("GEMMA_QUANTIZE", "").lower() == "nf4"
|
| 210 |
+
gemma_overhead_gb = 11.0 if gemma_nf4 else 16.0
|
| 211 |
+
self.needs_offload = (total_model_gb + gemma_overhead_gb) > self.vram_gb
|
| 212 |
+
logger.info(
|
| 213 |
+
"VRAM strategy: %.1f GB models + %.1f GB Gemma overhead (%s) vs %.1f GB VRAM -> offload=%s",
|
| 214 |
+
total_model_gb,
|
| 215 |
+
gemma_overhead_gb,
|
| 216 |
+
"nf4" if gemma_nf4 else "bf16",
|
| 217 |
+
self.vram_gb,
|
| 218 |
+
"yes" if self.needs_offload else "no",
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
def to_gpu(self, *names: str) -> None:
|
| 222 |
+
"""Move specified models to GPU, offloading others if needed.
|
| 223 |
+
|
| 224 |
+
If offloading is required (VRAM < 40GB), all models NOT in the
|
| 225 |
+
requested set are moved to CPU first to free VRAM.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
names: Model names that should be on GPU for the current phase.
|
| 229 |
+
"""
|
| 230 |
+
if not self.needs_offload:
|
| 231 |
+
# High VRAM: just ensure requested models are on GPU
|
| 232 |
+
for name in names:
|
| 233 |
+
if name not in self._on_gpu and name in self._models:
|
| 234 |
+
self._models[name].cuda()
|
| 235 |
+
self._on_gpu.add(name)
|
| 236 |
+
return
|
| 237 |
+
|
| 238 |
+
# Offload models that shouldn't be on GPU
|
| 239 |
+
needed = set(names)
|
| 240 |
+
to_offload = self._on_gpu - needed
|
| 241 |
+
for name in to_offload:
|
| 242 |
+
if name in self._models:
|
| 243 |
+
self._models[name].cpu()
|
| 244 |
+
self._on_gpu.discard(name)
|
| 245 |
+
logger.debug("Offloaded %s to CPU", name)
|
| 246 |
+
|
| 247 |
+
torch.cuda.empty_cache()
|
| 248 |
+
|
| 249 |
+
# Load requested models to GPU
|
| 250 |
+
for name in names:
|
| 251 |
+
if name not in self._on_gpu and name in self._models:
|
| 252 |
+
self._models[name].cuda()
|
| 253 |
+
self._on_gpu.add(name)
|
| 254 |
+
logger.debug("Loaded %s to GPU", name)
|
| 255 |
+
|
| 256 |
+
def free_all(self) -> None:
|
| 257 |
+
"""Move all models to CPU."""
|
| 258 |
+
for name in list(self._on_gpu):
|
| 259 |
+
if name in self._models:
|
| 260 |
+
self._models[name].cpu()
|
| 261 |
+
self._on_gpu.clear()
|
| 262 |
+
torch.cuda.empty_cache()
|
| 263 |
+
|
| 264 |
+
@contextmanager
|
| 265 |
+
def phase(self, *names: str):
|
| 266 |
+
"""Context manager for a VRAM phase.
|
| 267 |
+
|
| 268 |
+
Ensures specified models are on GPU for the duration, then
|
| 269 |
+
returns to previous state on exit.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
names: Model names needed on GPU for this phase.
|
| 273 |
+
"""
|
| 274 |
+
prev_on_gpu = set(self._on_gpu)
|
| 275 |
+
self.to_gpu(*names)
|
| 276 |
+
try:
|
| 277 |
+
yield
|
| 278 |
+
finally:
|
| 279 |
+
# Restore previous state only if offloading is needed
|
| 280 |
+
if self.needs_offload:
|
| 281 |
+
to_restore = prev_on_gpu - set(names)
|
| 282 |
+
to_remove = set(names) - prev_on_gpu
|
| 283 |
+
for name in to_remove:
|
| 284 |
+
if name in self._models and name in self._on_gpu:
|
| 285 |
+
self._models[name].cpu()
|
| 286 |
+
self._on_gpu.discard(name)
|
| 287 |
+
for name in to_restore:
|
| 288 |
+
if name in self._models and name not in self._on_gpu:
|
| 289 |
+
self._models[name].cuda()
|
| 290 |
+
self._on_gpu.add(name)
|
| 291 |
+
torch.cuda.empty_cache()
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# ββ Audio Engine ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class AudioEngine:
|
| 298 |
+
"""LTX 2.3 audio-only generation engine.
|
| 299 |
+
|
| 300 |
+
Loads the baked audio checkpoint, Audio VAE encoder, and Gemma 3 12B
|
| 301 |
+
text encoder. VRAM is managed automatically per inference phase.
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
def __init__(
|
| 305 |
+
self,
|
| 306 |
+
audio_ckpt_path: str,
|
| 307 |
+
vae_encoder_path: str,
|
| 308 |
+
gemma_root: str,
|
| 309 |
+
pipeline_ckpt_path: str | None = None,
|
| 310 |
+
):
|
| 311 |
+
"""Initialize AudioEngine.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
audio_ckpt_path: Path to the audio-only transformer checkpoint.
|
| 315 |
+
vae_encoder_path: Path to the standalone Audio VAE encoder checkpoint.
|
| 316 |
+
gemma_root: Path to the Gemma 3 12B model directory.
|
| 317 |
+
pipeline_ckpt_path: Path to checkpoint for DistilledPipeline.
|
| 318 |
+
"""
|
| 319 |
+
self.audio_ckpt_path = audio_ckpt_path
|
| 320 |
+
self.vae_encoder_path = vae_encoder_path
|
| 321 |
+
self.gemma_root = gemma_root
|
| 322 |
+
self.pipeline_ckpt_path = pipeline_ckpt_path or audio_ckpt_path
|
| 323 |
+
|
| 324 |
+
self._config = None
|
| 325 |
+
self._mdl_wrapper = None
|
| 326 |
+
self._audio_encoder = None
|
| 327 |
+
self._pipeline = None
|
| 328 |
+
self._vram: VRAMManager | None = None
|
| 329 |
+
self._vae_sr = None
|
| 330 |
+
self._loaded = False
|
| 331 |
+
|
| 332 |
+
@property
|
| 333 |
+
def vae_sample_rate(self) -> int:
|
| 334 |
+
return self._vae_sr or 16000
|
| 335 |
+
|
| 336 |
+
def load(self) -> None:
|
| 337 |
+
"""Load all models. Call once at startup."""
|
| 338 |
+
if self._loaded:
|
| 339 |
+
return
|
| 340 |
+
|
| 341 |
+
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 342 |
+
ram_gb = psutil.virtual_memory().total / 1e9
|
| 343 |
+
logger.info(
|
| 344 |
+
"System: %.1f GB VRAM, %.1f GB RAM, GPU: %s",
|
| 345 |
+
vram_gb,
|
| 346 |
+
ram_gb,
|
| 347 |
+
torch.cuda.get_device_name(0),
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
if vram_gb < 11:
|
| 351 |
+
raise RuntimeError(
|
| 352 |
+
f"Insufficient VRAM: {vram_gb:.0f}GB. Minimum 11GB required."
|
| 353 |
+
)
|
| 354 |
+
if ram_gb < 24:
|
| 355 |
+
raise RuntimeError(
|
| 356 |
+
f"Insufficient RAM: {ram_gb:.0f}GB. Minimum 24GB required."
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
self._vram = VRAMManager(vram_gb)
|
| 360 |
+
|
| 361 |
+
self._load_audio_model()
|
| 362 |
+
self._load_vae_encoder()
|
| 363 |
+
self._patch_transformer_blocks()
|
| 364 |
+
self._build_pipeline()
|
| 365 |
+
|
| 366 |
+
# Determine offloading strategy based on actual model sizes vs VRAM
|
| 367 |
+
self._vram.finalize()
|
| 368 |
+
|
| 369 |
+
self._loaded = True
|
| 370 |
+
logger.info("AudioEngine loaded")
|
| 371 |
+
|
| 372 |
+
def _load_audio_model(self) -> None:
|
| 373 |
+
"""Load the audio-only checkpoint to GPU.
|
| 374 |
+
|
| 375 |
+
Supports both bf16 and INT8 quantized checkpoints. INT8 checkpoints
|
| 376 |
+
store weights as .weight.int8 (int8) + .weight.scale (float32) pairs.
|
| 377 |
+
For INT8, nn.Linear layers are replaced with Int8Linear modules that
|
| 378 |
+
keep weights quantized in VRAM (~5GB vs 9.8GB) and dequantize during
|
| 379 |
+
the forward pass.
|
| 380 |
+
"""
|
| 381 |
+
t0 = time.time()
|
| 382 |
+
|
| 383 |
+
with safe_open(self.audio_ckpt_path, framework="pt") as f:
|
| 384 |
+
self._config = json.loads(f.metadata()["config"])
|
| 385 |
+
|
| 386 |
+
with torch.device("meta"):
|
| 387 |
+
mdl = LTXModelConfigurator.from_config(self._config)
|
| 388 |
+
|
| 389 |
+
sd = load_file(self.audio_ckpt_path, device="cpu")
|
| 390 |
+
|
| 391 |
+
# Detect INT8 checkpoint format
|
| 392 |
+
int8_map = {
|
| 393 |
+
k.replace(".weight.int8", ""): k for k in sd if k.endswith(".weight.int8")
|
| 394 |
+
}
|
| 395 |
+
scale_map = {
|
| 396 |
+
k.replace(".weight.scale", ""): k for k in sd if k.endswith(".weight.scale")
|
| 397 |
+
}
|
| 398 |
+
is_int8 = len(int8_map) > 0
|
| 399 |
+
|
| 400 |
+
if is_int8:
|
| 401 |
+
# Load only non-quantized keys first (biases, norms, embeddings)
|
| 402 |
+
regular_sd = {
|
| 403 |
+
k: v
|
| 404 |
+
for k, v in sd.items()
|
| 405 |
+
if not k.endswith(".int8") and not k.endswith(".scale")
|
| 406 |
+
}
|
| 407 |
+
mdl_wrapper = X0Model(mdl)
|
| 408 |
+
mdl_wrapper.load_state_dict(regular_sd, strict=False, assign=True)
|
| 409 |
+
|
| 410 |
+
# Replace nn.Linear with Int8Linear for quantized weights
|
| 411 |
+
n_replaced = 0
|
| 412 |
+
for name in int8_map:
|
| 413 |
+
w_int8 = sd[int8_map[name]]
|
| 414 |
+
w_scale = sd[scale_map[name]]
|
| 415 |
+
parts = name.split(".")
|
| 416 |
+
parent = mdl_wrapper
|
| 417 |
+
for p in parts[:-1]:
|
| 418 |
+
parent = getattr(parent, p)
|
| 419 |
+
old = getattr(parent, parts[-1])
|
| 420 |
+
bias_key = name + ".bias"
|
| 421 |
+
bias = sd.get(bias_key)
|
| 422 |
+
if bias is None and hasattr(old, "bias") and old.bias is not None:
|
| 423 |
+
bias = old.bias.data
|
| 424 |
+
setattr(parent, parts[-1], _Int8Linear(w_int8, w_scale, bias))
|
| 425 |
+
n_replaced += 1
|
| 426 |
+
|
| 427 |
+
logger.info("INT8: replaced %d Linear layers with Int8Linear", n_replaced)
|
| 428 |
+
else:
|
| 429 |
+
mdl_wrapper = X0Model(mdl)
|
| 430 |
+
mdl_wrapper.load_state_dict(sd, strict=False, assign=True)
|
| 431 |
+
|
| 432 |
+
# Runtime INT8 quantization via BnB (bf16 checkpoint β INT8 on GPU)
|
| 433 |
+
if os.environ.get("TRANSFORMER_QUANTIZE", "").lower() == "int8":
|
| 434 |
+
import bitsandbytes as bnb
|
| 435 |
+
|
| 436 |
+
n_quantized = 0
|
| 437 |
+
for name, module in list(mdl_wrapper.named_modules()):
|
| 438 |
+
for cn, child in list(module.named_children()):
|
| 439 |
+
if (
|
| 440 |
+
isinstance(child, torch.nn.Linear)
|
| 441 |
+
and child.weight.numel() > 1_000_000
|
| 442 |
+
):
|
| 443 |
+
int8_layer = bnb.nn.Linear8bitLt(
|
| 444 |
+
child.in_features,
|
| 445 |
+
child.out_features,
|
| 446 |
+
bias=child.bias is not None,
|
| 447 |
+
has_fp16_weights=False,
|
| 448 |
+
)
|
| 449 |
+
int8_layer.weight = bnb.nn.Int8Params(
|
| 450 |
+
child.weight.data,
|
| 451 |
+
requires_grad=False,
|
| 452 |
+
has_fp16_weights=False,
|
| 453 |
+
)
|
| 454 |
+
if child.bias is not None:
|
| 455 |
+
int8_layer.bias = child.bias
|
| 456 |
+
setattr(module, cn, int8_layer)
|
| 457 |
+
n_quantized += 1
|
| 458 |
+
logger.info(
|
| 459 |
+
"Runtime INT8: quantized %d Linear layers via BnB", n_quantized
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
del sd
|
| 463 |
+
gc.collect()
|
| 464 |
+
|
| 465 |
+
for block in mdl.transformer_blocks:
|
| 466 |
+
block.attn1 = torch.nn.Identity()
|
| 467 |
+
block.attn2 = torch.nn.Identity()
|
| 468 |
+
block.ff = torch.nn.Identity()
|
| 469 |
+
block.audio_to_video_attn = torch.nn.Identity()
|
| 470 |
+
gc.collect()
|
| 471 |
+
|
| 472 |
+
_materialize_meta_tensors(mdl_wrapper)
|
| 473 |
+
|
| 474 |
+
cross_pe = max(
|
| 475 |
+
mdl.positional_embedding_max_pos[0],
|
| 476 |
+
mdl.audio_positional_embedding_max_pos[0],
|
| 477 |
+
)
|
| 478 |
+
mdl._init_preprocessors(cross_pe)
|
| 479 |
+
|
| 480 |
+
self._mdl_wrapper = mdl_wrapper.cuda().eval()
|
| 481 |
+
self._vram.register("audio_model", self._mdl_wrapper, on_gpu=True)
|
| 482 |
+
|
| 483 |
+
logger.info(
|
| 484 |
+
"Audio model loaded: %.1f GB, %.1fs",
|
| 485 |
+
torch.cuda.memory_allocated() / 1e9,
|
| 486 |
+
time.time() - t0,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
def _load_vae_encoder(self) -> None:
|
| 490 |
+
"""Load Audio VAE encoder from standalone checkpoint."""
|
| 491 |
+
t0 = time.time()
|
| 492 |
+
avae_cfg = self._config["audio_vae"]
|
| 493 |
+
preproc = avae_cfg["preprocessing"]
|
| 494 |
+
self._vae_sr = preproc["audio"]["sampling_rate"]
|
| 495 |
+
|
| 496 |
+
with torch.device("meta"):
|
| 497 |
+
encoder = AudioEncoderConfigurator().from_config(avae_cfg)
|
| 498 |
+
|
| 499 |
+
sd = load_file(self.vae_encoder_path, device="cpu")
|
| 500 |
+
encoder.load_state_dict(sd, strict=False, assign=True)
|
| 501 |
+
|
| 502 |
+
pcs = encoder.per_channel_statistics
|
| 503 |
+
if "per_channel_statistics.std-of-means" in sd:
|
| 504 |
+
pcs._buffers["std-of-means"] = sd["per_channel_statistics.std-of-means"]
|
| 505 |
+
pcs._buffers["mean-of-means"] = sd["per_channel_statistics.mean-of-means"]
|
| 506 |
+
del sd
|
| 507 |
+
|
| 508 |
+
dd = avae_cfg["model"]["params"]["ddconfig"]
|
| 509 |
+
encoder.mel_bins = dd["mel_bins"]
|
| 510 |
+
encoder.mid.attn_1 = torch.nn.Identity()
|
| 511 |
+
|
| 512 |
+
_materialize_meta_tensors(encoder, device="cpu")
|
| 513 |
+
|
| 514 |
+
self._audio_encoder = encoder.cuda().eval().to(torch.bfloat16)
|
| 515 |
+
self._vram.register("vae_encoder", self._audio_encoder, on_gpu=True)
|
| 516 |
+
|
| 517 |
+
logger.info(
|
| 518 |
+
"Audio VAE encoder loaded: %.1fM params, %.1fs",
|
| 519 |
+
sum(p.numel() for p in self._audio_encoder.parameters()) / 1e6,
|
| 520 |
+
time.time() - t0,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
def _patch_transformer_blocks(self) -> None:
|
| 524 |
+
"""Monkey-patch transformer blocks for audio-only forward pass."""
|
| 525 |
+
BasicAVTransformerBlock.forward = _audio_only_forward
|
| 526 |
+
logger.info("Transformer blocks patched for audio-only forward")
|
| 527 |
+
|
| 528 |
+
def _build_pipeline(self) -> None:
|
| 529 |
+
"""Build DistilledPipeline and cache Gemma + embeddings processor in CPU RAM.
|
| 530 |
+
|
| 531 |
+
Caching eliminates the ~35s rebuild cost on every encode call.
|
| 532 |
+
Gemma stays in CPU RAM permanently, streams to GPU layer-by-layer.
|
| 533 |
+
Embeddings processor shuttles between CPU and GPU per call.
|
| 534 |
+
"""
|
| 535 |
+
t0 = time.time()
|
| 536 |
+
mdl_wrapper = self._mdl_wrapper
|
| 537 |
+
|
| 538 |
+
# Use NONE offload when VRAM is sufficient so Gemma stays GPU-resident
|
| 539 |
+
# for fast encoding (~0.5s vs ~7s streaming). Fall back to CPU streaming
|
| 540 |
+
# on smaller cards.
|
| 541 |
+
offload = (
|
| 542 |
+
OffloadMode.NONE
|
| 543 |
+
if self._vram.vram_gb >= HIGH_VRAM_THRESHOLD_GB
|
| 544 |
+
else OffloadMode.CPU
|
| 545 |
+
)
|
| 546 |
+
self._pipeline = DistilledPipeline(
|
| 547 |
+
distilled_checkpoint_path=self.pipeline_ckpt_path,
|
| 548 |
+
gemma_root=self.gemma_root,
|
| 549 |
+
spatial_upsampler_path=None,
|
| 550 |
+
loras=[],
|
| 551 |
+
offload_mode=offload,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
@contextmanager
|
| 555 |
+
def _gpu_ctx(**kw):
|
| 556 |
+
yield mdl_wrapper
|
| 557 |
+
|
| 558 |
+
self._pipeline.stage._transformer_ctx = _gpu_ctx
|
| 559 |
+
|
| 560 |
+
pe = self._pipeline.prompt_encoder
|
| 561 |
+
|
| 562 |
+
# Gemma loading strategy:
|
| 563 |
+
# NF4: BitsAndBytes int4 quantization (~8GB on GPU, ~0.1s encode)
|
| 564 |
+
# bf16 GPU: full precision on GPU (~24GB, ~1-2s encode) β when VRAM >= 40GB
|
| 565 |
+
# bf16 streaming: streams from CPU RAM layer-by-layer (~7s encode) β when VRAM < 40GB
|
| 566 |
+
self._gemma_nf4 = os.environ.get("GEMMA_QUANTIZE", "").lower() == "nf4"
|
| 567 |
+
self._gemma_on_gpu = False
|
| 568 |
+
|
| 569 |
+
if self._gemma_nf4:
|
| 570 |
+
self._build_nf4_gemma()
|
| 571 |
+
# NF4 needs its own embeddings processor and tokenizer
|
| 572 |
+
self._cached_emb_proc = pe._embeddings_processor_builder.build(
|
| 573 |
+
device="cuda",
|
| 574 |
+
dtype=torch.bfloat16,
|
| 575 |
+
).eval()
|
| 576 |
+
self._cached_tokenizer = LTXVGemmaTokenizer(self.gemma_root)
|
| 577 |
+
logger.info("Embeddings processor cached on CUDA (NF4 mode)")
|
| 578 |
+
elif self._vram.vram_gb >= HIGH_VRAM_THRESHOLD_GB:
|
| 579 |
+
# Build pipeline's text encoder ONCE on GPU and keep it resident.
|
| 580 |
+
# This uses the same builder as pipeline.prompt_encoder but
|
| 581 |
+
# avoids the build/destroy cycle that makes each call ~30s.
|
| 582 |
+
t_gemma = time.time()
|
| 583 |
+
self._resident_text_encoder = pe._text_encoder_builder.build(
|
| 584 |
+
device=torch.device("cuda"),
|
| 585 |
+
dtype=torch.bfloat16,
|
| 586 |
+
).eval()
|
| 587 |
+
self._cached_emb_proc = pe._embeddings_processor_builder.build(
|
| 588 |
+
device="cuda",
|
| 589 |
+
dtype=torch.bfloat16,
|
| 590 |
+
).eval()
|
| 591 |
+
self._gemma_on_gpu = True
|
| 592 |
+
vram_gb = torch.cuda.memory_allocated() / (1024**3)
|
| 593 |
+
logger.info(
|
| 594 |
+
"Gemma bf16 (pipeline encoder) GPU-resident: %.1fGB VRAM, %.1fs",
|
| 595 |
+
vram_gb,
|
| 596 |
+
time.time() - t_gemma,
|
| 597 |
+
)
|
| 598 |
+
else:
|
| 599 |
+
# Low VRAM: pipeline.prompt_encoder streams from CPU (~7s/encode)
|
| 600 |
+
logger.info("Gemma managed by pipeline prompt_encoder (CPU streaming)")
|
| 601 |
+
|
| 602 |
+
logger.info("Pipeline built: %.1fs", time.time() - t0)
|
| 603 |
+
|
| 604 |
+
def _build_nf4_gemma(self) -> None:
|
| 605 |
+
"""Load Gemma 3 12B with BitsAndBytes NF4 quantization (~8GB on GPU).
|
| 606 |
+
|
| 607 |
+
NF4 Gemma stays on GPU permanently. Encode is near-instant (~0.1s)
|
| 608 |
+
since there's no CPU->GPU streaming. Slight quality tradeoff vs bf16
|
| 609 |
+
but acceptable for production use.
|
| 610 |
+
"""
|
| 611 |
+
t0 = time.time()
|
| 612 |
+
quant_config = BitsAndBytesConfig(
|
| 613 |
+
load_in_4bit=True,
|
| 614 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 615 |
+
bnb_4bit_quant_type="nf4",
|
| 616 |
+
)
|
| 617 |
+
self._nf4_gemma_model = Gemma3ForConditionalGeneration.from_pretrained(
|
| 618 |
+
self.gemma_root,
|
| 619 |
+
quantization_config=quant_config,
|
| 620 |
+
device_map="cuda",
|
| 621 |
+
dtype=torch.bfloat16,
|
| 622 |
+
).eval()
|
| 623 |
+
|
| 624 |
+
# No streaming text encoder needed β _cached_text_encoder stays None
|
| 625 |
+
self._cached_text_encoder = None
|
| 626 |
+
|
| 627 |
+
vram_gb = torch.cuda.memory_allocated() / (1024**3)
|
| 628 |
+
logger.info(
|
| 629 |
+
"Gemma NF4 loaded on GPU: %.1fGB VRAM, %.1fs", vram_gb, time.time() - t0
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
def _build_bf16_gemma_gpu(self) -> None:
|
| 633 |
+
"""Load Gemma 3 12B bf16 directly on GPU (~24GB).
|
| 634 |
+
|
| 635 |
+
For cards with >= 40GB VRAM. Gemma stays on GPU permanently.
|
| 636 |
+
Encode is ~1-2s (pure inference, no CPU->GPU streaming).
|
| 637 |
+
"""
|
| 638 |
+
t0 = time.time()
|
| 639 |
+
self._nf4_gemma_model = Gemma3ForConditionalGeneration.from_pretrained(
|
| 640 |
+
self.gemma_root,
|
| 641 |
+
device_map="cuda",
|
| 642 |
+
torch_dtype=torch.bfloat16,
|
| 643 |
+
).eval()
|
| 644 |
+
|
| 645 |
+
self._cached_text_encoder = None
|
| 646 |
+
self._gemma_on_gpu = True
|
| 647 |
+
|
| 648 |
+
vram_gb = torch.cuda.memory_allocated() / (1024**3)
|
| 649 |
+
logger.info(
|
| 650 |
+
"Gemma bf16 loaded on GPU: %.1fGB VRAM, %.1fs", vram_gb, time.time() - t0
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
def unload(self) -> None:
|
| 654 |
+
"""Free all GPU and CPU memory."""
|
| 655 |
+
if self._vram:
|
| 656 |
+
self._vram.free_all()
|
| 657 |
+
if (
|
| 658 |
+
hasattr(self, "_cached_text_encoder")
|
| 659 |
+
and self._cached_text_encoder is not None
|
| 660 |
+
):
|
| 661 |
+
self._cached_text_encoder.teardown()
|
| 662 |
+
self._cached_text_encoder = None
|
| 663 |
+
if hasattr(self, "_nf4_gemma_model"):
|
| 664 |
+
del self._nf4_gemma_model
|
| 665 |
+
self._nf4_gemma_model = None
|
| 666 |
+
if hasattr(self, "_cached_emb_proc"):
|
| 667 |
+
self._cached_emb_proc = None
|
| 668 |
+
if hasattr(self, "_cached_tokenizer"):
|
| 669 |
+
self._cached_tokenizer = None
|
| 670 |
+
self._mdl_wrapper = None
|
| 671 |
+
self._audio_encoder = None
|
| 672 |
+
self._pipeline = None
|
| 673 |
+
self._vram = None
|
| 674 |
+
self._loaded = False
|
| 675 |
+
gc.collect()
|
| 676 |
+
torch.cuda.empty_cache()
|
| 677 |
+
logger.info("AudioEngine unloaded")
|
| 678 |
+
|
| 679 |
+
def encode_text(self, prompt: str):
|
| 680 |
+
"""Encode text prompt via Gemma 3 12B.
|
| 681 |
+
|
| 682 |
+
Uses the pipeline's PromptEncoder which builds Gemma through
|
| 683 |
+
the LTX-native builder. This ensures identical encoding to the
|
| 684 |
+
reference pipeline (critical for SFX generation quality).
|
| 685 |
+
|
| 686 |
+
Falls back to NF4/bf16 GPU-resident Gemma when available for speed,
|
| 687 |
+
but routes through the pipeline encoder for correctness.
|
| 688 |
+
|
| 689 |
+
Args:
|
| 690 |
+
prompt: Compiled video-style text prompt.
|
| 691 |
+
|
| 692 |
+
Returns:
|
| 693 |
+
Tuple of (video_context, audio_context) tensors for denoising.
|
| 694 |
+
"""
|
| 695 |
+
t0 = time.time()
|
| 696 |
+
with torch.inference_mode():
|
| 697 |
+
if self._gemma_nf4:
|
| 698 |
+
# NF4: use BitsAndBytes quantized Gemma (fast, ~0.1s)
|
| 699 |
+
tp = self._cached_tokenizer.tokenize_with_weights(prompt)["gemma"]
|
| 700 |
+
ids = torch.tensor([[t[0] for t in tp]], device="cuda")
|
| 701 |
+
mask = torch.tensor([[w[1] for w in tp]], device="cuda")
|
| 702 |
+
out = self._nf4_gemma_model.model(
|
| 703 |
+
input_ids=ids,
|
| 704 |
+
attention_mask=mask,
|
| 705 |
+
output_hidden_states=True,
|
| 706 |
+
)
|
| 707 |
+
hs = out.hidden_states
|
| 708 |
+
am = mask
|
| 709 |
+
del out, ids
|
| 710 |
+
emb = self._cached_emb_proc.process_hidden_states(hs, am)
|
| 711 |
+
vc = emb.video_encoding
|
| 712 |
+
ac = emb.audio_encoding
|
| 713 |
+
del hs, am, emb
|
| 714 |
+
elif self._gemma_on_gpu:
|
| 715 |
+
# bf16 GPU-resident: use pipeline's text encoder (fast, ~0.5s)
|
| 716 |
+
hs, am = self._resident_text_encoder.encode(prompt)
|
| 717 |
+
emb = self._cached_emb_proc.process_hidden_states(hs, am)
|
| 718 |
+
vc = emb.video_encoding
|
| 719 |
+
ac = emb.audio_encoding
|
| 720 |
+
del hs, am, emb
|
| 721 |
+
else:
|
| 722 |
+
# CPU streaming: use pipeline's prompt encoder (~7s)
|
| 723 |
+
(emb,) = self._pipeline.prompt_encoder([prompt])
|
| 724 |
+
vc = emb.video_encoding
|
| 725 |
+
ac = emb.audio_encoding
|
| 726 |
+
|
| 727 |
+
logger.info("Gemma encode: %.1fs", time.time() - t0)
|
| 728 |
+
return vc, ac
|
| 729 |
+
|
| 730 |
+
def encode_reference(self, waveform_np: np.ndarray, sample_rate: int):
|
| 731 |
+
"""Encode reference audio to latent via Audio VAE encoder.
|
| 732 |
+
|
| 733 |
+
Args:
|
| 734 |
+
waveform_np: Audio samples, shape (samples,) or (samples, channels).
|
| 735 |
+
sample_rate: Sample rate of the input audio in Hz.
|
| 736 |
+
|
| 737 |
+
Returns:
|
| 738 |
+
Reference latent tensor [B, C, T, F] on GPU.
|
| 739 |
+
"""
|
| 740 |
+
# Ensure VAE encoder is on GPU
|
| 741 |
+
self._vram.to_gpu("vae_encoder")
|
| 742 |
+
|
| 743 |
+
if waveform_np.ndim == 1:
|
| 744 |
+
waveform_np = np.stack([waveform_np, waveform_np], axis=-1)
|
| 745 |
+
|
| 746 |
+
if waveform_np.ndim == 2 and waveform_np.shape[1] == 2:
|
| 747 |
+
wav = torch.from_numpy(waveform_np.T).float()
|
| 748 |
+
else:
|
| 749 |
+
wav = torch.from_numpy(waveform_np).float()
|
| 750 |
+
|
| 751 |
+
if sample_rate != self._vae_sr:
|
| 752 |
+
wav = torchaudio.functional.resample(wav, sample_rate, self._vae_sr)
|
| 753 |
+
|
| 754 |
+
max_samples = MAX_REF_SECONDS * self._vae_sr
|
| 755 |
+
if wav.shape[1] > max_samples:
|
| 756 |
+
wav = wav[:, :max_samples]
|
| 757 |
+
|
| 758 |
+
audio_obj = Audio(waveform=wav.unsqueeze(0), sampling_rate=self._vae_sr)
|
| 759 |
+
with torch.inference_mode():
|
| 760 |
+
latent = encode_audio(audio_obj, self._audio_encoder)
|
| 761 |
+
|
| 762 |
+
logger.info("Reference encoded: %s", latent.shape)
|
| 763 |
+
return latent
|
| 764 |
+
|
| 765 |
+
def generate(
|
| 766 |
+
self,
|
| 767 |
+
vc,
|
| 768 |
+
ac,
|
| 769 |
+
duration: float,
|
| 770 |
+
seed: int,
|
| 771 |
+
ref_latent=None,
|
| 772 |
+
) -> AudioResult:
|
| 773 |
+
"""Generate audio with optional A2V reference conditioning.
|
| 774 |
+
|
| 775 |
+
Args:
|
| 776 |
+
vc: Video context from encode_text().
|
| 777 |
+
ac: Audio context from encode_text().
|
| 778 |
+
duration: Target duration in seconds.
|
| 779 |
+
seed: Random seed for reproducibility.
|
| 780 |
+
ref_latent: Optional reference latent from encode_reference()
|
| 781 |
+
for A2V voice conditioning.
|
| 782 |
+
|
| 783 |
+
Returns:
|
| 784 |
+
AudioResult with waveform numpy array and metadata.
|
| 785 |
+
"""
|
| 786 |
+
return self._generate_impl(vc, ac, duration, seed, ref_latent)
|
| 787 |
+
|
| 788 |
+
@torch.inference_mode()
|
| 789 |
+
def _generate_impl(self, vc, ac, duration, seed, ref_latent=None):
|
| 790 |
+
# Ensure audio model is on GPU for denoising
|
| 791 |
+
self._vram.to_gpu("audio_model")
|
| 792 |
+
|
| 793 |
+
num_frames = ((int(duration * FPS) + 7) // 8) * 8 + 1
|
| 794 |
+
device = torch.device("cuda")
|
| 795 |
+
|
| 796 |
+
gen = torch.Generator(device=device).manual_seed(seed)
|
| 797 |
+
noiser = GaussianNoiser(generator=gen)
|
| 798 |
+
sigmas = DISTILLED_SIGMAS.to(dtype=torch.float32, device=device)
|
| 799 |
+
|
| 800 |
+
pixel_shape = VideoPixelShape(
|
| 801 |
+
batch=1, frames=num_frames, width=64, height=64, fps=FPS
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
v_shape = VideoLatentShape.from_pixel_shape(pixel_shape)
|
| 805 |
+
video_tools = VideoLatentTools(
|
| 806 |
+
VideoLatentPatchifier(patch_size=1), v_shape, fps=FPS
|
| 807 |
+
)
|
| 808 |
+
video_state = _build_state(
|
| 809 |
+
ModalitySpec(context=vc, conditionings=[]),
|
| 810 |
+
video_tools,
|
| 811 |
+
noiser,
|
| 812 |
+
torch.bfloat16,
|
| 813 |
+
device,
|
| 814 |
+
)
|
| 815 |
+
|
| 816 |
+
a_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
|
| 817 |
+
audio_tools = AudioLatentTools(AudioPatchifier(patch_size=1), a_shape)
|
| 818 |
+
audio_state = _build_state(
|
| 819 |
+
ModalitySpec(context=ac),
|
| 820 |
+
audio_tools,
|
| 821 |
+
noiser,
|
| 822 |
+
torch.bfloat16,
|
| 823 |
+
device,
|
| 824 |
+
)
|
| 825 |
+
|
| 826 |
+
ref_frames = 0
|
| 827 |
+
if ref_latent is not None:
|
| 828 |
+
ref = ref_latent.to(device=device, dtype=torch.bfloat16)
|
| 829 |
+
ref_frames = ref.shape[2]
|
| 830 |
+
total_t = ref_frames + audio_state.latent.shape[1]
|
| 831 |
+
|
| 832 |
+
ref_patchified = ref.permute(0, 2, 1, 3).reshape(1, ref_frames, -1)
|
| 833 |
+
combined_latent = torch.cat([ref_patchified, audio_state.latent], dim=1)
|
| 834 |
+
|
| 835 |
+
ref_mask = torch.zeros(
|
| 836 |
+
1, ref_frames, 1, device=device, dtype=audio_state.denoise_mask.dtype
|
| 837 |
+
)
|
| 838 |
+
combined_mask = torch.cat([ref_mask, audio_state.denoise_mask], dim=1)
|
| 839 |
+
combined_clean = torch.cat(
|
| 840 |
+
[ref_patchified, torch.zeros_like(audio_state.clean_latent)], dim=1
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
combined_a_shape = AudioLatentShape(
|
| 844 |
+
batch=1, channels=8, frames=total_t, mel_bins=16
|
| 845 |
+
)
|
| 846 |
+
combined_audio_tools = AudioLatentTools(
|
| 847 |
+
AudioPatchifier(patch_size=1), combined_a_shape
|
| 848 |
+
)
|
| 849 |
+
gen2 = torch.Generator(device=device).manual_seed(seed)
|
| 850 |
+
noiser2 = GaussianNoiser(generator=gen2)
|
| 851 |
+
tmp_state = _build_state(
|
| 852 |
+
ModalitySpec(context=ac),
|
| 853 |
+
combined_audio_tools,
|
| 854 |
+
noiser2,
|
| 855 |
+
torch.bfloat16,
|
| 856 |
+
device,
|
| 857 |
+
)
|
| 858 |
+
combined_positions = tmp_state.positions
|
| 859 |
+
del tmp_state
|
| 860 |
+
|
| 861 |
+
audio_state_final = LatentState(
|
| 862 |
+
latent=combined_latent,
|
| 863 |
+
denoise_mask=combined_mask,
|
| 864 |
+
positions=combined_positions,
|
| 865 |
+
clean_latent=combined_clean,
|
| 866 |
+
attention_mask=None,
|
| 867 |
+
)
|
| 868 |
+
else:
|
| 869 |
+
audio_state_final = audio_state
|
| 870 |
+
|
| 871 |
+
stepper = EulerDiffusionStep()
|
| 872 |
+
with self._pipeline.stage._transformer_ctx() as transformer:
|
| 873 |
+
wrapped = BatchSplitAdapter(transformer, max_batch_size=1)
|
| 874 |
+
t0 = time.time()
|
| 875 |
+
_, audio_state_out = euler_denoising_loop(
|
| 876 |
+
sigmas=sigmas,
|
| 877 |
+
video_state=video_state,
|
| 878 |
+
audio_state=audio_state_final,
|
| 879 |
+
stepper=stepper,
|
| 880 |
+
transformer=wrapped,
|
| 881 |
+
denoiser=SimpleDenoiser(vc, ac),
|
| 882 |
+
)
|
| 883 |
+
logger.debug("Denoise: %.2fs", time.time() - t0)
|
| 884 |
+
|
| 885 |
+
if ref_latent is not None and audio_state_out is not None and ref_frames > 0:
|
| 886 |
+
audio_state_out = dc_replace(
|
| 887 |
+
audio_state_out,
|
| 888 |
+
latent=audio_state_out.latent[:, ref_frames:],
|
| 889 |
+
denoise_mask=audio_state_out.denoise_mask[:, ref_frames:],
|
| 890 |
+
positions=audio_state_out.positions[:, :, ref_frames:],
|
| 891 |
+
clean_latent=(
|
| 892 |
+
audio_state_out.clean_latent[:, ref_frames:]
|
| 893 |
+
if audio_state_out.clean_latent is not None
|
| 894 |
+
else None
|
| 895 |
+
),
|
| 896 |
+
)
|
| 897 |
+
|
| 898 |
+
audio_state_out = audio_tools.clear_conditioning(audio_state_out)
|
| 899 |
+
audio_state_out = audio_tools.unpatchify(audio_state_out)
|
| 900 |
+
|
| 901 |
+
if torch.isnan(audio_state_out.latent).any():
|
| 902 |
+
logger.warning("NaN detected in denoised latent")
|
| 903 |
+
|
| 904 |
+
# Offload audio model before VAE decode (pipeline handles decoder GPU usage)
|
| 905 |
+
self._vram.to_gpu()
|
| 906 |
+
audio = self._pipeline.audio_decoder(audio_state_out.latent)
|
| 907 |
+
# Restore audio model after decode
|
| 908 |
+
self._vram.to_gpu("audio_model")
|
| 909 |
+
|
| 910 |
+
w, sr = extract_wav(audio)
|
| 911 |
+
return AudioResult(waveform_np=w, sample_rate=sr, duration_s=w.shape[0] / sr)
|
src/audio_core/enhancer.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""VoiceFixer audio post-processing for Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Applies neural speech restoration to improve clarity, remove artifacts,
|
| 8 |
+
and bring speech to studio quality. Runs on GPU after SeedVC as the
|
| 9 |
+
final processing step.
|
| 10 |
+
|
| 11 |
+
Model is downloaded on first use and cached to disk for subsequent runs.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import subprocess
|
| 17 |
+
import sys
|
| 18 |
+
import tempfile
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import soundfile as sf
|
| 22 |
+
import torchaudio
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
_voicefixer = None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _ensure_installed():
|
| 30 |
+
"""Install voicefixer if not available."""
|
| 31 |
+
try:
|
| 32 |
+
import voicefixer # noqa: F401
|
| 33 |
+
except ImportError:
|
| 34 |
+
logger.info("Installing voicefixer...")
|
| 35 |
+
try:
|
| 36 |
+
subprocess.check_call(
|
| 37 |
+
[sys.executable, "-m", "pip", "install", "voicefixer", "--quiet"],
|
| 38 |
+
)
|
| 39 |
+
logger.info("voicefixer installed")
|
| 40 |
+
except subprocess.CalledProcessError:
|
| 41 |
+
logger.warning("Failed to install voicefixer, enhancement will be skipped")
|
| 42 |
+
raise ImportError("voicefixer not available")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _get_voicefixer():
|
| 46 |
+
"""Get or initialize the VoiceFixer model.
|
| 47 |
+
|
| 48 |
+
Downloaded on first use and cached by the library's default cache.
|
| 49 |
+
"""
|
| 50 |
+
global _voicefixer
|
| 51 |
+
|
| 52 |
+
if _voicefixer is not None:
|
| 53 |
+
return _voicefixer
|
| 54 |
+
|
| 55 |
+
_ensure_installed()
|
| 56 |
+
|
| 57 |
+
from voicefixer import VoiceFixer # noqa: E402
|
| 58 |
+
|
| 59 |
+
_voicefixer = VoiceFixer()
|
| 60 |
+
logger.info("VoiceFixer model loaded")
|
| 61 |
+
return _voicefixer
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def enhance_audio(audio_np: np.ndarray, sr: int) -> np.ndarray:
|
| 65 |
+
"""Apply VoiceFixer to audio for studio-quality output.
|
| 66 |
+
|
| 67 |
+
VoiceFixer works on WAV files, so we write to temp, process, and read back.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
audio_np: Audio array (mono or stereo), any sample rate.
|
| 71 |
+
sr: Sample rate.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
Enhanced audio array at original sample rate.
|
| 75 |
+
"""
|
| 76 |
+
try:
|
| 77 |
+
vf = _get_voicefixer()
|
| 78 |
+
except (ImportError, Exception) as e:
|
| 79 |
+
logger.warning("VoiceFixer unavailable: %s, skipping", e)
|
| 80 |
+
return audio_np
|
| 81 |
+
|
| 82 |
+
is_stereo = audio_np.ndim == 2 and audio_np.shape[1] == 2
|
| 83 |
+
|
| 84 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 85 |
+
input_path = os.path.join(tmp, "input.wav")
|
| 86 |
+
output_path = os.path.join(tmp, "output.wav")
|
| 87 |
+
|
| 88 |
+
sf.write(input_path, audio_np, sr)
|
| 89 |
+
|
| 90 |
+
try:
|
| 91 |
+
vf.restore(
|
| 92 |
+
input=input_path,
|
| 93 |
+
output=output_path,
|
| 94 |
+
cuda=True,
|
| 95 |
+
mode=0, # 0=general, 1=speech-specific
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
enhanced, enhanced_sr = sf.read(output_path)
|
| 99 |
+
|
| 100 |
+
# Resample back to original sr if needed
|
| 101 |
+
if enhanced_sr != sr:
|
| 102 |
+
import torch
|
| 103 |
+
|
| 104 |
+
t = torch.from_numpy(
|
| 105 |
+
enhanced.T if enhanced.ndim == 2 else enhanced
|
| 106 |
+
).float()
|
| 107 |
+
if t.ndim == 1:
|
| 108 |
+
t = t.unsqueeze(0)
|
| 109 |
+
t = torchaudio.functional.resample(t, enhanced_sr, sr)
|
| 110 |
+
enhanced = t.squeeze(0).numpy()
|
| 111 |
+
if enhanced.ndim == 1 and is_stereo:
|
| 112 |
+
enhanced = np.stack([enhanced, enhanced], axis=1)
|
| 113 |
+
elif enhanced.ndim == 2:
|
| 114 |
+
enhanced = enhanced.T
|
| 115 |
+
|
| 116 |
+
logger.info("Enhanced audio: %.1fs", len(enhanced) / sr)
|
| 117 |
+
return enhanced
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
logger.warning("VoiceFixer failed: %s, returning original", e)
|
| 121 |
+
return audio_np
|
src/audio_core/inference.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Inference orchestration for Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Generates audio for planned chunks with A2V voice conditioning between
|
| 8 |
+
chunks and concatenates the results. A2V reference from each chunk's tail
|
| 9 |
+
guides the next chunk toward a consistent voice, which SeedVC then
|
| 10 |
+
polishes for exact identity matching.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from .audio_utils import normalize_volume, trim_silence
|
| 18 |
+
from .chunker import ChunkSpec
|
| 19 |
+
from .engine import AudioEngine, AudioResult
|
| 20 |
+
from .whisper_aligner import validate_text
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
REF_TAIL_SECONDS = 3.0
|
| 25 |
+
MAX_RETRIES = 3
|
| 26 |
+
RETRY_DURATION_FACTOR = 1.3
|
| 27 |
+
MIN_WORD_MATCH_RATIO = 0.90
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def generate_chunks(
|
| 31 |
+
engine: AudioEngine,
|
| 32 |
+
chunks: list[ChunkSpec],
|
| 33 |
+
ref_latent=None,
|
| 34 |
+
ref_duration_s: float = REF_TAIL_SECONDS,
|
| 35 |
+
validate: bool = False,
|
| 36 |
+
min_match_ratio: float = MIN_WORD_MATCH_RATIO,
|
| 37 |
+
anchor_ref: bool = False,
|
| 38 |
+
) -> list[AudioResult]:
|
| 39 |
+
"""Generate audio for all chunks with A2V voice conditioning.
|
| 40 |
+
|
| 41 |
+
Each chunk gets its own Gemma encode (since each has different text).
|
| 42 |
+
The tail of each chunk's audio is encoded via Audio VAE and used as
|
| 43 |
+
A2V reference for the next chunk, guiding voice consistency. SeedVC
|
| 44 |
+
is applied afterward by the processor for exact identity matching.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
engine: AudioEngine instance
|
| 48 |
+
chunks: List of ChunkSpec from plan_chunks()
|
| 49 |
+
ref_latent: Initial reference latent (from user-provided voice URL)
|
| 50 |
+
ref_duration_s: Seconds of tail audio to use as A2V reference
|
| 51 |
+
validate: If True, run Whisper validation with retry loop.
|
| 52 |
+
If False (default), generate once without validation.
|
| 53 |
+
anchor_ref: If True, every chunk uses ref_latent instead of
|
| 54 |
+
chaining from the previous chunk's tail. Keeps voice
|
| 55 |
+
anchored to the external reference.
|
| 56 |
+
"""
|
| 57 |
+
results: list[AudioResult] = []
|
| 58 |
+
|
| 59 |
+
for i, chunk in enumerate(chunks):
|
| 60 |
+
label = "with ref" if ref_latent is not None else "no ref"
|
| 61 |
+
logger.info(
|
| 62 |
+
"Chunk %d/%d (%s, %.1fs): %s",
|
| 63 |
+
i + 1,
|
| 64 |
+
len(chunks),
|
| 65 |
+
label,
|
| 66 |
+
chunk.duration_s,
|
| 67 |
+
chunk.expected_text[:60] + ("..." if len(chunk.expected_text) > 60 else ""),
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Gemma encode once per chunk (reused across retries)
|
| 71 |
+
logger.info("Compiled prompt: %s", chunk.compiled_prompt)
|
| 72 |
+
vc, ac = engine.encode_text(chunk.compiled_prompt)
|
| 73 |
+
|
| 74 |
+
duration = chunk.duration_s
|
| 75 |
+
seed = chunk.seed
|
| 76 |
+
|
| 77 |
+
if not validate:
|
| 78 |
+
# Single generation, no whisper validation
|
| 79 |
+
result = engine.generate(vc, ac, duration, seed, ref_latent=ref_latent)
|
| 80 |
+
best_result = result
|
| 81 |
+
else:
|
| 82 |
+
# Validation retry loop with whisper
|
| 83 |
+
best_result = None
|
| 84 |
+
best_ratio = -1.0
|
| 85 |
+
|
| 86 |
+
for attempt in range(MAX_RETRIES + 1):
|
| 87 |
+
result = engine.generate(vc, ac, duration, seed, ref_latent=ref_latent)
|
| 88 |
+
|
| 89 |
+
passed, transcribed, ratio = validate_text(
|
| 90 |
+
result.waveform_np,
|
| 91 |
+
result.sample_rate,
|
| 92 |
+
chunk.expected_text,
|
| 93 |
+
language=chunk.language,
|
| 94 |
+
min_word_ratio=min_match_ratio,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
if ratio > best_ratio:
|
| 98 |
+
best_result = result
|
| 99 |
+
best_ratio = ratio
|
| 100 |
+
|
| 101 |
+
if passed:
|
| 102 |
+
logger.info(
|
| 103 |
+
" Chunk %d validated: %.0f%% word match",
|
| 104 |
+
i + 1,
|
| 105 |
+
ratio * 100,
|
| 106 |
+
)
|
| 107 |
+
break
|
| 108 |
+
|
| 109 |
+
if attempt < MAX_RETRIES:
|
| 110 |
+
duration = min(duration * RETRY_DURATION_FACTOR, 20.0)
|
| 111 |
+
seed += 1
|
| 112 |
+
logger.info(
|
| 113 |
+
" Chunk %d retry %d: %.0f%% match, extending to %.1fs, seed=%d",
|
| 114 |
+
i + 1,
|
| 115 |
+
attempt + 1,
|
| 116 |
+
ratio * 100,
|
| 117 |
+
duration,
|
| 118 |
+
seed,
|
| 119 |
+
)
|
| 120 |
+
else:
|
| 121 |
+
logger.warning(
|
| 122 |
+
" Chunk %d: best %.0f%% match after %d retries, accepting",
|
| 123 |
+
i + 1,
|
| 124 |
+
best_ratio * 100,
|
| 125 |
+
MAX_RETRIES,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
results.append(best_result)
|
| 129 |
+
|
| 130 |
+
# A2V: use tail of this chunk as reference for the next
|
| 131 |
+
# In anchor mode, keep using the original ref_latent for every chunk
|
| 132 |
+
if i < len(chunks) - 1 and not anchor_ref:
|
| 133 |
+
tail_samples = int(ref_duration_s * result.sample_rate)
|
| 134 |
+
tail_wav = result.waveform_np[-tail_samples:]
|
| 135 |
+
ref_latent = engine.encode_reference(tail_wav, result.sample_rate)
|
| 136 |
+
|
| 137 |
+
return results
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def concatenate_chunks(
|
| 141 |
+
results: list[AudioResult],
|
| 142 |
+
trim: bool = True,
|
| 143 |
+
normalize: bool = True,
|
| 144 |
+
) -> tuple[np.ndarray, int]:
|
| 145 |
+
"""Concatenate audio chunks with silence trimming and volume normalization.
|
| 146 |
+
|
| 147 |
+
Trims excess silence from chunk boundaries and normalizes volume
|
| 148 |
+
per-chunk to ensure consistent loudness across the full output.
|
| 149 |
+
Chunks are hard-concatenated (no crossfade).
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
results: List of AudioResult from generate_chunks().
|
| 153 |
+
trim: Whether to trim silence from chunk boundaries.
|
| 154 |
+
normalize: Whether to normalize volume per chunk.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
Tuple of (concatenated waveform numpy array, sample_rate).
|
| 158 |
+
"""
|
| 159 |
+
if not results:
|
| 160 |
+
raise ValueError("No chunks to concatenate")
|
| 161 |
+
|
| 162 |
+
sr = results[0].sample_rate
|
| 163 |
+
processed: list[np.ndarray] = []
|
| 164 |
+
|
| 165 |
+
for i, r in enumerate(results):
|
| 166 |
+
w = r.waveform_np
|
| 167 |
+
if trim:
|
| 168 |
+
w = trim_silence(w, sr, max_silence=0.5)
|
| 169 |
+
if normalize:
|
| 170 |
+
w = normalize_volume(w, sr)
|
| 171 |
+
processed.append(w)
|
| 172 |
+
logger.debug(
|
| 173 |
+
"Chunk %d: %.1fs -> %.1fs",
|
| 174 |
+
i,
|
| 175 |
+
r.duration_s,
|
| 176 |
+
w.shape[0] / sr,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
result = np.concatenate(processed, axis=0)
|
| 180 |
+
logger.info(
|
| 181 |
+
"Concatenated: %.1fs from %d chunks", result.shape[0] / sr, len(processed)
|
| 182 |
+
)
|
| 183 |
+
return result, sr
|
src/audio_core/main.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Scenema Audio entry point.
|
| 6 |
+
|
| 7 |
+
CRITICAL: CUDA memory config must happen before torch imports.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
if "expandable_segments" not in os.environ.get("PYTORCH_CUDA_ALLOC_CONF", ""):
|
| 13 |
+
_alloc = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
|
| 14 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
|
| 15 |
+
(_alloc + ",expandable_segments:True") if _alloc else "expandable_segments:True"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
import logging
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(
|
| 21 |
+
level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO,
|
| 22 |
+
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
| 23 |
+
)
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
|
| 28 |
+
# These imports are inside main() because CUDA config above
|
| 29 |
+
# must execute before torch is imported (processor -> engine -> torch)
|
| 30 |
+
from common.runner import run
|
| 31 |
+
|
| 32 |
+
from .processor import AudioProcessor
|
| 33 |
+
|
| 34 |
+
handler_mode = os.environ.get("HANDLER_MODE", "http")
|
| 35 |
+
logger.info("Starting Scenema Audio in %s mode", handler_mode)
|
| 36 |
+
|
| 37 |
+
processor = AudioProcessor()
|
| 38 |
+
run(processor, service_type="scenema_audio")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
main()
|
src/audio_core/processor.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Scenema Audio processor. Processor protocol implementation.
|
| 6 |
+
|
| 7 |
+
Handles HTTP sync/async requests for audio generation and voice design.
|
| 8 |
+
Follows the pattern of gpu_x2v/processor.py.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import io
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import random
|
| 15 |
+
import shutil
|
| 16 |
+
import tempfile
|
| 17 |
+
import time
|
| 18 |
+
from datetime import datetime, timezone
|
| 19 |
+
|
| 20 |
+
import httpx
|
| 21 |
+
import numpy as np
|
| 22 |
+
import psutil
|
| 23 |
+
import soundfile as sf
|
| 24 |
+
import torch
|
| 25 |
+
import torchaudio
|
| 26 |
+
|
| 27 |
+
from common.handlers.base import ProcessJob, ProcessOutput, ProcessResult
|
| 28 |
+
|
| 29 |
+
from .audio_utils import (
|
| 30 |
+
ensure_stereo,
|
| 31 |
+
load_wav,
|
| 32 |
+
normalize_volume,
|
| 33 |
+
shorten_long_silence,
|
| 34 |
+
save_wav,
|
| 35 |
+
to_mono,
|
| 36 |
+
trim_silence,
|
| 37 |
+
)
|
| 38 |
+
from .chunker import plan_chunks
|
| 39 |
+
from .compiler import compile_prompt
|
| 40 |
+
from .engine import AudioEngine, HIGH_VRAM_THRESHOLD_GB
|
| 41 |
+
from .inference import concatenate_chunks, generate_chunks
|
| 42 |
+
from .seedvc import SeedVC
|
| 43 |
+
from .validate_and_patch import validate_and_patch
|
| 44 |
+
from .validator import validate_prompt
|
| 45 |
+
from .vocal_separator import VocalSeparator
|
| 46 |
+
|
| 47 |
+
logger = logging.getLogger(__name__)
|
| 48 |
+
|
| 49 |
+
VOICE_DESIGN_DURATION_S = 15.0
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class AudioProcessor:
|
| 53 |
+
"""Processor for Scenema Audio generation.
|
| 54 |
+
|
| 55 |
+
Implements the Processor protocol (startup/shutdown/process).
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(self):
|
| 59 |
+
self.engine: AudioEngine | None = None
|
| 60 |
+
self.vocal_separator = None
|
| 61 |
+
self.seedvc = None
|
| 62 |
+
self._http_client = None
|
| 63 |
+
|
| 64 |
+
def startup(self) -> None:
|
| 65 |
+
"""Load models. Called once by handler at startup."""
|
| 66 |
+
if self.engine is not None:
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
audio_ckpt = os.environ.get(
|
| 70 |
+
"AUDIO_CKPT",
|
| 71 |
+
"/app/models/scenema-audio-transformer.safetensors",
|
| 72 |
+
)
|
| 73 |
+
vae_encoder = os.environ.get(
|
| 74 |
+
"VAE_ENCODER_CKPT",
|
| 75 |
+
"/app/models/scenema-audio-vae-encoder.safetensors",
|
| 76 |
+
)
|
| 77 |
+
gemma_root = os.environ.get(
|
| 78 |
+
"GEMMA_ROOT",
|
| 79 |
+
"/app/models/gemma-3-12b-it",
|
| 80 |
+
)
|
| 81 |
+
pipeline_ckpt = os.environ.get(
|
| 82 |
+
"PIPELINE_CKPT",
|
| 83 |
+
"/app/models/ltx-2.3-22b-distilled.safetensors",
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
self.engine = AudioEngine(
|
| 87 |
+
audio_ckpt_path=audio_ckpt,
|
| 88 |
+
vae_encoder_path=vae_encoder,
|
| 89 |
+
gemma_root=gemma_root,
|
| 90 |
+
pipeline_ckpt_path=pipeline_ckpt,
|
| 91 |
+
)
|
| 92 |
+
self.engine.load()
|
| 93 |
+
|
| 94 |
+
self.vocal_separator = VocalSeparator()
|
| 95 |
+
self.seedvc = SeedVC()
|
| 96 |
+
|
| 97 |
+
# Preload all models on high-VRAM cards (>= 40GB), keep resident
|
| 98 |
+
vram_gb = (
|
| 99 |
+
torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 100 |
+
if torch.cuda.is_available()
|
| 101 |
+
else 0
|
| 102 |
+
)
|
| 103 |
+
self._keep_resident = vram_gb >= HIGH_VRAM_THRESHOLD_GB
|
| 104 |
+
if self._keep_resident:
|
| 105 |
+
self.vocal_separator.load()
|
| 106 |
+
self.seedvc.load()
|
| 107 |
+
logger.info("All models preloaded and resident (%.0fGB VRAM)", vram_gb)
|
| 108 |
+
else:
|
| 109 |
+
logger.info("Low VRAM (%.0fGB), models loaded on-demand", vram_gb)
|
| 110 |
+
|
| 111 |
+
logger.info("AudioProcessor ready")
|
| 112 |
+
|
| 113 |
+
def shutdown(self) -> None:
|
| 114 |
+
"""Unload all models."""
|
| 115 |
+
if self.engine:
|
| 116 |
+
self.engine.unload()
|
| 117 |
+
self.engine = None
|
| 118 |
+
if self.vocal_separator:
|
| 119 |
+
self.vocal_separator.unload()
|
| 120 |
+
self.vocal_separator = None
|
| 121 |
+
if self.seedvc and self.seedvc._loaded:
|
| 122 |
+
self.seedvc.unload()
|
| 123 |
+
logger.info("AudioProcessor shutdown")
|
| 124 |
+
|
| 125 |
+
async def process(self, job: ProcessJob) -> ProcessResult:
|
| 126 |
+
"""Process an audio generation job."""
|
| 127 |
+
start_time = time.time()
|
| 128 |
+
started_at = datetime.now(timezone.utc).isoformat()
|
| 129 |
+
torch.cuda.reset_peak_memory_stats()
|
| 130 |
+
|
| 131 |
+
try:
|
| 132 |
+
if self.engine is None:
|
| 133 |
+
self.startup()
|
| 134 |
+
|
| 135 |
+
config = self._parse_input(job)
|
| 136 |
+
|
| 137 |
+
if config["mode"] == "voice_design":
|
| 138 |
+
wav_np, sr = await self._voice_design(config)
|
| 139 |
+
else:
|
| 140 |
+
wav_np, sr = await self._generate(config)
|
| 141 |
+
|
| 142 |
+
wav_bytes = self._encode_wav(wav_np, sr)
|
| 143 |
+
processing_ms = int((time.time() - start_time) * 1000)
|
| 144 |
+
|
| 145 |
+
return ProcessResult(
|
| 146 |
+
job_id=job.job_id,
|
| 147 |
+
success=True,
|
| 148 |
+
output=ProcessOutput(
|
| 149 |
+
success=True,
|
| 150 |
+
data=wav_bytes,
|
| 151 |
+
content_type="audio/wav",
|
| 152 |
+
metadata=self._build_metadata(
|
| 153 |
+
config, wav_np, sr, processing_ms, started_at
|
| 154 |
+
),
|
| 155 |
+
),
|
| 156 |
+
processing_ms=processing_ms,
|
| 157 |
+
)
|
| 158 |
+
except Exception as e:
|
| 159 |
+
logger.error("Processing failed: %s", e, exc_info=True)
|
| 160 |
+
processing_ms = int((time.time() - start_time) * 1000)
|
| 161 |
+
return ProcessResult(
|
| 162 |
+
job_id=job.job_id,
|
| 163 |
+
success=False,
|
| 164 |
+
output=ProcessOutput(success=False, error=str(e)),
|
| 165 |
+
error=str(e),
|
| 166 |
+
processing_ms=processing_ms,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
def _parse_input(self, job: ProcessJob) -> dict:
|
| 170 |
+
"""Parse and validate job input.
|
| 171 |
+
|
| 172 |
+
Input schema:
|
| 173 |
+
prompt: str - Required. <speak> XML string.
|
| 174 |
+
mode: str - "generate" (default) or "voice_design".
|
| 175 |
+
reference_voice_url: str | None - URL to reference audio for voice cloning.
|
| 176 |
+
background_sfx: bool - Keep background SFX (default: false, strips via MelBandRoFormer).
|
| 177 |
+
validate: bool - Enable Whisper speech validation (default: false).
|
| 178 |
+
When true, each generated chunk is transcribed by faster-whisper
|
| 179 |
+
(GPU, float16, ~1GB VRAM) and compared against the expected text.
|
| 180 |
+
If word match ratio falls below 60%, the chunk is regenerated with
|
| 181 |
+
extended duration and a new seed (up to 3 retries), keeping the
|
| 182 |
+
best result. Adds <1s per chunk on GPU. When false, each chunk is
|
| 183 |
+
generated once with no quality gate, which is faster and sufficient
|
| 184 |
+
for most prompts.
|
| 185 |
+
seed: int - Base seed (-1 for random).
|
| 186 |
+
"""
|
| 187 |
+
inp = job.input
|
| 188 |
+
|
| 189 |
+
prompt = inp.get("prompt")
|
| 190 |
+
if not prompt:
|
| 191 |
+
raise ValueError("Missing required 'prompt' field")
|
| 192 |
+
|
| 193 |
+
mode = inp.get("mode", "generate")
|
| 194 |
+
if mode not in ("generate", "voice_design"):
|
| 195 |
+
raise ValueError(
|
| 196 |
+
f"Invalid mode: {mode}. Must be 'generate' or 'voice_design'"
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
result = validate_prompt(prompt)
|
| 200 |
+
if not result.valid:
|
| 201 |
+
raise ValueError(f"Invalid prompt XML: {'; '.join(result.errors)}")
|
| 202 |
+
|
| 203 |
+
seed = inp.get("seed", -1)
|
| 204 |
+
if seed == -1:
|
| 205 |
+
seed = random.randint(0, 999999)
|
| 206 |
+
|
| 207 |
+
return {
|
| 208 |
+
"prompt": prompt,
|
| 209 |
+
"mode": mode,
|
| 210 |
+
"reference_voice_url": inp.get("reference_voice_url"),
|
| 211 |
+
"background_sfx": inp.get("background_sfx", False),
|
| 212 |
+
"validate": inp.get("validate", True),
|
| 213 |
+
"seed": seed,
|
| 214 |
+
"pace": inp.get("pace", 1.5),
|
| 215 |
+
"min_match_ratio": inp.get("min_match_ratio", 0.90),
|
| 216 |
+
"vc_cfg_rate": inp.get("vc_cfg_rate", 0.5),
|
| 217 |
+
"vc_steps": inp.get("vc_steps", 25),
|
| 218 |
+
"skip_vc": inp.get("skip_vc", False),
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
async def _voice_design(self, config: dict) -> tuple[np.ndarray, int]:
|
| 222 |
+
"""Generate a 15s voice sample for voice design."""
|
| 223 |
+
compiled = compile_prompt(config["prompt"])
|
| 224 |
+
vc, ac = self.engine.encode_text(compiled.prompt)
|
| 225 |
+
result = self.engine.generate(vc, ac, VOICE_DESIGN_DURATION_S, config["seed"])
|
| 226 |
+
|
| 227 |
+
wav = result.waveform_np
|
| 228 |
+
sr = result.sample_rate
|
| 229 |
+
|
| 230 |
+
if not config["background_sfx"]:
|
| 231 |
+
wav = self._strip_background(wav, sr)
|
| 232 |
+
|
| 233 |
+
wav = trim_silence(wav, sr)
|
| 234 |
+
wav = shorten_long_silence(wav, sr)
|
| 235 |
+
wav = normalize_volume(wav, sr)
|
| 236 |
+
|
| 237 |
+
return wav, sr
|
| 238 |
+
|
| 239 |
+
async def _generate(self, config: dict) -> tuple[np.ndarray, int]:
|
| 240 |
+
"""Full generation pipeline with chunking and post-processing."""
|
| 241 |
+
chunks = plan_chunks(
|
| 242 |
+
config["prompt"], base_seed=config["seed"], pace=config["pace"]
|
| 243 |
+
)
|
| 244 |
+
logger.info("Planned %d chunk(s)", len(chunks))
|
| 245 |
+
|
| 246 |
+
ref_wav_path = None
|
| 247 |
+
if config["reference_voice_url"]:
|
| 248 |
+
ref_wav_path = await self._download_reference(config["reference_voice_url"])
|
| 249 |
+
|
| 250 |
+
# skip_vc: seed every chunk with the reference audio's tail latent,
|
| 251 |
+
# identical to how inter-chunk chaining works. The model sees the
|
| 252 |
+
# reference as "what I just generated" and continues in that voice.
|
| 253 |
+
# Disables the normal chaining (each chunk chains from the ref, not
|
| 254 |
+
# from the previous chunk) to keep the voice anchored to the reference.
|
| 255 |
+
anchor_latent = None
|
| 256 |
+
if config["skip_vc"] and ref_wav_path:
|
| 257 |
+
ref_wav, ref_sr = load_wav(ref_wav_path)
|
| 258 |
+
ref_mono = to_mono(ref_wav)
|
| 259 |
+
tail_seconds = 3.0
|
| 260 |
+
tail_samples = int(tail_seconds * ref_sr)
|
| 261 |
+
if ref_mono.shape[0] > tail_samples:
|
| 262 |
+
ref_tail = ref_mono[-tail_samples:]
|
| 263 |
+
else:
|
| 264 |
+
ref_tail = ref_mono
|
| 265 |
+
anchor_latent = self.engine.encode_reference(ref_tail, ref_sr)
|
| 266 |
+
logger.info(
|
| 267 |
+
"Anchor mode: every chunk seeded from %.1fs reference tail",
|
| 268 |
+
ref_tail.shape[0] / ref_sr,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
with torch.inference_mode():
|
| 272 |
+
results = generate_chunks(
|
| 273 |
+
self.engine,
|
| 274 |
+
chunks,
|
| 275 |
+
ref_latent=anchor_latent,
|
| 276 |
+
anchor_ref=anchor_latent is not None,
|
| 277 |
+
validate=config["validate"],
|
| 278 |
+
min_match_ratio=config["min_match_ratio"],
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
wav, sr = concatenate_chunks(results)
|
| 282 |
+
|
| 283 |
+
# Strip background music/SFX from the concatenated audio (single pass)
|
| 284 |
+
if not config["background_sfx"]:
|
| 285 |
+
wav = self._strip_background(wav, sr)
|
| 286 |
+
|
| 287 |
+
# Cap silence β scale with pace
|
| 288 |
+
max_silence = min(0.5 * config["pace"], 1.5)
|
| 289 |
+
wav = shorten_long_silence(
|
| 290 |
+
wav, sr, max_duration=max_silence, target_duration=max_silence * 0.6
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Apply SeedVC when: reference voice provided, or multiple chunks (voice consistency).
|
| 294 |
+
# Skip for single-chunk generations without reference (preserves SFX).
|
| 295 |
+
needs_vc = ref_wav_path or len(results) > 1
|
| 296 |
+
if not config["skip_vc"] and needs_vc:
|
| 297 |
+
wav = self._apply_seedvc(
|
| 298 |
+
wav,
|
| 299 |
+
sr,
|
| 300 |
+
results,
|
| 301 |
+
ref_wav_path,
|
| 302 |
+
vc_steps=config["vc_steps"],
|
| 303 |
+
vc_cfg_rate=config["vc_cfg_rate"],
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# Post-SeedVC alignment trimming (disabled by default, needs refinement)
|
| 307 |
+
if config.get("patch", False):
|
| 308 |
+
expected_text = " ".join(c.expected_text for c in chunks)
|
| 309 |
+
wav = validate_and_patch(wav, sr, expected_text)
|
| 310 |
+
|
| 311 |
+
# Ensure stereo final output
|
| 312 |
+
wav = ensure_stereo(wav)
|
| 313 |
+
|
| 314 |
+
if ref_wav_path and os.path.exists(ref_wav_path):
|
| 315 |
+
os.unlink(ref_wav_path)
|
| 316 |
+
|
| 317 |
+
return wav, sr
|
| 318 |
+
|
| 319 |
+
def _strip_background(self, wav_np: np.ndarray, sr: int) -> np.ndarray:
|
| 320 |
+
"""Strip background music/SFX using MelBandRoFormer.
|
| 321 |
+
|
| 322 |
+
Loads the model on-demand and unloads after to free VRAM.
|
| 323 |
+
"""
|
| 324 |
+
if self.vocal_separator is None:
|
| 325 |
+
return wav_np
|
| 326 |
+
|
| 327 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
| 328 |
+
input_path = f.name
|
| 329 |
+
vocals_path = input_path.replace(".wav", "_vocals.wav")
|
| 330 |
+
|
| 331 |
+
try:
|
| 332 |
+
if not self._keep_resident:
|
| 333 |
+
self.vocal_separator.load()
|
| 334 |
+
stereo = ensure_stereo(wav_np)
|
| 335 |
+
save_wav(stereo, sr, input_path)
|
| 336 |
+
self.vocal_separator.separate(input_path, vocals_path, None)
|
| 337 |
+
vocals, _ = load_wav(vocals_path)
|
| 338 |
+
return vocals
|
| 339 |
+
except Exception as e:
|
| 340 |
+
logger.warning("Vocal separation failed: %s", e)
|
| 341 |
+
return wav_np
|
| 342 |
+
finally:
|
| 343 |
+
if not self._keep_resident:
|
| 344 |
+
self.vocal_separator.unload()
|
| 345 |
+
for p in [input_path, vocals_path]:
|
| 346 |
+
if os.path.exists(p):
|
| 347 |
+
os.unlink(p)
|
| 348 |
+
|
| 349 |
+
def _apply_seedvc(
|
| 350 |
+
self,
|
| 351 |
+
wav: np.ndarray,
|
| 352 |
+
sr: int,
|
| 353 |
+
chunk_results: list,
|
| 354 |
+
ref_wav_path: str | None,
|
| 355 |
+
vc_steps: int = 20,
|
| 356 |
+
vc_cfg_rate: float = 0.5,
|
| 357 |
+
) -> np.ndarray:
|
| 358 |
+
"""Apply SeedVC voice cloning.
|
| 359 |
+
|
| 360 |
+
If reference_voice_url provided: convert against reference.
|
| 361 |
+
If no reference: convert all against chunk 0 (first chunk sets identity).
|
| 362 |
+
"""
|
| 363 |
+
if self.seedvc is None:
|
| 364 |
+
logger.info("SeedVC not available, skipping voice cloning")
|
| 365 |
+
return wav
|
| 366 |
+
|
| 367 |
+
try:
|
| 368 |
+
if not self._keep_resident:
|
| 369 |
+
self.seedvc.load()
|
| 370 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 371 |
+
source_path = os.path.join(tmp, "source_22k.wav")
|
| 372 |
+
target_path = os.path.join(tmp, "target_22k.wav")
|
| 373 |
+
|
| 374 |
+
source_mono = to_mono(wav)
|
| 375 |
+
source_t = torch.from_numpy(source_mono).float().unsqueeze(0)
|
| 376 |
+
source_22k = torchaudio.functional.resample(source_t, sr, 22050)
|
| 377 |
+
save_wav(source_22k.squeeze(0).numpy(), 22050, source_path)
|
| 378 |
+
|
| 379 |
+
if ref_wav_path:
|
| 380 |
+
target_wav, target_sr = load_wav(ref_wav_path)
|
| 381 |
+
target_mono = to_mono(target_wav)
|
| 382 |
+
target_t = torch.from_numpy(target_mono).float().unsqueeze(0)
|
| 383 |
+
target_22k = torchaudio.functional.resample(
|
| 384 |
+
target_t, target_sr, 22050
|
| 385 |
+
)
|
| 386 |
+
save_wav(target_22k.squeeze(0).numpy(), 22050, target_path)
|
| 387 |
+
else:
|
| 388 |
+
chunk0 = chunk_results[0].waveform_np
|
| 389 |
+
chunk0_mono = to_mono(chunk0)
|
| 390 |
+
chunk0_t = torch.from_numpy(chunk0_mono).float().unsqueeze(0)
|
| 391 |
+
chunk0_22k = torchaudio.functional.resample(
|
| 392 |
+
chunk0_t, chunk_results[0].sample_rate, 22050
|
| 393 |
+
)
|
| 394 |
+
save_wav(chunk0_22k.squeeze(0).numpy(), 22050, target_path)
|
| 395 |
+
|
| 396 |
+
converted = self.seedvc.convert(
|
| 397 |
+
source_path,
|
| 398 |
+
target_path,
|
| 399 |
+
diffusion_steps=vc_steps,
|
| 400 |
+
cfg_rate=vc_cfg_rate,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
conv_t = torch.from_numpy(converted).float().unsqueeze(0)
|
| 404 |
+
result = torchaudio.functional.resample(conv_t, 22050, sr)
|
| 405 |
+
wav = result.squeeze(0).numpy()
|
| 406 |
+
wav = ensure_stereo(wav)
|
| 407 |
+
|
| 408 |
+
except Exception as e:
|
| 409 |
+
logger.error("SeedVC failed: %s", e, exc_info=True)
|
| 410 |
+
finally:
|
| 411 |
+
if not self._keep_resident:
|
| 412 |
+
try:
|
| 413 |
+
self.seedvc.unload()
|
| 414 |
+
except Exception:
|
| 415 |
+
pass
|
| 416 |
+
|
| 417 |
+
return wav
|
| 418 |
+
|
| 419 |
+
async def _download_reference(self, url: str) -> str:
|
| 420 |
+
"""Download reference audio from URL to temp file."""
|
| 421 |
+
if self._http_client is None:
|
| 422 |
+
self._http_client = httpx.AsyncClient(timeout=60.0, follow_redirects=True)
|
| 423 |
+
|
| 424 |
+
response = await self._http_client.get(url)
|
| 425 |
+
response.raise_for_status()
|
| 426 |
+
|
| 427 |
+
suffix = ".wav"
|
| 428 |
+
if "mp3" in url.lower() or "mpeg" in response.headers.get("content-type", ""):
|
| 429 |
+
suffix = ".mp3"
|
| 430 |
+
|
| 431 |
+
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
|
| 432 |
+
f.write(response.content)
|
| 433 |
+
logger.info(
|
| 434 |
+
"Downloaded reference: %d bytes to %s", len(response.content), f.name
|
| 435 |
+
)
|
| 436 |
+
return f.name
|
| 437 |
+
|
| 438 |
+
def _encode_wav(self, wav_np: np.ndarray, sr: int) -> bytes:
|
| 439 |
+
"""Encode numpy array to WAV bytes."""
|
| 440 |
+
buf = io.BytesIO()
|
| 441 |
+
sf.write(buf, wav_np, sr, format="WAV")
|
| 442 |
+
return buf.getvalue()
|
| 443 |
+
|
| 444 |
+
def _build_metadata(
|
| 445 |
+
self,
|
| 446 |
+
config: dict,
|
| 447 |
+
wav_np: np.ndarray,
|
| 448 |
+
sr: int,
|
| 449 |
+
processing_ms: int,
|
| 450 |
+
started_at: str = "",
|
| 451 |
+
) -> dict:
|
| 452 |
+
"""Build comprehensive metadata matching x2v pattern."""
|
| 453 |
+
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"
|
| 454 |
+
vram_total_mb = 0
|
| 455 |
+
vram_peak_mb = 0
|
| 456 |
+
if torch.cuda.is_available():
|
| 457 |
+
vram_total_mb = round(
|
| 458 |
+
torch.cuda.get_device_properties(0).total_memory / 1024**2
|
| 459 |
+
)
|
| 460 |
+
vram_peak_mb = round(torch.cuda.max_memory_allocated() / 1024**2)
|
| 461 |
+
|
| 462 |
+
cpu_cores_total = os.cpu_count() or 0
|
| 463 |
+
system_ram_gb = round(psutil.virtual_memory().total / 1024**3)
|
| 464 |
+
disk = shutil.disk_usage("/")
|
| 465 |
+
|
| 466 |
+
return {
|
| 467 |
+
"duration_s": round(wav_np.shape[0] / sr, 2),
|
| 468 |
+
"sample_rate": sr,
|
| 469 |
+
"mode": config["mode"],
|
| 470 |
+
"seed": config["seed"],
|
| 471 |
+
"background_sfx": config["background_sfx"],
|
| 472 |
+
"has_reference_voice": config["reference_voice_url"] is not None,
|
| 473 |
+
"validate": config["validate"],
|
| 474 |
+
"processing_ms": processing_ms,
|
| 475 |
+
"vram_peak_mb": vram_peak_mb,
|
| 476 |
+
"vram_total_mb": vram_total_mb,
|
| 477 |
+
"gpu": gpu_name,
|
| 478 |
+
"cpu_cores_total": cpu_cores_total,
|
| 479 |
+
"system_ram_gb": system_ram_gb,
|
| 480 |
+
"disk_total_gb": round(disk.total / 1024**3, 1),
|
| 481 |
+
"disk_free_gb": round(disk.free / 1024**3, 1),
|
| 482 |
+
"started_at": started_at,
|
| 483 |
+
"completed_at": datetime.now(timezone.utc).isoformat(),
|
| 484 |
+
}
|
src/audio_core/seedvc.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""SeedVC voice conversion for Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Converts the voice identity of generated audio to match a reference speaker
|
| 8 |
+
while preserving prosody, rhythm, and emotion. Uses the Seed-VC model with
|
| 9 |
+
DiT backbone, CAMPPlus speaker encoder, and BigVGAN vocoder.
|
| 10 |
+
|
| 11 |
+
Expects 22050Hz mono WAV input for both source and target.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import inspect
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
import types
|
| 19 |
+
from argparse import Namespace
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
DEFAULT_SEEDVC_PATH = Path(os.environ.get("SEEDVC_PATH", "/app/seed-vc"))
|
| 28 |
+
DEFAULT_DIFFUSION_STEPS = 25
|
| 29 |
+
DEFAULT_CFG_RATE = 0.5
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SeedVC:
|
| 33 |
+
"""Voice conversion engine using Seed-VC.
|
| 34 |
+
|
| 35 |
+
Converts source audio voice identity to match a target speaker
|
| 36 |
+
while preserving the source's delivery, emotion, and pacing.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, seedvc_path: Path = DEFAULT_SEEDVC_PATH):
|
| 40 |
+
self.seedvc_path = seedvc_path
|
| 41 |
+
self._loaded = False
|
| 42 |
+
self._original_cwd: str | None = None
|
| 43 |
+
self._app_vc = None
|
| 44 |
+
|
| 45 |
+
def load(self) -> None:
|
| 46 |
+
"""Load SeedVC models to GPU.
|
| 47 |
+
|
| 48 |
+
Changes working directory to seedvc_path (required by SeedVC internals),
|
| 49 |
+
stubs gradio, and loads all models via app_vc.load_models().
|
| 50 |
+
"""
|
| 51 |
+
if self._loaded:
|
| 52 |
+
return
|
| 53 |
+
|
| 54 |
+
logger.info("Loading SeedVC from %s", self.seedvc_path)
|
| 55 |
+
|
| 56 |
+
self._original_cwd = os.getcwd()
|
| 57 |
+
os.chdir(self.seedvc_path)
|
| 58 |
+
|
| 59 |
+
if "gradio" not in sys.modules:
|
| 60 |
+
sys.modules["gradio"] = types.ModuleType("gradio")
|
| 61 |
+
|
| 62 |
+
seedvc_str = str(self.seedvc_path)
|
| 63 |
+
if seedvc_str not in sys.path:
|
| 64 |
+
sys.path.insert(0, seedvc_str)
|
| 65 |
+
|
| 66 |
+
os.environ.setdefault(
|
| 67 |
+
"HF_HUB_CACHE",
|
| 68 |
+
str(self.seedvc_path / "checkpoints" / "hf_cache"),
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Patch BigVGAN for huggingface_hub compat (same as gpu_vc)
|
| 72 |
+
import modules.bigvgan.bigvgan as _bigvgan_mod
|
| 73 |
+
|
| 74 |
+
_orig = _bigvgan_mod.BigVGAN._from_pretrained
|
| 75 |
+
|
| 76 |
+
@classmethod
|
| 77 |
+
def _patched(cls, **kwargs):
|
| 78 |
+
kwargs.setdefault("proxies", None)
|
| 79 |
+
kwargs.setdefault("resume_download", False)
|
| 80 |
+
return _orig.__func__(cls, **kwargs)
|
| 81 |
+
|
| 82 |
+
_bigvgan_mod.BigVGAN._from_pretrained = _patched
|
| 83 |
+
|
| 84 |
+
# Load models (exact pattern from gpu_vc/seedvc_engine.py)
|
| 85 |
+
import app_vc
|
| 86 |
+
|
| 87 |
+
self._app_vc = app_vc
|
| 88 |
+
app_vc.device = torch.device("cuda")
|
| 89 |
+
|
| 90 |
+
args = Namespace(checkpoint=None, config=None, fp16=True, gpu=0)
|
| 91 |
+
(
|
| 92 |
+
app_vc.model,
|
| 93 |
+
app_vc.semantic_fn,
|
| 94 |
+
app_vc.vocoder_fn,
|
| 95 |
+
app_vc.campplus_model,
|
| 96 |
+
app_vc.to_mel,
|
| 97 |
+
app_vc.mel_fn_args,
|
| 98 |
+
) = app_vc.load_models(args)
|
| 99 |
+
|
| 100 |
+
app_vc.max_context_window = app_vc.sr // app_vc.hop_length * 30
|
| 101 |
+
app_vc.overlap_wave_len = app_vc.overlap_frame_len * app_vc.hop_length
|
| 102 |
+
|
| 103 |
+
self._loaded = True
|
| 104 |
+
logger.info("SeedVC loaded: sr=%d, device=%s", app_vc.sr, app_vc.device)
|
| 105 |
+
|
| 106 |
+
def unload(self) -> None:
|
| 107 |
+
"""Free SeedVC models from GPU."""
|
| 108 |
+
if not self._loaded:
|
| 109 |
+
return
|
| 110 |
+
|
| 111 |
+
if self._app_vc is not None:
|
| 112 |
+
for attr in [
|
| 113 |
+
"model",
|
| 114 |
+
"semantic_fn",
|
| 115 |
+
"vocoder_fn",
|
| 116 |
+
"campplus_model",
|
| 117 |
+
"to_mel",
|
| 118 |
+
]:
|
| 119 |
+
if hasattr(self._app_vc, attr):
|
| 120 |
+
delattr(self._app_vc, attr)
|
| 121 |
+
self._app_vc = None
|
| 122 |
+
|
| 123 |
+
torch.cuda.empty_cache()
|
| 124 |
+
|
| 125 |
+
if self._original_cwd:
|
| 126 |
+
os.chdir(self._original_cwd)
|
| 127 |
+
self._original_cwd = None
|
| 128 |
+
|
| 129 |
+
self._loaded = False
|
| 130 |
+
logger.info("SeedVC unloaded")
|
| 131 |
+
|
| 132 |
+
def convert(
|
| 133 |
+
self,
|
| 134 |
+
source_wav_path: str,
|
| 135 |
+
target_wav_path: str,
|
| 136 |
+
diffusion_steps: int = DEFAULT_DIFFUSION_STEPS,
|
| 137 |
+
cfg_rate: float = DEFAULT_CFG_RATE,
|
| 138 |
+
) -> np.ndarray:
|
| 139 |
+
"""Convert voice identity of source to match target.
|
| 140 |
+
|
| 141 |
+
Both files must be 22050Hz mono WAV.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
source_wav_path: Path to source audio (generated speech)
|
| 145 |
+
target_wav_path: Path to target audio (reference voice)
|
| 146 |
+
diffusion_steps: Number of diffusion steps (quality vs speed)
|
| 147 |
+
cfg_rate: Classifier-free guidance rate
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
Converted audio as float32 numpy array at 22050Hz mono
|
| 151 |
+
"""
|
| 152 |
+
if not self._loaded:
|
| 153 |
+
raise RuntimeError("SeedVC not loaded. Call load() first.")
|
| 154 |
+
|
| 155 |
+
logger.info(
|
| 156 |
+
"Converting voice: %s -> %s (%d steps, cfg_rate=%.2f)",
|
| 157 |
+
source_wav_path,
|
| 158 |
+
target_wav_path,
|
| 159 |
+
diffusion_steps,
|
| 160 |
+
cfg_rate,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
audio_tuple = None
|
| 164 |
+
vc_kwargs = {
|
| 165 |
+
"source": source_wav_path,
|
| 166 |
+
"target": target_wav_path,
|
| 167 |
+
"diffusion_steps": diffusion_steps,
|
| 168 |
+
"length_adjust": 1.0,
|
| 169 |
+
"inference_cfg_rate": cfg_rate,
|
| 170 |
+
}
|
| 171 |
+
# n_quantizers removed in newer SeedVC versions
|
| 172 |
+
sig = inspect.signature(self._app_vc.voice_conversion)
|
| 173 |
+
if "n_quantizers" in sig.parameters:
|
| 174 |
+
vc_kwargs["n_quantizers"] = 3
|
| 175 |
+
for result in self._app_vc.voice_conversion(**vc_kwargs):
|
| 176 |
+
if isinstance(result, tuple) and len(result) == 2:
|
| 177 |
+
_, audio_tuple = result
|
| 178 |
+
|
| 179 |
+
if audio_tuple is None:
|
| 180 |
+
raise RuntimeError("SeedVC produced no output")
|
| 181 |
+
|
| 182 |
+
sample_rate, samples = audio_tuple
|
| 183 |
+
|
| 184 |
+
if samples.dtype == np.int16:
|
| 185 |
+
samples = samples.astype(np.float32) / 32768.0
|
| 186 |
+
elif samples.dtype != np.float32:
|
| 187 |
+
samples = samples.astype(np.float32)
|
| 188 |
+
|
| 189 |
+
peak = np.abs(samples).max()
|
| 190 |
+
if peak > 1.0:
|
| 191 |
+
samples = samples / peak
|
| 192 |
+
|
| 193 |
+
logger.info("Converted: %.1fs at %dHz", len(samples) / sample_rate, sample_rate)
|
| 194 |
+
return samples
|
src/audio_core/validate_and_patch.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Forced alignment and hallucination trimming for Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Uses Needleman-Wunsch sequence alignment (same algorithm as DNA matching)
|
| 8 |
+
to optimally align Whisper-transcribed words against expected text. Words
|
| 9 |
+
in the transcription that are INSERTIONS (not in the expected text) are
|
| 10 |
+
trimmed at silence boundaries. Substitutions (misrecognized words) are kept.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import re
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from .audio_utils import to_mono
|
| 19 |
+
from .whisper_aligner import _get_whisper
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
SILENCE_THRESHOLD = 0.015
|
| 24 |
+
TRIM_PAD_S = 0.02
|
| 25 |
+
|
| 26 |
+
# Alignment scoring
|
| 27 |
+
MATCH_SCORE = 2
|
| 28 |
+
MISMATCH_SCORE = -1
|
| 29 |
+
GAP_SCORE = -1 # Cost of insertion or deletion
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _normalize_words(text: str) -> list[str]:
|
| 33 |
+
"""Normalize text to lowercase words without punctuation."""
|
| 34 |
+
text = text.lower()
|
| 35 |
+
text = re.sub(r"[^\w\s]", "", text)
|
| 36 |
+
return text.split()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _fuzzy_match(a: str, b: str) -> bool:
|
| 40 |
+
"""Check if two words are similar enough (edit distance based)."""
|
| 41 |
+
if a == b:
|
| 42 |
+
return True
|
| 43 |
+
if not a or not b or len(a) < 4 or len(b) < 4:
|
| 44 |
+
return False
|
| 45 |
+
m, n = len(a), len(b)
|
| 46 |
+
dp = list(range(n + 1))
|
| 47 |
+
for i in range(1, m + 1):
|
| 48 |
+
prev = dp[0]
|
| 49 |
+
dp[0] = i
|
| 50 |
+
for j in range(1, n + 1):
|
| 51 |
+
temp = dp[j]
|
| 52 |
+
dp[j] = prev if a[i - 1] == b[j - 1] else 1 + min(prev, dp[j], dp[j - 1])
|
| 53 |
+
prev = temp
|
| 54 |
+
return 1 - (dp[n] / max(m, n)) >= 0.5
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _score(a: str, b: str) -> int:
|
| 58 |
+
"""Score for aligning word a with word b."""
|
| 59 |
+
if a == b:
|
| 60 |
+
return MATCH_SCORE
|
| 61 |
+
if _fuzzy_match(a, b):
|
| 62 |
+
return MATCH_SCORE # Treat fuzzy matches same as exact
|
| 63 |
+
return MISMATCH_SCORE
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _needleman_wunsch(
|
| 67 |
+
transcribed: list[str],
|
| 68 |
+
expected: list[str],
|
| 69 |
+
) -> list[str]:
|
| 70 |
+
"""Needleman-Wunsch global alignment.
|
| 71 |
+
|
| 72 |
+
Returns a list of labels for each transcribed word:
|
| 73 |
+
- "match": word aligns to an expected word (exact or fuzzy)
|
| 74 |
+
- "substitution": word replaces an expected word (poor match)
|
| 75 |
+
- "insertion": word has no counterpart in expected text (hallucinated)
|
| 76 |
+
|
| 77 |
+
Expected words that have no counterpart are deletions (not returned
|
| 78 |
+
since we only label transcribed words).
|
| 79 |
+
"""
|
| 80 |
+
m = len(transcribed)
|
| 81 |
+
n = len(expected)
|
| 82 |
+
|
| 83 |
+
# Build score matrix
|
| 84 |
+
dp = [[0] * (n + 1) for _ in range(m + 1)]
|
| 85 |
+
for i in range(1, m + 1):
|
| 86 |
+
dp[i][0] = dp[i - 1][0] + GAP_SCORE
|
| 87 |
+
for j in range(1, n + 1):
|
| 88 |
+
dp[0][j] = dp[0][j - 1] + GAP_SCORE
|
| 89 |
+
|
| 90 |
+
for i in range(1, m + 1):
|
| 91 |
+
for j in range(1, n + 1):
|
| 92 |
+
match = dp[i - 1][j - 1] + _score(transcribed[i - 1], expected[j - 1])
|
| 93 |
+
delete = dp[i - 1][j] + GAP_SCORE # transcribed word is insertion
|
| 94 |
+
insert = dp[i][j - 1] + GAP_SCORE # expected word is deletion
|
| 95 |
+
dp[i][j] = max(match, delete, insert)
|
| 96 |
+
|
| 97 |
+
# Traceback
|
| 98 |
+
labels = []
|
| 99 |
+
i, j = m, n
|
| 100 |
+
while i > 0 or j > 0:
|
| 101 |
+
if (
|
| 102 |
+
i > 0
|
| 103 |
+
and j > 0
|
| 104 |
+
and dp[i][j]
|
| 105 |
+
== dp[i - 1][j - 1] + _score(transcribed[i - 1], expected[j - 1])
|
| 106 |
+
):
|
| 107 |
+
s = _score(transcribed[i - 1], expected[j - 1])
|
| 108 |
+
labels.append("match" if s == MATCH_SCORE else "substitution")
|
| 109 |
+
i -= 1
|
| 110 |
+
j -= 1
|
| 111 |
+
elif i > 0 and dp[i][j] == dp[i - 1][j] + GAP_SCORE:
|
| 112 |
+
labels.append("insertion")
|
| 113 |
+
i -= 1
|
| 114 |
+
else:
|
| 115 |
+
j -= 1 # Deletion in expected β skip
|
| 116 |
+
|
| 117 |
+
labels.reverse()
|
| 118 |
+
return labels
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _transcribe_with_timestamps(
|
| 122 |
+
audio_mono: np.ndarray,
|
| 123 |
+
sr: int,
|
| 124 |
+
language: str,
|
| 125 |
+
) -> list[dict]:
|
| 126 |
+
"""Transcribe audio with word-level timestamps."""
|
| 127 |
+
if sr != 16000:
|
| 128 |
+
import librosa
|
| 129 |
+
|
| 130 |
+
audio_16k = librosa.resample(audio_mono, orig_sr=sr, target_sr=16000)
|
| 131 |
+
else:
|
| 132 |
+
audio_16k = audio_mono
|
| 133 |
+
|
| 134 |
+
model = _get_whisper()
|
| 135 |
+
segments, _ = model.transcribe(
|
| 136 |
+
audio_16k,
|
| 137 |
+
language=language,
|
| 138 |
+
word_timestamps=True,
|
| 139 |
+
vad_filter=True,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
words = []
|
| 143 |
+
for seg in segments:
|
| 144 |
+
if seg.words:
|
| 145 |
+
for w in seg.words:
|
| 146 |
+
words.append(
|
| 147 |
+
{
|
| 148 |
+
"word": w.word.strip().lower(),
|
| 149 |
+
"start": w.start,
|
| 150 |
+
"end": w.end,
|
| 151 |
+
}
|
| 152 |
+
)
|
| 153 |
+
return words
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _find_silence_boundary(
|
| 157 |
+
audio: np.ndarray,
|
| 158 |
+
sr: int,
|
| 159 |
+
center_sample: int,
|
| 160 |
+
direction: str = "left",
|
| 161 |
+
window_s: float = 0.3,
|
| 162 |
+
) -> int:
|
| 163 |
+
"""Find nearest silence point from center position."""
|
| 164 |
+
hop = int(0.01 * sr)
|
| 165 |
+
window_samples = int(window_s * sr)
|
| 166 |
+
|
| 167 |
+
if direction == "left":
|
| 168 |
+
positions = range(center_sample, max(0, center_sample - window_samples), -hop)
|
| 169 |
+
else:
|
| 170 |
+
positions = range(
|
| 171 |
+
center_sample, min(len(audio), center_sample + window_samples), hop
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
for pos in positions:
|
| 175 |
+
chunk = audio[max(0, pos - hop // 2) : min(len(audio), pos + hop // 2)]
|
| 176 |
+
if (
|
| 177 |
+
len(chunk) > 0
|
| 178 |
+
and np.sqrt(np.mean(chunk.astype(np.float64) ** 2)) < SILENCE_THRESHOLD
|
| 179 |
+
):
|
| 180 |
+
return pos
|
| 181 |
+
|
| 182 |
+
return center_sample
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _merge_ranges(
|
| 186 |
+
ranges: list[tuple[float, float]], gap: float = 0.15
|
| 187 |
+
) -> list[tuple[float, float]]:
|
| 188 |
+
"""Merge consecutive time ranges that are close together."""
|
| 189 |
+
if not ranges:
|
| 190 |
+
return []
|
| 191 |
+
merged = []
|
| 192 |
+
for start, end in sorted(ranges):
|
| 193 |
+
if merged and start - merged[-1][1] < gap:
|
| 194 |
+
merged[-1] = (merged[-1][0], end)
|
| 195 |
+
else:
|
| 196 |
+
merged.append((start, end))
|
| 197 |
+
return merged
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _detect_audio_repetition(
|
| 201 |
+
mono: np.ndarray,
|
| 202 |
+
sr: int,
|
| 203 |
+
expected_words: list[str],
|
| 204 |
+
min_duration_s: float = 1.5,
|
| 205 |
+
similarity_threshold: float = 0.85,
|
| 206 |
+
) -> list[tuple[float, float]]:
|
| 207 |
+
"""Detect repeated audio segments via mel spectrogram cross-correlation.
|
| 208 |
+
|
| 209 |
+
Slides a window across the audio and compares each segment against
|
| 210 |
+
all subsequent segments. If two non-overlapping segments have high
|
| 211 |
+
cosine similarity and the expected text does NOT contain that phrase
|
| 212 |
+
repeated, the second segment is marked for removal.
|
| 213 |
+
|
| 214 |
+
Only detects segments >= min_duration_s to avoid false positives on
|
| 215 |
+
short common sounds (breaths, pauses).
|
| 216 |
+
"""
|
| 217 |
+
import torch
|
| 218 |
+
|
| 219 |
+
total_s = len(mono) / sr
|
| 220 |
+
if total_s < min_duration_s * 3:
|
| 221 |
+
return []
|
| 222 |
+
|
| 223 |
+
# Compute mel spectrogram
|
| 224 |
+
hop_length = int(0.02 * sr) # 20ms hops
|
| 225 |
+
n_fft = int(0.04 * sr) # 40ms window
|
| 226 |
+
audio_t = torch.from_numpy(mono).float()
|
| 227 |
+
|
| 228 |
+
try:
|
| 229 |
+
mel_spec = torch.stft(
|
| 230 |
+
audio_t,
|
| 231 |
+
n_fft=n_fft,
|
| 232 |
+
hop_length=hop_length,
|
| 233 |
+
window=torch.hann_window(n_fft),
|
| 234 |
+
return_complex=True,
|
| 235 |
+
).abs()
|
| 236 |
+
except Exception:
|
| 237 |
+
return []
|
| 238 |
+
|
| 239 |
+
# Reduce to energy per time frame
|
| 240 |
+
energy = mel_spec.mean(dim=0).numpy() # (time_frames,)
|
| 241 |
+
frames_per_sec = sr / hop_length
|
| 242 |
+
|
| 243 |
+
# Slide window: check segments of varying length
|
| 244 |
+
repeated_ranges = []
|
| 245 |
+
|
| 246 |
+
for window_s in [3.0, 2.0, 1.5]:
|
| 247 |
+
win_frames = int(window_s * frames_per_sec)
|
| 248 |
+
if win_frames >= len(energy):
|
| 249 |
+
continue
|
| 250 |
+
|
| 251 |
+
step = win_frames // 2
|
| 252 |
+
for i in range(0, len(energy) - win_frames, step):
|
| 253 |
+
seg_a = energy[i : i + win_frames]
|
| 254 |
+
norm_a = np.linalg.norm(seg_a)
|
| 255 |
+
if norm_a < 1e-6:
|
| 256 |
+
continue
|
| 257 |
+
|
| 258 |
+
for j in range(i + win_frames, len(energy) - win_frames, step):
|
| 259 |
+
seg_b = energy[j : j + win_frames]
|
| 260 |
+
norm_b = np.linalg.norm(seg_b)
|
| 261 |
+
if norm_b < 1e-6:
|
| 262 |
+
continue
|
| 263 |
+
|
| 264 |
+
similarity = np.dot(seg_a, seg_b) / (norm_a * norm_b)
|
| 265 |
+
if similarity >= similarity_threshold:
|
| 266 |
+
start_s = j / frames_per_sec
|
| 267 |
+
end_s = (j + win_frames) / frames_per_sec
|
| 268 |
+
repeated_ranges.append((start_s, end_s))
|
| 269 |
+
|
| 270 |
+
# Deduplicate overlapping ranges
|
| 271 |
+
if not repeated_ranges:
|
| 272 |
+
return []
|
| 273 |
+
|
| 274 |
+
merged = _merge_ranges(repeated_ranges, gap=0.5)
|
| 275 |
+
logger.debug("Audio fingerprint candidates: %d segments", len(merged))
|
| 276 |
+
return merged
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _build_trim_mask(
|
| 280 |
+
mono: np.ndarray,
|
| 281 |
+
sr: int,
|
| 282 |
+
insertion_ranges: list[tuple[float, float]],
|
| 283 |
+
) -> np.ndarray:
|
| 284 |
+
"""Build boolean mask removing insertion segments at silence boundaries."""
|
| 285 |
+
total_samples = len(mono)
|
| 286 |
+
keep_mask = np.ones(total_samples, dtype=bool)
|
| 287 |
+
pad_samples = int(TRIM_PAD_S * sr)
|
| 288 |
+
|
| 289 |
+
for start_s, end_s in insertion_ranges:
|
| 290 |
+
trim_start = _find_silence_boundary(mono, sr, int(start_s * sr), "left")
|
| 291 |
+
trim_end = _find_silence_boundary(mono, sr, int(end_s * sr), "right")
|
| 292 |
+
trim_start = max(0, trim_start - pad_samples)
|
| 293 |
+
trim_end = min(total_samples, trim_end + pad_samples)
|
| 294 |
+
keep_mask[trim_start:trim_end] = False
|
| 295 |
+
|
| 296 |
+
return keep_mask
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def validate_and_patch(
|
| 300 |
+
audio_np: np.ndarray,
|
| 301 |
+
sr: int,
|
| 302 |
+
expected_text: str,
|
| 303 |
+
language: str = "en",
|
| 304 |
+
) -> np.ndarray:
|
| 305 |
+
"""Trim hallucinated content using Needleman-Wunsch sequence alignment.
|
| 306 |
+
|
| 307 |
+
1. Transcribe audio with Whisper (word timestamps)
|
| 308 |
+
2. Align transcribed words against expected text (NW algorithm)
|
| 309 |
+
3. Label each transcribed word: match, substitution, or insertion
|
| 310 |
+
4. Trim insertion words (hallucinated) at silence boundaries
|
| 311 |
+
5. Keep substitutions (misrecognized real speech)
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
audio_np: Audio array (mono or stereo).
|
| 315 |
+
sr: Sample rate.
|
| 316 |
+
expected_text: Full expected plain text.
|
| 317 |
+
language: Language code.
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
Trimmed audio array.
|
| 321 |
+
"""
|
| 322 |
+
expected_words = _normalize_words(expected_text)
|
| 323 |
+
if not expected_words:
|
| 324 |
+
return audio_np
|
| 325 |
+
|
| 326 |
+
mono = to_mono(audio_np).astype(np.float32)
|
| 327 |
+
|
| 328 |
+
try:
|
| 329 |
+
transcribed = _transcribe_with_timestamps(mono, sr, language)
|
| 330 |
+
except Exception as e:
|
| 331 |
+
logger.warning("Forced alignment failed: %s, skipping", e)
|
| 332 |
+
return audio_np
|
| 333 |
+
|
| 334 |
+
if not transcribed:
|
| 335 |
+
logger.info("No words transcribed, skipping trim")
|
| 336 |
+
return audio_np
|
| 337 |
+
|
| 338 |
+
# Extract just the words for alignment
|
| 339 |
+
transcribed_words = [re.sub(r"[^\w]", "", tw["word"]) for tw in transcribed]
|
| 340 |
+
transcribed_words = [w for w in transcribed_words if w] # Remove empty
|
| 341 |
+
|
| 342 |
+
# Build index mapping: filtered word index -> original transcribed index
|
| 343 |
+
word_indices = [
|
| 344 |
+
i for i, tw in enumerate(transcribed) if re.sub(r"[^\w]", "", tw["word"])
|
| 345 |
+
]
|
| 346 |
+
|
| 347 |
+
# Run Needleman-Wunsch alignment
|
| 348 |
+
labels = _needleman_wunsch(transcribed_words, expected_words)
|
| 349 |
+
|
| 350 |
+
# Collect insertion ranges (hallucinated words)
|
| 351 |
+
insertion_ranges = []
|
| 352 |
+
n_match = 0
|
| 353 |
+
n_sub = 0
|
| 354 |
+
n_ins = 0
|
| 355 |
+
|
| 356 |
+
for idx, label in enumerate(labels):
|
| 357 |
+
orig_idx = word_indices[idx]
|
| 358 |
+
if label == "insertion":
|
| 359 |
+
insertion_ranges.append(
|
| 360 |
+
(transcribed[orig_idx]["start"], transcribed[orig_idx]["end"])
|
| 361 |
+
)
|
| 362 |
+
n_ins += 1
|
| 363 |
+
elif label == "match":
|
| 364 |
+
n_match += 1
|
| 365 |
+
else:
|
| 366 |
+
n_sub += 1
|
| 367 |
+
|
| 368 |
+
logger.info(
|
| 369 |
+
"NW alignment: %d matched, %d substituted, %d inserted (of %d transcribed vs %d expected)",
|
| 370 |
+
n_match,
|
| 371 |
+
n_sub,
|
| 372 |
+
n_ins,
|
| 373 |
+
len(transcribed_words),
|
| 374 |
+
len(expected_words),
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
# Audio fingerprint: detect repeated audio segments that Whisper missed
|
| 378 |
+
fingerprint_ranges = _detect_audio_repetition(mono, sr, expected_words)
|
| 379 |
+
if fingerprint_ranges:
|
| 380 |
+
logger.info(
|
| 381 |
+
"Audio fingerprint found %d repeated segments", len(fingerprint_ranges)
|
| 382 |
+
)
|
| 383 |
+
insertion_ranges.extend(fingerprint_ranges)
|
| 384 |
+
|
| 385 |
+
if not insertion_ranges:
|
| 386 |
+
logger.info("No insertions detected, audio clean")
|
| 387 |
+
return audio_np
|
| 388 |
+
|
| 389 |
+
# Merge consecutive insertions and trim
|
| 390 |
+
merged = _merge_ranges(insertion_ranges)
|
| 391 |
+
keep_mask = _build_trim_mask(mono, sr, merged)
|
| 392 |
+
result = audio_np[keep_mask]
|
| 393 |
+
|
| 394 |
+
trimmed_s = (len(mono) - np.sum(keep_mask)) / sr
|
| 395 |
+
logger.info(
|
| 396 |
+
"Trimmed %.1fs of hallucinated content (%.1fs -> %.1fs)",
|
| 397 |
+
trimmed_s,
|
| 398 |
+
len(mono) / sr,
|
| 399 |
+
np.sum(keep_mask) / sr,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
return result
|
src/audio_core/validator.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""XML prompt validation for Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Validates the <speak> XML format:
|
| 8 |
+
<speak voice="..." scene="..." language="...">
|
| 9 |
+
<action>delivery/stage direction</action>
|
| 10 |
+
Speech text here.
|
| 11 |
+
<action>more direction</action>
|
| 12 |
+
More speech text.
|
| 13 |
+
</speak>
|
| 14 |
+
|
| 15 |
+
Only <speak> root with <action> children allowed. All content is freeform.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import xml.etree.ElementTree as ET
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
|
| 21 |
+
ALLOWED_CHILD_TAGS = {"action", "sound"}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class ValidationResult:
|
| 26 |
+
valid: bool
|
| 27 |
+
errors: list[str] = field(default_factory=list)
|
| 28 |
+
voice: str | None = None
|
| 29 |
+
scene: str | None = None
|
| 30 |
+
language: str | None = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def validate_prompt(xml_string: str) -> ValidationResult:
|
| 34 |
+
"""Validate a Scenema Audio XML prompt.
|
| 35 |
+
|
| 36 |
+
Checks for valid XML structure, required <speak> root element,
|
| 37 |
+
required voice attribute, and only <action> child elements.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
xml_string: Raw XML string to validate.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
ValidationResult with parsed attributes if valid,
|
| 44 |
+
or a list of errors if invalid.
|
| 45 |
+
"""
|
| 46 |
+
errors: list[str] = []
|
| 47 |
+
|
| 48 |
+
if not xml_string or not xml_string.strip():
|
| 49 |
+
return ValidationResult(valid=False, errors=["Prompt is empty"])
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
root = ET.fromstring(xml_string)
|
| 53 |
+
except ET.ParseError as e:
|
| 54 |
+
return ValidationResult(valid=False, errors=[f"Invalid XML: {e}"])
|
| 55 |
+
|
| 56 |
+
if root.tag != "speak":
|
| 57 |
+
errors.append(f"Root element must be <speak>, got <{root.tag}>")
|
| 58 |
+
return ValidationResult(valid=False, errors=errors)
|
| 59 |
+
|
| 60 |
+
voice = root.get("voice")
|
| 61 |
+
if not voice or not voice.strip():
|
| 62 |
+
errors.append("Missing required 'voice' attribute on <speak>")
|
| 63 |
+
|
| 64 |
+
gender = root.get("gender")
|
| 65 |
+
if not gender or gender.strip() not in ("male", "female"):
|
| 66 |
+
errors.append(
|
| 67 |
+
"Missing or invalid 'gender' attribute on <speak>. Must be 'male' or 'female'"
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
scene = root.get("scene")
|
| 71 |
+
language = root.get("language", "en")
|
| 72 |
+
|
| 73 |
+
allowed_attrs = {"voice", "scene", "language", "gender", "shot"}
|
| 74 |
+
for attr in root.attrib:
|
| 75 |
+
if attr not in allowed_attrs:
|
| 76 |
+
errors.append(f"Unknown attribute '{attr}' on <speak>")
|
| 77 |
+
|
| 78 |
+
for child in root:
|
| 79 |
+
if child.tag not in ALLOWED_CHILD_TAGS:
|
| 80 |
+
errors.append(
|
| 81 |
+
f"Unsupported tag <{child.tag}>. Only <action> and <sound> are allowed inside <speak>"
|
| 82 |
+
)
|
| 83 |
+
if len(list(child)) > 0:
|
| 84 |
+
errors.append(f"<{child.tag}> must contain only text, no nested elements")
|
| 85 |
+
|
| 86 |
+
has_text = False
|
| 87 |
+
if root.text and root.text.strip():
|
| 88 |
+
has_text = True
|
| 89 |
+
for child in root:
|
| 90 |
+
if child.tail and child.tail.strip():
|
| 91 |
+
has_text = True
|
| 92 |
+
break
|
| 93 |
+
|
| 94 |
+
if not has_text:
|
| 95 |
+
errors.append("Prompt must contain at least one speech text node")
|
| 96 |
+
|
| 97 |
+
if errors:
|
| 98 |
+
return ValidationResult(valid=False, errors=errors)
|
| 99 |
+
|
| 100 |
+
return ValidationResult(
|
| 101 |
+
valid=True,
|
| 102 |
+
voice=voice.strip() if voice else None,
|
| 103 |
+
scene=scene.strip() if scene else None,
|
| 104 |
+
language=language.strip() if language else None,
|
| 105 |
+
)
|
src/audio_core/vocal_separator.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""MelBandRoFormer vocal separation for Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Separates vocals from background music/SFX in audio. Used to clean
|
| 8 |
+
generated audio that may contain unwanted background sounds from the
|
| 9 |
+
diffusion model (which was trained on video with ambient audio).
|
| 10 |
+
|
| 11 |
+
Expects stereo 44100Hz input. Processes in overlapping chunks for
|
| 12 |
+
smooth transitions.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
import subprocess
|
| 18 |
+
import sys
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from safetensors.torch import load_file
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
DEFAULT_MODEL_PATH = Path(
|
| 28 |
+
os.environ.get("MELBAND_MODEL_PATH", "/app/models/MelBandRoformer_fp16.safetensors")
|
| 29 |
+
)
|
| 30 |
+
DEFAULT_NODE_PATH = Path(
|
| 31 |
+
os.environ.get("MELBAND_NODE_PATH", "/app/melband_roformer_node")
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
MODEL_CONFIG = {
|
| 35 |
+
"dim": 384,
|
| 36 |
+
"depth": 6,
|
| 37 |
+
"stereo": True,
|
| 38 |
+
"num_stems": 1,
|
| 39 |
+
"time_transformer_depth": 1,
|
| 40 |
+
"freq_transformer_depth": 1,
|
| 41 |
+
"num_bands": 60,
|
| 42 |
+
"dim_head": 64,
|
| 43 |
+
"heads": 8,
|
| 44 |
+
"attn_dropout": 0,
|
| 45 |
+
"ff_dropout": 0,
|
| 46 |
+
"flash_attn": True,
|
| 47 |
+
"dim_freqs_in": 1025,
|
| 48 |
+
"sample_rate": 44100,
|
| 49 |
+
"stft_n_fft": 2048,
|
| 50 |
+
"stft_hop_length": 441,
|
| 51 |
+
"stft_win_length": 2048,
|
| 52 |
+
"stft_normalized": False,
|
| 53 |
+
"mask_estimator_depth": 2,
|
| 54 |
+
"multi_stft_resolution_loss_weight": 1.0,
|
| 55 |
+
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
|
| 56 |
+
"multi_stft_hop_size": 147,
|
| 57 |
+
"multi_stft_normalized": False,
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
CHUNK_SIZE = 352800 # ~8 seconds at 44100Hz
|
| 61 |
+
OVERLAP_FACTOR = 2
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class VocalSeparator:
|
| 65 |
+
"""Separates vocals from background audio using MelBandRoFormer.
|
| 66 |
+
|
| 67 |
+
Processes audio in overlapping chunks with fade windows for
|
| 68 |
+
smooth transitions. Keeps model loaded on GPU for repeated use.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
model_path: Path = DEFAULT_MODEL_PATH,
|
| 74 |
+
node_path: Path = DEFAULT_NODE_PATH,
|
| 75 |
+
):
|
| 76 |
+
self.model_path = model_path
|
| 77 |
+
self.node_path = node_path
|
| 78 |
+
self._model = None
|
| 79 |
+
self._loaded = False
|
| 80 |
+
|
| 81 |
+
def load(self) -> None:
|
| 82 |
+
"""Load MelBandRoFormer model to GPU."""
|
| 83 |
+
if self._loaded:
|
| 84 |
+
return
|
| 85 |
+
|
| 86 |
+
# Lazy import: model architecture only available after node_path added to sys.path
|
| 87 |
+
node_str = str(self.node_path)
|
| 88 |
+
if node_str not in sys.path:
|
| 89 |
+
sys.path.insert(0, node_str)
|
| 90 |
+
from model.mel_band_roformer import MelBandRoformer
|
| 91 |
+
|
| 92 |
+
logger.info("Loading MelBandRoFormer from %s", self.model_path)
|
| 93 |
+
|
| 94 |
+
model = MelBandRoformer(**MODEL_CONFIG)
|
| 95 |
+
sd = load_file(str(self.model_path))
|
| 96 |
+
model.load_state_dict(sd)
|
| 97 |
+
del sd
|
| 98 |
+
|
| 99 |
+
self._model = model.cuda().eval().float()
|
| 100 |
+
self._loaded = True
|
| 101 |
+
|
| 102 |
+
param_count = sum(p.numel() for p in self._model.parameters())
|
| 103 |
+
logger.info("MelBandRoFormer loaded: %.1fM params", param_count / 1e6)
|
| 104 |
+
|
| 105 |
+
def unload(self) -> None:
|
| 106 |
+
"""Free model from GPU."""
|
| 107 |
+
if not self._loaded:
|
| 108 |
+
return
|
| 109 |
+
|
| 110 |
+
self._model = None
|
| 111 |
+
torch.cuda.empty_cache()
|
| 112 |
+
self._loaded = False
|
| 113 |
+
logger.info("MelBandRoFormer unloaded")
|
| 114 |
+
|
| 115 |
+
def separate(
|
| 116 |
+
self,
|
| 117 |
+
input_path: str,
|
| 118 |
+
vocals_path: str,
|
| 119 |
+
sfx_path: str | None = None,
|
| 120 |
+
) -> dict:
|
| 121 |
+
"""Separate vocals from background audio.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
input_path: Path to input audio file (any format ffmpeg supports)
|
| 125 |
+
vocals_path: Output path for isolated vocals
|
| 126 |
+
sfx_path: Output path for isolated SFX/background (optional)
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
Dict with metadata: input_duration, sample_rate
|
| 130 |
+
"""
|
| 131 |
+
if not self._loaded:
|
| 132 |
+
raise RuntimeError("VocalSeparator not loaded. Call load() first.")
|
| 133 |
+
|
| 134 |
+
sr = MODEL_CONFIG["sample_rate"]
|
| 135 |
+
|
| 136 |
+
audio = self._load_audio_ffmpeg(input_path, sr)
|
| 137 |
+
input_duration = audio.shape[1] / sr
|
| 138 |
+
|
| 139 |
+
logger.info("Separating: %.1fs audio", input_duration)
|
| 140 |
+
|
| 141 |
+
with torch.inference_mode():
|
| 142 |
+
vocals = self._chunked_inference(audio, sr)
|
| 143 |
+
|
| 144 |
+
self._save_audio_ffmpeg(vocals, sr, vocals_path)
|
| 145 |
+
|
| 146 |
+
if sfx_path:
|
| 147 |
+
sfx = audio - vocals
|
| 148 |
+
self._save_audio_ffmpeg(sfx, sr, sfx_path)
|
| 149 |
+
|
| 150 |
+
return {
|
| 151 |
+
"input_duration": input_duration,
|
| 152 |
+
"sample_rate": sr,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
def _chunked_inference(self, audio: np.ndarray, sr: int) -> np.ndarray:
|
| 156 |
+
"""Run model inference in overlapping chunks with fade windows."""
|
| 157 |
+
total_samples = audio.shape[1]
|
| 158 |
+
chunk_size = CHUNK_SIZE
|
| 159 |
+
overlap = chunk_size // OVERLAP_FACTOR
|
| 160 |
+
step = chunk_size - overlap
|
| 161 |
+
|
| 162 |
+
fade_in = np.linspace(0, 1, overlap, dtype=np.float32)
|
| 163 |
+
fade_out = np.linspace(1, 0, overlap, dtype=np.float32)
|
| 164 |
+
|
| 165 |
+
result = np.zeros_like(audio)
|
| 166 |
+
weight = np.zeros(total_samples, dtype=np.float32)
|
| 167 |
+
|
| 168 |
+
pos = 0
|
| 169 |
+
while pos < total_samples:
|
| 170 |
+
end = min(pos + chunk_size, total_samples)
|
| 171 |
+
chunk = audio[:, pos:end]
|
| 172 |
+
|
| 173 |
+
if chunk.shape[1] < chunk_size:
|
| 174 |
+
pad_width = chunk_size - chunk.shape[1]
|
| 175 |
+
chunk = np.pad(chunk, ((0, 0), (0, pad_width)))
|
| 176 |
+
|
| 177 |
+
chunk_t = torch.from_numpy(chunk.copy()).unsqueeze(0).cuda().float()
|
| 178 |
+
out = self._model(chunk_t)
|
| 179 |
+
out_np = out.squeeze(0).cpu().float().numpy()[:, : end - pos]
|
| 180 |
+
|
| 181 |
+
chunk_len = end - pos
|
| 182 |
+
w = np.ones(chunk_len, dtype=np.float32)
|
| 183 |
+
if pos > 0:
|
| 184 |
+
fade_len = min(overlap, chunk_len)
|
| 185 |
+
w[:fade_len] *= fade_in[:fade_len]
|
| 186 |
+
if end < total_samples:
|
| 187 |
+
fade_len = min(overlap, chunk_len)
|
| 188 |
+
w[-fade_len:] *= fade_out[:fade_len]
|
| 189 |
+
|
| 190 |
+
result[:, pos:end] += out_np * w[np.newaxis, :]
|
| 191 |
+
weight[pos:end] += w
|
| 192 |
+
|
| 193 |
+
pos += step
|
| 194 |
+
|
| 195 |
+
weight = np.maximum(weight, 1e-8)
|
| 196 |
+
result /= weight[np.newaxis, :]
|
| 197 |
+
|
| 198 |
+
return result
|
| 199 |
+
|
| 200 |
+
def _load_audio_ffmpeg(self, path: str, target_sr: int) -> np.ndarray:
|
| 201 |
+
"""Load audio to stereo float32 numpy via ffmpeg."""
|
| 202 |
+
cmd = [
|
| 203 |
+
"ffmpeg",
|
| 204 |
+
"-i",
|
| 205 |
+
path,
|
| 206 |
+
"-f",
|
| 207 |
+
"f32le",
|
| 208 |
+
"-acodec",
|
| 209 |
+
"pcm_f32le",
|
| 210 |
+
"-ac",
|
| 211 |
+
"2",
|
| 212 |
+
"-ar",
|
| 213 |
+
str(target_sr),
|
| 214 |
+
"-v",
|
| 215 |
+
"quiet",
|
| 216 |
+
"pipe:1",
|
| 217 |
+
]
|
| 218 |
+
proc = subprocess.run(cmd, capture_output=True, check=True)
|
| 219 |
+
audio = np.frombuffer(proc.stdout, dtype=np.float32)
|
| 220 |
+
return audio.reshape(-1, 2).T # (2, samples)
|
| 221 |
+
|
| 222 |
+
def _save_audio_ffmpeg(self, audio: np.ndarray, sr: int, path: str) -> None:
|
| 223 |
+
"""Save stereo float32 numpy to WAV via ffmpeg."""
|
| 224 |
+
interleaved = audio.T.astype(np.float32).tobytes()
|
| 225 |
+
cmd = [
|
| 226 |
+
"ffmpeg",
|
| 227 |
+
"-y",
|
| 228 |
+
"-f",
|
| 229 |
+
"f32le",
|
| 230 |
+
"-acodec",
|
| 231 |
+
"pcm_f32le",
|
| 232 |
+
"-ac",
|
| 233 |
+
"2",
|
| 234 |
+
"-ar",
|
| 235 |
+
str(sr),
|
| 236 |
+
"-i",
|
| 237 |
+
"pipe:0",
|
| 238 |
+
"-acodec",
|
| 239 |
+
"pcm_s16le",
|
| 240 |
+
path,
|
| 241 |
+
"-v",
|
| 242 |
+
"quiet",
|
| 243 |
+
]
|
| 244 |
+
subprocess.run(cmd, input=interleaved, check=True)
|
src/audio_core/whisper_aligner.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Whisper alignment for audio validation in Scenema Audio.
|
| 6 |
+
|
| 7 |
+
Uses faster-whisper (CTranslate2) on GPU to transcribe generated audio
|
| 8 |
+
and validate that the expected text was spoken. Whisper-small is 244M
|
| 9 |
+
params (~1GB VRAM, float16). Runs after denoise when VRAM is free.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
import re
|
| 14 |
+
import unicodedata
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Singleton whisper model (loaded once, reused)
|
| 21 |
+
_whisper_model = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_whisper():
|
| 25 |
+
"""Get or initialize the whisper-small model.
|
| 26 |
+
|
| 27 |
+
Loaded once and cached for the process lifetime.
|
| 28 |
+
Runs on GPU with float16 β whisper-small is 244M params (~1GB VRAM).
|
| 29 |
+
By the time validation runs, denoise is complete and VRAM is free.
|
| 30 |
+
CTranslate2 uses its own CUDA allocator so no conflict with PyTorch.
|
| 31 |
+
"""
|
| 32 |
+
global _whisper_model
|
| 33 |
+
|
| 34 |
+
if _whisper_model is not None:
|
| 35 |
+
return _whisper_model
|
| 36 |
+
|
| 37 |
+
from faster_whisper import WhisperModel
|
| 38 |
+
|
| 39 |
+
logger.info("Loading whisper-small for alignment validation (GPU, float16)...")
|
| 40 |
+
_whisper_model = WhisperModel("small", device="cuda", compute_type="float16")
|
| 41 |
+
logger.info("whisper-small loaded (GPU)")
|
| 42 |
+
return _whisper_model
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def transcribe(audio_np: np.ndarray, sr: int, language: str = "en") -> str:
|
| 46 |
+
"""Transcribe audio and return the text.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
audio_np: Audio samples, shape (samples,) or (samples, channels).
|
| 50 |
+
sr: Sample rate in Hz.
|
| 51 |
+
language: Language code for transcription.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
Transcribed text string.
|
| 55 |
+
"""
|
| 56 |
+
model = _get_whisper()
|
| 57 |
+
|
| 58 |
+
# Convert to mono float32 if needed
|
| 59 |
+
if audio_np.ndim == 2:
|
| 60 |
+
audio_mono = audio_np.mean(axis=1).astype(np.float32)
|
| 61 |
+
else:
|
| 62 |
+
audio_mono = audio_np.astype(np.float32)
|
| 63 |
+
|
| 64 |
+
# Resample to 16kHz if needed
|
| 65 |
+
if sr != 16000:
|
| 66 |
+
import librosa
|
| 67 |
+
|
| 68 |
+
audio_mono = librosa.resample(audio_mono, orig_sr=sr, target_sr=16000)
|
| 69 |
+
|
| 70 |
+
try:
|
| 71 |
+
segments, _ = model.transcribe(
|
| 72 |
+
audio_mono,
|
| 73 |
+
language=language,
|
| 74 |
+
word_timestamps=False,
|
| 75 |
+
vad_filter=True,
|
| 76 |
+
)
|
| 77 |
+
text = " ".join(seg.text.strip() for seg in segments).strip()
|
| 78 |
+
except (ValueError, TypeError):
|
| 79 |
+
# Mocked model in tests returns wrong types
|
| 80 |
+
logger.debug("Whisper transcribe returned unexpected type (test env?)")
|
| 81 |
+
text = ""
|
| 82 |
+
|
| 83 |
+
return text
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def validate_text(
|
| 87 |
+
audio_np: np.ndarray,
|
| 88 |
+
sr: int,
|
| 89 |
+
expected_text: str,
|
| 90 |
+
language: str = "en",
|
| 91 |
+
min_word_ratio: float = 0.6,
|
| 92 |
+
) -> tuple[bool, str, float]:
|
| 93 |
+
"""Validate that generated audio contains the expected text.
|
| 94 |
+
|
| 95 |
+
Transcribes the audio and checks what fraction of expected words
|
| 96 |
+
appear in the transcription.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
audio_np: Audio samples.
|
| 100 |
+
sr: Sample rate.
|
| 101 |
+
expected_text: The text that should have been spoken.
|
| 102 |
+
language: Language code.
|
| 103 |
+
min_word_ratio: Minimum fraction of expected words that must
|
| 104 |
+
appear in transcription (0.0 to 1.0).
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
Tuple of (passed, transcribed_text, word_match_ratio).
|
| 108 |
+
"""
|
| 109 |
+
transcribed = transcribe(audio_np, sr, language)
|
| 110 |
+
|
| 111 |
+
# Normalize both texts for comparison (strip accents for cross-locale matching)
|
| 112 |
+
def normalize(t):
|
| 113 |
+
t = unicodedata.normalize("NFD", t)
|
| 114 |
+
t = "".join(c for c in t if unicodedata.category(c) != "Mn")
|
| 115 |
+
t = t.lower()
|
| 116 |
+
t = re.sub(r"[^\w\s]", "", t)
|
| 117 |
+
return set(t.split())
|
| 118 |
+
|
| 119 |
+
expected_words = normalize(expected_text)
|
| 120 |
+
transcribed_words = normalize(transcribed)
|
| 121 |
+
|
| 122 |
+
if not expected_words:
|
| 123 |
+
return True, transcribed, 1.0
|
| 124 |
+
|
| 125 |
+
matched = expected_words & transcribed_words
|
| 126 |
+
ratio = len(matched) / len(expected_words)
|
| 127 |
+
|
| 128 |
+
passed = ratio >= min_word_ratio
|
| 129 |
+
if not passed:
|
| 130 |
+
logger.warning(
|
| 131 |
+
"Validation failed: %.0f%% word match (need %.0f%%). "
|
| 132 |
+
"Expected: %s... Got: %s...",
|
| 133 |
+
ratio * 100,
|
| 134 |
+
min_word_ratio * 100,
|
| 135 |
+
expected_text[:60],
|
| 136 |
+
transcribed[:60],
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
return passed, transcribed, ratio
|
src/common/__init__.py
ADDED
|
File without changes
|
src/common/handlers/__init__.py
ADDED
|
File without changes
|
src/common/handlers/base.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Minimal handler types for standalone deployment.
|
| 6 |
+
|
| 7 |
+
Drop-in replacement for the production common.handlers.base module.
|
| 8 |
+
Provides ProcessJob, ProcessOutput, and ProcessResult so that
|
| 9 |
+
audio_core.processor imports resolve without modification.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Any, Optional
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class ProcessJob:
|
| 18 |
+
job_id: str
|
| 19 |
+
input: dict[str, Any]
|
| 20 |
+
upload_url: Optional[str] = None
|
| 21 |
+
webhook_url: Optional[str] = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class ProcessOutput:
|
| 26 |
+
success: bool = True
|
| 27 |
+
data: Optional[bytes] = None
|
| 28 |
+
content_type: Optional[str] = None
|
| 29 |
+
result: Optional[dict] = None
|
| 30 |
+
metadata: Optional[dict] = None
|
| 31 |
+
error: Optional[str] = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class ProcessResult:
|
| 36 |
+
job_id: str
|
| 37 |
+
success: bool
|
| 38 |
+
output: Optional[ProcessOutput] = None
|
| 39 |
+
processing_ms: int = 0
|
| 40 |
+
error: Optional[str] = None
|
src/server.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 Scenema AI
|
| 2 |
+
# https://scenema.ai
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
|
| 5 |
+
"""Scenema Audio standalone server.
|
| 6 |
+
|
| 7 |
+
Thin FastAPI wrapper around the production AudioProcessor.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import asyncio
|
| 11 |
+
import base64
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import uuid
|
| 15 |
+
from contextlib import asynccontextmanager
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
from fastapi import FastAPI, Request
|
| 19 |
+
from fastapi.responses import JSONResponse
|
| 20 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 21 |
+
import uvicorn
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("scenema-audio")
|
| 24 |
+
|
| 25 |
+
# Must be set before any torch import
|
| 26 |
+
os.environ.setdefault(
|
| 27 |
+
"PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
from audio_core.processor import AudioProcessor # noqa: E402
|
| 31 |
+
from common.handlers.base import ProcessJob # noqa: E402
|
| 32 |
+
|
| 33 |
+
# ββ Model download ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
+
|
| 35 |
+
HF_REPO = "ScenemaAI/scenema-audio"
|
| 36 |
+
GEMMA_REPO = "google/gemma-3-12b-it"
|
| 37 |
+
SEEDVC_REPO = "Plachta/Seed-VC"
|
| 38 |
+
BIGVGAN_REPO = "nvidia/bigvgan_v2_22khz_80band_256x"
|
| 39 |
+
WHISPER_REPO = "openai/whisper-small"
|
| 40 |
+
|
| 41 |
+
MODEL_DIR = Path(os.environ.get("MODEL_DIR", "/app/models"))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _download_models():
|
| 45 |
+
"""Download missing model checkpoints from HuggingFace."""
|
| 46 |
+
|
| 47 |
+
token = os.environ.get("HF_TOKEN")
|
| 48 |
+
|
| 49 |
+
# Audio transformer (INT8 by default)
|
| 50 |
+
audio_ckpt = Path(os.environ.get(
|
| 51 |
+
"AUDIO_CKPT",
|
| 52 |
+
str(MODEL_DIR / "scenema-audio-transformer-int8.safetensors"),
|
| 53 |
+
))
|
| 54 |
+
if not audio_ckpt.exists():
|
| 55 |
+
logger.info("Downloading audio transformer (INT8, ~4.9 GB)...")
|
| 56 |
+
hf_hub_download(
|
| 57 |
+
HF_REPO,
|
| 58 |
+
"scenema-audio-transformer-int8.safetensors",
|
| 59 |
+
local_dir=str(audio_ckpt.parent),
|
| 60 |
+
token=token,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Pipeline checkpoint
|
| 64 |
+
pipeline_ckpt = Path(os.environ.get(
|
| 65 |
+
"PIPELINE_CKPT",
|
| 66 |
+
str(MODEL_DIR / "scenema-audio-pipeline.safetensors"),
|
| 67 |
+
))
|
| 68 |
+
if not pipeline_ckpt.exists():
|
| 69 |
+
logger.info("Downloading pipeline checkpoint (~7.1 GB)...")
|
| 70 |
+
hf_hub_download(
|
| 71 |
+
HF_REPO,
|
| 72 |
+
"scenema-audio-pipeline.safetensors",
|
| 73 |
+
local_dir=str(pipeline_ckpt.parent),
|
| 74 |
+
token=token,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# VAE encoder (small, may already be baked)
|
| 78 |
+
vae_ckpt = Path(os.environ.get(
|
| 79 |
+
"VAE_ENCODER_CKPT",
|
| 80 |
+
str(MODEL_DIR / "scenema-audio-vae-encoder.safetensors"),
|
| 81 |
+
))
|
| 82 |
+
if not vae_ckpt.exists():
|
| 83 |
+
logger.info("Downloading VAE encoder (~42 MB)...")
|
| 84 |
+
hf_hub_download(
|
| 85 |
+
HF_REPO,
|
| 86 |
+
"scenema-audio-vae-encoder.safetensors",
|
| 87 |
+
local_dir=str(vae_ckpt.parent),
|
| 88 |
+
token=token,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# Gemma 3 12B IT
|
| 92 |
+
gemma_root = Path(os.environ.get("GEMMA_ROOT", str(MODEL_DIR / "gemma-3-12b-it")))
|
| 93 |
+
if not gemma_root.exists() or not any(gemma_root.glob("*.safetensors")):
|
| 94 |
+
logger.info("Downloading Gemma 3 12B IT (~24 GB, gated model)...")
|
| 95 |
+
snapshot_download(
|
| 96 |
+
GEMMA_REPO,
|
| 97 |
+
local_dir=str(gemma_root),
|
| 98 |
+
ignore_patterns=["*.gguf"],
|
| 99 |
+
token=token,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# SeedVC
|
| 103 |
+
seedvc_path = Path(os.environ.get("SEEDVC_PATH", "/app/seed-vc"))
|
| 104 |
+
seedvc_cache = seedvc_path / "checkpoints"
|
| 105 |
+
if not seedvc_cache.exists() or not any(seedvc_cache.glob("*.pth")):
|
| 106 |
+
logger.info("Downloading SeedVC checkpoints (~1.6 GB)...")
|
| 107 |
+
hf_cache = seedvc_cache / "hf_cache"
|
| 108 |
+
hf_cache.mkdir(parents=True, exist_ok=True)
|
| 109 |
+
os.environ["HF_HUB_CACHE"] = str(hf_cache)
|
| 110 |
+
hf_hub_download(
|
| 111 |
+
SEEDVC_REPO,
|
| 112 |
+
"DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
|
| 113 |
+
local_dir=str(seedvc_cache),
|
| 114 |
+
token=token,
|
| 115 |
+
)
|
| 116 |
+
hf_hub_download(
|
| 117 |
+
SEEDVC_REPO,
|
| 118 |
+
"config_dit_mel_seed_uvit_whisper_small_wavenet.yml",
|
| 119 |
+
local_dir=str(seedvc_cache),
|
| 120 |
+
token=token,
|
| 121 |
+
)
|
| 122 |
+
snapshot_download(BIGVGAN_REPO, local_dir=str(hf_cache / "bigvgan"))
|
| 123 |
+
snapshot_download(WHISPER_REPO, local_dir=str(hf_cache / "whisper-small"))
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ββ FastAPI app βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 127 |
+
|
| 128 |
+
processor = AudioProcessor()
|
| 129 |
+
_semaphore = asyncio.Semaphore(1)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@asynccontextmanager
|
| 133 |
+
async def lifespan(app: FastAPI):
|
| 134 |
+
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 135 |
+
_download_models()
|
| 136 |
+
processor.startup()
|
| 137 |
+
logger.info("Scenema Audio ready on port %s", os.environ.get("PORT", "8000"))
|
| 138 |
+
yield
|
| 139 |
+
processor.shutdown()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
app = FastAPI(title="Scenema Audio", lifespan=lifespan)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@app.get("/health")
|
| 146 |
+
async def health():
|
| 147 |
+
return {"status": "ok"}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@app.post("/generate")
|
| 151 |
+
async def generate(request: Request):
|
| 152 |
+
body = await request.json()
|
| 153 |
+
|
| 154 |
+
job = ProcessJob(
|
| 155 |
+
job_id=str(uuid.uuid4()),
|
| 156 |
+
input=body,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
async with _semaphore:
|
| 160 |
+
result = await processor.process(job)
|
| 161 |
+
|
| 162 |
+
if not result.success:
|
| 163 |
+
return JSONResponse(
|
| 164 |
+
status_code=500,
|
| 165 |
+
content={
|
| 166 |
+
"status": "failed",
|
| 167 |
+
"error": result.error or "Generation failed",
|
| 168 |
+
},
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
output = result.output
|
| 172 |
+
audio_b64 = base64.b64encode(output.data).decode() if output.data else None
|
| 173 |
+
|
| 174 |
+
return {
|
| 175 |
+
"status": "succeeded",
|
| 176 |
+
"audio": audio_b64,
|
| 177 |
+
"content_type": output.content_type or "audio/wav",
|
| 178 |
+
"metadata": output.metadata or {},
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
if __name__ == "__main__":
|
| 183 |
+
logging.basicConfig(
|
| 184 |
+
level=logging.INFO,
|
| 185 |
+
format="%(asctime)s %(name)s %(levelname)s %(message)s",
|
| 186 |
+
)
|
| 187 |
+
port = int(os.environ.get("PORT", "8000"))
|
| 188 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|