github-actions[bot] commited on
Commit
2c3df98
·
1 Parent(s): 4230483

deploy: switch to dramabox requirements @ b5b35d7

Browse files
dramabox_src/audio_conditioning.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio reference conditioning item for IC-LoRA voice cloning."""
2
+
3
+ import torch
4
+
5
+ from ltx_core.components.patchifiers import AudioPatchifier
6
+ from ltx_core.conditioning.item import ConditioningItem
7
+ from ltx_core.tools import AudioLatentTools
8
+ from ltx_core.types import AudioLatentShape, LatentState
9
+
10
+
11
+ class AudioConditionByReferenceLatent(ConditioningItem):
12
+ """Conditions audio generation on a reference audio latent for voice cloning.
13
+
14
+ Mirrors VideoConditionByReferenceLatent but for audio:
15
+ - Patchifies reference latent [B, C, T, F] -> [B, ref_T, 128]
16
+ - Computes 1D temporal positions via AudioPatchifier
17
+ - Sets denoise_mask = 1.0 - strength (strength=1.0 -> mask=0 -> frozen)
18
+ - Builds ASYMMETRIC attention mask: target->ref=1 (attend), ref->target=0 (read-only)
19
+ - APPENDS ref tokens to END of latent sequence (IC-LoRA pattern)
20
+ - Uses OVERLAPPING positions (same coordinate space) so RoPE doesn't
21
+ decay target->ref attention. The asymmetric mask provides the structural
22
+ signal that ref tokens are conditioning, not reconstruction targets.
23
+
24
+ Args:
25
+ latent: Reference audio latent [B, C, T, F] (pre-VAE-encoded).
26
+ strength: Conditioning strength. 1.0 = full (ref kept clean),
27
+ 0.0 = none (ref fully denoised). Default 1.0.
28
+ """
29
+
30
+ def __init__(self, latent: torch.Tensor, strength: float = 1.0):
31
+ self.latent = latent
32
+ self.strength = strength
33
+
34
+ def apply_to(
35
+ self,
36
+ latent_state: LatentState,
37
+ latent_tools: AudioLatentTools,
38
+ ) -> LatentState:
39
+ """Append reference audio tokens with positions and attention mask."""
40
+ tokens = latent_tools.patchifier.patchify(self.latent)
41
+
42
+ # Compute positions for the reference audio — small offset (0.5s) from
43
+ # target start to avoid exact t=0 overlap (which causes ref content to
44
+ # bleed into target start), while keeping RoPE decay minimal.
45
+ # 0.5s / max_pos(20s) = 0.025 fractional — negligible RoPE decay.
46
+ ref_shape = AudioLatentShape(
47
+ batch=self.latent.shape[0],
48
+ channels=self.latent.shape[1],
49
+ frames=self.latent.shape[2],
50
+ mel_bins=self.latent.shape[3],
51
+ )
52
+ positions = latent_tools.patchifier.get_patch_grid_bounds(
53
+ output_shape=ref_shape,
54
+ device=self.latent.device,
55
+ )
56
+ # Small offset to prevent t=0 position collision between target and ref
57
+ positions = positions + 0.5
58
+
59
+ # Denoise mask: 0 for frozen (strength=1.0), 1 for fully denoised (strength=0.0)
60
+ denoise_mask = torch.full(
61
+ size=(*tokens.shape[:2], 1),
62
+ fill_value=1.0 - self.strength,
63
+ device=self.latent.device,
64
+ dtype=torch.float32,
65
+ )
66
+
67
+ # Build ASYMMETRIC attention mask manually.
68
+ # Structure:
69
+ # target (N) ref (M)
70
+ # ┌────────────┬──────────┐
71
+ # target │ 1.0 │ 1.0 │ target attends to everything
72
+ # (N) │ │ │
73
+ # ├────────────┼──────────┤
74
+ # ref │ 0.0 │ 1.0 │ ref only attends to itself
75
+ # (M) │ │ │
76
+ # └────────────┴──────────┘
77
+ #
78
+ # This makes reference tokens "read-only conditioning":
79
+ # - Target tokens freely attend to ref (voice cloning signal)
80
+ # - Ref tokens don't attend to noisy target (stays clean/stable)
81
+ batch_size = tokens.shape[0]
82
+ num_target = latent_state.latent.shape[1]
83
+ num_ref = tokens.shape[1]
84
+ total = num_target + num_ref
85
+
86
+ # Use float32 for the [0,1] mask — _prepare_self_attention_mask converts
87
+ # to log-space bias in the model's compute dtype before it reaches attention.
88
+ mask = torch.zeros(
89
+ (batch_size, total, total),
90
+ device=self.latent.device,
91
+ dtype=torch.float32,
92
+ )
93
+
94
+ # Incorporate existing mask if present, otherwise full attention for target
95
+ if latent_state.attention_mask is not None:
96
+ mask[:, :num_target, :num_target] = latent_state.attention_mask
97
+ else:
98
+ mask[:, :num_target, :num_target] = 1.0
99
+
100
+ # Target -> ref: FULL attention (target can read reference voice)
101
+ mask[:, :num_target, num_target:] = 1.0
102
+
103
+ # Ref -> target: BLOCKED (ref is read-only, doesn't see noisy target)
104
+ # mask[:, num_target:, :num_target] remains 0.0
105
+
106
+ # Ref -> ref: full self-attention within reference
107
+ mask[:, num_target:, num_target:] = 1.0
108
+
109
+ return LatentState(
110
+ latent=torch.cat([latent_state.latent, tokens], dim=1),
111
+ denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
112
+ positions=torch.cat([latent_state.positions, positions], dim=2),
113
+ clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
114
+ attention_mask=mask,
115
+ )
dramabox_src/audio_conditioning.py.training_helpers ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio reference conditioning item for IC-LoRA voice cloning."""
2
+
3
+ import torch
4
+
5
+ from ltx_core.components.patchifiers import AudioPatchifier
6
+ from ltx_core.conditioning.item import ConditioningItem
7
+ from ltx_core.tools import AudioLatentTools
8
+ from ltx_core.types import AudioLatentShape, LatentState
9
+
10
+
11
+ class AudioConditionByReferenceLatent(ConditioningItem):
12
+ """Conditions audio generation on a reference audio latent for voice cloning.
13
+
14
+ Mirrors VideoConditionByReferenceLatent but for audio:
15
+ - Patchifies reference latent [B, C, T, F] -> [B, ref_T, 128]
16
+ - Computes 1D temporal positions via AudioPatchifier
17
+ - Sets denoise_mask = 1.0 - strength (strength=1.0 -> mask=0 -> frozen)
18
+ - Builds ASYMMETRIC attention mask: target->ref=1 (attend), ref->target=0 (read-only)
19
+ - APPENDS ref tokens to END of latent sequence (IC-LoRA pattern)
20
+ - Uses OVERLAPPING positions (same coordinate space) so RoPE doesn't
21
+ decay target->ref attention. The asymmetric mask provides the structural
22
+ signal that ref tokens are conditioning, not reconstruction targets.
23
+
24
+ Args:
25
+ latent: Reference audio latent [B, C, T, F] (pre-VAE-encoded).
26
+ strength: Conditioning strength. 1.0 = full (ref kept clean),
27
+ 0.0 = none (ref fully denoised). Default 1.0.
28
+ """
29
+
30
+ def __init__(self, latent: torch.Tensor, strength: float = 1.0):
31
+ self.latent = latent
32
+ self.strength = strength
33
+
34
+ def apply_to(
35
+ self,
36
+ latent_state: LatentState,
37
+ latent_tools: AudioLatentTools,
38
+ ) -> LatentState:
39
+ """Append reference audio tokens with positions and attention mask."""
40
+ tokens = latent_tools.patchifier.patchify(self.latent)
41
+
42
+ # Compute positions for the reference audio — small offset (0.5s) from
43
+ # target start to avoid exact t=0 overlap (which causes ref content to
44
+ # bleed into target start), while keeping RoPE decay minimal.
45
+ # 0.5s / max_pos(20s) = 0.025 fractional — negligible RoPE decay.
46
+ ref_shape = AudioLatentShape(
47
+ batch=self.latent.shape[0],
48
+ channels=self.latent.shape[1],
49
+ frames=self.latent.shape[2],
50
+ mel_bins=self.latent.shape[3],
51
+ )
52
+ positions = latent_tools.patchifier.get_patch_grid_bounds(
53
+ output_shape=ref_shape,
54
+ device=self.latent.device,
55
+ )
56
+ # Small offset to prevent t=0 position collision between target and ref
57
+ positions = positions + 0.5
58
+
59
+ # Denoise mask: 0 for frozen (strength=1.0), 1 for fully denoised (strength=0.0)
60
+ denoise_mask = torch.full(
61
+ size=(*tokens.shape[:2], 1),
62
+ fill_value=1.0 - self.strength,
63
+ device=self.latent.device,
64
+ dtype=torch.float32,
65
+ )
66
+
67
+ # Build ASYMMETRIC attention mask manually.
68
+ # Structure:
69
+ # target (N) ref (M)
70
+ # ┌────────────┬──────────┐
71
+ # target │ 1.0 │ 1.0 │ target attends to everything
72
+ # (N) │ │ │
73
+ # ├────────────┼──────────┤
74
+ # ref │ 0.0 │ 1.0 │ ref only attends to itself
75
+ # (M) │ │ │
76
+ # └────────────┴──────────┘
77
+ #
78
+ # This makes reference tokens "read-only conditioning":
79
+ # - Target tokens freely attend to ref (voice cloning signal)
80
+ # - Ref tokens don't attend to noisy target (stays clean/stable)
81
+ batch_size = tokens.shape[0]
82
+ num_target = latent_state.latent.shape[1]
83
+ num_ref = tokens.shape[1]
84
+ total = num_target + num_ref
85
+
86
+ # Use float32 for the [0,1] mask — _prepare_self_attention_mask converts
87
+ # to log-space bias in the model's compute dtype before it reaches attention.
88
+ mask = torch.zeros(
89
+ (batch_size, total, total),
90
+ device=self.latent.device,
91
+ dtype=torch.float32,
92
+ )
93
+
94
+ # Incorporate existing mask if present, otherwise full attention for target
95
+ if latent_state.attention_mask is not None:
96
+ mask[:, :num_target, :num_target] = latent_state.attention_mask
97
+ else:
98
+ mask[:, :num_target, :num_target] = 1.0
99
+
100
+ # Target -> ref: FULL attention (target can read reference voice)
101
+ mask[:, :num_target, num_target:] = 1.0
102
+
103
+ # Ref -> target: BLOCKED (ref is read-only, doesn't see noisy target)
104
+ # mask[:, num_target:, :num_target] remains 0.0
105
+
106
+ # Ref -> ref: full self-attention within reference
107
+ mask[:, num_target:, num_target:] = 1.0
108
+
109
+ return LatentState(
110
+ latent=torch.cat([latent_state.latent, tokens], dim=1),
111
+ denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
112
+ positions=torch.cat([latent_state.positions, positions], dim=2),
113
+ clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
114
+ attention_mask=mask,
115
+ )
dramabox_src/inference.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ LTX-2.3 TTS with IC-LoRA voice cloning.
4
+
5
+ Uses AudioConditionByReferenceLatent to append reference audio tokens to the
6
+ end of the target sequence. Auto-detects distilled vs dev checkpoint and
7
+ selects the appropriate denoiser (SimpleDenoiser / GuidedDenoiser) and sigma
8
+ schedule. Leverages the official euler_denoising_loop, AudioLatentTools,
9
+ GaussianNoiser, and X0Model wrapper throughout.
10
+
11
+ Usage (distilled):
12
+ python tts_iclora.py \
13
+ --voice-sample reference.wav \
14
+ --prompt "A woman speaks clearly: The weather today will be sunny." \
15
+ --output tts_output.wav
16
+
17
+ Usage (dev):
18
+ python tts_iclora.py \
19
+ --voice-sample reference.wav \
20
+ --prompt "A woman speaks clearly: The weather today will be sunny." \
21
+ --checkpoint ltx-2.3-22b-dev-audio-only.safetensors \
22
+ --full-checkpoint ltx-2.3-22b-dev.safetensors \
23
+ --output tts_output.wav
24
+ """
25
+
26
+ import argparse
27
+ import json
28
+ import logging
29
+ import os
30
+ import re
31
+ import struct
32
+ import sys
33
+ import time
34
+ from pathlib import Path
35
+
36
+ import torch
37
+ import torchaudio
38
+
39
+ REPO_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
40
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ltx2"))
41
+ # ltx-pipelines already on path via ltx2/
42
+
43
+ # Also add the local directory so audio_conditioning.py is importable
44
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
45
+
46
+ MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "models")
47
+ GEMMA_DIR = os.environ.get("GEMMA_DIR", "gemma-3-12b-it-qat-q4_0-unquantized")
48
+
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # Helpers
52
+ # ---------------------------------------------------------------------------
53
+
54
+
55
+ def detect_model_type(checkpoint_path: str) -> str:
56
+ """Detect if checkpoint is distilled or dev by checking filename and metadata."""
57
+ path_lower = checkpoint_path.lower()
58
+ if "distilled" in path_lower:
59
+ return "distilled"
60
+ if "dev" in path_lower:
61
+ return "dev"
62
+ # Fallback: try to read safetensors metadata
63
+ try:
64
+ with open(checkpoint_path, "rb") as f:
65
+ header_size = struct.unpack("<Q", f.read(8))[0]
66
+ header = json.loads(f.read(header_size).decode())
67
+ metadata = header.get("__metadata__", {})
68
+ version = metadata.get("model_version", "")
69
+ if "distilled" in version.lower():
70
+ return "distilled"
71
+ except Exception:
72
+ pass
73
+ # Default to distilled (most common for audio-only)
74
+ return "distilled"
75
+
76
+
77
+ _LAUGH_VERBS = {
78
+ # base seconds per occurrence; gets scaled by the modifier found nearby.
79
+ # Verb regex covers inflections: laugh/laughs/laughed/laughing.
80
+ r"\blaugh(?:s|ed|ing)?\b": 1.5,
81
+ r"\bcackl(?:e|es|ed|ing)\b": 1.5,
82
+ r"\bchuckl(?:e|es|ed|ing)\b": 1.0,
83
+ r"\bgiggl(?:e|es|ed|ing)\b": 1.0,
84
+ r"\bsnicker(?:s|ed|ing)?\b": 0.8,
85
+ r"\bcru?el laugh\b": 1.5,
86
+ }
87
+
88
+
89
+ def _contextual_laugh_duration(text: str) -> float:
90
+ """Context-aware laugh budget.
91
+
92
+ For each laugh verb in the prompt, look at the adjective/adverb that
93
+ modifies it and scale the base duration:
94
+ - short modifiers (briefly, softly, once) -> 0.4x base
95
+ - long modifiers (maniacally, heartily, ...) -> 1.2x base
96
+ - default (no mod / neutral) -> 1.0x base
97
+ Also reward phonetic repetition inside quotes -- 'Hahahahahaha' buys more
98
+ time than 'Haha' -- at ~0.2s per extra repeated syllable.
99
+ """
100
+ # "softly" / "quietly" describe volume not length, so keep at default 1.0x.
101
+ short_mod = re.compile(
102
+ r"^\s*(?:[a-z]+ly )?(?:briefly|shortly|once|quickly)",
103
+ re.IGNORECASE)
104
+ long_mod = re.compile(
105
+ r"^\s*(?:[a-z]+ly )?(?:maniacally|heartily|uproariously|uncontrollably|"
106
+ r"hysterically|darkly|wickedly|evilly|loudly|long)"
107
+ r"|^\s*between phrases", re.IGNORECASE)
108
+
109
+ total = 0.0
110
+ for pat, base_dur in _LAUGH_VERBS.items():
111
+ for m in re.finditer(pat, text, re.IGNORECASE):
112
+ ctx = text[m.end(): m.end() + 40]
113
+ if short_mod.match(ctx):
114
+ total += base_dur * 0.4
115
+ elif long_mod.match(ctx):
116
+ total += base_dur * 1.2
117
+ else:
118
+ total += base_dur
119
+
120
+ # Phonetic laugh repetition inside quotes:
121
+ # 'Haha' = 2 syllables (base, no bonus)
122
+ # 'Hahahaha' = 4 syllables (+0.4s)
123
+ # 'Hehehehahahahahahahaha' ~ 10 syllables (+1.6s)
124
+ for q in re.findall(r'"([^"]+)"', text) + re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text):
125
+ for run in re.findall(r"(?:h[ae]){3,}|(?:h[ae][ \-]?){3,}", q, re.IGNORECASE):
126
+ syls = len(re.findall(r"h[ae]", run, re.IGNORECASE))
127
+ total += 0.2 * max(syls - 2, 0)
128
+ return total
129
+
130
+
131
+ def _estimate_nonverbal_duration(text: str) -> float:
132
+ """Estimate extra duration for non-verbal sounds and actions in the prompt.
133
+
134
+ Laugh-verb handling lives in ``_contextual_laugh_duration`` so cackle /
135
+ chuckle / laugh budgets scale with the adjective ("maniacally" vs
136
+ "briefly") and with the repetition length of 'Ha'/'He' tokens inside
137
+ quotes.
138
+ """
139
+ PATTERNS = {
140
+ # Breathing / sighs
141
+ r'\bsighs?\b': 0.8, r'\bshaky breath\b': 1.0, r'\bbreathing deeply\b': 1.0,
142
+ r'\bgasps?\b': 0.5, r'\bburps?\b': 0.5, r'\byawns?\b': 1.0,
143
+ r'\bpants?\b': 0.8, r'\bwheezes?\b': 0.8, r'\bcoughs?\b': 0.8,
144
+ r'\bsniffles?\b': 0.5, r'\bsnorts?\b': 0.3, r'\bgroans?\b': 0.8,
145
+ # Pauses (trimmed; earlier values over-budgeted silence)
146
+ r'\blong pause\b': 1.0, r'\bpauses? briefly\b': 0.3,
147
+ r'\bpauses?\b': 0.5, r'\bsilence\b': 1.0,
148
+ r'\blets? the .{1,20} hang\b': 1.0, r'\blets? .{1,20} sink in\b': 1.0,
149
+ # Physical actions that produce sound
150
+ r'\bslams?\b': 0.5, r'\bclaps?\b': 0.3,
151
+ r'\bdraws? (?:his|her|a) sword\b': 0.5,
152
+ r'\btakes? a (?:drag|swig|sip|drink)\b': 0.5,
153
+ r'\bwhistles?\b': 1.0, r'\bhums?\b': 0.8,
154
+ # Vocal actions (not in quotes but take time)
155
+ r'\bmutters?\b': 1.5, r'\bmumbles?\b': 1.0, r'\bwhispers?\b': 0.0,
156
+ r'\bclears? (?:his|her) throat\b': 0.5, r'\bgulps?\b': 0.5,
157
+ r'\bswallows?\b': 0.5,
158
+ # (laugh / chuckle / cackle / giggle / snicker handled by
159
+ # _contextual_laugh_duration below -- modifier-aware, not flat.)
160
+ # Emotional transitions
161
+ r'\bvoice (?:breaks?|cracks?|trembles?|drops?|rises?)\b': 0.5,
162
+ r'\bsteadies? (?:him|her)self\b': 1.0,
163
+ r'\bcatches? (?:his|her) breath\b': 1.0,
164
+ r'\bcomposes? (?:him|her)self\b': 0.8,
165
+ # Scene transitions that imply time
166
+ r'\bdemeanor shifts?\b': 0.5, r'\bsettles? in\b': 0.5,
167
+ r'\bleans? in\b': 0.3, r'\bwipes? (?:his|her) eyes\b': 0.5,
168
+ }
169
+ extra = 0.0
170
+ for pattern, dur in PATTERNS.items():
171
+ extra += dur * len(re.findall(pattern, text, re.IGNORECASE))
172
+ extra += _contextual_laugh_duration(text)
173
+ return extra
174
+
175
+
176
+ def estimate_speech_duration(text: str, speed: float = 1.0) -> float:
177
+ """Estimate speech duration from spoken content + non-verbal actions.
178
+
179
+ Extracts spoken text by priority:
180
+ 1. Quoted text ('...' or "...") -- official prompt guide format
181
+ 2. Text after colon -- simple "Speaker: dialogue" format
182
+ 3. Full text -- fallback
183
+
184
+ Also scans the full prompt for non-verbal cues (laughs, pauses, sighs,
185
+ gasps, etc.) and adds estimated duration for each.
186
+ """
187
+ # Try double quotes first (clean, no contraction issues)
188
+ quotes = re.findall(r'"([^"]+)"', text)
189
+ if not quotes:
190
+ # Single quotes: allow apostrophes in contractions (don't, can't, it's)
191
+ # Match ' to ' but apostrophes NOT followed by space/punctuation are kept inside
192
+ quotes = re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text)
193
+ # Filter out short fragments (scene directions like "He pauses")
194
+ quotes = [q for q in quotes if len(q.split()) > 3]
195
+ if quotes:
196
+ spoken = " ".join(quotes)
197
+ elif ":" in text:
198
+ spoken = text.split(":", 1)[1].strip()
199
+ else:
200
+ spoken = text
201
+
202
+ CHARS_PER_SEC = 14.0
203
+ text_len = len(spoken)
204
+
205
+ if text_len < 40:
206
+ chars_per_sec = CHARS_PER_SEC * 0.6
207
+ elif text_len < 80:
208
+ chars_per_sec = CHARS_PER_SEC * 0.8
209
+ else:
210
+ chars_per_sec = CHARS_PER_SEC
211
+
212
+ chars_per_sec *= speed
213
+ duration = text_len / chars_per_sec
214
+
215
+ sentence_count = spoken.count(".") + spoken.count("!") + spoken.count("?")
216
+ duration += sentence_count * 0.3
217
+
218
+ # Add time for non-verbal sounds/actions in the full prompt
219
+ duration += _estimate_nonverbal_duration(text)
220
+
221
+ return max(3.0, round(duration + 2.0, 1))
222
+
223
+
224
+ def parse_args():
225
+ p = argparse.ArgumentParser(description="LTX-2.3 TTS with IC-LoRA voice cloning")
226
+
227
+ p.add_argument("--voice-sample", default=None, help="Voice reference WAV")
228
+ p.add_argument("--no-ref", action="store_true", help="Skip voice reference conditioning (raw base model)")
229
+ p.add_argument("--prompt", required=True, help="Text/scene description to synthesize")
230
+ p.add_argument("--output", default="tts_output.wav")
231
+
232
+ p.add_argument("--ref-duration", type=float, default=10.0, help="Seconds of voice reference to use")
233
+ p.add_argument("--gen-duration", type=float, default=0.0,
234
+ help="Target output duration in seconds (0 = auto from prompt + multiplier). "
235
+ "Set explicitly for long-form prompts (e.g. --gen-duration 30 for music). "
236
+ "Outputs >20.5s automatically engage the end-of-clip silence-prior patch.")
237
+ p.add_argument("--pad-start", type=float, default=0.0,
238
+ help="Prepend N seconds of silent padding, trimmed after decode (use 0 for clean starts)")
239
+ p.add_argument("--speed", type=float, default=1.0)
240
+ p.add_argument("--duration-multiplier", type=float, default=1.0,
241
+ help="Multiply auto-estimated duration by this factor (e.g. 1.1 for 10%% more breathing room)")
242
+
243
+ p.add_argument("--checkpoint", default=os.path.join(MODEL_DIR, "ltx-2.3-audio-only.safetensors"))
244
+ p.add_argument("--full-checkpoint", default=os.path.join(MODEL_DIR, "ltx-2.3-22b-distilled.safetensors"))
245
+ p.add_argument("--gemma-root", default=GEMMA_DIR)
246
+ p.add_argument("--bnb-4bit", dest="bnb_4bit", action="store_true", default=True,
247
+ help="Load Gemma text encoder via the bitsandbytes 4-bit path "
248
+ "(required for the default unsloth/gemma-3-12b-it-bnb-4bit "
249
+ "pre-quantized weights). Default: on.")
250
+ p.add_argument("--no-bnb-4bit", dest="bnb_4bit", action="store_false",
251
+ help="Disable the bitsandbytes path (use only if --gemma-root "
252
+ "points at an unquantized Gemma checkpoint).")
253
+ p.add_argument("--lora", default=None, help="Path to trained IC-LoRA .safetensors (audio-only)")
254
+ p.add_argument("--lora-rank", type=int, default=128, help="LoRA rank (must match training)")
255
+ p.add_argument("--id-guidance-scale", type=float, default=3.0, help="Identity guidance scale (0=disabled)")
256
+ p.add_argument("--seed", type=int, default=42)
257
+
258
+ # Auto-set based on model type but overridable
259
+ p.add_argument("--no-watermark", action="store_true",
260
+ help="Skip Perth audio watermarking on the output (default: watermark on).")
261
+ p.add_argument("--sampler", choices=["euler", "heun"], default="euler",
262
+ help="Denoising loop. 'heun' = jkass_quality 2nd-order predictor-corrector (~2x model calls, cleaner audio).")
263
+ p.add_argument("--cfg-scale", type=float, default=None, help="CFG scale (auto: 1.0 distilled, 7.0 dev)")
264
+ p.add_argument("--stg-scale", type=float, default=None, help="STG scale (auto: 0.0 distilled, 1.0 dev)")
265
+ p.add_argument("--stg-block", type=int, default=29, help="Block index for STG perturbation")
266
+ p.add_argument("--rescale-scale", type=float, default=None,
267
+ help="Latent CFG std-rescale (default auto: cfg-aware schedule that prevents "
268
+ "output clipping at high cfg; pass any float in [0,1] to override).")
269
+ p.add_argument("--modality-scale", type=float, default=None, help="Modality (auto: 1.0 distilled, 3.0 dev)")
270
+ p.add_argument("--cfg-clamp", type=float, default=0.0, help="Clamp guided pred std to N * cond std (0=disabled)")
271
+ p.add_argument("--steps", type=int, default=None, help="Override steps (auto: distilled sigmas / 30 dev)")
272
+ p.add_argument("--fps", type=float, default=None, help="FPS (auto: 24.0 distilled, 25.0 dev)")
273
+ p.add_argument(
274
+ "--negative-prompt",
275
+ default=(
276
+ "worst quality, inconsistent motion, blurry, jittery, distorted, "
277
+ "robotic voice, echo, background noise, off-sync audio, repetitive speech"
278
+ ),
279
+ help="Negative prompt for CFG (dev model)",
280
+ )
281
+
282
+ return p.parse_args()
283
+
284
+
285
+ @torch.inference_mode()
286
+ def main():
287
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
288
+ args = parse_args()
289
+ t0 = time.time()
290
+
291
+ # ---- Imports (deferred to avoid startup cost when checking --help) ----
292
+ from audio_conditioning import AudioConditionByReferenceLatent
293
+
294
+ from ltx_core.batch_split import BatchSplitAdapter
295
+ from ltx_core.components.diffusion_steps import EulerDiffusionStep
296
+ from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
297
+ from ltx_core.components.noisers import GaussianNoiser
298
+ from ltx_core.components.patchifiers import AudioPatchifier
299
+ from ltx_core.components.schedulers import LTX2Scheduler
300
+ from ltx_core.loader.registry import DummyRegistry
301
+ from ltx_core.loader.sd_ops import SDOps
302
+ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
303
+ from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
304
+ from ltx_core.model.model_protocol import ModelConfigurator
305
+ from ltx_core.model.transformer.attention import AttentionFunction
306
+ from ltx_core.model.transformer.model import LTXModel, LTXModelType, X0Model
307
+ from ltx_core.model.transformer.rope import LTXRopeType
308
+ from ltx_core.tools import AudioLatentTools
309
+ from ltx_core.types import Audio, AudioLatentShape, LatentState, VideoPixelShape
310
+ from ltx_pipelines.utils.blocks import AudioConditioner, AudioDecoder, PromptEncoder
311
+ from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES
312
+ from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
313
+ from ltx_pipelines.utils.gpu_model import gpu_model
314
+ from ltx_pipelines.utils.media_io import decode_audio_from_file
315
+ from ltx_pipelines.utils.samplers import euler_denoising_loop, heun_denoising_loop
316
+
317
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
318
+ dtype = torch.bfloat16
319
+ patchifier = AudioPatchifier(patch_size=1)
320
+
321
+ # ---- Detect model type and set defaults ----
322
+ model_type = detect_model_type(args.full_checkpoint)
323
+ logging.info(f"Detected model type: {model_type}")
324
+
325
+ is_distilled = model_type == "distilled"
326
+
327
+ if args.cfg_scale is None:
328
+ args.cfg_scale = 1.0 if is_distilled else 7.0
329
+ if args.stg_scale is None:
330
+ args.stg_scale = 0.0 if is_distilled else 1.0
331
+ if args.rescale_scale is None:
332
+ # Auto cfg-aware rescale: imported from inference_server to keep one source of truth.
333
+ from inference_server import auto_rescale_for_cfg
334
+ args.rescale_scale = 0.0 if is_distilled else auto_rescale_for_cfg(args.cfg_scale)
335
+ if args.modality_scale is None:
336
+ args.modality_scale = 1.0 if is_distilled else 3.0
337
+ if args.fps is None:
338
+ args.fps = 24.0 if is_distilled else 25.0
339
+
340
+ logging.info(
341
+ f"Params: cfg={args.cfg_scale}, stg={args.stg_scale}, rescale={args.rescale_scale}, "
342
+ f"modality={args.modality_scale}, fps={args.fps}"
343
+ )
344
+
345
+ # ---- Auto duration ----
346
+ if args.gen_duration <= 0:
347
+ args.gen_duration = estimate_speech_duration(args.prompt, args.speed)
348
+ if args.duration_multiplier != 1.0:
349
+ args.gen_duration = round(args.gen_duration * args.duration_multiplier, 1)
350
+ logging.info(f"Auto duration: {args.gen_duration}s for {len(args.prompt)} chars"
351
+ f"{f' (x{args.duration_multiplier})' if args.duration_multiplier != 1.0 else ''}")
352
+
353
+ # ---- Compute target shape (include pad_start in duration) ----
354
+ padded_duration = args.gen_duration + args.pad_start
355
+ raw_frames = int(round(padded_duration * args.fps)) + 1
356
+ num_frames = ((raw_frames - 1 + 4) // 8) * 8 + 1
357
+ pixel_shape = VideoPixelShape(batch=1, frames=num_frames, height=64, width=64, fps=args.fps)
358
+ tgt_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
359
+ logging.info(f"Target shape: {tgt_shape} ({args.gen_duration}s, {num_frames} frames)")
360
+
361
+ # ---- AudioLatentTools for target ----
362
+ audio_tools = AudioLatentTools(patchifier=patchifier, target_shape=tgt_shape)
363
+
364
+ # ---- Create initial state ----
365
+ state = audio_tools.create_initial_state(device, dtype)
366
+ logging.info(
367
+ f"Initial state: latent={state.latent.shape}, positions={state.positions.shape}, "
368
+ f"denoise_mask={state.denoise_mask.shape}"
369
+ )
370
+
371
+ if not args.no_ref and args.voice_sample:
372
+ # ---- Encode voice reference ----
373
+ logging.info(f"Loading voice reference: {args.voice_sample}")
374
+ voice = decode_audio_from_file(args.voice_sample, device, 0.0, args.ref_duration)
375
+ if voice is None:
376
+ raise ValueError(f"Could not load audio from {args.voice_sample}")
377
+
378
+ w = voice.waveform
379
+ if w.dim() == 2:
380
+ if w.shape[0] == 1:
381
+ w = w.repeat(2, 1)
382
+ w = w.unsqueeze(0)
383
+ elif w.dim() == 3 and w.shape[1] == 1:
384
+ w = w.repeat(1, 2, 1)
385
+
386
+ target_samples = int(args.ref_duration * voice.sampling_rate)
387
+ if w.shape[-1] < target_samples:
388
+ w = w.repeat(1, 1, (target_samples // w.shape[-1]) + 1)
389
+ w = w[..., :target_samples]
390
+
391
+ # Peak normalize reference
392
+ peak = w.abs().max()
393
+ if peak > 0:
394
+ target_peak = 10 ** (-4.0 / 20) # -4dB
395
+ w = w * (target_peak / peak)
396
+ logging.info(f"Normalized reference: peak {peak:.4f} -> {target_peak:.4f}")
397
+
398
+ voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)
399
+
400
+ logging.info("Encoding voice through Audio VAE...")
401
+ ac = AudioConditioner(checkpoint_path=args.full_checkpoint, dtype=dtype, device=device)
402
+ ref_latent = ac(lambda enc: vae_encode_audio(voice, enc, None))
403
+ del ac
404
+ torch.cuda.empty_cache()
405
+ logging.info(f"Reference latent: {ref_latent.shape}")
406
+
407
+ # ---- Apply conditioning: append ref tokens to END ----
408
+ conditioning = AudioConditionByReferenceLatent(latent=ref_latent.to(device, dtype), strength=1.0)
409
+ state = conditioning.apply_to(latent_state=state, latent_tools=audio_tools)
410
+ logging.info(
411
+ f"After conditioning: latent={state.latent.shape}, positions={state.positions.shape}, "
412
+ f"attention_mask={'None' if state.attention_mask is None else state.attention_mask.shape}"
413
+ )
414
+ else:
415
+ logging.info("No voice reference — running raw base model")
416
+
417
+ # ---- Apply noise ----
418
+ generator = torch.Generator(device=device).manual_seed(args.seed)
419
+ noiser = GaussianNoiser(generator=generator)
420
+ noised_state = noiser(state, noise_scale=1.0)
421
+ logging.info("Applied Gaussian noise to state")
422
+
423
+ # ---- Encode prompt ----
424
+ use_cfg = args.cfg_scale > 1.0
425
+ logging.info("Encoding prompt...")
426
+ pe = PromptEncoder(checkpoint_path=args.full_checkpoint, gemma_root=args.gemma_root, dtype=dtype, device=device,
427
+ use_bnb_4bit=args.bnb_4bit, warm=True)
428
+ prompts_to_encode = [args.prompt]
429
+ if use_cfg:
430
+ prompts_to_encode.append(args.negative_prompt)
431
+ ctx = pe(prompts_to_encode, streaming_prefetch_count=None)
432
+ a_ctx = ctx[0].audio_encoding
433
+ a_ctx_neg = ctx[1].audio_encoding if use_cfg else None
434
+ del pe
435
+ torch.cuda.empty_cache()
436
+ logging.info(f"Prompt encoded: a_ctx={a_ctx.shape}" + (f", a_ctx_neg={a_ctx_neg.shape}" if a_ctx_neg is not None else ""))
437
+
438
+ # ---- Build audio-only model ----
439
+ logging.info("Building audio-only model...")
440
+ audio_only_sd_ops = SDOps("AO").with_matching(prefix="model.diffusion_model.").with_replacement(
441
+ "model.diffusion_model.", ""
442
+ )
443
+
444
+ class AudioOnlyConfigurator(ModelConfigurator[LTXModel]):
445
+ @classmethod
446
+ def from_config(cls, config):
447
+ t = config.get("transformer", {})
448
+ cp = None
449
+ if not t.get("caption_proj_before_connector", False):
450
+ from ltx_core.model.transformer.text_projection import create_caption_projection
451
+
452
+ with torch.device("meta"):
453
+ cp = create_caption_projection(t, audio=True)
454
+ return LTXModel(
455
+ model_type=LTXModelType.AudioOnly,
456
+ audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
457
+ audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
458
+ audio_in_channels=t.get("audio_in_channels", 128),
459
+ audio_out_channels=t.get("audio_out_channels", 128),
460
+ num_layers=t.get("num_layers", 48),
461
+ audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
462
+ norm_eps=t.get("norm_eps", 1e-6),
463
+ attention_type=AttentionFunction(t.get("attention_type", "default")),
464
+ positional_embedding_theta=10000.0,
465
+ audio_positional_embedding_max_pos=[20.0],
466
+ timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
467
+ use_middle_indices_grid=t.get("use_middle_indices_grid", True),
468
+ rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
469
+ double_precision_rope=t.get("frequencies_precision", False) == "float64",
470
+ apply_gated_attention=t.get("apply_gated_attention", False),
471
+ audio_caption_projection=cp,
472
+ cross_attention_adaln=t.get("cross_attention_adaln", False),
473
+ )
474
+
475
+ builder = Builder(
476
+ model_path=args.checkpoint,
477
+ model_class_configurator=AudioOnlyConfigurator,
478
+ model_sd_ops=audio_only_sd_ops,
479
+ registry=DummyRegistry(),
480
+ )
481
+ velocity_model = builder.build(device=device, dtype=dtype).to(device).eval()
482
+
483
+ # ---- Load LoRA weights (if provided) ----
484
+ if args.lora and os.path.exists(args.lora):
485
+ from peft import LoraConfig, get_peft_model
486
+ from safetensors.torch import load_file as st_load
487
+
488
+ logging.info(f"Loading LoRA: {args.lora}")
489
+ lora_sd = st_load(args.lora)
490
+
491
+ is_peft_format = any("base_model.model." in k for k in lora_sd.keys())
492
+ is_original_idlora = any("diffusion_model." in k for k in lora_sd.keys())
493
+
494
+ lora_config = LoraConfig(
495
+ r=args.lora_rank,
496
+ lora_alpha=args.lora_rank,
497
+ lora_dropout=0.0,
498
+ bias="none",
499
+ target_modules=[
500
+ "audio_attn1.to_k",
501
+ "audio_attn1.to_q",
502
+ "audio_attn1.to_v",
503
+ "audio_attn1.to_out.0",
504
+ "audio_attn2.to_k",
505
+ "audio_attn2.to_q",
506
+ "audio_attn2.to_v",
507
+ "audio_attn2.to_out.0",
508
+ "audio_ff.net.0.proj",
509
+ "audio_ff.net.2",
510
+ ],
511
+ )
512
+ velocity_model = get_peft_model(velocity_model, lora_config)
513
+
514
+ if is_peft_format:
515
+ mapped_sd = {}
516
+ for k, v in lora_sd.items():
517
+ new_key = k
518
+ if ".lora_A.weight" in k and ".lora_A.default.weight" not in k:
519
+ new_key = k.replace(".lora_A.weight", ".lora_A.default.weight")
520
+ if ".lora_B.weight" in k and ".lora_B.default.weight" not in k:
521
+ new_key = k.replace(".lora_B.weight", ".lora_B.default.weight")
522
+ mapped_sd[new_key] = v
523
+ missing, unexpected = velocity_model.load_state_dict(mapped_sd, strict=False)
524
+ loaded = len(mapped_sd) - len(unexpected)
525
+ logging.info(f"Loaded {loaded} LoRA weights (peft format)")
526
+ elif is_original_idlora:
527
+ audio_keys = {
528
+ k: v
529
+ for k, v in lora_sd.items()
530
+ if "audio_attn1" in k or "audio_attn2" in k or "audio_ff" in k
531
+ }
532
+ mapped_sd = {}
533
+ for k, v in audio_keys.items():
534
+ new_key = k.replace("diffusion_model.", "base_model.model.")
535
+ new_key = new_key.replace(".lora_A.weight", ".lora_A.default.weight")
536
+ new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight")
537
+ mapped_sd[new_key] = v
538
+ missing, unexpected = velocity_model.load_state_dict(mapped_sd, strict=False)
539
+ loaded = len(mapped_sd) - len(unexpected)
540
+ logging.info(f"Loaded {loaded} LoRA weights (original ID-LoRA)")
541
+
542
+ velocity_model = velocity_model.merge_and_unload()
543
+ logging.info("Merged LoRA into model")
544
+
545
+ logging.info(f"Model: {sum(p.numel() for p in velocity_model.parameters()) / 1e9:.1f}B params")
546
+
547
+ # ---- Wrap velocity model in X0Model ----
548
+ x0_model = X0Model(velocity_model)
549
+
550
+ # ---- Build denoiser and sigmas ----
551
+ stepper = EulerDiffusionStep()
552
+
553
+ # ---- Sigma schedule ----
554
+ if is_distilled:
555
+ if args.steps is not None and args.steps > 0:
556
+ sigmas = LTX2Scheduler().execute(steps=args.steps, latent=noised_state.latent).to(device)
557
+ logging.info(f"Distilled with custom {args.steps}-step schedule")
558
+ else:
559
+ sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, dtype=torch.float32, device=device)
560
+ logging.info(f"Distilled {len(DISTILLED_SIGMA_VALUES) - 1}-step schedule")
561
+ else:
562
+ steps = args.steps if args.steps is not None and args.steps > 0 else 30
563
+ sigmas = LTX2Scheduler().execute(steps=steps, latent=noised_state.latent).to(device)
564
+ logging.info(f"Dev {steps}-step schedule")
565
+
566
+ # ---- Denoiser: use GuidedDenoiser if any guidance is active, SimpleDenoiser otherwise ----
567
+ needs_guidance = args.cfg_scale > 1.0 or args.stg_scale > 0.0 or args.modality_scale > 1.0
568
+ if needs_guidance:
569
+ audio_guider = MultiModalGuider(
570
+ params=MultiModalGuiderParams(
571
+ cfg_scale=args.cfg_scale,
572
+ stg_scale=args.stg_scale,
573
+ stg_blocks=[args.stg_block] if args.stg_scale > 0 else [],
574
+ rescale_scale=args.rescale_scale,
575
+ modality_scale=args.modality_scale,
576
+ cfg_clamp_scale=args.cfg_clamp,
577
+ ),
578
+ negative_context=a_ctx_neg,
579
+ )
580
+ denoiser = GuidedDenoiser(
581
+ v_context=None,
582
+ a_context=a_ctx,
583
+ video_guider=None,
584
+ audio_guider=audio_guider,
585
+ )
586
+ logging.info(f"GuidedDenoiser: cfg={args.cfg_scale}, stg={args.stg_scale}, "
587
+ f"rescale={args.rescale_scale}, modality={args.modality_scale}")
588
+ else:
589
+ denoiser = SimpleDenoiser(v_context=None, a_context=a_ctx)
590
+ logging.info("SimpleDenoiser (no guidance)")
591
+
592
+ logging.info(f"Sigmas: {sigmas.tolist()}")
593
+
594
+ # ---- Denoising loop ----
595
+ logging.info(f"Running denoising loop ({len(sigmas) - 1} steps)...")
596
+ with gpu_model(x0_model) as model:
597
+ batched_model = BatchSplitAdapter(model, max_batch_size=1)
598
+
599
+ denoise_fn = heun_denoising_loop if args.sampler == "heun" else euler_denoising_loop
600
+ _, audio_state = denoise_fn(
601
+ sigmas=sigmas,
602
+ video_state=None,
603
+ audio_state=noised_state,
604
+ stepper=stepper,
605
+ transformer=batched_model,
606
+ denoiser=denoiser,
607
+ )
608
+
609
+ del velocity_model, x0_model
610
+ torch.cuda.empty_cache()
611
+
612
+ # ---- Strip ref tokens and unpatchify ----
613
+ logging.info("Stripping conditioning and unpatchifying...")
614
+ audio_state = audio_tools.clear_conditioning(audio_state)
615
+ audio_state = audio_tools.unpatchify(audio_state)
616
+ logging.info(f"Final latent shape: {audio_state.latent.shape}")
617
+
618
+ # ---- End-of-clip silence-prior fix ----
619
+ # Base LTX-2.3 22B was trained on audio clips ≤ ~20 s and learned a strong
620
+ # "clip-end silence" prior at the next patchifier-aligned latent boundary
621
+ # (frame 513 = 8 × 64 + 1). For longer outputs that prior leaks through as
622
+ # a ~30 ms hard silence dip near 20.4 s. Linearly interpolating frames
623
+ # 512–513 between their neighbours (511 and 514) removes the dip cleanly.
624
+ latent_in = audio_state.latent
625
+ if latent_in.shape[2] > 513:
626
+ f0, f1 = 511, 514
627
+ n = f1 - f0
628
+ patched = latent_in.clone()
629
+ for f in (512, 513):
630
+ t = (f - f0) / n
631
+ patched[:, :, f, :] = (1.0 - t) * latent_in[:, :, f0, :] + t * latent_in[:, :, f1, :]
632
+ latent_in = patched
633
+
634
+ # ---- Decode audio ----
635
+ logging.info("Decoding audio...")
636
+ ad = AudioDecoder(checkpoint_path=args.full_checkpoint, dtype=dtype, device=device)
637
+ decoded = ad(latent_in)
638
+ del ad
639
+ torch.cuda.empty_cache()
640
+
641
+ wav = decoded.waveform
642
+ if wav.dim() == 1:
643
+ wav = wav.unsqueeze(0)
644
+ sr = decoded.sampling_rate
645
+
646
+ # Trim leading pad if --pad-start was used
647
+ if args.pad_start > 0:
648
+ trim_samples = int(args.pad_start * sr)
649
+ wav = wav[..., trim_samples:]
650
+ logging.info(f"Trimmed {args.pad_start}s ({trim_samples} samples) of start padding")
651
+
652
+ # Apply Perth (Perceptual Threshold) imperceptible neural watermark — see
653
+ # https://github.com/resemble-ai/perth. Mono waveform required; if stereo,
654
+ # we average to mono for the watermark and broadcast back. Skip on
655
+ # --no-watermark for debugging.
656
+ wav_cpu = wav.float().cpu()
657
+ if not getattr(args, "no_watermark", False):
658
+ try:
659
+ import perth
660
+ import numpy as np
661
+ wm = perth.PerthImplicitWatermarker()
662
+ mono = wav_cpu.mean(dim=0).numpy() if wav_cpu.shape[0] > 1 else wav_cpu[0].numpy()
663
+ mono_wm = wm.apply_watermark(mono, sample_rate=sr)
664
+ mono_wm_t = torch.from_numpy(np.asarray(mono_wm, dtype=np.float32)).unsqueeze(0)
665
+ wav_cpu = mono_wm_t if wav_cpu.shape[0] == 1 else mono_wm_t.repeat(wav_cpu.shape[0], 1)
666
+ except Exception as e:
667
+ logging.warning(f"Perth watermark skipped ({e})")
668
+
669
+ os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
670
+ torchaudio.save(args.output, wav_cpu, sr)
671
+
672
+ elapsed = time.time() - t0
673
+ logging.info(f"Output: {args.output} ({wav.shape[-1] / sr:.1f}s)")
674
+ logging.info(f"Total time: {elapsed:.1f}s")
675
+
676
+
677
+ if __name__ == "__main__":
678
+ main()
dramabox_src/inference_server.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Warm TTS server — loads models once, accepts requests via stdin or function call.
4
+
5
+ The key insight: inference.py spends 11s on Gemma + 8s on model load every call.
6
+ This server loads everything once and keeps it warm.
7
+
8
+ We import and call the same code paths as inference.py but cache the heavy objects.
9
+ """
10
+ import json
11
+ import logging
12
+ import os
13
+ import re
14
+ import sys
15
+ import time
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torchaudio
20
+
21
+ # Setup paths
22
+ APP_DIR = Path(__file__).parent.parent
23
+ sys.path.insert(0, str(APP_DIR / "ltx2"))
24
+ sys.path.insert(0, str(APP_DIR / "src"))
25
+
26
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
27
+
28
+ from audio_conditioning import AudioConditionByReferenceLatent
29
+ from ltx_core.components.noisers import GaussianNoiser
30
+ from ltx_core.components.patchifiers import AudioPatchifier
31
+ from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
32
+ from ltx_core.components.schedulers import LTX2Scheduler
33
+ from ltx_core.components.diffusion_steps import EulerDiffusionStep
34
+ from ltx_core.loader import DummyRegistry
35
+ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
36
+ from ltx_core.loader.sd_ops import SDOps
37
+ from ltx_core.model.transformer.model import LTXModel, LTXModelType, X0Model
38
+ from ltx_core.model.transformer.rope import LTXRopeType
39
+ from ltx_core.model.transformer.text_projection import create_caption_projection
40
+ from ltx_core.model.transformer.attention import AttentionFunction
41
+ from ltx_core.model.model_protocol import ModelConfigurator
42
+ from ltx_core.tools import AudioLatentTools
43
+ from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
44
+ from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
45
+ from ltx_pipelines.utils.blocks import AudioConditioner, AudioDecoder, PromptEncoder
46
+ from ltx_pipelines.utils.media_io import decode_audio_from_file
47
+ from ltx_pipelines.utils.denoisers import GuidedDenoiser
48
+ from ltx_pipelines.utils.samplers import euler_denoising_loop
49
+ from safetensors import safe_open
50
+
51
+
52
+ DEFAULT_NEG = "worst quality, inconsistent, robotic, distorted, noise, static, muffled, unclear, unnatural, monotone"
53
+
54
+
55
+ def estimate_duration(prompt, multiplier=1.1):
56
+ """Defer to the richer CLI estimator (sentence-aware + non-verbal action
57
+ budget) so warm-server outputs match the lengths of the per-call CLI runs."""
58
+ from inference import estimate_speech_duration
59
+ base = estimate_speech_duration(prompt)
60
+ return max(3.0, round(base * multiplier, 1))
61
+
62
+
63
+ def auto_rescale_for_cfg(cfg: float) -> float:
64
+ """CFG-aware std-rescale schedule that prevents output clipping at high cfg.
65
+
66
+ The CFG formula `pred = cond + (cfg-1)*(cond - uncond)` makes pred.std()
67
+ grow roughly linearly with cfg, which the audio VAE+vocoder render as
68
+ progressively louder waveforms. By cfg≈3 the output starts hard-clipping
69
+ at 0 dBFS — and clipped information is unrecoverable in post.
70
+
71
+ Empirical sweep on the blues prompt with the back-porch-boogie ref
72
+ (rescale_scale needed for ≥1 dB peak headroom):
73
+ cfg=2.5 → 0.2 ; cfg=3 → 0.6 ; cfg=4 → 0.8 ; cfg=5–8 → 0.8 ; cfg=10 → 1.0
74
+
75
+ Piecewise-linear fit through those points; returns 0 below cfg=2 (no CFG
76
+ even applied at cfg=1), plateaus at 0.8 between cfg=4 and cfg=8 to
77
+ preserve the "extra punch" of high-CFG generations, and ramps to 1.0 by
78
+ cfg=10.
79
+ """
80
+ if cfg <= 2.0:
81
+ return 0.0
82
+ if cfg <= 3.0:
83
+ return 0.6 * (cfg - 2.0) # 0 → 0.6
84
+ if cfg <= 4.0:
85
+ return 0.6 + 0.2 * (cfg - 3.0) # 0.6 → 0.8
86
+ if cfg <= 8.0:
87
+ return 0.8 # plateau
88
+ return min(1.0, 0.8 + 0.1 * (cfg - 8.0)) # 0.8 → 1.0 at cfg=10
89
+
90
+
91
+ class TTSServer:
92
+ def __init__(self, checkpoint=None, full_checkpoint=None, gemma_root=None,
93
+ device="cuda", dtype="bf16", compile_model=True, bnb_4bit=True):
94
+ MODELS = APP_DIR / "models"
95
+ self.checkpoint = checkpoint or str(MODELS / "ltx-2.3-22b-dev-audio-only-v13-merged.safetensors")
96
+ self.full_checkpoint = full_checkpoint or os.environ.get(
97
+ "LTX_FULL_CHECKPOINT", "/mnt/persistent0/manmay/models/ltx23/ltx-2.3-22b-dev.safetensors")
98
+ if gemma_root is None and not os.environ.get("GEMMA_DIR"):
99
+ from model_downloader import get_gemma_path
100
+ gemma_root = get_gemma_path()
101
+ self.gemma_root = gemma_root or os.environ["GEMMA_DIR"]
102
+ self.device = torch.device(device)
103
+ self.dtype = torch.float16 if dtype == "fp16" else torch.bfloat16
104
+ self.compile_model = compile_model
105
+ self.bnb_4bit = bnb_4bit
106
+ self.patchifier = AudioPatchifier(patch_size=1)
107
+
108
+ # Cached models
109
+ self._prompt_encoder = None
110
+ self._velocity_model = None
111
+ self._audio_conditioner = None
112
+ self._audio_decoder = None
113
+
114
+ logging.info(f"TTSServer loading on {device}...")
115
+ t0 = time.time()
116
+ self._load_all()
117
+ logging.info(f"All models loaded in {time.time()-t0:.1f}s — ready for requests")
118
+
119
+ def _load_all(self):
120
+ # 1. Prompt encoder (Gemma + embeddings processor kept warm)
121
+ t0 = time.time()
122
+ self._prompt_encoder = PromptEncoder(
123
+ checkpoint_path=self.full_checkpoint,
124
+ gemma_root=self.gemma_root,
125
+ dtype=self.dtype, device=self.device,
126
+ warm=True,
127
+ use_bnb_4bit=self.bnb_4bit,
128
+ audio_only=True,
129
+ )
130
+ logging.info(f" PromptEncoder (warm): {time.time()-t0:.1f}s")
131
+
132
+ # 2. Audio conditioner (VAE encoder kept warm)
133
+ t0 = time.time()
134
+ self._audio_conditioner = AudioConditioner(
135
+ checkpoint_path=self.full_checkpoint,
136
+ dtype=self.dtype, device=self.device,
137
+ warm=True,
138
+ )
139
+ logging.info(f" AudioConditioner (warm): {time.time()-t0:.1f}s")
140
+
141
+ # 3. Transformer
142
+ t0 = time.time()
143
+ with safe_open(self.checkpoint, framework="pt") as f:
144
+ config = json.loads(f.metadata()["config"])
145
+
146
+ t = config.get("transformer", {})
147
+
148
+ class AudioOnlyConfigurator(ModelConfigurator[LTXModel]):
149
+ @classmethod
150
+ def from_config(cls, cfg):
151
+ t = cfg.get("transformer", {})
152
+ cp = None
153
+ if not t.get("caption_proj_before_connector", False):
154
+ with torch.device("meta"):
155
+ cp = create_caption_projection(t, audio=True)
156
+ return LTXModel(
157
+ model_type=LTXModelType.AudioOnly,
158
+ audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
159
+ audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
160
+ audio_in_channels=t.get("audio_in_channels", 128),
161
+ audio_out_channels=t.get("audio_out_channels", 128),
162
+ num_layers=t.get("num_layers", 48),
163
+ audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
164
+ norm_eps=t.get("norm_eps", 1e-6),
165
+ attention_type=AttentionFunction(t.get("attention_type", "default")),
166
+ positional_embedding_theta=10000.0,
167
+ audio_positional_embedding_max_pos=[20.0],
168
+ timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
169
+ use_middle_indices_grid=t.get("use_middle_indices_grid", True),
170
+ rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
171
+ double_precision_rope=t.get("frequencies_precision", False) == "float64",
172
+ apply_gated_attention=t.get("apply_gated_attention", False),
173
+ audio_caption_projection=cp,
174
+ cross_attention_adaln=t.get("cross_attention_adaln", False),
175
+ )
176
+
177
+ audio_sd_ops = SDOps("AO").with_matching(prefix="model.diffusion_model.").with_replacement(
178
+ "model.diffusion_model.", "")
179
+ builder = Builder(
180
+ model_path=self.checkpoint,
181
+ model_class_configurator=AudioOnlyConfigurator,
182
+ model_sd_ops=audio_sd_ops,
183
+ registry=DummyRegistry(),
184
+ )
185
+ self._velocity_model = builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()
186
+ n_params = sum(p.numel() for p in self._velocity_model.parameters()) / 1e9
187
+ vram_gb = sum(p.numel() * p.element_size() for p in self._velocity_model.parameters()) / 1e9
188
+ logging.info(f" Transformer: {time.time()-t0:.1f}s ({n_params:.1f}B params, {vram_gb:.1f}GB VRAM, {self.dtype})")
189
+
190
+ # torch.compile for faster denoising
191
+ if self.compile_model:
192
+ t0 = time.time()
193
+ logging.info(" Compiling transformer with torch.compile (default mode)...")
194
+ self._velocity_model = torch.compile(self._velocity_model, mode="default", dynamic=True)
195
+ logging.info(f" Compiled: {time.time()-t0:.1f}s (first call triggers actual compilation)")
196
+
197
+ # 4. Audio decoder (VAE decoder + vocoder kept warm)
198
+ t0 = time.time()
199
+ self._audio_decoder = AudioDecoder(
200
+ checkpoint_path=self.full_checkpoint,
201
+ dtype=self.dtype, device=self.device,
202
+ warm=True,
203
+ )
204
+ logging.info(f" AudioDecoder (warm): {time.time()-t0:.1f}s")
205
+
206
+ @torch.inference_mode()
207
+ def generate(self, prompt, voice_ref=None, cfg_scale=2.5, stg_scale=1.5,
208
+ duration_multiplier=1.1, seed=42, ref_duration=10.0,
209
+ rescale_scale="auto", gen_duration: float = 0.0):
210
+ """Generate audio. Returns (waveform_path, duration_seconds).
211
+
212
+ rescale_scale: latent-side CFG std-rescale that prevents clipping at
213
+ high cfg. Set to "auto" (default) for the cfg-aware schedule, a
214
+ float in [0, 1] for a fixed override, or 0 to disable.
215
+ gen_duration: explicit target duration in seconds. 0 (default) → auto
216
+ from prompt + duration_multiplier; >0 overrides everything else.
217
+ """
218
+ t_total = time.time()
219
+
220
+ # Duration + target shape — explicit gen_duration wins over the estimator.
221
+ if gen_duration and gen_duration > 0:
222
+ gen_dur = float(gen_duration)
223
+ else:
224
+ gen_dur = estimate_duration(prompt, duration_multiplier)
225
+ fps = 25.0
226
+ n_frames = int(round(gen_dur * fps)) + 1
227
+ n_frames = ((n_frames - 1 + 4) // 8) * 8 + 1
228
+ pixel_shape = VideoPixelShape(batch=1, frames=n_frames, height=64, width=64, fps=fps)
229
+ target_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
230
+ audio_tools = AudioLatentTools(patchifier=self.patchifier, target_shape=target_shape)
231
+
232
+ # Initial state
233
+ state = audio_tools.create_initial_state(device=self.device, dtype=self.dtype)
234
+
235
+ # Voice ref conditioning
236
+ if voice_ref and os.path.exists(voice_ref):
237
+ t0 = time.time()
238
+ voice = decode_audio_from_file(voice_ref, self.device, 0.0, ref_duration)
239
+ w = voice.waveform
240
+ if w.dim() == 2:
241
+ if w.shape[0] == 1:
242
+ w = w.repeat(2, 1)
243
+ w = w.unsqueeze(0)
244
+ elif w.dim() == 3 and w.shape[1] == 1:
245
+ w = w.repeat(1, 2, 1)
246
+ target_samples = int(ref_duration * voice.sampling_rate)
247
+ if w.shape[-1] < target_samples:
248
+ w = w.repeat(1, 1, (target_samples // w.shape[-1]) + 1)
249
+ w = w[..., :target_samples]
250
+ peak = w.abs().max()
251
+ if peak > 0:
252
+ w = w * (10 ** (-4.0 / 20) / peak)
253
+ voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)
254
+ ref_latent = self._audio_conditioner(lambda enc: vae_encode_audio(voice, enc, None))
255
+ cond = AudioConditionByReferenceLatent(latent=ref_latent.to(self.device, self.dtype), strength=1.0)
256
+ state = cond.apply_to(state, audio_tools)
257
+ logging.info(f"Voice ref: {time.time()-t0:.2f}s")
258
+
259
+ # Noise
260
+ gen = torch.Generator(device=self.device).manual_seed(seed)
261
+ noiser = GaussianNoiser(generator=gen)
262
+ state = noiser(state, noise_scale=1.0)
263
+
264
+ # Prompt encode
265
+ t0 = time.time()
266
+ prompts = [prompt, DEFAULT_NEG] if cfg_scale > 1.0 else [prompt]
267
+ ctx = self._prompt_encoder(prompts, streaming_prefetch_count=None)
268
+ a_ctx = ctx[0].audio_encoding
269
+ a_ctx_neg = ctx[1].audio_encoding if cfg_scale > 1.0 else None
270
+ logging.info(f"Prompt: {time.time()-t0:.2f}s")
271
+
272
+ # Denoiser
273
+ resc = auto_rescale_for_cfg(cfg_scale) if rescale_scale == "auto" else float(rescale_scale)
274
+ if rescale_scale == "auto":
275
+ logging.info(f"Auto rescale_scale = {resc:.2f} for cfg={cfg_scale}")
276
+ guider = MultiModalGuider(
277
+ params=MultiModalGuiderParams(
278
+ cfg_scale=cfg_scale, stg_scale=stg_scale,
279
+ stg_blocks=[29], rescale_scale=resc, modality_scale=1.0,
280
+ ),
281
+ negative_context=a_ctx_neg,
282
+ )
283
+ denoiser = GuidedDenoiser(
284
+ v_context=None, a_context=a_ctx,
285
+ video_guider=None, audio_guider=guider,
286
+ )
287
+
288
+ # Sigmas
289
+ sigmas = LTX2Scheduler().execute(steps=30, latent=state.latent).to(self.device)
290
+
291
+ # Denoise
292
+ t0 = time.time()
293
+ x0 = X0Model(self._velocity_model)
294
+ _, audio_state = euler_denoising_loop(
295
+ sigmas=sigmas, video_state=None, audio_state=state,
296
+ stepper=EulerDiffusionStep(), transformer=x0, denoiser=denoiser,
297
+ )
298
+ logging.info(f"Denoise (30 steps): {time.time()-t0:.2f}s")
299
+
300
+ # Strip + unpatchify + decode
301
+ audio_state = audio_tools.clear_conditioning(audio_state)
302
+ audio_state = audio_tools.unpatchify(audio_state)
303
+
304
+ # End-of-clip silence-prior fix.
305
+ # The base LTX-2.3 22B DiT was trained on audio clips ≤ ~20 s and
306
+ # learned a strong "clip-end silence" prior that lands on the next
307
+ # patchifier-aligned latent frame after 20 s — index 513 = 8*64+1.
308
+ # When inference produces longer audio, this prior leaks through as a
309
+ # high-norm latent burst at frame 513 (and adjacent 512), which the
310
+ # audio VAE + vocoder render as a ~30 ms hard silence dip near 20.4 s.
311
+ # Linear interpolation across the two affected frames removes the dip
312
+ # cleanly without any retraining. Only runs when the latent is long
313
+ # enough to actually contain the boundary.
314
+ latent = audio_state.latent
315
+ if latent.shape[2] > 513:
316
+ f0, f1 = 511, 514 # neighbours used for interpolation
317
+ n = f1 - f0 # = 3
318
+ patched = latent.clone()
319
+ for f in (512, 513):
320
+ t = (f - f0) / n
321
+ patched[:, :, f, :] = (1.0 - t) * latent[:, :, f0, :] + t * latent[:, :, f1, :]
322
+ latent = patched
323
+
324
+ t0 = time.time()
325
+ decoded = self._audio_decoder(latent)
326
+ logging.info(f"Decode: {time.time()-t0:.2f}s")
327
+
328
+ total = time.time() - t_total
329
+ dur = decoded.waveform.shape[-1] / decoded.sampling_rate
330
+ logging.info(f"Total: {total:.2f}s for {dur:.1f}s audio")
331
+ return decoded.waveform, decoded.sampling_rate
332
+
333
+ def generate_to_file(self, prompt, output, watermark: bool = True, **kwargs):
334
+ waveform, sr = self.generate(prompt, **kwargs)
335
+ wav_cpu = waveform.cpu().float()
336
+ if watermark:
337
+ try:
338
+ import numpy as np, perth
339
+ if not hasattr(self, "_perth"):
340
+ self._perth = perth.PerthImplicitWatermarker()
341
+ mono = wav_cpu.mean(dim=0).numpy() if wav_cpu.shape[0] > 1 else wav_cpu[0].numpy()
342
+ mono_wm = self._perth.apply_watermark(mono, sample_rate=sr)
343
+ mono_wm_t = torch.from_numpy(np.asarray(mono_wm, dtype=np.float32)).unsqueeze(0)
344
+ wav_cpu = mono_wm_t if wav_cpu.shape[0] == 1 else mono_wm_t.repeat(wav_cpu.shape[0], 1)
345
+ except Exception as e:
346
+ logging.warning(f"Perth watermark skipped ({e})")
347
+ torchaudio.save(output, wav_cpu, sr)
348
+ logging.info(f"Saved: {output}")
349
+ return output
350
+
351
+
352
+ if __name__ == "__main__":
353
+ import argparse
354
+ p = argparse.ArgumentParser()
355
+ p.add_argument("--device", default="cuda")
356
+ p.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
357
+ p.add_argument("--no-compile", action="store_true")
358
+ p.add_argument("--no-bnb-4bit", action="store_true",
359
+ help="Disable bitsandbytes 4-bit path (default: on, since the default "
360
+ "unsloth Gemma checkpoint is pre-quantized).")
361
+ args = p.parse_args()
362
+
363
+ server = TTSServer(device=args.device, dtype=args.dtype, compile_model=not args.no_compile,
364
+ bnb_4bit=not args.no_bnb_4bit)
365
+
366
+ # First call - includes any warmup
367
+ logging.info("=== First request ===")
368
+ server.generate_to_file(
369
+ prompt='A woman speaks clearly, "The weather today will be sunny."',
370
+ output="/tmp/warm_test1.wav",
371
+ voice_ref="/mnt/persistent0/manmay/expressive/female_radio_nikole/female_radio_nikole.wav",
372
+ )
373
+
374
+ # Second call - should be much faster (models already warm)
375
+ logging.info("\n=== Second request (warm) ===")
376
+ server.generate_to_file(
377
+ prompt='A man speaks excitedly, "This is amazing, I cannot believe it!"',
378
+ output="/tmp/warm_test2.wav",
379
+ voice_ref="/mnt/persistent0/manmay/expressive/male_arnie/male_arnie.mp3",
380
+ )
dramabox_src/model_downloader.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Download Dramabox models from HuggingFace.
4
+
5
+ Models are cached locally after first download.
6
+ Gemma text encoder is fetched separately from Google's repo.
7
+ """
8
+ import logging
9
+ import os
10
+ from pathlib import Path
11
+
12
+ from huggingface_hub import hf_hub_download, snapshot_download
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ DRAMABOX_REPO = "ResembleAI/Dramabox"
17
+ GEMMA_REPO = "unsloth/gemma-3-12b-it-bnb-4bit"
18
+
19
+ # Default cache directory
20
+ DEFAULT_CACHE = os.environ.get(
21
+ "DRAMABOX_CACHE",
22
+ os.path.join(os.path.expanduser("~"), ".cache", "dramabox"),
23
+ )
24
+
25
+ # Model files in the HF repo (flat structure)
26
+ MODEL_FILES = {
27
+ "transformer": "dramabox-dit-v1.safetensors",
28
+ "audio_components": "dramabox-audio-components.safetensors",
29
+ "silence_latent": "assets/silence_latent_frame.pt",
30
+ }
31
+
32
+
33
+ def get_model_path(name: str, cache_dir: str = None) -> str:
34
+ """Download a model file from HF and return local path.
35
+
36
+ Args:
37
+ name: One of 'transformer', 'audio_components', 'silence_latent'
38
+ cache_dir: Local cache directory (default: ~/.cache/dramabox)
39
+
40
+ Returns:
41
+ Local file path
42
+ """
43
+ cache_dir = cache_dir or DEFAULT_CACHE
44
+
45
+ if name not in MODEL_FILES:
46
+ raise ValueError(f"Unknown model: {name}. Choose from: {list(MODEL_FILES.keys())}")
47
+
48
+ repo_path = MODEL_FILES[name]
49
+ logger.info(f"Fetching {name} from {DRAMABOX_REPO}/{repo_path}...")
50
+
51
+ local_path = hf_hub_download(
52
+ repo_id=DRAMABOX_REPO,
53
+ filename=repo_path,
54
+ cache_dir=cache_dir,
55
+ token=os.environ.get("HF_TOKEN"),
56
+ )
57
+ logger.info(f" -> {local_path}")
58
+ return local_path
59
+
60
+
61
+ def get_gemma_path(cache_dir: str = None) -> str:
62
+ """Download Gemma 3 12B IT (pre-quantized bnb-4bit via unsloth) and return
63
+ the snapshot directory. Using the pre-quantized variant skips runtime
64
+ bitsandbytes quantization and ~halves the Gemma load time.
65
+ """
66
+ cache_dir = cache_dir or DEFAULT_CACHE
67
+ logger.info(f"Fetching Gemma from {GEMMA_REPO}...")
68
+
69
+ local_dir = snapshot_download(
70
+ repo_id=GEMMA_REPO,
71
+ cache_dir=cache_dir,
72
+ token=os.environ.get("HF_TOKEN"),
73
+ )
74
+ logger.info(f" -> {local_dir}")
75
+ return local_dir
76
+
77
+
78
+ def get_all_paths(cache_dir: str = None) -> dict:
79
+ """Download all required models and return paths dict.
80
+
81
+ Returns:
82
+ {
83
+ 'transformer': '/path/to/transformer.safetensors',
84
+ 'audio_components': '/path/to/audio-components.safetensors',
85
+ 'silence_latent': '/path/to/silence_latent_frame.pt',
86
+ 'gemma_root': '/path/to/unsloth/gemma-3-12b-it-bnb-4bit/',
87
+ }
88
+ """
89
+ cache_dir = cache_dir or DEFAULT_CACHE
90
+ paths = {}
91
+
92
+ for name in MODEL_FILES:
93
+ paths[name] = get_model_path(name, cache_dir)
94
+
95
+ paths["gemma_root"] = get_gemma_path(cache_dir)
96
+ return paths
97
+
98
+
99
+ if __name__ == "__main__":
100
+ logging.basicConfig(level=logging.INFO)
101
+ paths = get_all_paths()
102
+ print("\nAll models downloaded:")
103
+ for k, v in paths.items():
104
+ size = os.path.getsize(v) / 1e9 if os.path.isfile(v) else "dir"
105
+ print(f" {k}: {v} ({size:.2f}GB)" if isinstance(size, float) else f" {k}: {v} (directory)")
dramabox_src/preprocess.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Preprocess TTS datasets for LTX-2.3 audio-only LoRA fine-tuning.
4
+
5
+ Takes paired (audio, transcript) data and produces the format expected by
6
+ the LTX trainer:
7
+ .precomputed/
8
+ ├── latents/sample_N.pt # Dummy video latents (minimal)
9
+ ├── conditions/sample_N.pt # Text embeddings from Gemma
10
+ └── audio_latents/sample_N.pt # Audio VAE-encoded latents
11
+
12
+ Supports multiple dataset formats:
13
+ - gemini_synthetic: index.txt with ~-separated fields (id~speaker~lang~sr~samples~dur~phonemes~text)
14
+ - libriheavy: index_ft.txt with ~-separated fields (id~speaker~lang~samples~dur~phonemes~text)
15
+ - manifest: JSON/JSONL with {"audio_filepath": ..., "text": ...}
16
+ - tsv: TSV file with audio_path<TAB>text columns
17
+
18
+ Usage:
19
+ python preprocess_tts_data.py \
20
+ --dataset-type gemini_synthetic \
21
+ --index /mnt/large-datasets/gemini_synthetic_dataset/conversational_dataset_pp/index.txt \
22
+ --audio-dir /mnt/large-datasets/gemini_synthetic_dataset/conversational_dataset_pp/wavs \
23
+ --output-dir /mnt/persistent0/manmay/tts_training_data \
24
+ --max-samples 10000 \
25
+ --max-duration 20.0 \
26
+ --min-duration 3.0
27
+ """
28
+
29
+ import argparse
30
+ import json
31
+ import logging
32
+ import os
33
+ import sys
34
+ from pathlib import Path
35
+
36
+ import torch
37
+ import torchaudio
38
+
39
+ REPO_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
40
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ltx2"))
41
+ # ltx-pipelines on path via ltx2/
42
+
43
+ MODEL_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
44
+ GEMMA_DIR = os.environ.get("GEMMA_DIR", "gemma-3-12b-it-qat-q4_0-unquantized")
45
+
46
+
47
+ def parse_args():
48
+ p = argparse.ArgumentParser(description="Preprocess TTS data for LTX-2.3 fine-tuning")
49
+ p.add_argument("--dataset-type", required=True,
50
+ choices=["gemini_synthetic", "libriheavy", "manifest", "tsv"],
51
+ help="Dataset format type")
52
+ p.add_argument("--index", required=True, help="Path to index/manifest file")
53
+ p.add_argument("--audio-dir", default=None,
54
+ help="Base directory for audio files (if paths in index are relative)")
55
+ p.add_argument("--output-dir", required=True, help="Output directory for preprocessed data")
56
+ p.add_argument("--checkpoint", default=os.path.join(MODEL_DIR, "ltx-2.3-22b-distilled.safetensors"))
57
+ p.add_argument("--gemma-root", default=GEMMA_DIR)
58
+ p.add_argument("--max-samples", type=int, default=0, help="Max samples to process (0=all)")
59
+ p.add_argument("--max-duration", type=float, default=20.0, help="Max audio duration in seconds")
60
+ p.add_argument("--min-duration", type=float, default=2.0, help="Min audio duration in seconds")
61
+ p.add_argument("--batch-size", type=int, default=8, help="Batch size for text encoding")
62
+ p.add_argument("--skip-existing", action="store_true", help="Skip already processed samples")
63
+ p.add_argument("--audio-only-ckpt", default=None,
64
+ help="Audio-only checkpoint for VAE encoding (optional, uses full ckpt if not set)")
65
+ p.add_argument("--shard", type=int, default=0, help="Shard index (for parallel processing)")
66
+ p.add_argument("--num-shards", type=int, default=1, help="Total number of shards")
67
+ p.add_argument("--gpu", type=int, default=None, help="GPU device index to use")
68
+ return p.parse_args()
69
+
70
+
71
+ def parse_gemini_synthetic(index_path: str, audio_dir: str | None) -> list[dict]:
72
+ """Parse gemini_synthetic format: id~speaker~lang~sr~samples~dur~phonemes~text"""
73
+ samples = []
74
+ with open(index_path) as f:
75
+ for line in f:
76
+ parts = line.strip().split("~")
77
+ if len(parts) < 7:
78
+ continue
79
+ file_id = parts[0]
80
+ text = parts[-1] # Last field is always the text
81
+ sr = int(parts[3])
82
+ n_samples = int(parts[4])
83
+ duration = n_samples / sr
84
+
85
+ # Find audio file
86
+ if audio_dir:
87
+ # Try common extensions
88
+ for ext in [".flac", ".wav", ".mp3"]:
89
+ audio_path = os.path.join(audio_dir, file_id + ext)
90
+ if os.path.exists(audio_path):
91
+ break
92
+ else:
93
+ continue
94
+ else:
95
+ audio_path = file_id
96
+
97
+ samples.append({
98
+ "id": file_id,
99
+ "audio_path": audio_path,
100
+ "text": text,
101
+ "duration": duration,
102
+ })
103
+ return samples
104
+
105
+
106
+ def parse_libriheavy(index_path: str, audio_dir: str | None) -> list[dict]:
107
+ """Parse libriheavy format: id~speaker~lang~samples~dur~phonemes~text"""
108
+ samples = []
109
+ with open(index_path) as f:
110
+ for line in f:
111
+ parts = line.strip().split("~")
112
+ if len(parts) < 7:
113
+ continue
114
+ file_id = parts[0]
115
+ text = parts[-1]
116
+ n_samples = int(parts[3])
117
+ duration = int(parts[4]) / 1000.0 # milliseconds to seconds
118
+
119
+ if audio_dir:
120
+ for ext in [".flac", ".wav", ".mp3"]:
121
+ audio_path = os.path.join(audio_dir, file_id + ext)
122
+ if os.path.exists(audio_path):
123
+ break
124
+ else:
125
+ continue
126
+ else:
127
+ audio_path = file_id
128
+
129
+ samples.append({
130
+ "id": file_id,
131
+ "audio_path": audio_path,
132
+ "text": text,
133
+ "duration": duration,
134
+ })
135
+ return samples
136
+
137
+
138
+ def parse_manifest(index_path: str, audio_dir: str | None) -> list[dict]:
139
+ """Parse JSON/JSONL manifest with audio_filepath and text fields."""
140
+ samples = []
141
+ with open(index_path) as f:
142
+ for line in f:
143
+ entry = json.loads(line.strip())
144
+ audio_path = entry.get("audio_filepath", entry.get("audio_path", ""))
145
+ text = entry.get("text", entry.get("transcript", ""))
146
+ duration = entry.get("duration", 0.0)
147
+
148
+ if audio_dir and not os.path.isabs(audio_path):
149
+ audio_path = os.path.join(audio_dir, audio_path)
150
+
151
+ if os.path.exists(audio_path) and text:
152
+ samples.append({
153
+ "id": Path(audio_path).stem,
154
+ "audio_path": audio_path,
155
+ "text": text,
156
+ "duration": duration,
157
+ })
158
+ return samples
159
+
160
+
161
+ def parse_tsv(index_path: str, audio_dir: str | None) -> list[dict]:
162
+ """Parse TSV file with audio_path<TAB>text."""
163
+ samples = []
164
+ with open(index_path) as f:
165
+ for line in f:
166
+ parts = line.strip().split("\t")
167
+ if len(parts) < 2:
168
+ continue
169
+ audio_path, text = parts[0], parts[1]
170
+ if audio_dir and not os.path.isabs(audio_path):
171
+ audio_path = os.path.join(audio_dir, audio_path)
172
+ if os.path.exists(audio_path):
173
+ samples.append({
174
+ "id": Path(audio_path).stem,
175
+ "audio_path": audio_path,
176
+ "text": text,
177
+ "duration": 0.0,
178
+ })
179
+ return samples
180
+
181
+
182
+ PARSERS = {
183
+ "gemini_synthetic": parse_gemini_synthetic,
184
+ "libriheavy": parse_libriheavy,
185
+ "manifest": parse_manifest,
186
+ "tsv": parse_tsv,
187
+ }
188
+
189
+
190
+ @torch.inference_mode()
191
+ def main():
192
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
193
+ args = parse_args()
194
+
195
+ from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
196
+ from ltx_core.types import Audio
197
+ from ltx_pipelines.utils.blocks import AudioConditioner
198
+ from ltx_pipelines.utils.media_io import decode_audio_from_file
199
+ from ltx_trainer.model_loader import load_text_encoder, load_embeddings_processor
200
+
201
+ if args.gpu is not None:
202
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
203
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
204
+ dtype = torch.bfloat16
205
+
206
+ # Create output directories
207
+ out = Path(args.output_dir)
208
+ (out / "latents").mkdir(parents=True, exist_ok=True)
209
+ (out / "conditions").mkdir(parents=True, exist_ok=True)
210
+ (out / "audio_latents").mkdir(parents=True, exist_ok=True)
211
+
212
+ # Parse dataset
213
+ logging.info(f"Parsing {args.dataset_type} dataset from {args.index}...")
214
+ samples = PARSERS[args.dataset_type](args.index, args.audio_dir)
215
+ logging.info(f"Found {len(samples)} samples")
216
+
217
+ # Filter by duration
218
+ before = len(samples)
219
+ samples = [s for s in samples if args.min_duration <= s["duration"] <= args.max_duration]
220
+ logging.info(f"After duration filter [{args.min_duration}s, {args.max_duration}s]: {len(samples)} (dropped {before - len(samples)})")
221
+
222
+ if args.max_samples > 0:
223
+ samples = samples[:args.max_samples]
224
+ logging.info(f"Limiting to {len(samples)} samples")
225
+
226
+ # Assign global indices before sharding
227
+ for i, s in enumerate(samples):
228
+ s["global_idx"] = i
229
+
230
+ # Shard the data for parallel processing
231
+ if args.num_shards > 1:
232
+ total = len(samples)
233
+ samples = samples[args.shard::args.num_shards]
234
+ logging.info(f"Shard {args.shard}/{args.num_shards}: {len(samples)} samples (of {total} total)")
235
+
236
+ # ── Step 1: Encode text with Gemma (Blocks 1+2 only) ──
237
+ # The trainer runs Block 3 (embeddings processor/connectors) during training,
238
+ # so we only precompute Blocks 1+2 here (Gemma LLM + feature extractor).
239
+ logging.info("Loading text encoder (Gemma + feature extractor)...")
240
+ text_encoder = load_text_encoder(args.gemma_root, device=device, dtype=dtype)
241
+
242
+ # Load feature extractor on CPU first to save GPU memory, then move to device
243
+ logging.info("Loading feature extractor (on CPU first to save GPU memory)...")
244
+ emb_proc = load_embeddings_processor(args.checkpoint, device="cpu", dtype=dtype)
245
+ text_encoder.feature_extractor = emb_proc.feature_extractor.to(device)
246
+ del emb_proc
247
+ torch.cuda.empty_cache()
248
+
249
+ logging.info("Encoding text prompts (Blocks 1+2: Gemma + feature extractor)...")
250
+ for i, sample in enumerate(samples):
251
+ gidx = sample["global_idx"]
252
+ cond_path = out / "conditions" / f"sample_{gidx:06d}.pt"
253
+ if args.skip_existing and cond_path.exists():
254
+ continue
255
+
256
+ text = sample["text"]
257
+ # Run Blocks 1+2: Gemma LLM → feature extractor
258
+ hidden_states, attention_mask = text_encoder.encode(text)
259
+ video_feats, audio_feats = text_encoder.feature_extractor(
260
+ hidden_states, attention_mask, "left"
261
+ )
262
+
263
+ torch.save({
264
+ "video_prompt_embeds": video_feats.squeeze(0).cpu(),
265
+ "audio_prompt_embeds": audio_feats.squeeze(0).cpu() if audio_feats is not None else video_feats.squeeze(0).cpu(),
266
+ "prompt_attention_mask": attention_mask.squeeze(0).bool().cpu(),
267
+ }, cond_path)
268
+
269
+ if i % 100 == 0:
270
+ logging.info(f" Text encoding: {i}/{len(samples)}")
271
+
272
+ del text_encoder
273
+ torch.cuda.empty_cache()
274
+
275
+ # ── Step 2: Encode audio with Audio VAE ──
276
+ ckpt_for_vae = args.audio_only_ckpt or args.checkpoint
277
+ logging.info(f"Loading audio VAE from {ckpt_for_vae}...")
278
+
279
+ ac = AudioConditioner(checkpoint_path=ckpt_for_vae, dtype=dtype, device=device)
280
+
281
+ logging.info("Encoding audio samples...")
282
+ for idx, sample in enumerate(samples):
283
+ gidx = sample["global_idx"]
284
+ audio_path = out / "audio_latents" / f"sample_{gidx:06d}.pt"
285
+ if args.skip_existing and audio_path.exists():
286
+ continue
287
+
288
+ try:
289
+ # Load audio
290
+ voice = decode_audio_from_file(sample["audio_path"], device, 0.0, args.max_duration)
291
+ if voice is None:
292
+ logging.warning(f" Skipping {sample['id']}: no audio")
293
+ continue
294
+
295
+ w = voice.waveform
296
+ if w.dim() == 2:
297
+ if w.shape[0] == 1:
298
+ w = w.repeat(2, 1)
299
+ w = w.unsqueeze(0)
300
+ elif w.dim() == 3 and w.shape[1] == 1:
301
+ w = w.repeat(1, 2, 1)
302
+ voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)
303
+
304
+ # Encode through Audio VAE
305
+ audio_latent = ac(lambda enc: vae_encode_audio(voice, enc, None))
306
+
307
+ # Save audio latent
308
+ torch.save({
309
+ "latents": audio_latent.squeeze(0).cpu(), # [C=8, T, F=16]
310
+ "sample_rate": 16000,
311
+ }, audio_path)
312
+
313
+ except Exception as e:
314
+ logging.warning(f" Skipping {sample['id']}: {e}")
315
+ continue
316
+
317
+ if idx % 100 == 0:
318
+ logging.info(f" Audio encoding: {idx}/{len(samples)}")
319
+
320
+ del ac
321
+ torch.cuda.empty_cache()
322
+
323
+ # ── Step 3: Create dummy video latents ──
324
+ logging.info("Creating dummy video latents...")
325
+ # Minimal video: 1 frame, 64x64 = 2x2 in latent space
326
+ dummy_video = {
327
+ "latents": torch.zeros(128, 1, 2, 2),
328
+ "num_frames": 1,
329
+ "height": 2,
330
+ "width": 2,
331
+ "fps": 24.0,
332
+ }
333
+ for idx, sample in enumerate(samples):
334
+ gidx = sample["global_idx"]
335
+ latent_path = out / "latents" / f"sample_{gidx:06d}.pt"
336
+ if args.skip_existing and latent_path.exists():
337
+ continue
338
+ torch.save(dummy_video, latent_path)
339
+
340
+ # ── Summary ──
341
+ n_audio = len(list((out / "audio_latents").glob("*.pt")))
342
+ n_cond = len(list((out / "conditions").glob("*.pt")))
343
+ n_lat = len(list((out / "latents").glob("*.pt")))
344
+ logging.info(f"\nDone! Output: {args.output_dir}")
345
+ logging.info(f" audio_latents: {n_audio} files")
346
+ logging.info(f" conditions: {n_cond} files")
347
+ logging.info(f" latents: {n_lat} files")
348
+
349
+
350
+ if __name__ == "__main__":
351
+ main()
dramabox_src/train.py ADDED
@@ -0,0 +1,882 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Audio-Only IC-LoRA Training for Voice Cloning on LTX-2.3.
4
+
5
+ Uses the IC-LoRA pattern: reference audio tokens are APPENDED to the end of
6
+ the target sequence using AudioConditionByReferenceLatent. Loss is computed
7
+ only on target tokens; reference tokens remain clean (denoise_mask=0).
8
+
9
+ This follows the official video-to-video IC-LoRA strategy closely, but adapted
10
+ for the audio-only modality path.
11
+
12
+ Usage (single GPU):
13
+ CUDA_VISIBLE_DEVICES=0 python train_audio_iclora.py --data-dir ... --speaker-index ...
14
+
15
+ Usage (multi-GPU with accelerate):
16
+ CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --num_processes=4 train_audio_iclora.py ...
17
+ """
18
+
19
+ import argparse
20
+ import logging
21
+ import math
22
+ import os
23
+ import random
24
+ import shutil
25
+ import sys
26
+ import time
27
+ from collections import defaultdict
28
+ from pathlib import Path
29
+
30
+ import torch
31
+ import torch.nn.functional as F
32
+ from torch.utils.data import DataLoader, Dataset
33
+
34
+ REPO_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
35
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ltx2"))
36
+ # ltx-pipelines already on path via ltx2/
37
+
38
+ MODEL_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
39
+
40
+ # Import audio conditioning item from our module
41
+ sys.path.insert(0, MODEL_DIR)
42
+ from audio_conditioning import AudioConditionByReferenceLatent
43
+
44
+
45
+ # ─── Timestep Sampling ───
46
+
47
+ class DistilledTimestepSampler:
48
+ """Sample timesteps from the distilled sigma schedule.
49
+
50
+ The distilled model was trained to denoise at these specific sigma values.
51
+ We sample uniformly from the intervals between consecutive sigmas,
52
+ matching the distribution the model actually operates on.
53
+ """
54
+
55
+ # Distilled 8-step sigma values (boundaries of denoising intervals)
56
+ SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
57
+
58
+ def __init__(self, jitter: float = 0.02):
59
+ self.jitter = jitter
60
+
61
+ def sample(self, batch_size: int, seq_length: int = None, device: torch.device = None) -> torch.Tensor:
62
+ n_intervals = len(self.SIGMAS) - 1
63
+ interval_idx = torch.randint(0, n_intervals, (batch_size,), device=device)
64
+ t = torch.rand(batch_size, device=device)
65
+ sigma_high = torch.tensor([self.SIGMAS[i] for i in interval_idx], device=device)
66
+ sigma_low = torch.tensor([self.SIGMAS[i + 1] for i in interval_idx], device=device)
67
+ sigma = sigma_low + t * (sigma_high - sigma_low)
68
+ return sigma.clamp(0.01, 0.99)
69
+
70
+
71
+ class ShiftedLogitNormalTimestepSampler:
72
+ """Shifted logit-normal distribution, shift depends on sequence length."""
73
+
74
+ def __init__(self, std: float = 1.0, eps: float = 1e-3, uniform_prob: float = 0.1):
75
+ self.std = std
76
+ self.eps = eps
77
+ self.uniform_prob = uniform_prob
78
+ self.normal_999_percentile = 3.0902 * std
79
+ self.normal_005_percentile = -2.5758 * std
80
+
81
+ def sample(self, batch_size: int, seq_length: int, device: torch.device = None) -> torch.Tensor:
82
+ mu = self._get_shift(seq_length)
83
+ normal = torch.randn(batch_size, device=device) * self.std + mu
84
+ logitnormal = torch.sigmoid(normal)
85
+
86
+ p999 = torch.sigmoid(torch.tensor(mu + self.normal_999_percentile, device=device))
87
+ p005 = torch.sigmoid(torch.tensor(mu + self.normal_005_percentile, device=device))
88
+ stretched = (logitnormal - p005) / (p999 - p005)
89
+ stretched = torch.where(stretched >= self.eps, stretched, 2 * self.eps - stretched)
90
+ stretched = stretched.clamp(0, 1)
91
+
92
+ uniform = (1 - self.eps) * torch.rand(batch_size, device=device) + self.eps
93
+ prob = torch.rand(batch_size, device=device)
94
+ return torch.where(prob > self.uniform_prob, stretched, uniform)
95
+
96
+ @staticmethod
97
+ def _get_shift(seq_length, min_tok=1024, max_tok=4096, min_s=0.95, max_s=2.05):
98
+ m = (max_s - min_s) / (max_tok - min_tok)
99
+ return m * seq_length + (min_s - m * min_tok)
100
+
101
+
102
+ # ─── Dataset ───
103
+
104
+ def build_speaker_map(index_paths, data_dirs):
105
+ """Map speaker → [(data_dir, sample_idx)] from index file(s).
106
+
107
+ The sample index comes from field 0 of the `~`-delimited row when it
108
+ parses as int (allows subset indexes that keep original sample numbers),
109
+ otherwise we fall back to the row's line number (legacy behaviour for
110
+ string-keyed indexes like tts_training_data_podcast).
111
+ """
112
+ speaker_to_samples = defaultdict(list)
113
+ for index_path, data_dir in zip(index_paths, data_dirs):
114
+ with open(index_path) as f:
115
+ for line_num, line in enumerate(f):
116
+ parts = line.strip().split("~")
117
+ if len(parts) < 7:
118
+ continue
119
+ try:
120
+ idx = int(parts[0])
121
+ except ValueError:
122
+ idx = line_num
123
+ speaker_id = parts[1]
124
+ speaker_to_samples[speaker_id].append((data_dir, idx))
125
+ return {k: v for k, v in speaker_to_samples.items() if len(v) >= 2}
126
+
127
+
128
+ class IDLoRADataset(Dataset):
129
+ # Silence-latent reference loaded once, used to detect and strip any
130
+ # leading silence frames baked into the preprocessed audio_latents. The
131
+ # training loop ALREADY prepends 0-25 random silence frames, so we don't
132
+ # want accidental silence in the source data compounding on top.
133
+ _silence_ref = None
134
+
135
+ @classmethod
136
+ def _load_silence_ref(cls):
137
+ if cls._silence_ref is None:
138
+ p = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
139
+ "assets", "silence_latent_frame.pt")
140
+ if os.path.exists(p):
141
+ cls._silence_ref = torch.load(p, weights_only=True).float().squeeze() # [C, F]
142
+ return cls._silence_ref
143
+
144
+ def __init__(self, speaker_map):
145
+ self.samples = []
146
+ self.speaker_map = {}
147
+ for speaker, entries in speaker_map.items():
148
+ valid = []
149
+ for data_dir, idx in entries:
150
+ audio_path = Path(data_dir) / "audio_latents" / f"sample_{idx:06d}.pt"
151
+ cond_path = Path(data_dir) / "conditions" / f"sample_{idx:06d}.pt"
152
+ if audio_path.exists() and cond_path.exists():
153
+ valid.append((data_dir, idx))
154
+ if len(valid) >= 2:
155
+ self.speaker_map[speaker] = valid
156
+ for speaker, entries in self.speaker_map.items():
157
+ for entry in entries:
158
+ self.samples.append((entry, speaker))
159
+ IDLoRADataset._load_silence_ref()
160
+
161
+ def __len__(self):
162
+ return len(self.samples)
163
+
164
+ def _load_sample(self, data_dir, idx):
165
+ base = Path(data_dir)
166
+ audio = torch.load(base / "audio_latents" / f"sample_{idx:06d}.pt", weights_only=False)
167
+ # Prefer prefix-stripped text embeddings if they exist (re-encoded with
168
+ # just the quoted dialogue, dropping the "A woman says, " / "A man
169
+ # speaks with X accent, " scene-description prefix).
170
+ stripped = base / "conditions_stripped" / f"sample_{idx:06d}.pt"
171
+ cond_path = stripped if stripped.exists() else base / "conditions" / f"sample_{idx:06d}.pt"
172
+ cond = torch.load(cond_path, weights_only=False)
173
+ if isinstance(audio, dict):
174
+ audio = audio.get("audio_latent", audio.get("latent", list(audio.values())[0]))
175
+ if audio.dim() == 2:
176
+ audio = audio.unsqueeze(0)
177
+ audio_feats = cond.get("audio_prompt_embeds", cond.get("prompt_embeds"))
178
+ attn_mask = cond.get("prompt_attention_mask")
179
+ # The audio_connector has num_learnable_registers=128 and asserts the
180
+ # input sequence length is divisible by 128. Our new preprocessing
181
+ # saved trimmed conditions (dropping left-padding to save disk), which
182
+ # produces short/irregular sequence lengths. Left-pad back to the next
183
+ # multiple of 128 with zeros (matching the tokenizer's left-padding
184
+ # convention) so this assertion holds.
185
+ REG = 128
186
+ L = audio_feats.shape[0]
187
+ target_L = ((L + REG - 1) // REG) * REG
188
+ if target_L != L:
189
+ pad_len = target_L - L
190
+ pad_emb = torch.zeros(pad_len, audio_feats.shape[1],
191
+ dtype=audio_feats.dtype)
192
+ pad_mask = torch.zeros(pad_len, dtype=attn_mask.dtype)
193
+ audio_feats = torch.cat([pad_emb, audio_feats], dim=0)
194
+ attn_mask = torch.cat([pad_mask, attn_mask], dim=0)
195
+ return audio, audio_feats, attn_mask
196
+
197
+ def __getitem__(self, idx):
198
+ (data_dir, tgt_idx), speaker = self.samples[idx]
199
+ tgt_latent, audio_feats, attn_mask = self._load_sample(data_dir, tgt_idx)
200
+
201
+ # Drop the reference entirely for non-voice-cloning categories:
202
+ # - SFX samples (speaker starts with "sfx_"): descriptive sound events,
203
+ # no speaker identity to clone.
204
+ # - Song/music samples (suno dataset): prompts describe the music style,
205
+ # reference audio doesn't transfer anything useful.
206
+ # Return a zero-length ref so the model trains target-only for these.
207
+ drop_ref = speaker.startswith("sfx_") or "preprocessed_ltx_suno" in str(data_dir)
208
+ if drop_ref:
209
+ C, F_dim = tgt_latent.shape[0], tgt_latent.shape[2]
210
+ ref_latent = torch.zeros(C, 0, F_dim, dtype=tgt_latent.dtype)
211
+ else:
212
+ entries = self.speaker_map[speaker]
213
+ ref_entry = random.choice([e for e in entries if e[1] != tgt_idx])
214
+ ref_latent, _, _ = self._load_sample(*ref_entry)
215
+
216
+ return {
217
+ "tgt_latent": tgt_latent,
218
+ "ref_latent": ref_latent,
219
+ "audio_features": audio_feats,
220
+ "attention_mask": attn_mask,
221
+ }
222
+
223
+
224
+ # ─── Model building ───
225
+
226
+ def build_audio_only_model(checkpoint_path, device, dtype):
227
+ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
228
+ from ltx_core.loader.registry import DummyRegistry
229
+ from ltx_core.loader.sd_ops import SDOps
230
+ from ltx_core.model.transformer.model import LTXModel, LTXModelType
231
+ from ltx_core.model.model_protocol import ModelConfigurator
232
+ from ltx_core.model.transformer.attention import AttentionFunction
233
+ from ltx_core.model.transformer.rope import LTXRopeType
234
+
235
+ sd_ops = SDOps("AO").with_matching(prefix="model.diffusion_model.").with_replacement("model.diffusion_model.", "")
236
+
237
+ class Cfg(ModelConfigurator[LTXModel]):
238
+ @classmethod
239
+ def from_config(cls, config):
240
+ t = config.get("transformer", {})
241
+ cp = None
242
+ if not t.get("caption_proj_before_connector", False):
243
+ from ltx_core.model.transformer.text_projection import create_caption_projection
244
+ with torch.device("meta"):
245
+ cp = create_caption_projection(t, audio=True)
246
+ return LTXModel(
247
+ model_type=LTXModelType.AudioOnly,
248
+ audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
249
+ audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
250
+ audio_in_channels=t.get("audio_in_channels", 128),
251
+ audio_out_channels=t.get("audio_out_channels", 128),
252
+ num_layers=t.get("num_layers", 48),
253
+ audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
254
+ norm_eps=t.get("norm_eps", 1e-6),
255
+ attention_type=AttentionFunction(t.get("attention_type", "default")),
256
+ positional_embedding_theta=t.get("positional_embedding_theta", 10000.0),
257
+ audio_positional_embedding_max_pos=t.get("audio_positional_embedding_max_pos", [20]),
258
+ timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
259
+ use_middle_indices_grid=t.get("use_middle_indices_grid", True),
260
+ rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
261
+ double_precision_rope=t.get("frequencies_precision", False) == "float64",
262
+ apply_gated_attention=t.get("apply_gated_attention", False),
263
+ audio_caption_projection=cp,
264
+ cross_attention_adaln=t.get("cross_attention_adaln", False),
265
+ )
266
+
267
+ builder = Builder(model_path=checkpoint_path, model_class_configurator=Cfg,
268
+ model_sd_ops=sd_ops, registry=DummyRegistry())
269
+ return builder.build(device=device, dtype=dtype)
270
+
271
+
272
+ def load_audio_connector(checkpoint_path, device, dtype):
273
+ # ltx-trainer already on path via ltx2/
274
+ from ltx_trainer.model_loader import load_embeddings_processor
275
+ emb_proc = load_embeddings_processor(checkpoint_path, device=device, dtype=dtype)
276
+ connector = emb_proc.audio_connector
277
+ del emb_proc
278
+ return connector
279
+
280
+
281
+ def apply_lora(model, rank, alpha, dropout=0.0):
282
+ from peft import LoraConfig, get_peft_model
283
+ config = LoraConfig(
284
+ r=rank, lora_alpha=alpha, lora_dropout=dropout, bias="none",
285
+ target_modules=[
286
+ # Self-attention over audio tokens (voice-transfer pathway via ref).
287
+ "audio_attn1.to_k", "audio_attn1.to_q", "audio_attn1.to_v", "audio_attn1.to_out.0",
288
+ # Cross-attention (audio ↔ text context) NOT adapted — keep base
289
+ # model's prompt→audio behaviour intact and rely on dataset balance
290
+ # to drive expressiveness. (v15c tried this with adaLN unfreeze,
291
+ # that proved too destructive; v16 tries it adaLN-frozen.)
292
+ # FFN — non-linear capacity for style/phonetic adaptation.
293
+ "audio_ff.net.0.proj", "audio_ff.net.2",
294
+ ],
295
+ )
296
+ model = get_peft_model(model, config)
297
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
298
+ total = sum(p.numel() for p in model.parameters())
299
+ logging.info(f"LoRA: {trainable:,} trainable / {total:,} total ({100*trainable/total:.1f}%)")
300
+ return model
301
+
302
+
303
+ @torch.no_grad()
304
+ def prepare_audio_context(audio_connector, audio_features, attention_mask, device, dtype):
305
+ from ltx_core.text_encoders.gemma.embeddings_processor import convert_to_additive_mask
306
+ audio_features = audio_features.to(device=device, dtype=dtype)
307
+ attention_mask = attention_mask.to(device=device)
308
+ if audio_features.shape[0] > 1:
309
+ results = []
310
+ for i in range(audio_features.shape[0]):
311
+ feat_i = audio_features[i:i+1]
312
+ mask_i = attention_mask[i:i+1]
313
+ additive = convert_to_additive_mask(mask_i, feat_i.dtype)
314
+ enc_i, _ = audio_connector(feat_i, additive)
315
+ results.append(enc_i)
316
+ return torch.cat(results, dim=0)
317
+ additive_mask = convert_to_additive_mask(attention_mask, audio_features.dtype)
318
+ audio_encoded, _ = audio_connector(audio_features, additive_mask)
319
+ return audio_encoded
320
+
321
+
322
+ # ─── Validation ───
323
+
324
+ def _unwrap_model_safe(model):
325
+ """Strip DDP / peft wrappers without going through accelerate.unwrap_model,
326
+ which imports deepspeed — broken in our env (torch API drift)."""
327
+ while hasattr(model, "module"):
328
+ model = model.module
329
+ return model
330
+
331
+
332
+ def run_validation(lora_path, val_config_path, output_dir, step, lora_rank=128):
333
+ """Call validate.py in a subprocess. It loads TTSServer (the same stack
334
+ the warm server / Gradio app uses), attaches our LoRA, then iterates every
335
+ entry in val_config with the same inference settings the user tests with.
336
+ Single subprocess amortises the model-load cost across all val entries.
337
+
338
+ Forces validation onto VAL_GPU (default "0") because training already
339
+ occupies the rest. Override via TRAIN_VAL_GPU env var.
340
+ """
341
+ import subprocess
342
+ val_dir = os.path.join(output_dir, "validation", f"step_{step:05d}")
343
+ os.makedirs(val_dir, exist_ok=True)
344
+ script = os.path.join(os.path.dirname(__file__), "validate.py")
345
+ cmd = [
346
+ sys.executable, script,
347
+ "--val-config", val_config_path,
348
+ "--output-dir", val_dir,
349
+ "--lora", lora_path,
350
+ "--lora-rank", str(lora_rank),
351
+ # Use raw estimator output (no +10% buffer) so we can hear
352
+ # whether the model needs more/less duration at current quality.
353
+ "--duration-multiplier", "1.0",
354
+ ]
355
+ log_path = os.path.join(val_dir, "validate.log")
356
+ env = os.environ.copy()
357
+ # Validation needs its OWN GPU (training fills the others).
358
+ env["CUDA_VISIBLE_DEVICES"] = os.environ.get("TRAIN_VAL_GPU", "0")
359
+ try:
360
+ with open(log_path, "w") as logf:
361
+ result = subprocess.run(
362
+ cmd, stdout=logf, stderr=subprocess.STDOUT, timeout=1800, env=env,
363
+ )
364
+ if result.returncode == 0:
365
+ logging.info(f" Validation step {step}: OK → {val_dir}")
366
+ else:
367
+ logging.warning(f" Validation step {step} FAILED (see {log_path})")
368
+ except subprocess.TimeoutExpired:
369
+ logging.warning(f" Validation step {step} TIMEOUT (>30min)")
370
+
371
+
372
+ # ─── Args ───
373
+
374
+ def parse_args():
375
+ # First pass: pull out --config so its values can become argparse defaults.
376
+ cfg_parser = argparse.ArgumentParser(add_help=False)
377
+ cfg_parser.add_argument("--config", default=None,
378
+ help="YAML file with default values for any of the flags below. "
379
+ "Explicit CLI flags still override the YAML.")
380
+ cfg_args, remaining = cfg_parser.parse_known_args()
381
+ yaml_defaults: dict = {}
382
+ if cfg_args.config:
383
+ import yaml as _yaml
384
+ with open(cfg_args.config) as f:
385
+ yaml_defaults = _yaml.safe_load(f) or {}
386
+ # YAML keys are dashes-or-underscores → normalize to argparse dest (underscore).
387
+ yaml_defaults = {k.replace("-", "_"): v for k, v in yaml_defaults.items()}
388
+
389
+ def _yaml(name, fallback):
390
+ return yaml_defaults.get(name, fallback)
391
+
392
+ p = argparse.ArgumentParser(
393
+ parents=[cfg_parser],
394
+ description="Audio-Only IC-LoRA Training for Voice Cloning",
395
+ )
396
+ p.add_argument("--data-dir", required="data_dir" not in yaml_defaults,
397
+ nargs="+", default=_yaml("data_dir", None))
398
+ p.add_argument("--speaker-index", required="speaker_index" not in yaml_defaults,
399
+ nargs="+", default=_yaml("speaker_index", None))
400
+ p.add_argument("--output-dir", default=_yaml("output_dir", os.path.join(MODEL_DIR, "tts_iclora_v1")))
401
+ p.add_argument("--checkpoint", default=_yaml("checkpoint", os.path.join(MODEL_DIR, "dramabox-dit-v1.safetensors")))
402
+ p.add_argument("--full-checkpoint", default=_yaml("full_checkpoint", os.path.join(MODEL_DIR, "dramabox-audio-components.safetensors")))
403
+ p.add_argument("--base-model", choices=["distilled", "dev"], default=_yaml("base_model", "dev"),
404
+ help="Base model type: distilled uses DistilledTimestepSampler, dev uses ShiftedLogitNormal")
405
+ p.add_argument("--lora-rank", type=int, default=_yaml("lora_rank", 128))
406
+ p.add_argument("--lora-alpha", type=int, default=_yaml("lora_alpha", 128))
407
+ p.add_argument("--lora-dropout", type=float, default=_yaml("lora_dropout", 0.0),
408
+ help="Dropout applied to LoRA A/B matrices during training. "
409
+ "Recommended ~0.1 for small datasets to regularize.")
410
+ p.add_argument("--resume-lora", default=_yaml("resume_lora", None))
411
+ p.add_argument("--resume-step-offset", type=int, default=_yaml("resume_step_offset", None),
412
+ help="Step to add when naming saved checkpoints. If None, inferred "
413
+ "from --resume-lora filename (e.g. lora_step_10000.safetensors → 10000). "
414
+ "Set to 0 to start numbering at 0 regardless.")
415
+ p.add_argument("--ref-ratio", type=float, default=_yaml("ref_ratio", 0.3),
416
+ help="Fraction of target length to use as reference (default 0.3)")
417
+ p.add_argument("--max-ref-tokens", type=int, default=_yaml("max_ref_tokens", 200),
418
+ help="Maximum reference tokens after patchification (default 200)")
419
+ p.add_argument("--text-dropout", type=float, default=_yaml("text_dropout", 0.0),
420
+ help="Probability of dropping text conditioning (forces reliance on voice ref)")
421
+ p.add_argument("--steps", type=int, default=_yaml("steps", 30000))
422
+ p.add_argument("--lr", type=float, default=_yaml("lr", 3e-5))
423
+ p.add_argument("--lr-scheduler", choices=["cosine", "linear", "constant"], default=_yaml("lr_scheduler", "cosine"))
424
+ p.add_argument("--batch-size", type=int, default=_yaml("batch_size", 1))
425
+ p.add_argument("--grad-accum", type=int, default=_yaml("grad_accum", 4))
426
+ p.add_argument("--max-grad-norm", type=float, default=_yaml("max_grad_norm", 1.0))
427
+ p.add_argument("--save-every", type=int, default=_yaml("save_every", 1000))
428
+ p.add_argument("--log-every", type=int, default=_yaml("log_every", 50))
429
+ p.add_argument("--seed", type=int, default=_yaml("seed", 42))
430
+ p.add_argument("--warmup-steps", type=int, default=_yaml("warmup_steps", 100))
431
+ p.add_argument("--val-config", default=_yaml("val_config", None))
432
+ return p.parse_args(remaining)
433
+
434
+
435
+ # ─── Main ───
436
+
437
+ def main():
438
+ from accelerate import Accelerator
439
+ from accelerate.utils import set_seed
440
+
441
+ args = parse_args()
442
+
443
+ accelerator = Accelerator(
444
+ gradient_accumulation_steps=args.grad_accum,
445
+ mixed_precision="bf16",
446
+ )
447
+
448
+ is_main = accelerator.is_main_process
449
+ if is_main:
450
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
451
+ else:
452
+ logging.basicConfig(level=logging.WARNING)
453
+
454
+ set_seed(args.seed)
455
+ device = accelerator.device
456
+ dtype = torch.bfloat16
457
+
458
+ os.makedirs(args.output_dir, exist_ok=True)
459
+
460
+ # Save training args
461
+ if is_main:
462
+ import yaml
463
+ args_dict = vars(args).copy()
464
+ args_dict["_meta"] = {
465
+ "world_size": accelerator.num_processes,
466
+ "dtype": str(dtype),
467
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
468
+ "script": "train_audio_iclora.py",
469
+ "pattern": "IC-LoRA (ref appended to end)",
470
+ }
471
+ with open(os.path.join(args.output_dir, "training_args.yaml"), "w") as f:
472
+ yaml.dump(args_dict, f, default_flow_style=False, sort_keys=False)
473
+
474
+ from ltx_core.components.patchifiers import AudioPatchifier
475
+ from ltx_core.model.transformer.modality import Modality
476
+ from ltx_core.guidance.perturbations import BatchedPerturbationConfig
477
+ from ltx_core.tools import AudioLatentTools
478
+ from ltx_core.types import AudioLatentShape, LatentState
479
+ from ltx_pipelines.utils.helpers import modality_from_latent_state, timesteps_from_mask
480
+
481
+ # Build speaker map
482
+ if is_main:
483
+ logging.info("Building speaker map...")
484
+ speaker_map = build_speaker_map(args.speaker_index, args.data_dir)
485
+ if is_main:
486
+ logging.info(f"Speaker map: {len(speaker_map)} speakers, "
487
+ f"{sum(len(v) for v in speaker_map.values())} samples")
488
+
489
+ # Load model
490
+ if is_main:
491
+ logging.info("Loading audio-only model...")
492
+ model = build_audio_only_model(args.checkpoint, device, dtype)
493
+
494
+ if is_main:
495
+ logging.info("Loading audio connector...")
496
+ audio_connector = load_audio_connector(args.full_checkpoint, device, dtype)
497
+ audio_connector.eval()
498
+ for p in audio_connector.parameters():
499
+ p.requires_grad = False
500
+
501
+ if is_main:
502
+ logging.info(f"Applying LoRA (rank={args.lora_rank}, alpha={args.lora_alpha})...")
503
+ model = apply_lora(model, args.lora_rank, args.lora_alpha, args.lora_dropout)
504
+
505
+ # Resume from checkpoint
506
+ if args.resume_lora:
507
+ from safetensors.torch import load_file as st_load
508
+ if is_main:
509
+ logging.info(f"Resuming from: {args.resume_lora}")
510
+ lora_sd = st_load(args.resume_lora)
511
+ mapped = {}
512
+ for k, v in lora_sd.items():
513
+ nk = k.replace(".lora_A.weight", ".lora_A.default.weight").replace(
514
+ ".lora_B.weight", ".lora_B.default.weight")
515
+ mapped[nk] = v
516
+ model.load_state_dict(mapped, strict=False)
517
+
518
+ # Determine step offset for save filenames. Without this, resuming a run
519
+ # restarts step numbering at 0 and would overwrite earlier phase-1
520
+ # checkpoints with the same save_every cadence.
521
+ if args.resume_step_offset is None:
522
+ resume_offset = 0
523
+ if args.resume_lora:
524
+ import re as _re
525
+ m = _re.search(r"lora_step_(\d+)", os.path.basename(args.resume_lora))
526
+ if m:
527
+ resume_offset = int(m.group(1))
528
+ args.resume_step_offset = resume_offset
529
+ if is_main and args.resume_step_offset:
530
+ logging.info(f"Save-step offset: +{args.resume_step_offset}")
531
+
532
+ model.train()
533
+ model.base_model.model.set_gradient_checkpointing(True)
534
+
535
+ # Dataset & DataLoader
536
+ dataset = IDLoRADataset(speaker_map)
537
+ if is_main:
538
+ logging.info(f"Dataset: {len(dataset)} samples, {len(dataset.speaker_map)} speakers")
539
+
540
+ def collate_fn(batch):
541
+ """Pad variable-length audio to max in batch, track real lengths for loss masking."""
542
+ max_tgt_T = max(b["tgt_latent"].shape[1] for b in batch) # [C, T, F]
543
+ max_ref_T = max(b["ref_latent"].shape[1] for b in batch)
544
+ C = batch[0]["tgt_latent"].shape[0]
545
+ F_dim = batch[0]["tgt_latent"].shape[2]
546
+
547
+ tgt_list, ref_list, feat_list, mask_list = [], [], [], []
548
+ tgt_lengths, ref_lengths = [], []
549
+
550
+ for b in batch:
551
+ tgt = b["tgt_latent"]
552
+ ref = b["ref_latent"]
553
+ tgt_lengths.append(tgt.shape[1])
554
+ ref_lengths.append(ref.shape[1])
555
+
556
+ if tgt.shape[1] < max_tgt_T:
557
+ pad = torch.zeros(C, max_tgt_T - tgt.shape[1], F_dim, dtype=tgt.dtype)
558
+ tgt = torch.cat([tgt, pad], dim=1)
559
+ tgt_list.append(tgt)
560
+
561
+ if ref.shape[1] < max_ref_T:
562
+ pad = torch.zeros(C, max_ref_T - ref.shape[1], F_dim, dtype=ref.dtype)
563
+ ref = torch.cat([ref, pad], dim=1)
564
+ ref_list.append(ref)
565
+
566
+ feat_list.append(b["audio_features"])
567
+ mask_list.append(b["attention_mask"])
568
+
569
+ return {
570
+ "tgt_latent": torch.stack(tgt_list),
571
+ "ref_latent": torch.stack(ref_list),
572
+ "audio_features": torch.stack(feat_list),
573
+ "attention_mask": torch.stack(mask_list),
574
+ "tgt_lengths": torch.tensor(tgt_lengths),
575
+ "ref_lengths": torch.tensor(ref_lengths),
576
+ }
577
+
578
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2,
579
+ pin_memory=True, drop_last=True, collate_fn=collate_fn)
580
+
581
+ # Optimizer & Scheduler
582
+ optimizer = torch.optim.AdamW(
583
+ [p for p in model.parameters() if p.requires_grad],
584
+ lr=args.lr, betas=(0.9, 0.999), weight_decay=0.01,
585
+ )
586
+
587
+ from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR, ConstantLR
588
+ warmup = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=args.warmup_steps)
589
+ remaining = args.steps - args.warmup_steps
590
+ if args.lr_scheduler == "cosine":
591
+ # Warmup -> constant hold (20% of remaining) -> cosine decay
592
+ hold_steps = max(remaining // 5, 0)
593
+ decay_steps = max(remaining - hold_steps, 1)
594
+ hold_sched = ConstantLR(optimizer, factor=1.0, total_iters=hold_steps)
595
+ decay_sched = CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=1e-6)
596
+ scheduler = SequentialLR(
597
+ optimizer,
598
+ [warmup, hold_sched, decay_sched],
599
+ milestones=[args.warmup_steps, args.warmup_steps + hold_steps],
600
+ )
601
+ elif args.lr_scheduler == "linear":
602
+ main_sched = LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=max(remaining, 1))
603
+ scheduler = SequentialLR(optimizer, [warmup, main_sched], milestones=[args.warmup_steps])
604
+ else:
605
+ main_sched = ConstantLR(optimizer, factor=1.0, total_iters=max(remaining, 1))
606
+ scheduler = SequentialLR(optimizer, [warmup, main_sched], milestones=[args.warmup_steps])
607
+
608
+ # Prepare with Accelerate — but NOT the scheduler. AcceleratedScheduler
609
+ # calls the underlying scheduler.step() `num_processes` times per sync,
610
+ # which silently scales down our warmup/cosine spans by that factor.
611
+ # We call scheduler.step() ourselves, gated on sync_gradients → exactly
612
+ # one advance per optimizer step, as the yaml spec intends.
613
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
614
+
615
+ patchifier = AudioPatchifier(patch_size=1)
616
+
617
+ # Select timestep sampler based on base model type
618
+ if args.base_model == "distilled":
619
+ timestep_sampler = DistilledTimestepSampler()
620
+ if is_main:
621
+ logging.info("Using DistilledTimestepSampler (matching distilled model sigmas)")
622
+ else:
623
+ timestep_sampler = ShiftedLogitNormalTimestepSampler()
624
+ if is_main:
625
+ logging.info("Using ShiftedLogitNormalTimestepSampler (dev model)")
626
+
627
+ # Training loop
628
+ if is_main:
629
+ logging.info(f"Training: {args.steps} steps, lr={args.lr}, scheduler={args.lr_scheduler}, "
630
+ f"batch={args.batch_size}, grad_accum={args.grad_accum}, "
631
+ f"world_size={accelerator.num_processes}, "
632
+ f"ref_ratio={args.ref_ratio}, max_ref_tokens={args.max_ref_tokens}")
633
+ logging.info("IC-LoRA pattern: ref tokens APPENDED to target, loss on target only")
634
+
635
+ data_iter = iter(dataloader)
636
+ step = 0
637
+ accum_loss = 0.0
638
+ best_loss = float("inf")
639
+ best_step = 0
640
+ t0 = time.time()
641
+
642
+ total_micro_steps = args.steps * args.grad_accum
643
+
644
+ for micro_step in range(total_micro_steps):
645
+ try:
646
+ batch = next(data_iter)
647
+ except StopIteration:
648
+ data_iter = iter(dataloader)
649
+ batch = next(data_iter)
650
+
651
+ is_opt_step = (micro_step + 1) % args.grad_accum == 0
652
+ if is_opt_step:
653
+ step += 1
654
+
655
+ with accelerator.accumulate(model):
656
+ tgt_latent = batch["tgt_latent"].to(dtype=dtype) # [B, C, max_tgt_T, F]
657
+ ref_latent = batch["ref_latent"].to(dtype=dtype) # [B, C, max_ref_T, F]
658
+ tgt_lengths = batch["tgt_lengths"].to(device=device) # [B]
659
+ B = tgt_latent.shape[0]
660
+
661
+ # ── Random silence padding (0-1s) ── ltx_audio_tts baseline.
662
+ # User observed reference-audio leak at end of generations when this
663
+ # was reduced to 5 (v14) or 10 frames (v16/v17) — the model seemed
664
+ # to use the extra target budget to regurgitate ref content. Full
665
+ # 25 frames (0-1s avg 500ms) was apparently load-bearing for
666
+ # regularising the boundary and reducing hallucinations.
667
+ # Uses the real silence latent (not zeros) so the VAE decodes it as
668
+ # true silence instead of static noise.
669
+ max_pad_frames = 25 # ~1s at 25 latent frames/sec
670
+ pad_frames = random.randint(0, max_pad_frames)
671
+ if pad_frames > 0:
672
+ C, F_dim = tgt_latent.shape[1], tgt_latent.shape[3]
673
+ if not hasattr(args, '_silence_frame') or args._silence_frame is None:
674
+ _sf_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "assets", "silence_latent_frame.pt")
675
+ if os.path.exists(_sf_path):
676
+ args._silence_frame = torch.load(_sf_path, weights_only=True) # [C, 1, F]
677
+ if is_main:
678
+ logging.info(f"Loaded silence latent from {_sf_path}")
679
+ else:
680
+ args._silence_frame = False # fallback to zeros
681
+ if is_main:
682
+ logging.warning(f"silence_latent_frame.pt not found, using zeros")
683
+ if args._silence_frame is not False:
684
+ sf = args._silence_frame.to(dtype=dtype, device=device) # [C, 1, F]
685
+ silence_pad = sf.unsqueeze(0).expand(B, -1, pad_frames, -1) # [B, C, pad, F]
686
+ else:
687
+ silence_pad = torch.zeros(B, C, pad_frames, F_dim, dtype=dtype, device=device)
688
+ tgt_latent = torch.cat([silence_pad, tgt_latent], dim=2)
689
+
690
+ # Cap reference to max_ref_tokens (in latent frames, before patchification)
691
+ # After patchification, ref_T tokens = ref frames (patch_size=1)
692
+ ref_T_frames = min(ref_latent.shape[2], args.max_ref_tokens)
693
+ ref_latent = ref_latent[:, :, :ref_T_frames, :]
694
+
695
+ tgt_T_frames = tgt_latent.shape[2] # max (padded) target frames
696
+
697
+ # ── Step 1: Create target AudioLatentShape and AudioLatentTools ──
698
+ tgt_shape = AudioLatentShape(
699
+ batch=B,
700
+ channels=tgt_latent.shape[1], # 8
701
+ frames=tgt_T_frames,
702
+ mel_bins=tgt_latent.shape[3], # 16
703
+ )
704
+
705
+ audio_tools = AudioLatentTools(
706
+ patchifier=patchifier,
707
+ target_shape=tgt_shape,
708
+ )
709
+
710
+ # ── Step 2: Create initial state from target latent ──
711
+ # create_initial_state patchifies: [B, C, T, F] -> [B, T, C*F]
712
+ # Also creates denoise_mask=1 (all target tokens will be denoised)
713
+ # and computes temporal positions
714
+ state = audio_tools.create_initial_state(
715
+ device=device,
716
+ dtype=dtype,
717
+ initial_latent=tgt_latent,
718
+ )
719
+ # state.latent: [B, tgt_T, 128], state.denoise_mask: [B, tgt_T, 1]
720
+ # state.positions: [B, 1, tgt_T, 2]
721
+
722
+ tgt_T = audio_tools.target_shape.token_count() # = tgt_T_frames
723
+
724
+ # ── Step 3: Apply flow-matching noise to target BEFORE appending ref ──
725
+ # Sample sigma
726
+ total_tokens = tgt_T + ref_T_frames
727
+ sigma = timestep_sampler.sample(B, total_tokens, device=device)
728
+ sigma_exp = sigma.view(-1, 1, 1) # [B, 1, 1]
729
+
730
+ noise = torch.randn_like(state.latent) # [B, tgt_T, 128]
731
+ noisy_tgt = (1 - sigma_exp) * state.latent + sigma_exp * noise
732
+
733
+ # Replace the latent in state with the noisy version
734
+ # (clean_latent stays clean for post_process_latent pattern)
735
+ state = LatentState(
736
+ latent=noisy_tgt,
737
+ denoise_mask=state.denoise_mask,
738
+ positions=state.positions,
739
+ clean_latent=state.clean_latent,
740
+ attention_mask=state.attention_mask,
741
+ )
742
+
743
+ # ── Step 4: Append reference tokens using AudioConditionByReferenceLatent ──
744
+ # This appends ref tokens to the END with denoise_mask=0 (frozen/clean)
745
+ # Skip entirely when ref_T=0 (SFX / song samples): the model trains
746
+ # target-only for those categories since there's no voice to clone.
747
+ if ref_T_frames > 0:
748
+ ref_conditioning = AudioConditionByReferenceLatent(
749
+ latent=ref_latent,
750
+ strength=1.0, # 1.0 = ref fully clean (denoise_mask=0)
751
+ )
752
+ state = ref_conditioning.apply_to(
753
+ latent_state=state,
754
+ latent_tools=audio_tools,
755
+ )
756
+ # state.latent: [B, tgt_T + ref_T, 128]
757
+ # state.denoise_mask: [B, tgt_T + ref_T, 1]
758
+ # target tokens: 1.0 (denoise), ref tokens: 0.0 (frozen)
759
+ # state.positions: [B, 1, tgt_T + ref_T, 2]
760
+
761
+ # ── Step 5: Build loss mask for target tokens (excluding padding) ──
762
+ # loss_mask: 1 for real target tokens, 0 for padding and ref tokens
763
+ loss_mask = torch.zeros(B, tgt_T, device=device)
764
+ for b_idx in range(B):
765
+ real_len = min(tgt_lengths[b_idx].item(), tgt_T)
766
+ loss_mask[b_idx, :real_len] = 1.0
767
+
768
+ # ── Step 6: Prepare text context ──
769
+ # Text conditioning dropout: randomly zero out text context to force
770
+ # the model to rely on the voice reference for identity/style.
771
+ with torch.no_grad():
772
+ audio_context = prepare_audio_context(
773
+ audio_connector, batch["audio_features"],
774
+ batch["attention_mask"], device, dtype)
775
+ if args.text_dropout > 0 and random.random() < args.text_dropout:
776
+ audio_context = torch.zeros_like(audio_context)
777
+
778
+ # ── Step 7: Build Modality using modality_from_latent_state ──
779
+ # timesteps = sigma * denoise_mask (ref gets 0, target gets sigma)
780
+ audio_mod = modality_from_latent_state(
781
+ state=state,
782
+ context=audio_context,
783
+ sigma=sigma,
784
+ enabled=True,
785
+ )
786
+
787
+ # ── Step 8: Forward pass ──
788
+ perturbations = BatchedPerturbationConfig.empty(B)
789
+ with torch.autocast(device_type="cuda", dtype=dtype):
790
+ _, velocity_pred = model(video=None, audio=audio_mod, perturbations=perturbations)
791
+
792
+ # ── Step 9: Compute loss (IC-LoRA pattern) ──
793
+ # Target is at the FRONT (indices 0..tgt_T), ref at the END
794
+ # velocity target = noise - clean
795
+ tgt_patchified = audio_tools.patchifier.patchify(tgt_latent) # [B, tgt_T, 128]
796
+ target_velocity = noise - tgt_patchified
797
+
798
+ # Extract target portion of prediction
799
+ pred_tgt = velocity_pred[:, :tgt_T] # [B, tgt_T, 128]
800
+
801
+ # MSE loss with mask: only on real target tokens (not padding or ref)
802
+ per_token_mse = (pred_tgt - target_velocity).pow(2).mean(dim=-1) # [B, tgt_T]
803
+ loss = per_token_mse.mul(loss_mask).div(loss_mask.mean().clamp(min=1e-6)).mean()
804
+
805
+ accelerator.backward(loss)
806
+
807
+ if accelerator.sync_gradients and args.max_grad_norm > 0:
808
+ accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
809
+
810
+ optimizer.step()
811
+ optimizer.zero_grad()
812
+ # Only advance the LR scheduler once per OPTIMIZER step (not per
813
+ # micro-step). Mirrors AcceleratedOptimizer.step() which is
814
+ # internally gated on sync_gradients.
815
+ if accelerator.sync_gradients:
816
+ scheduler.step()
817
+
818
+ accum_loss += loss.item()
819
+
820
+ # Logging & saving on optimization steps only
821
+ if is_opt_step and step % args.log_every == 0 and is_main:
822
+ avg_loss = accum_loss / (args.log_every * args.grad_accum)
823
+ lr = optimizer.param_groups[0]["lr"]
824
+ elapsed = time.time() - t0
825
+ sps = step / elapsed if elapsed > 0 else 0
826
+ eta = (args.steps - step) / sps if sps > 0 else 0
827
+ logging.info(
828
+ f"Step {step}/{args.steps} | loss={avg_loss:.4f} | lr={lr:.2e} | "
829
+ f"tgt_T={tgt_T} ref_T={ref_T_frames} total={tgt_T + ref_T_frames} | "
830
+ f"{sps:.1f} steps/s | ETA {eta/60:.0f}min"
831
+ )
832
+
833
+ # Save best whenever loss improves — no warmup gate, so we can
834
+ # observe best checkpoints during warmup too.
835
+ if avg_loss < best_loss:
836
+ best_loss = avg_loss
837
+ old_best = os.path.join(args.output_dir, f"best_step_{best_step:05d}.safetensors")
838
+ best_step = step + args.resume_step_offset
839
+ new_best = os.path.join(args.output_dir, f"best_step_{best_step:05d}.safetensors")
840
+ unwrapped = _unwrap_model_safe(model)
841
+ unwrapped.save_pretrained(args.output_dir)
842
+ adapter = os.path.join(args.output_dir, "adapter_model.safetensors")
843
+ if os.path.exists(adapter):
844
+ shutil.copy(adapter, new_best)
845
+ if old_best != new_best and os.path.exists(old_best):
846
+ os.remove(old_best)
847
+ logging.info(f"New best: loss={best_loss:.4f} at step {best_step}")
848
+
849
+ accum_loss = 0.0
850
+
851
+ if is_opt_step and step % args.save_every == 0 and is_main:
852
+ global_step = step + args.resume_step_offset
853
+ save_path = os.path.join(args.output_dir, f"lora_step_{global_step:05d}.safetensors")
854
+ logging.info(f"Saving: {save_path}")
855
+ unwrapped = _unwrap_model_safe(model)
856
+ unwrapped.save_pretrained(args.output_dir)
857
+ adapter = os.path.join(args.output_dir, "adapter_model.safetensors")
858
+ if os.path.exists(adapter):
859
+ shutil.copy(adapter, save_path)
860
+
861
+ if args.val_config:
862
+ logging.info(f"Running validation at step {global_step}...")
863
+ model.eval()
864
+ run_validation(save_path, args.val_config, args.output_dir, global_step,
865
+ lora_rank=args.lora_rank)
866
+ model.train()
867
+
868
+ # Final save
869
+ if is_main:
870
+ unwrapped = _unwrap_model_safe(model)
871
+ unwrapped.save_pretrained(args.output_dir)
872
+ adapter = os.path.join(args.output_dir, "adapter_model.safetensors")
873
+ global_step = step + args.resume_step_offset
874
+ save_path = os.path.join(args.output_dir, f"lora_step_{global_step:05d}.safetensors")
875
+ if os.path.exists(adapter):
876
+ shutil.copy(adapter, save_path)
877
+ logging.info(f"Training complete! {step} steps in {time.time()-t0:.0f}s")
878
+ logging.info(f"Best loss: {best_loss:.4f} at step {best_step}")
879
+
880
+
881
+ if __name__ == "__main__":
882
+ main()
dramabox_src/validate.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Warm validation runner — loads base dev + LoRA + all aux models ONCE,
3
+ then iterates every speaker in val_config generating each output.
4
+
5
+ Matches the same generation path as inference.py but keeps Gemma / audio VAE
6
+ / velocity model / audio decoder resident across entries. Inference
7
+ settings default to the Gradio warm-server values (cfg=2.5, stg=1.5,
8
+ modality=1.0, rescale=0, 30 steps, fps=25) — use --inference-params to
9
+ override.
10
+ """
11
+ import argparse
12
+ import logging
13
+ import os
14
+ import sys
15
+ import time
16
+ import traceback
17
+
18
+ import torch
19
+ import torchaudio
20
+
21
+ REPO_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
22
+ MODEL_DIR = REPO_DIR
23
+ sys.path.insert(0, os.path.join(REPO_DIR, "ltx2"))
24
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
25
+
26
+ DEV_FULL_CKPT = os.environ.get(
27
+ "LTX_FULL_CHECKPOINT",
28
+ os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ltx-2.3-22b-dev.safetensors"),
29
+ )
30
+ GEMMA_ROOT = os.environ.get(
31
+ "GEMMA_ROOT",
32
+ os.path.expanduser("~/.cache/dramabox/gemma-3-12b-it-bnb-4bit"),
33
+ )
34
+
35
+
36
+ def parse_args():
37
+ p = argparse.ArgumentParser()
38
+ p.add_argument("--val-config", required=True)
39
+ p.add_argument("--output-dir", required=True)
40
+ p.add_argument("--lora", default=None)
41
+ p.add_argument("--lora-rank", type=int, default=128)
42
+ p.add_argument("--full-checkpoint", default=DEV_FULL_CKPT)
43
+ p.add_argument("--gemma-root", default=GEMMA_ROOT)
44
+ p.add_argument("--cfg-scale", type=float, default=2.5)
45
+ p.add_argument("--stg-scale", type=float, default=1.5)
46
+ p.add_argument("--rescale-scale", type=float, default=0.0)
47
+ p.add_argument("--modality-scale", type=float, default=1.0)
48
+ p.add_argument("--steps", type=int, default=30)
49
+ p.add_argument("--fps", type=float, default=25.0)
50
+ p.add_argument("--stg-block", type=int, default=29)
51
+ p.add_argument("--cfg-clamp", type=float, default=0.0)
52
+ p.add_argument("--seed", type=int, default=42)
53
+ p.add_argument("--duration-multiplier", type=float, default=1.1)
54
+ # Match Gradio / inference_server.py DEFAULT_NEG exactly
55
+ p.add_argument("--negative-prompt", default=(
56
+ "worst quality, inconsistent, robotic, distorted, noise, static, "
57
+ "muffled, unclear, unnatural, monotone"
58
+ ))
59
+ return p.parse_args()
60
+
61
+
62
+ def estimate_speech_duration(prompt: str, speed: float = 1.0) -> float:
63
+ import re
64
+ quoted = re.findall(r'"([^"]*)"', prompt) or re.findall(r"'([^']*)'", prompt)
65
+ text = " ".join(quoted) if quoted else prompt
66
+ duration = len(text) * 0.065 / max(speed, 0.1) + 1.5
67
+ return max(3.0, round(duration, 1))
68
+
69
+
70
+ class WarmValidator:
71
+ def __init__(self, full_checkpoint, gemma_root, lora_path=None, lora_rank=128,
72
+ device="cuda", dtype=torch.bfloat16):
73
+ from audio_conditioning import AudioConditionByReferenceLatent # noqa: F401 (imported by inference.py)
74
+ from ltx_core.components.patchifiers import AudioPatchifier
75
+ from ltx_pipelines.utils.blocks import PromptEncoder, AudioConditioner, AudioDecoder
76
+
77
+ self.device = torch.device(device)
78
+ self.dtype = dtype
79
+ self.full_checkpoint = full_checkpoint
80
+ self.gemma_root = gemma_root
81
+ self.patchifier = AudioPatchifier(patch_size=1)
82
+
83
+ logging.info("Loading PromptEncoder (Gemma + embeddings_processor)...")
84
+ t0 = time.time()
85
+ self.prompt_encoder = PromptEncoder(
86
+ checkpoint_path=full_checkpoint, gemma_root=gemma_root,
87
+ dtype=dtype, device=self.device, warm=True, audio_only=True,
88
+ )
89
+ logging.info(f" PromptEncoder ready in {time.time()-t0:.1f}s")
90
+
91
+ logging.info("Loading AudioConditioner (audio VAE encoder)...")
92
+ t0 = time.time()
93
+ self.audio_conditioner = AudioConditioner(
94
+ checkpoint_path=full_checkpoint, dtype=dtype, device=self.device, warm=True,
95
+ )
96
+ logging.info(f" AudioConditioner ready in {time.time()-t0:.1f}s")
97
+
98
+ logging.info("Loading AudioDecoder...")
99
+ t0 = time.time()
100
+ self.audio_decoder = AudioDecoder(
101
+ checkpoint_path=full_checkpoint, dtype=dtype, device=self.device, warm=True,
102
+ )
103
+ logging.info(f" AudioDecoder ready in {time.time()-t0:.1f}s")
104
+
105
+ logging.info("Building velocity model (audio-only from base dev)...")
106
+ t0 = time.time()
107
+ self.velocity_model = self._build_velocity_model(full_checkpoint, lora_path, lora_rank)
108
+ logging.info(f" Velocity model ready in {time.time()-t0:.1f}s "
109
+ f"({sum(p.numel() for p in self.velocity_model.parameters()) / 1e9:.1f}B params)")
110
+
111
+ def _build_velocity_model(self, checkpoint_path, lora_path, lora_rank):
112
+ from ltx_core.loader.registry import DummyRegistry
113
+ from ltx_core.loader.sd_ops import SDOps
114
+ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
115
+ from ltx_core.model.model_protocol import ModelConfigurator
116
+ from ltx_core.model.transformer.attention import AttentionFunction
117
+ from ltx_core.model.transformer.model import LTXModel, LTXModelType
118
+ from ltx_core.model.transformer.rope import LTXRopeType
119
+
120
+ sd_ops = (
121
+ SDOps("AO")
122
+ .with_matching(prefix="model.diffusion_model.")
123
+ .with_replacement("model.diffusion_model.", "")
124
+ )
125
+
126
+ class Cfg(ModelConfigurator[LTXModel]):
127
+ @classmethod
128
+ def from_config(cls, config):
129
+ t = config.get("transformer", {})
130
+ cp = None
131
+ if not t.get("caption_proj_before_connector", False):
132
+ from ltx_core.model.transformer.text_projection import create_caption_projection
133
+ with torch.device("meta"):
134
+ cp = create_caption_projection(t, audio=True)
135
+ return LTXModel(
136
+ model_type=LTXModelType.AudioOnly,
137
+ audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
138
+ audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
139
+ audio_in_channels=t.get("audio_in_channels", 128),
140
+ audio_out_channels=t.get("audio_out_channels", 128),
141
+ num_layers=t.get("num_layers", 48),
142
+ audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
143
+ norm_eps=t.get("norm_eps", 1e-6),
144
+ attention_type=AttentionFunction(t.get("attention_type", "default")),
145
+ positional_embedding_theta=10000.0,
146
+ audio_positional_embedding_max_pos=[20.0],
147
+ timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
148
+ use_middle_indices_grid=t.get("use_middle_indices_grid", True),
149
+ rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
150
+ double_precision_rope=t.get("frequencies_precision", False) == "float64",
151
+ apply_gated_attention=t.get("apply_gated_attention", False),
152
+ audio_caption_projection=cp,
153
+ cross_attention_adaln=t.get("cross_attention_adaln", False),
154
+ )
155
+
156
+ builder = Builder(
157
+ model_path=checkpoint_path, model_class_configurator=Cfg,
158
+ model_sd_ops=sd_ops, registry=DummyRegistry(),
159
+ )
160
+ velocity = builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()
161
+
162
+ if lora_path and os.path.exists(lora_path):
163
+ from peft import LoraConfig, get_peft_model
164
+ from safetensors.torch import load_file as st_load
165
+ logging.info(f"Attaching LoRA: {lora_path}")
166
+ lora_sd = st_load(lora_path)
167
+ is_peft = any("base_model.model." in k for k in lora_sd.keys())
168
+ is_iclora = any("diffusion_model." in k for k in lora_sd.keys())
169
+ cfg = LoraConfig(
170
+ r=lora_rank, lora_alpha=lora_rank, lora_dropout=0.0, bias="none",
171
+ target_modules=[
172
+ "audio_attn1.to_k", "audio_attn1.to_q",
173
+ "audio_attn1.to_v", "audio_attn1.to_out.0",
174
+ "audio_attn2.to_k", "audio_attn2.to_q",
175
+ "audio_attn2.to_v", "audio_attn2.to_out.0",
176
+ "audio_ff.net.0.proj", "audio_ff.net.2",
177
+ ],
178
+ )
179
+ velocity = get_peft_model(velocity, cfg)
180
+
181
+ if is_peft:
182
+ mapped = {}
183
+ for k, v in lora_sd.items():
184
+ nk = k
185
+ if ".lora_A.weight" in k and ".lora_A.default.weight" not in k:
186
+ nk = k.replace(".lora_A.weight", ".lora_A.default.weight")
187
+ if ".lora_B.weight" in k and ".lora_B.default.weight" not in k:
188
+ nk = k.replace(".lora_B.weight", ".lora_B.default.weight")
189
+ mapped[nk] = v
190
+ _, unexpected = velocity.load_state_dict(mapped, strict=False)
191
+ logging.info(f" Loaded {len(mapped) - len(unexpected)} LoRA weights (peft)")
192
+ elif is_iclora:
193
+ audio_keys = {k: v for k, v in lora_sd.items()
194
+ if "audio_attn1" in k or "audio_attn2" in k or "audio_ff" in k}
195
+ mapped = {}
196
+ for k, v in audio_keys.items():
197
+ nk = k.replace("diffusion_model.", "base_model.model.")
198
+ nk = nk.replace(".lora_A.weight", ".lora_A.default.weight")
199
+ nk = nk.replace(".lora_B.weight", ".lora_B.default.weight")
200
+ mapped[nk] = v
201
+ _, unexpected = velocity.load_state_dict(mapped, strict=False)
202
+ logging.info(f" Loaded {len(mapped) - len(unexpected)} LoRA weights (iclora)")
203
+
204
+ velocity = velocity.merge_and_unload()
205
+ logging.info(" Merged LoRA into base weights")
206
+
207
+ return velocity
208
+
209
+ @torch.inference_mode()
210
+ def generate(self, prompt, output_path, voice_ref=None, args=None):
211
+ from audio_conditioning import AudioConditionByReferenceLatent
212
+ from ltx_core.batch_split import BatchSplitAdapter
213
+ from ltx_core.components.diffusion_steps import EulerDiffusionStep
214
+ from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
215
+ from ltx_core.components.noisers import GaussianNoiser
216
+ from ltx_core.components.schedulers import LTX2Scheduler
217
+ from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
218
+ from ltx_core.model.transformer.model import X0Model
219
+ from ltx_core.tools import AudioLatentTools
220
+ from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
221
+ from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
222
+ from ltx_pipelines.utils.gpu_model import gpu_model
223
+ from ltx_pipelines.utils.media_io import decode_audio_from_file
224
+ from ltx_pipelines.utils.samplers import euler_denoising_loop
225
+
226
+ t_total = time.time()
227
+
228
+ # ---- Duration + shape ----
229
+ gen_dur = estimate_speech_duration(prompt) * args.duration_multiplier
230
+ raw_frames = int(round(gen_dur * args.fps)) + 1
231
+ num_frames = ((raw_frames - 1 + 4) // 8) * 8 + 1
232
+ pixel_shape = VideoPixelShape(batch=1, frames=num_frames, height=64, width=64, fps=args.fps)
233
+ tgt_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
234
+ audio_tools = AudioLatentTools(patchifier=self.patchifier, target_shape=tgt_shape)
235
+
236
+ state = audio_tools.create_initial_state(self.device, self.dtype)
237
+
238
+ # ---- Voice reference ----
239
+ if voice_ref and os.path.exists(voice_ref):
240
+ voice = decode_audio_from_file(voice_ref, self.device, 0.0, 10.0)
241
+ if voice is not None:
242
+ w = voice.waveform
243
+ if w.dim() == 2:
244
+ if w.shape[0] == 1:
245
+ w = w.repeat(2, 1)
246
+ w = w.unsqueeze(0)
247
+ elif w.dim() == 3 and w.shape[1] == 1:
248
+ w = w.repeat(1, 2, 1)
249
+ target_samples = int(10.0 * voice.sampling_rate)
250
+ if w.shape[-1] < target_samples:
251
+ w = w.repeat(1, 1, (target_samples // w.shape[-1]) + 1)
252
+ w = w[..., :target_samples]
253
+ peak = w.abs().max()
254
+ if peak > 0:
255
+ w = w * (10 ** (-4.0 / 20) / peak)
256
+ voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)
257
+ ref_latent = self.audio_conditioner(lambda enc: vae_encode_audio(voice, enc, None))
258
+ cond = AudioConditionByReferenceLatent(
259
+ latent=ref_latent.to(self.device, self.dtype), strength=1.0,
260
+ )
261
+ state = cond.apply_to(latent_state=state, latent_tools=audio_tools)
262
+
263
+ # ---- Noise ----
264
+ gen = torch.Generator(device=self.device).manual_seed(args.seed)
265
+ noiser = GaussianNoiser(generator=gen)
266
+ state = noiser(state, noise_scale=1.0)
267
+
268
+ # ---- Prompt encode ----
269
+ use_cfg = args.cfg_scale > 1.0
270
+ prompts = [prompt, args.negative_prompt] if use_cfg else [prompt]
271
+ ctx = self.prompt_encoder(prompts, streaming_prefetch_count=None)
272
+ a_ctx = ctx[0].audio_encoding
273
+ a_ctx_neg = ctx[1].audio_encoding if use_cfg else None
274
+
275
+ # ---- Denoiser ----
276
+ needs_guidance = args.cfg_scale > 1.0 or args.stg_scale > 0.0 or args.modality_scale > 1.0
277
+ if needs_guidance:
278
+ guider = MultiModalGuider(
279
+ params=MultiModalGuiderParams(
280
+ cfg_scale=args.cfg_scale, stg_scale=args.stg_scale,
281
+ stg_blocks=[args.stg_block] if args.stg_scale > 0 else [],
282
+ rescale_scale=args.rescale_scale,
283
+ modality_scale=args.modality_scale,
284
+ cfg_clamp_scale=args.cfg_clamp,
285
+ ),
286
+ negative_context=a_ctx_neg,
287
+ )
288
+ denoiser = GuidedDenoiser(
289
+ v_context=None, a_context=a_ctx,
290
+ video_guider=None, audio_guider=guider,
291
+ )
292
+ else:
293
+ denoiser = SimpleDenoiser(v_context=None, a_context=a_ctx)
294
+
295
+ sigmas = LTX2Scheduler().execute(steps=args.steps, latent=state.latent).to(self.device)
296
+
297
+ # ---- Denoise ----
298
+ # NOTE: don't wrap in gpu_model() — that context manager moves the
299
+ # model back off GPU on exit, which breaks subsequent iterations of
300
+ # our warm validator. We keep the velocity model resident.
301
+ x0 = X0Model(self.velocity_model)
302
+ batched = BatchSplitAdapter(x0, max_batch_size=1)
303
+ _, audio_state = euler_denoising_loop(
304
+ sigmas=sigmas, video_state=None, audio_state=state,
305
+ stepper=EulerDiffusionStep(), transformer=batched, denoiser=denoiser,
306
+ )
307
+
308
+ audio_state = audio_tools.clear_conditioning(audio_state)
309
+ audio_state = audio_tools.unpatchify(audio_state)
310
+ decoded = self.audio_decoder(audio_state.latent)
311
+
312
+ wav = decoded.waveform
313
+ if wav.dim() == 1:
314
+ wav = wav.unsqueeze(0)
315
+ os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
316
+ torchaudio.save(output_path, wav.float().cpu(), decoded.sampling_rate)
317
+ logging.info(f" -> {output_path} ({wav.shape[-1]/decoded.sampling_rate:.1f}s, "
318
+ f"{time.time()-t_total:.1f}s)")
319
+
320
+
321
+ def main():
322
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
323
+ args = parse_args()
324
+ import yaml
325
+ with open(args.val_config) as f:
326
+ val_cfg = yaml.safe_load(f)
327
+ os.makedirs(args.output_dir, exist_ok=True)
328
+
329
+ # Build validator once (models warm for all entries).
330
+ validator = WarmValidator(
331
+ full_checkpoint=args.full_checkpoint,
332
+ gemma_root=args.gemma_root,
333
+ lora_path=args.lora,
334
+ lora_rank=args.lora_rank,
335
+ device="cuda" if torch.cuda.is_available() else "cpu",
336
+ dtype=torch.bfloat16,
337
+ )
338
+
339
+ n_ok = n_fail = 0
340
+ t0 = time.time()
341
+ for entry in val_cfg.get("speakers", []):
342
+ name = entry["name"]
343
+ out_path = os.path.join(args.output_dir, f"{name}.wav")
344
+ try:
345
+ validator.generate(
346
+ prompt=entry["prompt"],
347
+ output_path=out_path,
348
+ voice_ref=entry.get("reference"),
349
+ args=args,
350
+ )
351
+ n_ok += 1
352
+ logging.info(f" [{name}] OK")
353
+ except Exception as e:
354
+ n_fail += 1
355
+ logging.warning(f" [{name}] FAILED: {e}")
356
+ traceback.print_exc()
357
+
358
+ logging.info(f"Validation done: ok={n_ok} fail={n_fail} in {(time.time()-t0)/60:.1f}min "
359
+ f"at {args.output_dir}")
360
+
361
+
362
+ if __name__ == "__main__":
363
+ main()