Spaces:
Running on Zero
Running on Zero
github-actions[bot] committed on
Commit ·
2c3df98
1
Parent(s): 4230483
deploy: switch to dramabox requirements @ b5b35d7
Browse files- dramabox_src/audio_conditioning.py +115 -0
- dramabox_src/audio_conditioning.py.training_helpers +115 -0
- dramabox_src/inference.py +678 -0
- dramabox_src/inference_server.py +380 -0
- dramabox_src/model_downloader.py +105 -0
- dramabox_src/preprocess.py +351 -0
- dramabox_src/train.py +882 -0
- dramabox_src/validate.py +363 -0
dramabox_src/audio_conditioning.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio reference conditioning item for IC-LoRA voice cloning."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.components.patchifiers import AudioPatchifier
|
| 6 |
+
from ltx_core.conditioning.item import ConditioningItem
|
| 7 |
+
from ltx_core.tools import AudioLatentTools
|
| 8 |
+
from ltx_core.types import AudioLatentShape, LatentState
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AudioConditionByReferenceLatent(ConditioningItem):
    """Condition audio generation on a reference audio latent (voice cloning).

    Applying this item to a ``LatentState``:

    - Patchifies the reference latent with ``latent_tools.patchifier``
      (documented by the original author as [B, C, T, F] -> [B, ref_T, 128];
      the exact shapes are not verifiable from this file).
    - Computes positions for the reference tokens with
      ``get_patch_grid_bounds`` and shifts them by ``position_offset``.
    - Sets denoise_mask = 1.0 - strength (strength=1.0 -> mask=0 -> frozen).
    - Builds an ASYMMETRIC attention mask: target->ref=1 (attend),
      ref->target=0 (read-only).
    - APPENDS the reference tokens to the END of the latent sequence
      (IC-LoRA pattern).

    Reference positions stay in the same coordinate space as the target but
    are NOT exactly overlapping: they are shifted by ``position_offset``.
    The original author's rationale is that an exact t=0 overlap lets
    reference content bleed into the start of the target, while a small shift
    keeps RoPE decay on target->ref attention negligible.

    Args:
        latent: Reference audio latent [B, C, T, F] (pre-VAE-encoded).
        strength: Conditioning strength. 1.0 = full (ref kept clean),
            0.0 = none (ref fully denoised). Default 1.0.
        position_offset: Value added to every reference position. Default
            0.5 (presumably seconds, per the original comments -- confirm
            against AudioPatchifier.get_patch_grid_bounds).
    """

    def __init__(
        self,
        latent: torch.Tensor,
        strength: float = 1.0,
        *,
        position_offset: float = 0.5,
    ):
        self.latent = latent
        self.strength = strength
        self.position_offset = position_offset

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: AudioLatentTools,
    ) -> LatentState:
        """Append reference audio tokens with positions and attention mask.

        Returns a new ``LatentState`` whose latent, denoise mask, positions
        and clean latent are the inputs with the reference entries
        concatenated after them, and whose attention mask covers the combined
        (target + reference) sequence.
        """
        tokens = latent_tools.patchifier.patchify(self.latent)

        # Positions for the reference audio, computed from the reference
        # latent's own shape and then shifted so they do not collide with the
        # target's t=0 position.
        ref_shape = AudioLatentShape(
            batch=self.latent.shape[0],
            channels=self.latent.shape[1],
            frames=self.latent.shape[2],
            mel_bins=self.latent.shape[3],
        )
        positions = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=ref_shape,
            device=self.latent.device,
        )
        positions = positions + self.position_offset

        # Denoise mask: 0 for frozen (strength=1.0), 1 for fully denoised (strength=0.0)
        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=self.latent.device,
            dtype=torch.float32,
        )

        # Build ASYMMETRIC attention mask manually.
        # Structure:
        #                  target (N)    ref (M)
        #              +-------------+----------+
        #   target     |    1.0      |   1.0    |  target attends to everything
        #   (N)        |             |          |
        #              +-------------+----------+
        #   ref        |    0.0      |   1.0    |  ref only attends to itself
        #   (M)        |             |          |
        #              +-------------+----------+
        #
        # This makes reference tokens "read-only conditioning":
        # - Target tokens freely attend to ref (voice cloning signal)
        # - Ref tokens don't attend to noisy target (stays clean/stable)
        batch_size = tokens.shape[0]
        num_target = latent_state.latent.shape[1]
        num_ref = tokens.shape[1]
        total = num_target + num_ref

        # Use float32 for the [0,1] mask. Per the original author,
        # _prepare_self_attention_mask converts it to a log-space bias in the
        # model's compute dtype before it reaches attention.
        mask = torch.zeros(
            (batch_size, total, total),
            device=self.latent.device,
            dtype=torch.float32,
        )

        # Target -> target: keep an existing mask if present, otherwise full
        # attention.
        if latent_state.attention_mask is not None:
            mask[:, :num_target, :num_target] = latent_state.attention_mask
        else:
            mask[:, :num_target, :num_target] = 1.0

        # Target -> ref: FULL attention (target can read reference voice)
        mask[:, :num_target, num_target:] = 1.0

        # Ref -> target: BLOCKED (ref is read-only, doesn't see noisy target);
        # mask[:, num_target:, :num_target] is left at its initial 0.0.

        # Ref -> ref: full self-attention within reference
        mask[:, num_target:, num_target:] = 1.0

        return LatentState(
            latent=torch.cat([latent_state.latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
            attention_mask=mask,
        )
|
dramabox_src/audio_conditioning.py.training_helpers
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio reference conditioning item for IC-LoRA voice cloning."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.components.patchifiers import AudioPatchifier
|
| 6 |
+
from ltx_core.conditioning.item import ConditioningItem
|
| 7 |
+
from ltx_core.tools import AudioLatentTools
|
| 8 |
+
from ltx_core.types import AudioLatentShape, LatentState
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AudioConditionByReferenceLatent(ConditioningItem):
    """Condition audio generation on a reference audio latent (voice cloning).

    Applying this item to a ``LatentState``:

    - Patchifies the reference latent with ``latent_tools.patchifier``
      (documented by the original author as [B, C, T, F] -> [B, ref_T, 128];
      the exact shapes are not verifiable from this file).
    - Computes positions for the reference tokens with
      ``get_patch_grid_bounds`` and shifts them by ``position_offset``.
    - Sets denoise_mask = 1.0 - strength (strength=1.0 -> mask=0 -> frozen).
    - Builds an ASYMMETRIC attention mask: target->ref=1 (attend),
      ref->target=0 (read-only).
    - APPENDS the reference tokens to the END of the latent sequence
      (IC-LoRA pattern).

    Reference positions stay in the same coordinate space as the target but
    are NOT exactly overlapping: they are shifted by ``position_offset``.
    The original author's rationale is that an exact t=0 overlap lets
    reference content bleed into the start of the target, while a small shift
    keeps RoPE decay on target->ref attention negligible.

    Args:
        latent: Reference audio latent [B, C, T, F] (pre-VAE-encoded).
        strength: Conditioning strength. 1.0 = full (ref kept clean),
            0.0 = none (ref fully denoised). Default 1.0.
        position_offset: Value added to every reference position. Default
            0.5 (presumably seconds, per the original comments -- confirm
            against AudioPatchifier.get_patch_grid_bounds).
    """

    def __init__(
        self,
        latent: torch.Tensor,
        strength: float = 1.0,
        *,
        position_offset: float = 0.5,
    ):
        self.latent = latent
        self.strength = strength
        self.position_offset = position_offset

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: AudioLatentTools,
    ) -> LatentState:
        """Append reference audio tokens with positions and attention mask.

        Returns a new ``LatentState`` whose latent, denoise mask, positions
        and clean latent are the inputs with the reference entries
        concatenated after them, and whose attention mask covers the combined
        (target + reference) sequence.
        """
        tokens = latent_tools.patchifier.patchify(self.latent)

        # Positions for the reference audio, computed from the reference
        # latent's own shape and then shifted so they do not collide with the
        # target's t=0 position.
        ref_shape = AudioLatentShape(
            batch=self.latent.shape[0],
            channels=self.latent.shape[1],
            frames=self.latent.shape[2],
            mel_bins=self.latent.shape[3],
        )
        positions = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=ref_shape,
            device=self.latent.device,
        )
        positions = positions + self.position_offset

        # Denoise mask: 0 for frozen (strength=1.0), 1 for fully denoised (strength=0.0)
        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=self.latent.device,
            dtype=torch.float32,
        )

        # Build ASYMMETRIC attention mask manually.
        # Structure:
        #                  target (N)    ref (M)
        #              +-------------+----------+
        #   target     |    1.0      |   1.0    |  target attends to everything
        #   (N)        |             |          |
        #              +-------------+----------+
        #   ref        |    0.0      |   1.0    |  ref only attends to itself
        #   (M)        |             |          |
        #              +-------------+----------+
        #
        # This makes reference tokens "read-only conditioning":
        # - Target tokens freely attend to ref (voice cloning signal)
        # - Ref tokens don't attend to noisy target (stays clean/stable)
        batch_size = tokens.shape[0]
        num_target = latent_state.latent.shape[1]
        num_ref = tokens.shape[1]
        total = num_target + num_ref

        # Use float32 for the [0,1] mask. Per the original author,
        # _prepare_self_attention_mask converts it to a log-space bias in the
        # model's compute dtype before it reaches attention.
        mask = torch.zeros(
            (batch_size, total, total),
            device=self.latent.device,
            dtype=torch.float32,
        )

        # Target -> target: keep an existing mask if present, otherwise full
        # attention.
        if latent_state.attention_mask is not None:
            mask[:, :num_target, :num_target] = latent_state.attention_mask
        else:
            mask[:, :num_target, :num_target] = 1.0

        # Target -> ref: FULL attention (target can read reference voice)
        mask[:, :num_target, num_target:] = 1.0

        # Ref -> target: BLOCKED (ref is read-only, doesn't see noisy target);
        # mask[:, num_target:, :num_target] is left at its initial 0.0.

        # Ref -> ref: full self-attention within reference
        mask[:, num_target:, num_target:] = 1.0

        return LatentState(
            latent=torch.cat([latent_state.latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
            attention_mask=mask,
        )
|
dramabox_src/inference.py
ADDED
|
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
LTX-2.3 TTS with IC-LoRA voice cloning.
|
| 4 |
+
|
| 5 |
+
Uses AudioConditionByReferenceLatent to append reference audio tokens to the
|
| 6 |
+
end of the target sequence. Auto-detects distilled vs dev checkpoint and
|
| 7 |
+
selects the appropriate denoiser (SimpleDenoiser / GuidedDenoiser) and sigma
|
| 8 |
+
schedule. Leverages the official euler_denoising_loop, AudioLatentTools,
|
| 9 |
+
GaussianNoiser, and X0Model wrapper throughout.
|
| 10 |
+
|
| 11 |
+
Usage (distilled):
|
| 12 |
+
python inference.py \
|
| 13 |
+
--voice-sample reference.wav \
|
| 14 |
+
--prompt "A woman speaks clearly: The weather today will be sunny." \
|
| 15 |
+
--output tts_output.wav
|
| 16 |
+
|
| 17 |
+
Usage (dev):
|
| 18 |
+
python inference.py \
|
| 19 |
+
--voice-sample reference.wav \
|
| 20 |
+
--prompt "A woman speaks clearly: The weather today will be sunny." \
|
| 21 |
+
--checkpoint ltx-2.3-22b-dev-audio-only.safetensors \
|
| 22 |
+
--full-checkpoint ltx-2.3-22b-dev.safetensors \
|
| 23 |
+
--output tts_output.wav
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import json
|
| 28 |
+
import logging
|
| 29 |
+
import os
|
| 30 |
+
import re
|
| 31 |
+
import struct
|
| 32 |
+
import sys
|
| 33 |
+
import time
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
import torchaudio
|
| 38 |
+
|
| 39 |
+
# Directory containing this script (audio_conditioning.py lives next to it).
_SRC_DIR = os.path.dirname(os.path.abspath(__file__))
# Parent of the script directory.
REPO_DIR = os.path.dirname(_SRC_DIR)

# Make <repo>/ltx2 importable (ltx-pipelines is already reachable through it).
sys.path.insert(0, os.path.join(REPO_DIR, "ltx2"))
# Inserted last so the script directory ends up first on sys.path and
# audio_conditioning.py is importable.
sys.path.insert(0, _SRC_DIR)

MODEL_DIR = os.path.join(REPO_DIR, "models")
GEMMA_DIR = os.environ.get("GEMMA_DIR", "gemma-3-12b-it-qat-q4_0-unquantized")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
# Helpers
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Sanity cap on the JSON header we are willing to read while sniffing
# metadata, so a corrupt length prefix cannot trigger a huge allocation.
_SAFETENSORS_MAX_HEADER_BYTES = 100_000_000


def _read_safetensors_model_version(checkpoint_path: str) -> str:
    """Return ``__metadata__["model_version"]`` from a safetensors header.

    Best effort: returns "" when the file is missing, unreadable, truncated,
    not valid JSON, or lacks the metadata entry.
    """
    try:
        with open(checkpoint_path, "rb") as f:
            # First 8 bytes: little-endian u64 length of the JSON header.
            (header_size,) = struct.unpack("<Q", f.read(8))
            if header_size > _SAFETENSORS_MAX_HEADER_BYTES:
                return ""
            header = json.loads(f.read(header_size).decode())
    except (OSError, ValueError, struct.error):
        # ValueError covers both UnicodeDecodeError and json.JSONDecodeError;
        # struct.error covers files shorter than 8 bytes.
        return ""
    metadata = header.get("__metadata__") if isinstance(header, dict) else None
    if not isinstance(metadata, dict):
        return ""
    return str(metadata.get("model_version", ""))


def detect_model_type(checkpoint_path: str) -> str:
    """Detect if checkpoint is distilled or dev by checking filename and metadata.

    Only the file name is inspected, not the directories leading to it, so a
    checkpoint stored under e.g. ``/dev/shm`` or ``~/dev/project/models`` is
    not mistaken for a dev model.

    Args:
        checkpoint_path: Path to the checkpoint file. The file does not have
            to exist; the metadata probe is skipped if it cannot be read.

    Returns:
        "distilled" or "dev". Falls back to "distilled" when neither the file
        name nor the metadata decides.
    """
    filename = os.path.basename(checkpoint_path).lower()
    if "distilled" in filename:
        return "distilled"
    if "dev" in filename:
        return "dev"

    # Fallback: try to read safetensors metadata.
    # NOTE(review): metadata can currently only confirm "distilled", which is
    # also the default, so it never changes the result. Kept as in the
    # original; extend here if the metadata should be able to select "dev".
    if "distilled" in _read_safetensors_model_version(checkpoint_path).lower():
        return "distilled"

    # Default to distilled (most common for audio-only)
    return "distilled"
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
_LAUGH_VERBS = {
|
| 78 |
+
# base seconds per occurrence; gets scaled by the modifier found nearby.
|
| 79 |
+
# Verb regex covers inflections: laugh/laughs/laughed/laughing.
|
| 80 |
+
r"\blaugh(?:s|ed|ing)?\b": 1.5,
|
| 81 |
+
r"\bcackl(?:e|es|ed|ing)\b": 1.5,
|
| 82 |
+
r"\bchuckl(?:e|es|ed|ing)\b": 1.0,
|
| 83 |
+
r"\bgiggl(?:e|es|ed|ing)\b": 1.0,
|
| 84 |
+
r"\bsnicker(?:s|ed|ing)?\b": 0.8,
|
| 85 |
+
r"\bcru?el laugh\b": 1.5,
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _contextual_laugh_duration(text: str) -> float:
|
| 90 |
+
"""Context-aware laugh budget.
|
| 91 |
+
|
| 92 |
+
For each laugh verb in the prompt, look at the adjective/adverb that
|
| 93 |
+
modifies it and scale the base duration:
|
| 94 |
+
- short modifiers (briefly, softly, once) -> 0.4x base
|
| 95 |
+
- long modifiers (maniacally, heartily, ...) -> 1.2x base
|
| 96 |
+
- default (no mod / neutral) -> 1.0x base
|
| 97 |
+
Also reward phonetic repetition inside quotes -- 'Hahahahahaha' buys more
|
| 98 |
+
time than 'Haha' -- at ~0.2s per extra repeated syllable.
|
| 99 |
+
"""
|
| 100 |
+
# "softly" / "quietly" describe volume not length, so keep at default 1.0x.
|
| 101 |
+
short_mod = re.compile(
|
| 102 |
+
r"^\s*(?:[a-z]+ly )?(?:briefly|shortly|once|quickly)",
|
| 103 |
+
re.IGNORECASE)
|
| 104 |
+
long_mod = re.compile(
|
| 105 |
+
r"^\s*(?:[a-z]+ly )?(?:maniacally|heartily|uproariously|uncontrollably|"
|
| 106 |
+
r"hysterically|darkly|wickedly|evilly|loudly|long)"
|
| 107 |
+
r"|^\s*between phrases", re.IGNORECASE)
|
| 108 |
+
|
| 109 |
+
total = 0.0
|
| 110 |
+
for pat, base_dur in _LAUGH_VERBS.items():
|
| 111 |
+
for m in re.finditer(pat, text, re.IGNORECASE):
|
| 112 |
+
ctx = text[m.end(): m.end() + 40]
|
| 113 |
+
if short_mod.match(ctx):
|
| 114 |
+
total += base_dur * 0.4
|
| 115 |
+
elif long_mod.match(ctx):
|
| 116 |
+
total += base_dur * 1.2
|
| 117 |
+
else:
|
| 118 |
+
total += base_dur
|
| 119 |
+
|
| 120 |
+
# Phonetic laugh repetition inside quotes:
|
| 121 |
+
# 'Haha' = 2 syllables (base, no bonus)
|
| 122 |
+
# 'Hahahaha' = 4 syllables (+0.4s)
|
| 123 |
+
# 'Hehehehahahahahahahaha' ~ 10 syllables (+1.6s)
|
| 124 |
+
for q in re.findall(r'"([^"]+)"', text) + re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text):
|
| 125 |
+
for run in re.findall(r"(?:h[ae]){3,}|(?:h[ae][ \-]?){3,}", q, re.IGNORECASE):
|
| 126 |
+
syls = len(re.findall(r"h[ae]", run, re.IGNORECASE))
|
| 127 |
+
total += 0.2 * max(syls - 2, 0)
|
| 128 |
+
return total
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Regex -> seconds added per occurrence, matched case-insensitively against
# the whole prompt. Patterns are scanned independently of each other.
_NONVERBAL_PATTERNS = {
    # Breathing / sighs
    r'\bsighs?\b': 0.8, r'\bshaky breath\b': 1.0, r'\bbreathing deeply\b': 1.0,
    r'\bgasps?\b': 0.5, r'\bburps?\b': 0.5, r'\byawns?\b': 1.0,
    r'\bpants?\b': 0.8, r'\bwheezes?\b': 0.8, r'\bcoughs?\b': 0.8,
    r'\bsniffles?\b': 0.5, r'\bsnorts?\b': 0.3, r'\bgroans?\b': 0.8,
    # Pauses (trimmed; earlier values over-budgeted silence)
    r'\blong pause\b': 1.0, r'\bpauses? briefly\b': 0.3,
    # Generic pause. It must not fire inside the two phrases above, otherwise
    # "pauses briefly" would cost 0.3 + 0.5 (more than a plain pause) and
    # "long pause" 1.0 + 0.5. "long pauses" is not matched by the phrase
    # pattern above, so it is still counted here.
    r'\b(?!pauses? briefly\b)(?:(?<!\blong )pause|pauses)\b': 0.5,
    r'\bsilence\b': 1.0,
    r'\blets? the .{1,20} hang\b': 1.0, r'\blets? .{1,20} sink in\b': 1.0,
    # Physical actions that produce sound
    r'\bslams?\b': 0.5, r'\bclaps?\b': 0.3,
    r'\bdraws? (?:his|her|a) sword\b': 0.5,
    r'\btakes? a (?:drag|swig|sip|drink)\b': 0.5,
    r'\bwhistles?\b': 1.0, r'\bhums?\b': 0.8,
    # Vocal actions (not in quotes but take time)
    r'\bmutters?\b': 1.5, r'\bmumbles?\b': 1.0, r'\bwhispers?\b': 0.0,
    r'\bclears? (?:his|her) throat\b': 0.5, r'\bgulps?\b': 0.5,
    r'\bswallows?\b': 0.5,
    # (laugh / chuckle / cackle / giggle / snicker handled by
    #  _contextual_laugh_duration -- modifier-aware, not flat.)
    # Emotional transitions
    r'\bvoice (?:breaks?|cracks?|trembles?|drops?|rises?)\b': 0.5,
    r'\bsteadies? (?:him|her)self\b': 1.0,
    r'\bcatches? (?:his|her) breath\b': 1.0,
    r'\bcomposes? (?:him|her)self\b': 0.8,
    # Scene transitions that imply time
    r'\bdemeanor shifts?\b': 0.5, r'\bsettles? in\b': 0.5,
    r'\bleans? in\b': 0.3, r'\bwipes? (?:his|her) eyes\b': 0.5,
}


def _estimate_nonverbal_duration(text: str) -> float:
    """Estimate extra duration for non-verbal sounds and actions in the prompt.

    Adds a flat number of seconds for every occurrence of each cue in
    ``_NONVERBAL_PATTERNS`` (sighs, pauses, coughs, ...), then adds the laugh
    budget from ``_contextual_laugh_duration``, which scales with the
    modifier after the verb and with 'Ha'/'He' repetition inside quotes.

    Pause phrases are mutually exclusive: "long pause" counts 1.0s,
    "pause(s) briefly" 0.3s and any other "pause"/"pauses" 0.5s.

    Args:
        text: Full prompt text.

    Returns:
        Extra seconds to add to the speech estimate; 0.0 if no cue matches.
    """
    extra = 0.0
    for pattern, dur in _NONVERBAL_PATTERNS.items():
        extra += dur * len(re.findall(pattern, text, re.IGNORECASE))
    extra += _contextual_laugh_duration(text)
    return extra
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def estimate_speech_duration(text: str, speed: float = 1.0) -> float:
    """Estimate speech duration from spoken content + non-verbal actions.

    Extracts spoken text by priority:
    1. Quoted text ("..." first, otherwise '...') -- official prompt guide format
    2. Text after the first colon -- simple "Speaker: dialogue" format
    3. Full text -- fallback

    Also scans the full prompt for non-verbal cues (laughs, pauses, sighs,
    gasps, etc.) and adds estimated duration for each.

    Args:
        text: Full prompt.
        speed: Speaking-rate multiplier applied to the characters-per-second
            rate; values above 1.0 shorten the estimate. Must be positive.

    Returns:
        Estimated duration in seconds, rounded to 0.1s, including a fixed
        2.0s margin and never less than 3.0.

    Raises:
        ValueError: If ``speed`` is not positive (it is used as a divisor).
    """
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed!r}")

    # Try double quotes first (clean, no contraction issues)
    quotes = re.findall(r'"([^"]+)"', text)
    if not quotes:
        # Single quotes: allow apostrophes in contractions (don't, can't, it's)
        # Match ' to ' but apostrophes NOT followed by space/punctuation are kept inside
        quotes = re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text)
        # Apostrophes can mis-pair, so drop short fragments (scene directions
        # like "He pauses") that were captured by accident.
        quotes = [q for q in quotes if len(q.split()) > 3]

    if quotes:
        spoken = " ".join(quotes)
    elif ":" in text:
        spoken = text.split(":", 1)[1].strip()
    else:
        spoken = text

    CHARS_PER_SEC = 14.0
    text_len = len(spoken)

    # Shorter utterances are budgeted at a slower rate.
    if text_len < 40:
        chars_per_sec = CHARS_PER_SEC * 0.6
    elif text_len < 80:
        chars_per_sec = CHARS_PER_SEC * 0.8
    else:
        chars_per_sec = CHARS_PER_SEC

    chars_per_sec *= speed
    duration = text_len / chars_per_sec

    # 0.3s per sentence-ending punctuation mark.
    sentence_count = spoken.count(".") + spoken.count("!") + spoken.count("?")
    duration += sentence_count * 0.3

    # Add time for non-verbal sounds/actions in the full prompt
    duration += _estimate_nonverbal_duration(text)

    return max(3.0, round(duration + 2.0, 1))
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def parse_args(argv=None) -> argparse.Namespace:
    """Build the CLI parser and parse the command line.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, so existing ``parse_args()``
            callers are unaffected. Passing a list allows programmatic use.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(description="LTX-2.3 TTS with IC-LoRA voice cloning")

    # ---- Input / output ----
    p.add_argument("--voice-sample", default=None, help="Voice reference WAV")
    p.add_argument("--no-ref", action="store_true", help="Skip voice reference conditioning (raw base model)")
    p.add_argument("--prompt", required=True, help="Text/scene description to synthesize")
    p.add_argument("--output", default="tts_output.wav")

    # ---- Durations ----
    p.add_argument("--ref-duration", type=float, default=10.0, help="Seconds of voice reference to use")
    p.add_argument("--gen-duration", type=float, default=0.0,
                   help="Target output duration in seconds (0 = auto from prompt + multiplier). "
                        "Set explicitly for long-form prompts (e.g. --gen-duration 30 for music). "
                        "Outputs >20.5s automatically engage the end-of-clip silence-prior patch.")
    p.add_argument("--pad-start", type=float, default=0.0,
                   help="Prepend N seconds of silent padding, trimmed after decode (use 0 for clean starts)")
    p.add_argument("--speed", type=float, default=1.0)
    # "%%" is argparse's escape for a literal percent sign in help text.
    p.add_argument("--duration-multiplier", type=float, default=1.0,
                   help="Multiply auto-estimated duration by this factor (e.g. 1.1 for 10%% more breathing room)")

    # ---- Models ----
    p.add_argument("--checkpoint", default=os.path.join(MODEL_DIR, "ltx-2.3-audio-only.safetensors"))
    p.add_argument("--full-checkpoint", default=os.path.join(MODEL_DIR, "ltx-2.3-22b-distilled.safetensors"))
    p.add_argument("--gemma-root", default=GEMMA_DIR)
    # --bnb-4bit / --no-bnb-4bit share one destination; the default is on.
    p.add_argument("--bnb-4bit", dest="bnb_4bit", action="store_true", default=True,
                   help="Load Gemma text encoder via the bitsandbytes 4-bit path "
                        "(required for the default unsloth/gemma-3-12b-it-bnb-4bit "
                        "pre-quantized weights). Default: on.")
    p.add_argument("--no-bnb-4bit", dest="bnb_4bit", action="store_false",
                   help="Disable the bitsandbytes path (use only if --gemma-root "
                        "points at an unquantized Gemma checkpoint).")
    p.add_argument("--lora", default=None, help="Path to trained IC-LoRA .safetensors (audio-only)")
    p.add_argument("--lora-rank", type=int, default=128, help="LoRA rank (must match training)")
    p.add_argument("--id-guidance-scale", type=float, default=3.0, help="Identity guidance scale (0=disabled)")
    p.add_argument("--seed", type=int, default=42)

    # ---- Output post-processing / sampler ----
    p.add_argument("--no-watermark", action="store_true",
                   help="Skip Perth audio watermarking on the output (default: watermark on).")
    p.add_argument("--sampler", choices=["euler", "heun"], default="euler",
                   help="Denoising loop. 'heun' = jkass_quality 2nd-order predictor-corrector (~2x model calls, cleaner audio).")

    # ---- Guidance: a default of None means "auto-set based on model type",
    # ---- but every value is overridable from the command line.
    p.add_argument("--cfg-scale", type=float, default=None, help="CFG scale (auto: 1.0 distilled, 7.0 dev)")
    p.add_argument("--stg-scale", type=float, default=None, help="STG scale (auto: 0.0 distilled, 1.0 dev)")
    p.add_argument("--stg-block", type=int, default=29, help="Block index for STG perturbation")
    p.add_argument("--rescale-scale", type=float, default=None,
                   help="Latent CFG std-rescale (default auto: cfg-aware schedule that prevents "
                        "output clipping at high cfg; pass any float in [0,1] to override).")
    p.add_argument("--modality-scale", type=float, default=None, help="Modality (auto: 1.0 distilled, 3.0 dev)")
    p.add_argument("--cfg-clamp", type=float, default=0.0, help="Clamp guided pred std to N * cond std (0=disabled)")
    p.add_argument("--steps", type=int, default=None, help="Override steps (auto: distilled sigmas / 30 dev)")
    p.add_argument("--fps", type=float, default=None, help="FPS (auto: 24.0 distilled, 25.0 dev)")
    p.add_argument(
        "--negative-prompt",
        default=(
            "worst quality, inconsistent motion, blurry, jittery, distorted, "
            "robotic voice, echo, background noise, off-sync audio, repetitive speech"
        ),
        help="Negative prompt for CFG (dev model)",
    )

    return p.parse_args(argv)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
@torch.inference_mode()
def main():
    """Run one CLI text-to-speech generation and write the result to ``args.output``.

    Steps, in order: parse arguments and fill model-type-dependent defaults,
    size the target audio latent, optionally encode and attach a voice
    reference, add noise, encode the prompt(s), build the audio-only
    transformer (merging LoRA weights when given), run the denoising loop,
    strip the reference tokens, decode to a waveform, optionally watermark,
    and save.

    Raises:
        ValueError: if ``--voice-sample`` is given but no audio can be decoded
            from it.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    args = parse_args()
    t0 = time.time()

    # ---- Imports (deferred to avoid startup cost when checking --help) ----
    from audio_conditioning import AudioConditionByReferenceLatent

    from ltx_core.batch_split import BatchSplitAdapter
    from ltx_core.components.diffusion_steps import EulerDiffusionStep
    from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
    from ltx_core.components.noisers import GaussianNoiser
    from ltx_core.components.patchifiers import AudioPatchifier
    from ltx_core.components.schedulers import LTX2Scheduler
    from ltx_core.loader.registry import DummyRegistry
    from ltx_core.loader.sd_ops import SDOps
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
    from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
    from ltx_core.model.model_protocol import ModelConfigurator
    from ltx_core.model.transformer.attention import AttentionFunction
    from ltx_core.model.transformer.model import LTXModel, LTXModelType, X0Model
    from ltx_core.model.transformer.rope import LTXRopeType
    from ltx_core.tools import AudioLatentTools
    from ltx_core.types import Audio, AudioLatentShape, LatentState, VideoPixelShape
    from ltx_pipelines.utils.blocks import AudioConditioner, AudioDecoder, PromptEncoder
    from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES
    from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
    from ltx_pipelines.utils.gpu_model import gpu_model
    from ltx_pipelines.utils.media_io import decode_audio_from_file
    from ltx_pipelines.utils.samplers import euler_denoising_loop, heun_denoising_loop

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16
    patchifier = AudioPatchifier(patch_size=1)

    # ---- Detect model type and set defaults ----
    model_type = detect_model_type(args.full_checkpoint)
    logging.info(f"Detected model type: {model_type}")

    is_distilled = model_type == "distilled"

    # Every guidance knob left at None gets a distilled/dev-specific default.
    if args.cfg_scale is None:
        args.cfg_scale = 1.0 if is_distilled else 7.0
    if args.stg_scale is None:
        args.stg_scale = 0.0 if is_distilled else 1.0
    if args.rescale_scale is None:
        # Auto cfg-aware rescale: imported from inference_server to keep one source of truth.
        from inference_server import auto_rescale_for_cfg
        args.rescale_scale = 0.0 if is_distilled else auto_rescale_for_cfg(args.cfg_scale)
    if args.modality_scale is None:
        args.modality_scale = 1.0 if is_distilled else 3.0
    if args.fps is None:
        args.fps = 24.0 if is_distilled else 25.0

    logging.info(
        f"Params: cfg={args.cfg_scale}, stg={args.stg_scale}, rescale={args.rescale_scale}, "
        f"modality={args.modality_scale}, fps={args.fps}"
    )

    # ---- Auto duration ----
    if args.gen_duration <= 0:
        args.gen_duration = estimate_speech_duration(args.prompt, args.speed)
        if args.duration_multiplier != 1.0:
            args.gen_duration = round(args.gen_duration * args.duration_multiplier, 1)
        logging.info(f"Auto duration: {args.gen_duration}s for {len(args.prompt)} chars"
                     f"{f' (x{args.duration_multiplier})' if args.duration_multiplier != 1.0 else ''}")

    # ---- Compute target shape (include pad_start in duration) ----
    padded_duration = args.gen_duration + args.pad_start
    raw_frames = int(round(padded_duration * args.fps)) + 1
    # Snap to the nearest frame count of the form 8k + 1.
    num_frames = ((raw_frames - 1 + 4) // 8) * 8 + 1
    pixel_shape = VideoPixelShape(batch=1, frames=num_frames, height=64, width=64, fps=args.fps)
    tgt_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
    logging.info(f"Target shape: {tgt_shape} ({args.gen_duration}s, {num_frames} frames)")

    # ---- AudioLatentTools for target ----
    audio_tools = AudioLatentTools(patchifier=patchifier, target_shape=tgt_shape)

    # ---- Create initial state ----
    state = audio_tools.create_initial_state(device, dtype)
    logging.info(
        f"Initial state: latent={state.latent.shape}, positions={state.positions.shape}, "
        f"denoise_mask={state.denoise_mask.shape}"
    )

    if not args.no_ref and args.voice_sample:
        # ---- Encode voice reference ----
        logging.info(f"Loading voice reference: {args.voice_sample}")
        voice = decode_audio_from_file(args.voice_sample, device, 0.0, args.ref_duration)
        if voice is None:
            raise ValueError(f"Could not load audio from {args.voice_sample}")

        # Bring the waveform to a 3-D, two-channel layout: a single channel is
        # duplicated and a 2-D input gets a leading dimension.
        w = voice.waveform
        if w.dim() == 2:
            if w.shape[0] == 1:
                w = w.repeat(2, 1)
            w = w.unsqueeze(0)
        elif w.dim() == 3 and w.shape[1] == 1:
            w = w.repeat(1, 2, 1)

        # Loop a too-short reference until it covers ref_duration, then cut
        # to exactly target_samples.
        target_samples = int(args.ref_duration * voice.sampling_rate)
        if w.shape[-1] < target_samples:
            w = w.repeat(1, 1, (target_samples // w.shape[-1]) + 1)
        w = w[..., :target_samples]

        # Peak normalize reference
        peak = w.abs().max()
        if peak > 0:
            target_peak = 10 ** (-4.0 / 20)  # -4dB
            w = w * (target_peak / peak)
            logging.info(f"Normalized reference: peak {peak:.4f} -> {target_peak:.4f}")

        voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)

        logging.info("Encoding voice through Audio VAE...")
        ac = AudioConditioner(checkpoint_path=args.full_checkpoint, dtype=dtype, device=device)
        ref_latent = ac(lambda enc: vae_encode_audio(voice, enc, None))
        del ac
        torch.cuda.empty_cache()
        logging.info(f"Reference latent: {ref_latent.shape}")

        # ---- Apply conditioning: append ref tokens to END ----
        conditioning = AudioConditionByReferenceLatent(latent=ref_latent.to(device, dtype), strength=1.0)
        state = conditioning.apply_to(latent_state=state, latent_tools=audio_tools)
        logging.info(
            f"After conditioning: latent={state.latent.shape}, positions={state.positions.shape}, "
            f"attention_mask={'None' if state.attention_mask is None else state.attention_mask.shape}"
        )
    else:
        logging.info("No voice reference — running raw base model")

    # ---- Apply noise ----
    generator = torch.Generator(device=device).manual_seed(args.seed)
    noiser = GaussianNoiser(generator=generator)
    noised_state = noiser(state, noise_scale=1.0)
    logging.info("Applied Gaussian noise to state")

    # ---- Encode prompt ----
    # The negative prompt is only encoded when CFG is actually active.
    use_cfg = args.cfg_scale > 1.0
    logging.info("Encoding prompt...")
    pe = PromptEncoder(checkpoint_path=args.full_checkpoint, gemma_root=args.gemma_root, dtype=dtype, device=device,
                       use_bnb_4bit=args.bnb_4bit, warm=True)
    prompts_to_encode = [args.prompt]
    if use_cfg:
        prompts_to_encode.append(args.negative_prompt)
    ctx = pe(prompts_to_encode, streaming_prefetch_count=None)
    a_ctx = ctx[0].audio_encoding
    a_ctx_neg = ctx[1].audio_encoding if use_cfg else None
    del pe
    torch.cuda.empty_cache()
    logging.info(f"Prompt encoded: a_ctx={a_ctx.shape}" + (f", a_ctx_neg={a_ctx_neg.shape}" if a_ctx_neg is not None else ""))

    # ---- Build audio-only model ----
    logging.info("Building audio-only model...")
    audio_only_sd_ops = SDOps("AO").with_matching(prefix="model.diffusion_model.").with_replacement(
        "model.diffusion_model.", ""
    )

    class AudioOnlyConfigurator(ModelConfigurator[LTXModel]):
        """Builds an audio-only LTXModel from the checkpoint's ``transformer`` config."""

        @classmethod
        def from_config(cls, config):
            t = config.get("transformer", {})
            cp = None
            if not t.get("caption_proj_before_connector", False):
                from ltx_core.model.transformer.text_projection import create_caption_projection

                with torch.device("meta"):
                    cp = create_caption_projection(t, audio=True)
            return LTXModel(
                model_type=LTXModelType.AudioOnly,
                audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
                audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
                audio_in_channels=t.get("audio_in_channels", 128),
                audio_out_channels=t.get("audio_out_channels", 128),
                num_layers=t.get("num_layers", 48),
                audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
                norm_eps=t.get("norm_eps", 1e-6),
                attention_type=AttentionFunction(t.get("attention_type", "default")),
                positional_embedding_theta=10000.0,
                audio_positional_embedding_max_pos=[20.0],
                timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
                use_middle_indices_grid=t.get("use_middle_indices_grid", True),
                rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
                double_precision_rope=t.get("frequencies_precision", False) == "float64",
                apply_gated_attention=t.get("apply_gated_attention", False),
                audio_caption_projection=cp,
                cross_attention_adaln=t.get("cross_attention_adaln", False),
            )

    builder = Builder(
        model_path=args.checkpoint,
        model_class_configurator=AudioOnlyConfigurator,
        model_sd_ops=audio_only_sd_ops,
        registry=DummyRegistry(),
    )
    velocity_model = builder.build(device=device, dtype=dtype).to(device).eval()

    # ---- Load LoRA weights (if provided) ----
    if args.lora and os.path.exists(args.lora):
        from peft import LoraConfig, get_peft_model
        from safetensors.torch import load_file as st_load

        logging.info(f"Loading LoRA: {args.lora}")
        lora_sd = st_load(args.lora)

        # Two key layouts are recognized: peft-style ("base_model.model.")
        # and original ID-LoRA style ("diffusion_model.").
        is_peft_format = any("base_model.model." in k for k in lora_sd.keys())
        is_original_idlora = any("diffusion_model." in k for k in lora_sd.keys())

        lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_rank,
            lora_dropout=0.0,
            bias="none",
            target_modules=[
                "audio_attn1.to_k",
                "audio_attn1.to_q",
                "audio_attn1.to_v",
                "audio_attn1.to_out.0",
                "audio_attn2.to_k",
                "audio_attn2.to_q",
                "audio_attn2.to_v",
                "audio_attn2.to_out.0",
                "audio_ff.net.0.proj",
                "audio_ff.net.2",
            ],
        )
        velocity_model = get_peft_model(velocity_model, lora_config)

        if is_peft_format:
            # Insert the ".default" adapter name into keys that lack it.
            mapped_sd = {}
            for k, v in lora_sd.items():
                new_key = k
                if ".lora_A.weight" in k and ".lora_A.default.weight" not in k:
                    new_key = k.replace(".lora_A.weight", ".lora_A.default.weight")
                if ".lora_B.weight" in k and ".lora_B.default.weight" not in k:
                    new_key = k.replace(".lora_B.weight", ".lora_B.default.weight")
                mapped_sd[new_key] = v
            _missing, unexpected = velocity_model.load_state_dict(mapped_sd, strict=False)
            loaded = len(mapped_sd) - len(unexpected)
            logging.info(f"Loaded {loaded} LoRA weights (peft format)")
        elif is_original_idlora:
            # Keep only the audio branches, then rename to the peft layout.
            audio_keys = {
                k: v
                for k, v in lora_sd.items()
                if "audio_attn1" in k or "audio_attn2" in k or "audio_ff" in k
            }
            mapped_sd = {}
            for k, v in audio_keys.items():
                new_key = k.replace("diffusion_model.", "base_model.model.")
                new_key = new_key.replace(".lora_A.weight", ".lora_A.default.weight")
                new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight")
                mapped_sd[new_key] = v
            _missing, unexpected = velocity_model.load_state_dict(mapped_sd, strict=False)
            loaded = len(mapped_sd) - len(unexpected)
            logging.info(f"Loaded {loaded} LoRA weights (original ID-LoRA)")
        else:
            # Previously silent: nothing was loaded, yet "Merged LoRA" was logged.
            logging.warning(f"Unrecognized LoRA key format in {args.lora}; no LoRA weights were loaded")

        velocity_model = velocity_model.merge_and_unload()
        logging.info("Merged LoRA into model")
    elif args.lora:
        # Previously a wrong --lora path silently fell back to the base model.
        logging.warning(f"LoRA file not found, running without LoRA: {args.lora}")

    logging.info(f"Model: {sum(p.numel() for p in velocity_model.parameters()) / 1e9:.1f}B params")

    # ---- Wrap velocity model in X0Model ----
    x0_model = X0Model(velocity_model)

    # ---- Build denoiser and sigmas ----
    stepper = EulerDiffusionStep()

    # ---- Sigma schedule ----
    if is_distilled:
        if args.steps is not None and args.steps > 0:
            sigmas = LTX2Scheduler().execute(steps=args.steps, latent=noised_state.latent).to(device)
            logging.info(f"Distilled with custom {args.steps}-step schedule")
        else:
            sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, dtype=torch.float32, device=device)
            logging.info(f"Distilled {len(DISTILLED_SIGMA_VALUES) - 1}-step schedule")
    else:
        steps = args.steps if args.steps is not None and args.steps > 0 else 30
        sigmas = LTX2Scheduler().execute(steps=steps, latent=noised_state.latent).to(device)
        logging.info(f"Dev {steps}-step schedule")

    # ---- Denoiser: use GuidedDenoiser if any guidance is active, SimpleDenoiser otherwise ----
    needs_guidance = args.cfg_scale > 1.0 or args.stg_scale > 0.0 or args.modality_scale > 1.0
    if needs_guidance:
        audio_guider = MultiModalGuider(
            params=MultiModalGuiderParams(
                cfg_scale=args.cfg_scale,
                stg_scale=args.stg_scale,
                stg_blocks=[args.stg_block] if args.stg_scale > 0 else [],
                rescale_scale=args.rescale_scale,
                modality_scale=args.modality_scale,
                cfg_clamp_scale=args.cfg_clamp,
            ),
            negative_context=a_ctx_neg,
        )
        denoiser = GuidedDenoiser(
            v_context=None,
            a_context=a_ctx,
            video_guider=None,
            audio_guider=audio_guider,
        )
        logging.info(f"GuidedDenoiser: cfg={args.cfg_scale}, stg={args.stg_scale}, "
                     f"rescale={args.rescale_scale}, modality={args.modality_scale}")
    else:
        denoiser = SimpleDenoiser(v_context=None, a_context=a_ctx)
        logging.info("SimpleDenoiser (no guidance)")

    logging.info(f"Sigmas: {sigmas.tolist()}")

    # ---- Denoising loop ----
    logging.info(f"Running denoising loop ({len(sigmas) - 1} steps)...")
    with gpu_model(x0_model) as model:
        batched_model = BatchSplitAdapter(model, max_batch_size=1)

        denoise_fn = heun_denoising_loop if args.sampler == "heun" else euler_denoising_loop
        _, audio_state = denoise_fn(
            sigmas=sigmas,
            video_state=None,
            audio_state=noised_state,
            stepper=stepper,
            transformer=batched_model,
            denoiser=denoiser,
        )

    del velocity_model, x0_model
    torch.cuda.empty_cache()

    # ---- Strip ref tokens and unpatchify ----
    logging.info("Stripping conditioning and unpatchifying...")
    audio_state = audio_tools.clear_conditioning(audio_state)
    audio_state = audio_tools.unpatchify(audio_state)
    logging.info(f"Final latent shape: {audio_state.latent.shape}")

    # ---- End-of-clip silence-prior fix ----
    # Base LTX-2.3 22B was trained on audio clips ≤ ~20 s and learned a strong
    # "clip-end silence" prior at the next patchifier-aligned latent boundary
    # (frame 513 = 8 × 64 + 1). For longer outputs that prior leaks through as
    # a ~30 ms hard silence dip near 20.4 s. Linearly interpolating frames
    # 512–513 between their neighbours (511 and 514) removes the dip cleanly.
    latent_in = audio_state.latent
    f0, f1 = 511, 514
    # Frame f1 itself is read below, so the latent needs at least f1 + 1
    # frames. The earlier `> 513` guard let a 514-frame latent through and
    # raised IndexError on index 514.
    if latent_in.shape[2] > f1:
        n = f1 - f0
        patched = latent_in.clone()
        for f in (512, 513):
            t = (f - f0) / n
            patched[:, :, f, :] = (1.0 - t) * latent_in[:, :, f0, :] + t * latent_in[:, :, f1, :]
        latent_in = patched

    # ---- Decode audio ----
    logging.info("Decoding audio...")
    ad = AudioDecoder(checkpoint_path=args.full_checkpoint, dtype=dtype, device=device)
    decoded = ad(latent_in)
    del ad
    torch.cuda.empty_cache()

    wav = decoded.waveform
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)
    sr = decoded.sampling_rate

    # Trim leading pad if --pad-start was used
    if args.pad_start > 0:
        trim_samples = int(args.pad_start * sr)
        wav = wav[..., trim_samples:]
        logging.info(f"Trimmed {args.pad_start}s ({trim_samples} samples) of start padding")

    # Apply Perth (Perceptual Threshold) imperceptible neural watermark — see
    # https://github.com/resemble-ai/perth. Mono waveform required; if stereo,
    # we average to mono for the watermark and broadcast back. Skip on
    # --no-watermark for debugging.
    wav_cpu = wav.float().cpu()
    if not getattr(args, "no_watermark", False):
        # Deliberately best-effort: any failure (including a missing `perth`
        # package) is logged and the un-watermarked audio is saved instead.
        try:
            import perth
            import numpy as np
            wm = perth.PerthImplicitWatermarker()
            mono = wav_cpu.mean(dim=0).numpy() if wav_cpu.shape[0] > 1 else wav_cpu[0].numpy()
            mono_wm = wm.apply_watermark(mono, sample_rate=sr)
            mono_wm_t = torch.from_numpy(np.asarray(mono_wm, dtype=np.float32)).unsqueeze(0)
            wav_cpu = mono_wm_t if wav_cpu.shape[0] == 1 else mono_wm_t.repeat(wav_cpu.shape[0], 1)
        except Exception as e:
            logging.warning(f"Perth watermark skipped ({e})")

    # `or "."` covers an output path with no directory component.
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    torchaudio.save(args.output, wav_cpu, sr)

    elapsed = time.time() - t0
    logging.info(f"Output: {args.output} ({wav.shape[-1] / sr:.1f}s)")
    logging.info(f"Total time: {elapsed:.1f}s")
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
dramabox_src/inference_server.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Warm TTS server — loads models once, accepts requests via stdin or function call.
|
| 4 |
+
|
| 5 |
+
The key insight: inference.py spends 11s on Gemma + 8s on model load every call.
|
| 6 |
+
This server loads everything once and keeps it warm.
|
| 7 |
+
|
| 8 |
+
We import and call the same code paths as inference.py but cache the heavy objects.
|
| 9 |
+
"""
|
| 10 |
+
import json
|
| 11 |
+
import logging
|
| 12 |
+
import os
|
| 13 |
+
import re
|
| 14 |
+
import sys
|
| 15 |
+
import time
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torchaudio
|
| 20 |
+
|
| 21 |
+
# Setup paths
# APP_DIR is the directory two levels above this file. Its "ltx2" and "src"
# sub-directories are put at the front of sys.path before the imports below
# run (presumably so the ltx_* packages and audio_conditioning resolve from
# there — confirm against the deployed layout).
APP_DIR = Path(__file__).parent.parent
sys.path.insert(0, str(APP_DIR / "ltx2"))
sys.path.insert(0, str(APP_DIR / "src"))

# Configure root logging at import time, with timestamped INFO-level output.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 27 |
+
|
| 28 |
+
from audio_conditioning import AudioConditionByReferenceLatent
|
| 29 |
+
from ltx_core.components.noisers import GaussianNoiser
|
| 30 |
+
from ltx_core.components.patchifiers import AudioPatchifier
|
| 31 |
+
from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
|
| 32 |
+
from ltx_core.components.schedulers import LTX2Scheduler
|
| 33 |
+
from ltx_core.components.diffusion_steps import EulerDiffusionStep
|
| 34 |
+
from ltx_core.loader import DummyRegistry
|
| 35 |
+
from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
|
| 36 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 37 |
+
from ltx_core.model.transformer.model import LTXModel, LTXModelType, X0Model
|
| 38 |
+
from ltx_core.model.transformer.rope import LTXRopeType
|
| 39 |
+
from ltx_core.model.transformer.text_projection import create_caption_projection
|
| 40 |
+
from ltx_core.model.transformer.attention import AttentionFunction
|
| 41 |
+
from ltx_core.model.model_protocol import ModelConfigurator
|
| 42 |
+
from ltx_core.tools import AudioLatentTools
|
| 43 |
+
from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
|
| 44 |
+
from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
|
| 45 |
+
from ltx_pipelines.utils.blocks import AudioConditioner, AudioDecoder, PromptEncoder
|
| 46 |
+
from ltx_pipelines.utils.media_io import decode_audio_from_file
|
| 47 |
+
from ltx_pipelines.utils.denoisers import GuidedDenoiser
|
| 48 |
+
from ltx_pipelines.utils.samplers import euler_denoising_loop
|
| 49 |
+
from safetensors import safe_open
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Default negative-prompt text. NOTE(review): its consumer is not visible in
# this part of the file — presumably the CFG negative prompt used by
# TTSServer.generate; confirm.
DEFAULT_NEG = "worst quality, inconsistent, robotic, distorted, noise, static, muffled, unclear, unnatural, monotone"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def estimate_duration(prompt, multiplier=1.1):
    """Return the target clip length in seconds for ``prompt``.

    Delegates to the CLI estimator ``inference.estimate_speech_duration`` so
    that warm-server outputs match the lengths of per-call CLI runs, scales
    that estimate by ``multiplier``, rounds to one decimal place and never
    returns less than 3.0 seconds.
    """
    # Imported at call time rather than at module import.
    from inference import estimate_speech_duration

    scaled = round(estimate_speech_duration(prompt) * multiplier, 1)
    # Floor of 3.0 s, written out instead of max(3.0, scaled).
    return scaled if scaled > 3.0 else 3.0
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def auto_rescale_for_cfg(cfg: float) -> float:
    """Map a CFG scale to the std-rescale strength that keeps output from clipping.

    The CFG formula `pred = cond + (cfg-1)*(cond - uncond)` makes pred.std()
    grow roughly linearly with cfg, which the audio VAE+vocoder render as
    progressively louder waveforms. By cfg≈3 the output starts hard-clipping
    at 0 dBFS — and clipped information is unrecoverable in post.

    Empirical sweep on the blues prompt with the back-porch-boogie ref
    (rescale_scale needed for ≥1 dB peak headroom):
      cfg=2.5 → 0.2 ; cfg=3 → 0.6 ; cfg=4 → 0.8 ; cfg=5–8 → 0.8 ; cfg=10 → 1.0

    Schedule implemented here (piecewise linear):
      cfg ≤ 2      → 0.0
      2 < cfg ≤ 3  → ramps 0.0 → 0.6  (so cfg=2.5 yields 0.3, a little above
                     the 0.2 measured in the sweep)
      3 < cfg ≤ 4  → ramps 0.6 → 0.8
      4 < cfg ≤ 8  → 0.8 plateau, preserving the "extra punch" of high-CFG
                     generations
      cfg > 8      → ramps from 0.8, reaching and capped at 1.0 from cfg=10
    """
    # (inclusive upper bound of the segment, value on that segment)
    segments = (
        (2.0, lambda c: 0.0),
        (3.0, lambda c: 0.6 * (c - 2.0)),          # 0 → 0.6
        (4.0, lambda c: 0.6 + 0.2 * (c - 3.0)),    # 0.6 → 0.8
        (8.0, lambda c: 0.8),                      # plateau
    )
    for upper_bound, value_at in segments:
        if cfg <= upper_bound:
            return value_at(cfg)
    # Beyond the last bound: 0.8 → 1.0 at cfg=10, never above 1.0.
    return min(1.0, 0.8 + 0.1 * (cfg - 8.0))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class TTSServer:
|
| 92 |
+
def __init__(self, checkpoint=None, full_checkpoint=None, gemma_root=None,
             device="cuda", dtype="bf16", compile_model=True, bnb_4bit=True):
    """Resolve model locations and settings, then load every model once.

    Args:
        checkpoint: audio-only transformer checkpoint; defaults to a file
            under ``APP_DIR / "models"``.
        full_checkpoint: full checkpoint path; falls back to the
            ``LTX_FULL_CHECKPOINT`` environment variable, then to a
            hard-coded path.
        gemma_root: Gemma directory; when ``None`` and ``GEMMA_DIR`` is not
            set, it is obtained from ``model_downloader.get_gemma_path()``.
        device: torch device string.
        dtype: ``"fp16"`` selects float16; any other value selects bfloat16.
        compile_model: whether ``_load_all`` wraps the transformer in
            ``torch.compile``.
        bnb_4bit: forwarded to the prompt encoder as ``use_bnb_4bit``.

    Raises:
        KeyError: if no Gemma path is given or resolved and ``GEMMA_DIR`` is
            not set.
    """
    models_dir = APP_DIR / "models"

    if checkpoint:
        self.checkpoint = checkpoint
    else:
        self.checkpoint = str(models_dir / "ltx-2.3-22b-dev-audio-only-v13-merged.safetensors")

    if full_checkpoint:
        self.full_checkpoint = full_checkpoint
    else:
        self.full_checkpoint = os.environ.get(
            "LTX_FULL_CHECKPOINT", "/mnt/persistent0/manmay/models/ltx23/ltx-2.3-22b-dev.safetensors")

    # Only consult the downloader when neither an argument nor GEMMA_DIR is given.
    env_has_gemma = bool(os.environ.get("GEMMA_DIR"))
    if gemma_root is None and not env_has_gemma:
        from model_downloader import get_gemma_path
        gemma_root = get_gemma_path()
    if gemma_root:
        self.gemma_root = gemma_root
    else:
        self.gemma_root = os.environ["GEMMA_DIR"]

    self.device = torch.device(device)
    if dtype == "fp16":
        self.dtype = torch.float16
    else:
        self.dtype = torch.bfloat16
    self.compile_model = compile_model
    self.bnb_4bit = bnb_4bit
    self.patchifier = AudioPatchifier(patch_size=1)

    # Cached models
    self._prompt_encoder = None
    self._velocity_model = None
    self._audio_conditioner = None
    self._audio_decoder = None

    logging.info(f"TTSServer loading on {device}...")
    started = time.time()
    self._load_all()
    logging.info(f"All models loaded in {time.time()-started:.1f}s — ready for requests")
|
| 118 |
+
|
| 119 |
+
def _load_all(self):
    """Construct every heavy component once and cache it on the instance.

    Builds, in order, the prompt encoder, the audio conditioner, the
    audio-only transformer (wrapped in ``torch.compile`` when
    ``self.compile_model`` is set) and the audio decoder, logging the
    wall-clock time of each step.

    Raises whatever ``safe_open`` / ``json.loads`` raise when
    ``self.checkpoint`` cannot be opened or has no JSON ``"config"``
    metadata entry.
    """
    # 1. Prompt encoder (Gemma + embeddings processor kept warm)
    t0 = time.time()
    self._prompt_encoder = PromptEncoder(
        checkpoint_path=self.full_checkpoint,
        gemma_root=self.gemma_root,
        dtype=self.dtype, device=self.device,
        warm=True,
        use_bnb_4bit=self.bnb_4bit,
        audio_only=True,
    )
    logging.info(f" PromptEncoder (warm): {time.time()-t0:.1f}s")

    # 2. Audio conditioner (VAE encoder kept warm)
    t0 = time.time()
    self._audio_conditioner = AudioConditioner(
        checkpoint_path=self.full_checkpoint,
        dtype=self.dtype, device=self.device,
        warm=True,
    )
    logging.info(f" AudioConditioner (warm): {time.time()-t0:.1f}s")

    # 3. Transformer
    t0 = time.time()
    # Fail fast if the checkpoint is unreadable or carries no JSON "config"
    # metadata. The parsed value is not needed here: the previous outer
    # `t = config.get("transformer", {})` was never read, because
    # from_config() below derives its own `t` from the config it is given.
    with safe_open(self.checkpoint, framework="pt") as f:
        json.loads(f.metadata()["config"])

    class AudioOnlyConfigurator(ModelConfigurator[LTXModel]):
        """Builds an audio-only LTXModel from the ``transformer`` config section."""

        @classmethod
        def from_config(cls, cfg):
            t = cfg.get("transformer", {})
            cp = None
            if not t.get("caption_proj_before_connector", False):
                with torch.device("meta"):
                    cp = create_caption_projection(t, audio=True)
            return LTXModel(
                model_type=LTXModelType.AudioOnly,
                audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
                audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
                audio_in_channels=t.get("audio_in_channels", 128),
                audio_out_channels=t.get("audio_out_channels", 128),
                num_layers=t.get("num_layers", 48),
                audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
                norm_eps=t.get("norm_eps", 1e-6),
                attention_type=AttentionFunction(t.get("attention_type", "default")),
                positional_embedding_theta=10000.0,
                audio_positional_embedding_max_pos=[20.0],
                timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
                use_middle_indices_grid=t.get("use_middle_indices_grid", True),
                rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
                double_precision_rope=t.get("frequencies_precision", False) == "float64",
                apply_gated_attention=t.get("apply_gated_attention", False),
                audio_caption_projection=cp,
                cross_attention_adaln=t.get("cross_attention_adaln", False),
            )

    # Match keys under "model.diffusion_model." and replace that prefix with "".
    audio_sd_ops = SDOps("AO").with_matching(prefix="model.diffusion_model.").with_replacement(
        "model.diffusion_model.", "")
    builder = Builder(
        model_path=self.checkpoint,
        model_class_configurator=AudioOnlyConfigurator,
        model_sd_ops=audio_sd_ops,
        registry=DummyRegistry(),
    )
    self._velocity_model = builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()
    n_params = sum(p.numel() for p in self._velocity_model.parameters()) / 1e9
    # Parameter bytes only, reported as "VRAM" in the log line.
    vram_gb = sum(p.numel() * p.element_size() for p in self._velocity_model.parameters()) / 1e9
    logging.info(f" Transformer: {time.time()-t0:.1f}s ({n_params:.1f}B params, {vram_gb:.1f}GB VRAM, {self.dtype})")

    # torch.compile for faster denoising
    if self.compile_model:
        t0 = time.time()
        logging.info(" Compiling transformer with torch.compile (default mode)...")
        self._velocity_model = torch.compile(self._velocity_model, mode="default", dynamic=True)
        logging.info(f" Compiled: {time.time()-t0:.1f}s (first call triggers actual compilation)")

    # 4. Audio decoder (VAE decoder + vocoder kept warm)
    t0 = time.time()
    self._audio_decoder = AudioDecoder(
        checkpoint_path=self.full_checkpoint,
        dtype=self.dtype, device=self.device,
        warm=True,
    )
    logging.info(f" AudioDecoder (warm): {time.time()-t0:.1f}s")
|
| 205 |
+
|
| 206 |
+
@torch.inference_mode()
def generate(self, prompt, voice_ref=None, cfg_scale=2.5, stg_scale=1.5,
             duration_multiplier=1.1, seed=42, ref_duration=10.0,
             rescale_scale="auto", gen_duration: float = 0.0):
    """Generate audio for ``prompt`` and return ``(waveform, sampling_rate)``.

    Args:
        prompt: text prompt handed to the prompt encoder.
        voice_ref: optional path to a reference audio file used for voice
            conditioning. When the path does not exist (or yields no audio)
            a warning is logged and generation proceeds unconditioned.
        cfg_scale: classifier-free guidance scale; the negative prompt is
            only encoded when this is > 1.0.
        stg_scale: STG scale forwarded to the guider.
        duration_multiplier: forwarded to ``estimate_duration`` when
            ``gen_duration`` is not set.
        seed: seed for the noise generator.
        ref_duration: seconds of reference audio to use; shorter references
            are tiled and longer ones truncated to this length.
        rescale_scale: latent-side CFG std-rescale that prevents clipping at
            high cfg. Set to "auto" (default) for the cfg-aware schedule, a
            float in [0, 1] for a fixed override, or 0 to disable.
        gen_duration: explicit target duration in seconds. 0 (default) → auto
            from prompt + duration_multiplier; >0 overrides everything else.

    Returns:
        Tuple ``(waveform, sampling_rate)`` taken from the audio decoder output.
    """
    t_total = time.time()

    # Duration + target shape — explicit gen_duration wins over the estimator.
    if gen_duration and gen_duration > 0:
        gen_dur = float(gen_duration)
    else:
        gen_dur = estimate_duration(prompt, duration_multiplier)
    fps = 25.0
    n_frames = int(round(gen_dur * fps)) + 1
    # Snap to the nearest frame count of the form 8*k + 1.
    n_frames = ((n_frames - 1 + 4) // 8) * 8 + 1
    pixel_shape = VideoPixelShape(batch=1, frames=n_frames, height=64, width=64, fps=fps)
    target_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
    audio_tools = AudioLatentTools(patchifier=self.patchifier, target_shape=target_shape)

    # Initial state
    state = audio_tools.create_initial_state(device=self.device, dtype=self.dtype)

    # Voice ref conditioning
    if voice_ref and os.path.exists(voice_ref):
        t0 = time.time()
        voice = decode_audio_from_file(voice_ref, self.device, 0.0, ref_duration)
        if voice is None:
            # The preprocessing script guards against a None result from the
            # same decoder, so handle it here instead of crashing.
            logging.warning(f"Voice ref has no audio, generating without it: {voice_ref}")
        else:
            w = voice.waveform
            # Bring the waveform to 3 dims with 2 channels (mono is duplicated).
            if w.dim() == 2:
                if w.shape[0] == 1:
                    w = w.repeat(2, 1)
                w = w.unsqueeze(0)
            elif w.dim() == 3 and w.shape[1] == 1:
                w = w.repeat(1, 2, 1)
            # Tile short references, then cut to exactly ref_duration seconds.
            target_samples = int(ref_duration * voice.sampling_rate)
            if w.shape[-1] < target_samples:
                w = w.repeat(1, 1, (target_samples // w.shape[-1]) + 1)
            w = w[..., :target_samples]
            # Peak-normalise to -4 dBFS; skipped for an all-zero waveform.
            peak = w.abs().max()
            if peak > 0:
                w = w * (10 ** (-4.0 / 20) / peak)
            voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)
            ref_latent = self._audio_conditioner(lambda enc: vae_encode_audio(voice, enc, None))
            cond = AudioConditionByReferenceLatent(latent=ref_latent.to(self.device, self.dtype), strength=1.0)
            state = cond.apply_to(state, audio_tools)
            logging.info(f"Voice ref: {time.time()-t0:.2f}s")
    elif voice_ref:
        logging.warning(f"Voice ref not found, generating without it: {voice_ref}")

    # Noise
    gen = torch.Generator(device=self.device).manual_seed(seed)
    noiser = GaussianNoiser(generator=gen)
    state = noiser(state, noise_scale=1.0)

    # Prompt encode (the negative prompt is only needed when CFG is active)
    t0 = time.time()
    prompts = [prompt, DEFAULT_NEG] if cfg_scale > 1.0 else [prompt]
    ctx = self._prompt_encoder(prompts, streaming_prefetch_count=None)
    a_ctx = ctx[0].audio_encoding
    a_ctx_neg = ctx[1].audio_encoding if cfg_scale > 1.0 else None
    logging.info(f"Prompt: {time.time()-t0:.2f}s")

    # Denoiser
    resc = auto_rescale_for_cfg(cfg_scale) if rescale_scale == "auto" else float(rescale_scale)
    if rescale_scale == "auto":
        logging.info(f"Auto rescale_scale = {resc:.2f} for cfg={cfg_scale}")
    guider = MultiModalGuider(
        params=MultiModalGuiderParams(
            cfg_scale=cfg_scale, stg_scale=stg_scale,
            stg_blocks=[29], rescale_scale=resc, modality_scale=1.0,
        ),
        negative_context=a_ctx_neg,
    )
    denoiser = GuidedDenoiser(
        v_context=None, a_context=a_ctx,
        video_guider=None, audio_guider=guider,
    )

    # Sigmas
    sigmas = LTX2Scheduler().execute(steps=30, latent=state.latent).to(self.device)

    # Denoise
    t0 = time.time()
    x0 = X0Model(self._velocity_model)
    _, audio_state = euler_denoising_loop(
        sigmas=sigmas, video_state=None, audio_state=state,
        stepper=EulerDiffusionStep(), transformer=x0, denoiser=denoiser,
    )
    logging.info(f"Denoise (30 steps): {time.time()-t0:.2f}s")

    # Strip + unpatchify + decode
    audio_state = audio_tools.clear_conditioning(audio_state)
    audio_state = audio_tools.unpatchify(audio_state)

    # End-of-clip silence-prior fix.
    # The base LTX-2.3 22B DiT was trained on audio clips ≤ ~20 s and
    # learned a strong "clip-end silence" prior that lands on the next
    # patchifier-aligned latent frame after 20 s — index 513 = 8*64+1.
    # When inference produces longer audio, this prior leaks through as a
    # high-norm latent burst at frame 513 (and adjacent 512), which the
    # audio VAE + vocoder render as a ~30 ms hard silence dip near 20.4 s.
    # Linear interpolation across the two affected frames removes the dip
    # cleanly without any retraining. Only runs when the latent is long
    # enough to contain BOTH interpolation neighbours: the right neighbour
    # is frame 514, so the time axis needs more than 514 entries (guarding
    # on 513 would index out of bounds for a 514-frame latent).
    latent = audio_state.latent
    f0, f1 = 511, 514  # neighbours used for interpolation
    if latent.shape[2] > f1:
        n = f1 - f0  # = 3
        patched = latent.clone()
        for f in (512, 513):
            t = (f - f0) / n
            patched[:, :, f, :] = (1.0 - t) * latent[:, :, f0, :] + t * latent[:, :, f1, :]
        latent = patched

    t0 = time.time()
    decoded = self._audio_decoder(latent)
    logging.info(f"Decode: {time.time()-t0:.2f}s")

    total = time.time() - t_total
    dur = decoded.waveform.shape[-1] / decoded.sampling_rate
    logging.info(f"Total: {total:.2f}s for {dur:.1f}s audio")
    return decoded.waveform, decoded.sampling_rate
|
| 332 |
+
|
| 333 |
+
def generate_to_file(self, prompt, output, watermark: bool = True, **kwargs):
    """Run ``generate`` and write the result to ``output``; returns ``output``.

    Extra keyword arguments are forwarded to ``generate``. With ``watermark``
    enabled the audio is reduced to a single channel, passed through the Perth
    implicit watermarker, and copied back onto every original channel. Any
    failure in that step is logged and the unwatermarked audio is saved.
    """
    audio, sample_rate = self.generate(prompt, **kwargs)
    wav_cpu = audio.cpu().float()

    if watermark:
        try:
            import numpy as np
            import perth

            # Build the watermarker once and reuse it on later calls.
            if not hasattr(self, "_perth"):
                self._perth = perth.PerthImplicitWatermarker()

            n_channels = wav_cpu.shape[0]
            if n_channels > 1:
                mono = wav_cpu.mean(dim=0).numpy()
            else:
                mono = wav_cpu[0].numpy()

            marked = self._perth.apply_watermark(mono, sample_rate=sample_rate)
            marked_t = torch.from_numpy(np.asarray(marked, dtype=np.float32)).unsqueeze(0)

            if n_channels == 1:
                wav_cpu = marked_t
            else:
                wav_cpu = marked_t.repeat(n_channels, 1)
        except Exception as e:
            # Best effort: a missing package or watermark error must not block saving.
            logging.warning(f"Perth watermark skipped ({e})")

    torchaudio.save(output, wav_cpu, sample_rate)
    logging.info(f"Saved: {output}")
    return output
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
if __name__ == "__main__":
    # Smoke test: build the server once, then issue two requests so the
    # second one shows the timing with all models already loaded.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--device", default="cuda")
    cli.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
    cli.add_argument("--no-compile", action="store_true")
    cli.add_argument("--no-bnb-4bit", action="store_true",
                     help="Disable bitsandbytes 4-bit path (default: on, since the default "
                          "unsloth Gemma checkpoint is pre-quantized).")
    args = cli.parse_args()

    server = TTSServer(
        device=args.device,
        dtype=args.dtype,
        compile_model=not args.no_compile,
        bnb_4bit=not args.no_bnb_4bit,
    )

    # (log banner, prompt, output path, voice reference); the first entry
    # includes any warmup, the second should be much faster.
    demo_requests = (
        (
            "=== First request ===",
            'A woman speaks clearly, "The weather today will be sunny."',
            "/tmp/warm_test1.wav",
            "/mnt/persistent0/manmay/expressive/female_radio_nikole/female_radio_nikole.wav",
        ),
        (
            "\n=== Second request (warm) ===",
            'A man speaks excitedly, "This is amazing, I cannot believe it!"',
            "/tmp/warm_test2.wav",
            "/mnt/persistent0/manmay/expressive/male_arnie/male_arnie.mp3",
        ),
    )
    for banner, demo_prompt, demo_output, demo_voice in demo_requests:
        logging.info(banner)
        server.generate_to_file(
            prompt=demo_prompt,
            output=demo_output,
            voice_ref=demo_voice,
        )
|
dramabox_src/model_downloader.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Download Dramabox models from HuggingFace.
|
| 4 |
+
|
| 5 |
+
Models are cached locally after first download.
|
| 6 |
+
Gemma text encoder is fetched separately from the unsloth pre-quantized repo.
|
| 7 |
+
"""
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)

# Hugging Face repositories the assets are pulled from.
DRAMABOX_REPO = "ResembleAI/Dramabox"
GEMMA_REPO = "unsloth/gemma-3-12b-it-bnb-4bit"

# Cache location: the DRAMABOX_CACHE environment variable when set,
# otherwise ~/.cache/dramabox.
if "DRAMABOX_CACHE" in os.environ:
    DEFAULT_CACHE = os.environ["DRAMABOX_CACHE"]
else:
    DEFAULT_CACHE = os.path.join(os.path.expanduser("~"), ".cache", "dramabox")

# Logical model name -> file path inside DRAMABOX_REPO (flat structure).
MODEL_FILES = dict(
    transformer="dramabox-dit-v1.safetensors",
    audio_components="dramabox-audio-components.safetensors",
    silence_latent="assets/silence_latent_frame.pt",
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_model_path(name: str, cache_dir: str = None) -> str:
    """Download one Dramabox model file from HF and return its local path.

    Args:
        name: One of 'transformer', 'audio_components', 'silence_latent'
        cache_dir: Local cache directory (default: ~/.cache/dramabox)

    Returns:
        Local file path

    Raises:
        ValueError: if ``name`` is not a key of ``MODEL_FILES``.
    """
    target_cache = cache_dir if cache_dir else DEFAULT_CACHE

    if name not in MODEL_FILES:
        raise ValueError(f"Unknown model: {name}. Choose from: {list(MODEL_FILES.keys())}")

    file_in_repo = MODEL_FILES[name]
    logger.info(f"Fetching {name} from {DRAMABOX_REPO}/{file_in_repo}...")

    # HF_TOKEN is optional; when unset, None is passed as the token.
    request = {
        "repo_id": DRAMABOX_REPO,
        "filename": file_in_repo,
        "cache_dir": target_cache,
        "token": os.environ.get("HF_TOKEN"),
    }
    resolved = hf_hub_download(**request)
    logger.info(f" -> {resolved}")
    return resolved
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_gemma_path(cache_dir: str = None) -> str:
    """Fetch the Gemma 3 12B IT snapshot and return its local directory.

    The repository used is the pre-quantized bnb-4bit variant published by
    unsloth; using it skips runtime bitsandbytes quantization and ~halves the
    Gemma load time.

    Args:
        cache_dir: Local cache directory; falls back to ``DEFAULT_CACHE``.

    Returns:
        Path of the downloaded snapshot directory.
    """
    target_cache = cache_dir if cache_dir else DEFAULT_CACHE
    logger.info(f"Fetching Gemma from {GEMMA_REPO}...")

    hf_token = os.environ.get("HF_TOKEN")
    snapshot_dir = snapshot_download(repo_id=GEMMA_REPO, cache_dir=target_cache, token=hf_token)

    logger.info(f" -> {snapshot_dir}")
    return snapshot_dir
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_all_paths(cache_dir: str = None) -> dict:
    """Download every required model and return a name -> local path mapping.

    The result has one entry per key of ``MODEL_FILES`` plus ``'gemma_root'``:
        {
            'transformer': '/path/to/transformer.safetensors',
            'audio_components': '/path/to/audio-components.safetensors',
            'silence_latent': '/path/to/silence_latent_frame.pt',
            'gemma_root': '/path/to/unsloth/gemma-3-12b-it-bnb-4bit/',
        }
    """
    target_cache = cache_dir or DEFAULT_CACHE

    # Dramabox files first (in MODEL_FILES order), then the Gemma snapshot.
    resolved = {key: get_model_path(key, target_cache) for key in MODEL_FILES}
    resolved["gemma_root"] = get_gemma_path(target_cache)
    return resolved
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if __name__ == "__main__":
    # Script mode: download everything, then list each asset with its size.
    logging.basicConfig(level=logging.INFO)
    paths = get_all_paths()
    print("\nAll models downloaded:")
    for label, location in paths.items():
        if os.path.isfile(location):
            size_gb = os.path.getsize(location) / 1e9
            print(f" {label}: {location} ({size_gb:.2f}GB)")
        else:
            # Snapshot directories (e.g. gemma_root) have no single file size.
            print(f" {label}: {location} (directory)")
|
dramabox_src/preprocess.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Preprocess TTS datasets for LTX-2.3 audio-only LoRA fine-tuning.
|
| 4 |
+
|
| 5 |
+
Takes paired (audio, transcript) data and produces the format expected by
|
| 6 |
+
the LTX trainer:
|
| 7 |
+
.precomputed/
|
| 8 |
+
├── latents/sample_N.pt # Dummy video latents (minimal)
|
| 9 |
+
├── conditions/sample_N.pt # Text embeddings from Gemma
|
| 10 |
+
└── audio_latents/sample_N.pt # Audio VAE-encoded latents
|
| 11 |
+
|
| 12 |
+
Supports multiple dataset formats:
|
| 13 |
+
- gemini_synthetic: index.txt with ~-separated fields (id~speaker~lang~sr~samples~dur~phonemes~text)
|
| 14 |
+
- libriheavy: index_ft.txt with ~-separated fields (id~speaker~lang~samples~dur~phonemes~text)
|
| 15 |
+
- manifest: JSON/JSONL with {"audio_filepath": ..., "text": ...}
|
| 16 |
+
- tsv: TSV file with audio_path<TAB>text columns
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
python preprocess_tts_data.py \
|
| 20 |
+
--dataset-type gemini_synthetic \
|
| 21 |
+
--index /mnt/large-datasets/gemini_synthetic_dataset/conversational_dataset_pp/index.txt \
|
| 22 |
+
--audio-dir /mnt/large-datasets/gemini_synthetic_dataset/conversational_dataset_pp/wavs \
|
| 23 |
+
--output-dir /mnt/persistent0/manmay/tts_training_data \
|
| 24 |
+
--max-samples 10000 \
|
| 25 |
+
--max-duration 20.0 \
|
| 26 |
+
--min-duration 3.0
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import argparse
|
| 30 |
+
import json
|
| 31 |
+
import logging
|
| 32 |
+
import os
|
| 33 |
+
import sys
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
import torchaudio
|
| 38 |
+
|
| 39 |
+
# Directory one level above the folder that contains this file.
REPO_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# ltx-pipelines on path via ltx2/
sys.path.insert(0, os.path.join(REPO_DIR, "ltx2"))

# Default checkpoints are looked up in the same directory as REPO_DIR.
MODEL_DIR = REPO_DIR

# Gemma location: the GEMMA_DIR environment variable when set, otherwise a
# relative directory name.
if "GEMMA_DIR" in os.environ:
    GEMMA_DIR = os.environ["GEMMA_DIR"]
else:
    GEMMA_DIR = "gemma-3-12b-it-qat-q4_0-unquantized"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def parse_args():
    """Build the preprocessing CLI and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with dataset, model, filtering and sharding
        options.
    """
    parser = argparse.ArgumentParser(description="Preprocess TTS data for LTX-2.3 fine-tuning")
    add = parser.add_argument

    # Dataset input / output
    add("--dataset-type", required=True,
        choices=["gemini_synthetic", "libriheavy", "manifest", "tsv"],
        help="Dataset format type")
    add("--index", required=True, help="Path to index/manifest file")
    add("--audio-dir", default=None,
        help="Base directory for audio files (if paths in index are relative)")
    add("--output-dir", required=True, help="Output directory for preprocessed data")

    # Model locations
    add("--checkpoint", default=os.path.join(MODEL_DIR, "ltx-2.3-22b-distilled.safetensors"))
    add("--gemma-root", default=GEMMA_DIR)

    # Sample selection
    add("--max-samples", type=int, default=0, help="Max samples to process (0=all)")
    add("--max-duration", type=float, default=20.0, help="Max audio duration in seconds")
    add("--min-duration", type=float, default=2.0, help="Min audio duration in seconds")
    add("--batch-size", type=int, default=8, help="Batch size for text encoding")
    add("--skip-existing", action="store_true", help="Skip already processed samples")
    add("--audio-only-ckpt", default=None,
        help="Audio-only checkpoint for VAE encoding (optional, uses full ckpt if not set)")

    # Parallel processing
    add("--shard", type=int, default=0, help="Shard index (for parallel processing)")
    add("--num-shards", type=int, default=1, help="Total number of shards")
    add("--gpu", type=int, default=None, help="GPU device index to use")

    return parser.parse_args()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def parse_gemini_synthetic(index_path: str, audio_dir: str | None) -> list[dict]:
|
| 72 |
+
"""Parse gemini_synthetic format: id~speaker~lang~sr~samples~dur~phonemes~text"""
|
| 73 |
+
samples = []
|
| 74 |
+
with open(index_path) as f:
|
| 75 |
+
for line in f:
|
| 76 |
+
parts = line.strip().split("~")
|
| 77 |
+
if len(parts) < 7:
|
| 78 |
+
continue
|
| 79 |
+
file_id = parts[0]
|
| 80 |
+
text = parts[-1] # Last field is always the text
|
| 81 |
+
sr = int(parts[3])
|
| 82 |
+
n_samples = int(parts[4])
|
| 83 |
+
duration = n_samples / sr
|
| 84 |
+
|
| 85 |
+
# Find audio file
|
| 86 |
+
if audio_dir:
|
| 87 |
+
# Try common extensions
|
| 88 |
+
for ext in [".flac", ".wav", ".mp3"]:
|
| 89 |
+
audio_path = os.path.join(audio_dir, file_id + ext)
|
| 90 |
+
if os.path.exists(audio_path):
|
| 91 |
+
break
|
| 92 |
+
else:
|
| 93 |
+
continue
|
| 94 |
+
else:
|
| 95 |
+
audio_path = file_id
|
| 96 |
+
|
| 97 |
+
samples.append({
|
| 98 |
+
"id": file_id,
|
| 99 |
+
"audio_path": audio_path,
|
| 100 |
+
"text": text,
|
| 101 |
+
"duration": duration,
|
| 102 |
+
})
|
| 103 |
+
return samples
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def parse_libriheavy(index_path: str, audio_dir: str | None) -> list[dict]:
|
| 107 |
+
"""Parse libriheavy format: id~speaker~lang~samples~dur~phonemes~text"""
|
| 108 |
+
samples = []
|
| 109 |
+
with open(index_path) as f:
|
| 110 |
+
for line in f:
|
| 111 |
+
parts = line.strip().split("~")
|
| 112 |
+
if len(parts) < 7:
|
| 113 |
+
continue
|
| 114 |
+
file_id = parts[0]
|
| 115 |
+
text = parts[-1]
|
| 116 |
+
n_samples = int(parts[3])
|
| 117 |
+
duration = int(parts[4]) / 1000.0 # milliseconds to seconds
|
| 118 |
+
|
| 119 |
+
if audio_dir:
|
| 120 |
+
for ext in [".flac", ".wav", ".mp3"]:
|
| 121 |
+
audio_path = os.path.join(audio_dir, file_id + ext)
|
| 122 |
+
if os.path.exists(audio_path):
|
| 123 |
+
break
|
| 124 |
+
else:
|
| 125 |
+
continue
|
| 126 |
+
else:
|
| 127 |
+
audio_path = file_id
|
| 128 |
+
|
| 129 |
+
samples.append({
|
| 130 |
+
"id": file_id,
|
| 131 |
+
"audio_path": audio_path,
|
| 132 |
+
"text": text,
|
| 133 |
+
"duration": duration,
|
| 134 |
+
})
|
| 135 |
+
return samples
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def parse_manifest(index_path: str, audio_dir: str | None) -> list[dict]:
|
| 139 |
+
"""Parse JSON/JSONL manifest with audio_filepath and text fields."""
|
| 140 |
+
samples = []
|
| 141 |
+
with open(index_path) as f:
|
| 142 |
+
for line in f:
|
| 143 |
+
entry = json.loads(line.strip())
|
| 144 |
+
audio_path = entry.get("audio_filepath", entry.get("audio_path", ""))
|
| 145 |
+
text = entry.get("text", entry.get("transcript", ""))
|
| 146 |
+
duration = entry.get("duration", 0.0)
|
| 147 |
+
|
| 148 |
+
if audio_dir and not os.path.isabs(audio_path):
|
| 149 |
+
audio_path = os.path.join(audio_dir, audio_path)
|
| 150 |
+
|
| 151 |
+
if os.path.exists(audio_path) and text:
|
| 152 |
+
samples.append({
|
| 153 |
+
"id": Path(audio_path).stem,
|
| 154 |
+
"audio_path": audio_path,
|
| 155 |
+
"text": text,
|
| 156 |
+
"duration": duration,
|
| 157 |
+
})
|
| 158 |
+
return samples
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def parse_tsv(index_path: str, audio_dir: str | None) -> list[dict]:
|
| 162 |
+
"""Parse TSV file with audio_path<TAB>text."""
|
| 163 |
+
samples = []
|
| 164 |
+
with open(index_path) as f:
|
| 165 |
+
for line in f:
|
| 166 |
+
parts = line.strip().split("\t")
|
| 167 |
+
if len(parts) < 2:
|
| 168 |
+
continue
|
| 169 |
+
audio_path, text = parts[0], parts[1]
|
| 170 |
+
if audio_dir and not os.path.isabs(audio_path):
|
| 171 |
+
audio_path = os.path.join(audio_dir, audio_path)
|
| 172 |
+
if os.path.exists(audio_path):
|
| 173 |
+
samples.append({
|
| 174 |
+
"id": Path(audio_path).stem,
|
| 175 |
+
"audio_path": audio_path,
|
| 176 |
+
"text": text,
|
| 177 |
+
"duration": 0.0,
|
| 178 |
+
})
|
| 179 |
+
return samples
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Dispatch table: --dataset-type value -> function that reads that index format.
PARSERS = dict(
    gemini_synthetic=parse_gemini_synthetic,
    libriheavy=parse_libriheavy,
    manifest=parse_manifest,
    tsv=parse_tsv,
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@torch.inference_mode()
def main():
    """Preprocess a TTS dataset into the precomputed layout under --output-dir.

    Steps:
      1. Parse the index with the parser selected by --dataset-type, filter
         by duration, apply --max-samples, and optionally shard.
      2. Encode every transcript with the text encoder and feature extractor
         and save the result to ``conditions/sample_NNNNNN.pt``.
      3. Encode every audio file with the audio VAE and save the latent to
         ``audio_latents/sample_NNNNNN.pt``.
      4. Write a dummy video latent to ``latents/sample_NNNNNN.pt``.

    File names use the sample's index assigned before sharding, so shards
    running in parallel never write the same file.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    args = parse_args()

    from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
    from ltx_core.types import Audio
    from ltx_pipelines.utils.blocks import AudioConditioner
    from ltx_pipelines.utils.media_io import decode_audio_from_file
    from ltx_trainer.model_loader import load_text_encoder, load_embeddings_processor

    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16

    # Create output directories
    out = Path(args.output_dir)
    (out / "latents").mkdir(parents=True, exist_ok=True)
    (out / "conditions").mkdir(parents=True, exist_ok=True)
    (out / "audio_latents").mkdir(parents=True, exist_ok=True)

    # Parse dataset
    logging.info(f"Parsing {args.dataset_type} dataset from {args.index}...")
    samples = PARSERS[args.dataset_type](args.index, args.audio_dir)
    logging.info(f"Found {len(samples)} samples")

    # Filter by duration.
    # The tsv parser (and manifest entries without a "duration") report 0.0
    # as an "unknown" placeholder. Filtering those against --min-duration
    # would silently drop the entire dataset, so unknown durations are kept.
    # NOTE(review): audio is decoded with args.max_duration as the last
    # argument below, which presumably caps the loaded length — confirm.
    def _duration_unknown(sample):
        return not sample["duration"] or sample["duration"] <= 0

    def _duration_ok(sample):
        if _duration_unknown(sample):
            return True
        return args.min_duration <= sample["duration"] <= args.max_duration

    before = len(samples)
    n_unknown = sum(1 for s in samples if _duration_unknown(s))
    samples = [s for s in samples if _duration_ok(s)]
    logging.info(f"After duration filter [{args.min_duration}s, {args.max_duration}s]: {len(samples)} (dropped {before - len(samples)})")
    if n_unknown:
        logging.warning(f"{n_unknown} samples have no duration in the index and were kept without duration filtering")

    if args.max_samples > 0:
        samples = samples[:args.max_samples]
        logging.info(f"Limiting to {len(samples)} samples")

    # Assign global indices before sharding
    for i, s in enumerate(samples):
        s["global_idx"] = i

    # Shard the data for parallel processing
    if args.num_shards > 1:
        total = len(samples)
        samples = samples[args.shard::args.num_shards]
        logging.info(f"Shard {args.shard}/{args.num_shards}: {len(samples)} samples (of {total} total)")

    # ── Step 1: Encode text with Gemma (Blocks 1+2 only) ──
    # The trainer runs Block 3 (embeddings processor/connectors) during training,
    # so we only precompute Blocks 1+2 here (Gemma LLM + feature extractor).
    logging.info("Loading text encoder (Gemma + feature extractor)...")
    text_encoder = load_text_encoder(args.gemma_root, device=device, dtype=dtype)

    # Load feature extractor on CPU first to save GPU memory, then move to device
    logging.info("Loading feature extractor (on CPU first to save GPU memory)...")
    emb_proc = load_embeddings_processor(args.checkpoint, device="cpu", dtype=dtype)
    text_encoder.feature_extractor = emb_proc.feature_extractor.to(device)
    del emb_proc
    torch.cuda.empty_cache()

    logging.info("Encoding text prompts (Blocks 1+2: Gemma + feature extractor)...")
    for i, sample in enumerate(samples):
        gidx = sample["global_idx"]
        cond_path = out / "conditions" / f"sample_{gidx:06d}.pt"
        if args.skip_existing and cond_path.exists():
            continue

        text = sample["text"]
        # Run Blocks 1+2: Gemma LLM → feature extractor
        hidden_states, attention_mask = text_encoder.encode(text)
        video_feats, audio_feats = text_encoder.feature_extractor(
            hidden_states, attention_mask, "left"
        )

        # When the extractor yields no separate audio features, the video
        # features are stored under the audio key as well.
        torch.save({
            "video_prompt_embeds": video_feats.squeeze(0).cpu(),
            "audio_prompt_embeds": audio_feats.squeeze(0).cpu() if audio_feats is not None else video_feats.squeeze(0).cpu(),
            "prompt_attention_mask": attention_mask.squeeze(0).bool().cpu(),
        }, cond_path)

        if i % 100 == 0:
            logging.info(f" Text encoding: {i}/{len(samples)}")

    # Free the text encoder before the audio VAE is loaded.
    del text_encoder
    torch.cuda.empty_cache()

    # ── Step 2: Encode audio with Audio VAE ──
    ckpt_for_vae = args.audio_only_ckpt or args.checkpoint
    logging.info(f"Loading audio VAE from {ckpt_for_vae}...")

    ac = AudioConditioner(checkpoint_path=ckpt_for_vae, dtype=dtype, device=device)

    logging.info("Encoding audio samples...")
    for idx, sample in enumerate(samples):
        gidx = sample["global_idx"]
        audio_path = out / "audio_latents" / f"sample_{gidx:06d}.pt"
        if args.skip_existing and audio_path.exists():
            continue

        try:
            # Load audio
            voice = decode_audio_from_file(sample["audio_path"], device, 0.0, args.max_duration)
            if voice is None:
                logging.warning(f" Skipping {sample['id']}: no audio")
                continue

            # Bring the waveform to 3 dims with 2 channels (mono is duplicated).
            w = voice.waveform
            if w.dim() == 2:
                if w.shape[0] == 1:
                    w = w.repeat(2, 1)
                w = w.unsqueeze(0)
            elif w.dim() == 3 and w.shape[1] == 1:
                w = w.repeat(1, 2, 1)
            voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)

            # Encode through Audio VAE
            audio_latent = ac(lambda enc: vae_encode_audio(voice, enc, None))

            # Save audio latent
            torch.save({
                "latents": audio_latent.squeeze(0).cpu(),  # [C=8, T, F=16]
                "sample_rate": 16000,
            }, audio_path)

        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            logging.warning(f" Skipping {sample['id']}: {e}")
            continue

        if idx % 100 == 0:
            logging.info(f" Audio encoding: {idx}/{len(samples)}")

    del ac
    torch.cuda.empty_cache()

    # ── Step 3: Create dummy video latents ──
    logging.info("Creating dummy video latents...")
    # Minimal video: 1 frame, 64x64 = 2x2 in latent space
    dummy_video = {
        "latents": torch.zeros(128, 1, 2, 2),
        "num_frames": 1,
        "height": 2,
        "width": 2,
        "fps": 24.0,
    }
    for idx, sample in enumerate(samples):
        gidx = sample["global_idx"]
        latent_path = out / "latents" / f"sample_{gidx:06d}.pt"
        if args.skip_existing and latent_path.exists():
            continue
        torch.save(dummy_video, latent_path)

    # ── Summary ──
    # Counts cover everything in the output folders, including files written
    # by other shards or earlier runs.
    n_audio = len(list((out / "audio_latents").glob("*.pt")))
    n_cond = len(list((out / "conditions").glob("*.pt")))
    n_lat = len(list((out / "latents").glob("*.pt")))
    logging.info(f"\nDone! Output: {args.output_dir}")
    logging.info(f" audio_latents: {n_audio} files")
    logging.info(f" conditions: {n_cond} files")
    logging.info(f" latents: {n_lat} files")


if __name__ == "__main__":
    main()
|
dramabox_src/train.py
ADDED
|
@@ -0,0 +1,882 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Audio-Only IC-LoRA Training for Voice Cloning on LTX-2.3.
|
| 4 |
+
|
| 5 |
+
Uses the IC-LoRA pattern: reference audio tokens are APPENDED to the end of
|
| 6 |
+
the target sequence using AudioConditionByReferenceLatent. Loss is computed
|
| 7 |
+
only on target tokens; reference tokens remain clean (denoise_mask=0).
|
| 8 |
+
|
| 9 |
+
This follows the official video-to-video IC-LoRA strategy closely, but adapted
|
| 10 |
+
for the audio-only modality path.
|
| 11 |
+
|
| 12 |
+
Usage (single GPU):
|
| 13 |
+
CUDA_VISIBLE_DEVICES=0 python train_audio_iclora.py --data-dir ... --speaker-index ...
|
| 14 |
+
|
| 15 |
+
Usage (multi-GPU with accelerate):
|
| 16 |
+
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --num_processes=4 train_audio_iclora.py ...
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import logging
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import random
|
| 24 |
+
import shutil
|
| 25 |
+
import sys
|
| 26 |
+
import time
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
from torch.utils.data import DataLoader, Dataset
|
| 33 |
+
|
| 34 |
+
REPO_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 35 |
+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ltx2"))
|
| 36 |
+
# ltx-pipelines already on path via ltx2/
|
| 37 |
+
|
| 38 |
+
MODEL_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 39 |
+
|
| 40 |
+
# Import audio conditioning item from our module
|
| 41 |
+
sys.path.insert(0, MODEL_DIR)
|
| 42 |
+
from audio_conditioning import AudioConditionByReferenceLatent
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ─── Timestep Sampling ───
|
| 46 |
+
|
| 47 |
+
class DistilledTimestepSampler:
    """Sample timesteps from the distilled sigma schedule.

    The distilled model was trained to denoise at these specific sigma values.
    For every batch element one interval between two consecutive ``SIGMAS``
    entries is picked uniformly at random, then a point is drawn uniformly
    inside that interval. The result is clamped to ``[0.01, 0.99]``.
    """

    # Distilled 8-step sigma values (boundaries of denoising intervals)
    SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]

    def __init__(self, jitter: float = 0.02):
        # NOTE(review): `jitter` is stored but never read by `sample` below.
        self.jitter = jitter

    def sample(self, batch_size: int, seq_length: int = None, device: torch.device = None) -> torch.Tensor:
        """Return a 1-D tensor of ``batch_size`` sampled sigmas on ``device``.

        ``seq_length`` is accepted so the call signature matches the other
        sampler in this file; it is not used here.
        """
        n_intervals = len(self.SIGMAS) - 1
        interval_idx = torch.randint(0, n_intervals, (batch_size,), device=device)
        t = torch.rand(batch_size, device=device)
        # Gather both interval boundaries with tensor indexing. Looping over
        # the index tensor in Python would pull every element back to the host
        # one at a time (a sync per element when `device` is a GPU).
        schedule = torch.tensor(self.SIGMAS, device=device)
        sigma_high = schedule[interval_idx]
        sigma_low = schedule[interval_idx + 1]
        sigma = sigma_low + t * (sigma_high - sigma_low)
        return sigma.clamp(0.01, 0.99)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ShiftedLogitNormalTimestepSampler:
    """Logit-normal timestep sampler whose mean is shifted by sequence length.

    A normal sample with mean ``_get_shift(seq_length)`` is squashed through a
    sigmoid, rescaled so that two fixed tail points map to 0 and 1, and mixed
    with a plain uniform sample with probability ``uniform_prob``.
    """

    def __init__(self, std: float = 1.0, eps: float = 1e-3, uniform_prob: float = 0.1):
        self.std = std
        self.eps = eps
        self.uniform_prob = uniform_prob
        # Standard-normal z-scores of the 99.9th and 0.5th percentiles,
        # scaled by `std`; they define the tails used for rescaling.
        self.normal_999_percentile = 3.0902 * std
        self.normal_005_percentile = -2.5758 * std

    def sample(self, batch_size: int, seq_length: int, device: torch.device = None) -> torch.Tensor:
        """Return ``batch_size`` timesteps in ``[0, 1]`` as a 1-D tensor."""
        shift = self._get_shift(seq_length)
        gaussian = torch.randn(batch_size, device=device) * self.std + shift
        squashed = torch.sigmoid(gaussian)

        # Rescale so the lower tail lands on 0 and the upper tail on 1.
        upper = torch.sigmoid(torch.tensor(shift + self.normal_999_percentile, device=device))
        lower = torch.sigmoid(torch.tensor(shift + self.normal_005_percentile, device=device))
        rescaled = (squashed - lower) / (upper - lower)
        # Values that fall below eps are reflected around eps, then the
        # whole thing is clipped into the unit interval.
        mirrored = 2 * self.eps - rescaled
        rescaled = torch.where(rescaled >= self.eps, rescaled, mirrored)
        rescaled = rescaled.clamp(0, 1)

        # With probability `uniform_prob`, use a uniform draw from [eps, 1).
        flat = (1 - self.eps) * torch.rand(batch_size, device=device) + self.eps
        coin = torch.rand(batch_size, device=device)
        return torch.where(coin > self.uniform_prob, rescaled, flat)

    @staticmethod
    def _get_shift(seq_length, min_tok=1024, max_tok=4096, min_s=0.95, max_s=2.05):
        """Linearly map ``seq_length`` so ``min_tok -> min_s`` and ``max_tok -> max_s``."""
        slope = (max_s - min_s) / (max_tok - min_tok)
        intercept = min_s - slope * min_tok
        return slope * seq_length + intercept
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ─── Dataset ───
|
| 103 |
+
|
| 104 |
+
def build_speaker_map(index_paths, data_dirs):
    """Map speaker → [(data_dir, sample_idx)] from index file(s).

    Each index file is paired positionally with a data directory. Rows are
    `~`-delimited; rows with fewer than 7 fields are skipped and field 1 is
    the speaker id.

    The sample index comes from field 0 of the `~`-delimited row when it
    parses as int (allows subset indexes that keep original sample numbers),
    otherwise we fall back to the row's line number (legacy behaviour for
    string-keyed indexes like tts_training_data_podcast).

    Returns:
        dict of speaker id → list of ``(data_dir, sample_idx)``; speakers with
        fewer than two rows are dropped.
    """
    speaker_to_samples = defaultdict(list)
    for index_path, data_dir in zip(index_paths, data_dirs):
        # Read as UTF-8 explicitly so speaker ids with non-ASCII characters
        # decode the same way regardless of the machine's locale.
        with open(index_path, encoding="utf-8") as f:
            for line_num, line in enumerate(f):
                parts = line.strip().split("~")
                if len(parts) < 7:
                    continue
                try:
                    idx = int(parts[0])
                except ValueError:
                    idx = line_num
                speaker_to_samples[parts[1]].append((data_dir, idx))
    return {k: v for k, v in speaker_to_samples.items() if len(v) >= 2}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class IDLoRADataset(Dataset):
    """Pairs each target clip with a different clip of the same speaker.

    Every item is a dict with ``tgt_latent``, ``ref_latent``,
    ``audio_features`` and ``attention_mask``. Only speakers with at least two
    samples whose ``audio_latents`` and ``conditions`` files both exist on
    disk are kept.
    """

    # Silence-latent reference, loaded once per process. The original note
    # says it is meant for detecting and stripping leading silence frames
    # baked into the preprocessed audio_latents, because the training loop
    # already prepends 0-25 random silence frames.
    # NOTE(review): nothing inside this class reads it — confirm where the
    # stripping actually happens.
    _silence_ref = None

    @classmethod
    def _load_silence_ref(cls):
        """Load and cache ``assets/silence_latent_frame.pt`` if it exists."""
        if cls._silence_ref is None:
            p = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                             "assets", "silence_latent_frame.pt")
            if os.path.exists(p):
                cls._silence_ref = torch.load(p, weights_only=True).float().squeeze()  # [C, F]
        return cls._silence_ref

    def __init__(self, speaker_map):
        """Filter ``speaker_map`` (speaker → [(data_dir, idx)]) to usable samples."""
        self.samples = []
        self.speaker_map = {}
        for speaker, entries in speaker_map.items():
            valid = []
            for data_dir, idx in entries:
                audio_path = Path(data_dir) / "audio_latents" / f"sample_{idx:06d}.pt"
                cond_path = Path(data_dir) / "conditions" / f"sample_{idx:06d}.pt"
                if audio_path.exists() and cond_path.exists():
                    valid.append((data_dir, idx))
            # A speaker needs a second clip to serve as the reference.
            if len(valid) >= 2:
                self.speaker_map[speaker] = valid
        for speaker, entries in self.speaker_map.items():
            for entry in entries:
                self.samples.append((entry, speaker))
        IDLoRADataset._load_silence_ref()

    def __len__(self):
        return len(self.samples)

    def _load_sample(self, data_dir, idx):
        """Load one sample; return ``(audio_latent, text_features, attention_mask)``."""
        base = Path(data_dir)
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects. Only
        # point this dataset at preprocessing output you produced yourself.
        audio = torch.load(base / "audio_latents" / f"sample_{idx:06d}.pt", weights_only=False)
        # Prefer prefix-stripped text embeddings if they exist (re-encoded with
        # just the quoted dialogue, dropping the "A woman says, " / "A man
        # speaks with X accent, " scene-description prefix).
        stripped = base / "conditions_stripped" / f"sample_{idx:06d}.pt"
        cond_path = stripped if stripped.exists() else base / "conditions" / f"sample_{idx:06d}.pt"
        cond = torch.load(cond_path, weights_only=False)
        if isinstance(audio, dict):
            # Look the latent up by name. "latents" is the key written by the
            # preprocessing script; relying on "first value in the dict" breaks
            # as soon as another field (e.g. sample_rate) is stored first.
            for key in ("audio_latent", "latent", "latents"):
                if key in audio:
                    audio = audio[key]
                    break
            else:
                # Legacy fallback: take whatever was stored first.
                audio = list(audio.values())[0]
        if audio.dim() == 2:
            audio = audio.unsqueeze(0)
        audio_feats = cond.get("audio_prompt_embeds", cond.get("prompt_embeds"))
        attn_mask = cond.get("prompt_attention_mask")
        # The audio_connector has num_learnable_registers=128 and asserts the
        # input sequence length is divisible by 128. Our new preprocessing
        # saved trimmed conditions (dropping left-padding to save disk), which
        # produces short/irregular sequence lengths. Left-pad back to the next
        # multiple of 128 with zeros (matching the tokenizer's left-padding
        # convention) so this assertion holds.
        REG = 128
        L = audio_feats.shape[0]
        target_L = ((L + REG - 1) // REG) * REG
        if target_L != L:
            pad_len = target_L - L
            pad_emb = torch.zeros(pad_len, audio_feats.shape[1],
                                  dtype=audio_feats.dtype)
            pad_mask = torch.zeros(pad_len, dtype=attn_mask.dtype)
            audio_feats = torch.cat([pad_emb, audio_feats], dim=0)
            attn_mask = torch.cat([pad_mask, attn_mask], dim=0)
        return audio, audio_feats, attn_mask

    def __getitem__(self, idx):
        (data_dir, tgt_idx), speaker = self.samples[idx]
        tgt_latent, audio_feats, attn_mask = self._load_sample(data_dir, tgt_idx)

        # Drop the reference entirely for non-voice-cloning categories:
        #  - SFX samples (speaker starts with "sfx_"): descriptive sound events,
        #    no speaker identity to clone.
        #  - Song/music samples (suno dataset): prompts describe the music style,
        #    reference audio doesn't transfer anything useful.
        # Return a zero-length ref so the model trains target-only for these.
        drop_ref = speaker.startswith("sfx_") or "preprocessed_ltx_suno" in str(data_dir)
        if drop_ref:
            C, F_dim = tgt_latent.shape[0], tgt_latent.shape[2]
            ref_latent = torch.zeros(C, 0, F_dim, dtype=tgt_latent.dtype)
        else:
            # Exclude only the target itself. Sample indices are unique per
            # data_dir, not globally, so the comparison must use the full
            # (data_dir, idx) pair — otherwise a same-speaker clip that shares
            # the index in another directory is wrongly ruled out.
            target_entry = (data_dir, tgt_idx)
            candidates = [e for e in self.speaker_map[speaker] if e != target_entry]
            ref_entry = random.choice(candidates)
            ref_latent, _, _ = self._load_sample(*ref_entry)

        return {
            "tgt_latent": tgt_latent,
            "ref_latent": ref_latent,
            "audio_features": audio_feats,
            "attention_mask": attn_mask,
        }
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# ─── Model building ───
|
| 225 |
+
|
| 226 |
+
def build_audio_only_model(checkpoint_path, device, dtype):
    """Build the audio-only LTX transformer from ``checkpoint_path``.

    State-dict keys are matched on the ``model.diffusion_model.`` prefix and
    that prefix is removed. Model hyper-parameters are read from the
    ``transformer`` section of the checkpoint config, with the defaults below
    used for anything missing.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
    from ltx_core.loader.registry import DummyRegistry
    from ltx_core.loader.sd_ops import SDOps
    from ltx_core.model.transformer.model import LTXModel, LTXModelType
    from ltx_core.model.model_protocol import ModelConfigurator
    from ltx_core.model.transformer.attention import AttentionFunction
    from ltx_core.model.transformer.rope import LTXRopeType

    key_prefix = "model.diffusion_model."
    sd_ops = SDOps("AO").with_matching(prefix=key_prefix).with_replacement(key_prefix, "")

    class Cfg(ModelConfigurator[LTXModel]):
        @classmethod
        def from_config(cls, config):
            t = config.get("transformer", {})

            # The caption projection lives inside the model only when the
            # config does not place it before the connector.
            if t.get("caption_proj_before_connector", False):
                caption_proj = None
            else:
                from ltx_core.model.transformer.text_projection import create_caption_projection
                with torch.device("meta"):
                    caption_proj = create_caption_projection(t, audio=True)

            # Options forwarded verbatim from the config: name → default.
            passthrough_defaults = {
                "audio_num_attention_heads": 32,
                "audio_attention_head_dim": 64,
                "audio_in_channels": 128,
                "audio_out_channels": 128,
                "num_layers": 48,
                "audio_cross_attention_dim": 2048,
                "norm_eps": 1e-6,
                "positional_embedding_theta": 10000.0,
                "audio_positional_embedding_max_pos": [20],
                "timestep_scale_multiplier": 1000,
                "use_middle_indices_grid": True,
                "apply_gated_attention": False,
                "cross_attention_adaln": False,
            }
            kwargs = {name: t.get(name, default) for name, default in passthrough_defaults.items()}

            # Options that need converting before they reach the model.
            kwargs["attention_type"] = AttentionFunction(t.get("attention_type", "default"))
            kwargs["rope_type"] = LTXRopeType(t.get("rope_type", "interleaved"))
            kwargs["double_precision_rope"] = t.get("frequencies_precision", False) == "float64"

            return LTXModel(
                model_type=LTXModelType.AudioOnly,
                audio_caption_projection=caption_proj,
                **kwargs,
            )

    builder = Builder(
        model_path=checkpoint_path,
        model_class_configurator=Cfg,
        model_sd_ops=sd_ops,
        registry=DummyRegistry(),
    )
    return builder.build(device=device, dtype=dtype)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def load_audio_connector(checkpoint_path, device, dtype):
    """Return the ``audio_connector`` of the embeddings processor in ``checkpoint_path``.

    The embeddings processor is loaded only to reach its audio connector; no
    reference to the processor itself is kept once this function returns.
    """
    # ltx-trainer already on path via ltx2/
    from ltx_trainer.model_loader import load_embeddings_processor

    return load_embeddings_processor(checkpoint_path, device=device, dtype=dtype).audio_connector
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def apply_lora(model, rank, alpha, dropout=0.0):
    """Wrap ``model`` with peft LoRA adapters and log the trainable-parameter share.

    Adapters are attached to the audio self-attention projections and the
    audio feed-forward layers. Returns the peft-wrapped model.
    """
    from peft import LoraConfig, get_peft_model

    # Self-attention over audio tokens (voice-transfer pathway via ref).
    self_attention_targets = [
        "audio_attn1.to_k", "audio_attn1.to_q", "audio_attn1.to_v", "audio_attn1.to_out.0",
    ]
    # Cross-attention (audio ↔ text context) is deliberately NOT adapted — keep
    # the base model's prompt→audio behaviour intact and rely on dataset
    # balance to drive expressiveness. (v15c tried this with adaLN unfreeze,
    # that proved too destructive; v16 tries it adaLN-frozen.)
    # FFN — non-linear capacity for style/phonetic adaptation.
    feed_forward_targets = ["audio_ff.net.0.proj", "audio_ff.net.2"]

    lora_config = LoraConfig(
        r=rank,
        lora_alpha=alpha,
        lora_dropout=dropout,
        bias="none",
        target_modules=self_attention_targets + feed_forward_targets,
    )
    model = get_peft_model(model, lora_config)

    # Count parameters in one pass for the summary line.
    trainable = 0
    total = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    logging.info(f"LoRA: {trainable:,} trainable / {total:,} total ({100*trainable/total:.1f}%)")
    return model
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
@torch.no_grad()
def prepare_audio_context(audio_connector, audio_features, attention_mask, device, dtype):
    """Run text features through the audio connector without tracking gradients.

    Features are moved to ``device``/``dtype`` and the mask to ``device``.
    When the batch holds more than one item, each item is sent through the
    connector on its own and the results are concatenated along dim 0;
    otherwise the connector is called once on the whole input.
    """
    from ltx_core.text_encoders.gemma.embeddings_processor import convert_to_additive_mask

    feats = audio_features.to(device=device, dtype=dtype)
    mask = attention_mask.to(device=device)

    def _encode(feat, msk):
        # The connector returns a pair; only the first element is used.
        encoded, _ = audio_connector(feat, convert_to_additive_mask(msk, feat.dtype))
        return encoded

    batch = feats.shape[0]
    if batch <= 1:
        return _encode(feats, mask)
    per_item = [_encode(feats[i:i+1], mask[i:i+1]) for i in range(batch)]
    return torch.cat(per_item, dim=0)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# ─── Validation ───
|
| 323 |
+
|
| 324 |
+
def _unwrap_model_safe(model):
|
| 325 |
+
"""Strip DDP / peft wrappers without going through accelerate.unwrap_model,
|
| 326 |
+
which imports deepspeed — broken in our env (torch API drift)."""
|
| 327 |
+
while hasattr(model, "module"):
|
| 328 |
+
model = model.module
|
| 329 |
+
return model
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def run_validation(lora_path, val_config_path, output_dir, step, lora_rank=128):
    """Run validate.py (next to this file) in a subprocess for one checkpoint.

    Per the original notes, validate.py loads TTSServer (the same stack the
    warm server / Gradio app uses), attaches our LoRA, then iterates every
    entry in val_config; a single subprocess amortises the model-load cost
    across all val entries.

    Output goes to ``<output_dir>/validation/step_<step>/`` and the child's
    stdout/stderr are captured in ``validate.log`` there. The child sees only
    the GPU named by the TRAIN_VAL_GPU env var (default "0"), because training
    already occupies the rest. The call is abandoned after 30 minutes.
    Success, failure and timeout are logged; nothing is returned or raised
    for a failed validation.
    """
    import subprocess

    val_dir = os.path.join(output_dir, "validation", f"step_{step:05d}")
    os.makedirs(val_dir, exist_ok=True)
    log_path = os.path.join(val_dir, "validate.log")

    cmd = [sys.executable, os.path.join(os.path.dirname(__file__), "validate.py")]
    cmd += ["--val-config", val_config_path]
    cmd += ["--output-dir", val_dir]
    cmd += ["--lora", lora_path]
    cmd += ["--lora-rank", str(lora_rank)]
    # Use raw estimator output (no +10% buffer) so we can hear
    # whether the model needs more/less duration at current quality.
    cmd += ["--duration-multiplier", "1.0"]

    # Validation needs its OWN GPU (training fills the others).
    child_env = dict(os.environ)
    child_env["CUDA_VISIBLE_DEVICES"] = os.environ.get("TRAIN_VAL_GPU", "0")

    try:
        with open(log_path, "w") as logf:
            result = subprocess.run(
                cmd, stdout=logf, stderr=subprocess.STDOUT, timeout=1800, env=child_env,
            )
    except subprocess.TimeoutExpired:
        logging.warning(f" Validation step {step} TIMEOUT (>30min)")
        return

    if result.returncode != 0:
        logging.warning(f" Validation step {step} FAILED (see {log_path})")
    else:
        logging.info(f" Validation step {step}: OK → {val_dir}")
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
# ─── Args ───
|
| 373 |
+
|
| 374 |
+
def parse_args():
    """Parse command-line flags, optionally seeded with defaults from a YAML file.

    Parsing happens in two passes: the first extracts ``--config`` from
    ``sys.argv``; the YAML file's values (keys may use dashes or underscores)
    then become the argparse defaults for the second pass, so explicit CLI
    flags still win.

    Returns:
        argparse.Namespace with every flag below, plus ``config`` set to the
        YAML path that was used (or None).
    """
    # First pass: pull out --config so its values can become argparse defaults.
    cfg_parser = argparse.ArgumentParser(add_help=False)
    cfg_parser.add_argument("--config", default=None,
                            help="YAML file with default values for any of the flags below. "
                                 "Explicit CLI flags still override the YAML.")
    cfg_args, remaining = cfg_parser.parse_known_args()
    yaml_defaults: dict = {}
    if cfg_args.config:
        import yaml
        with open(cfg_args.config, encoding="utf-8") as f:
            yaml_defaults = yaml.safe_load(f) or {}
        if not isinstance(yaml_defaults, dict):
            cfg_parser.error(f"--config {cfg_args.config}: top level must be a YAML mapping")
        # YAML keys are dashes-or-underscores → normalize to argparse dest (underscore).
        yaml_defaults = {str(k).replace("-", "_"): v for k, v in yaml_defaults.items()}

    def _default(name, fallback):
        return yaml_defaults.get(name, fallback)

    def _default_list(name):
        # Flags declared with nargs="+" must default to a list. A scalar in
        # the YAML (`data_dir: /path`) would otherwise be used as a bare
        # string and later iterated character by character.
        value = yaml_defaults.get(name)
        if value is None:
            return None
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    data_dir_default = _default_list("data_dir")
    speaker_index_default = _default_list("speaker_index")

    p = argparse.ArgumentParser(
        parents=[cfg_parser],
        description="Audio-Only IC-LoRA Training for Voice Cloning",
    )
    # Still required on the CLI when the YAML gives no usable value.
    p.add_argument("--data-dir", required=data_dir_default is None,
                   nargs="+", default=data_dir_default)
    p.add_argument("--speaker-index", required=speaker_index_default is None,
                   nargs="+", default=speaker_index_default)
    p.add_argument("--output-dir", default=_default("output_dir", os.path.join(MODEL_DIR, "tts_iclora_v1")))
    p.add_argument("--checkpoint", default=_default("checkpoint", os.path.join(MODEL_DIR, "dramabox-dit-v1.safetensors")))
    p.add_argument("--full-checkpoint", default=_default("full_checkpoint", os.path.join(MODEL_DIR, "dramabox-audio-components.safetensors")))
    p.add_argument("--base-model", choices=["distilled", "dev"], default=_default("base_model", "dev"),
                   help="Base model type: distilled uses DistilledTimestepSampler, dev uses ShiftedLogitNormal")
    p.add_argument("--lora-rank", type=int, default=_default("lora_rank", 128))
    p.add_argument("--lora-alpha", type=int, default=_default("lora_alpha", 128))
    p.add_argument("--lora-dropout", type=float, default=_default("lora_dropout", 0.0),
                   help="Dropout applied to LoRA A/B matrices during training. "
                        "Recommended ~0.1 for small datasets to regularize.")
    p.add_argument("--resume-lora", default=_default("resume_lora", None))
    p.add_argument("--resume-step-offset", type=int, default=_default("resume_step_offset", None),
                   help="Step to add when naming saved checkpoints. If None, inferred "
                        "from --resume-lora filename (e.g. lora_step_10000.safetensors → 10000). "
                        "Set to 0 to start numbering at 0 regardless.")
    p.add_argument("--ref-ratio", type=float, default=_default("ref_ratio", 0.3),
                   help="Fraction of target length to use as reference (default 0.3)")
    p.add_argument("--max-ref-tokens", type=int, default=_default("max_ref_tokens", 200),
                   help="Maximum reference tokens after patchification (default 200)")
    p.add_argument("--text-dropout", type=float, default=_default("text_dropout", 0.0),
                   help="Probability of dropping text conditioning (forces reliance on voice ref)")
    p.add_argument("--steps", type=int, default=_default("steps", 30000))
    p.add_argument("--lr", type=float, default=_default("lr", 3e-5))
    p.add_argument("--lr-scheduler", choices=["cosine", "linear", "constant"], default=_default("lr_scheduler", "cosine"))
    p.add_argument("--batch-size", type=int, default=_default("batch_size", 1))
    p.add_argument("--grad-accum", type=int, default=_default("grad_accum", 4))
    p.add_argument("--max-grad-norm", type=float, default=_default("max_grad_norm", 1.0))
    p.add_argument("--save-every", type=int, default=_default("save_every", 1000))
    p.add_argument("--log-every", type=int, default=_default("log_every", 50))
    p.add_argument("--seed", type=int, default=_default("seed", 42))
    p.add_argument("--warmup-steps", type=int, default=_default("warmup_steps", 100))
    p.add_argument("--val-config", default=_default("val_config", None))

    args = p.parse_args(remaining)
    # --config was consumed by the first pass, so the second pass only ever
    # sees its default; put the real path back so it is recorded with the run.
    args.config = cfg_args.config
    return args
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
# ─── Main ───
|
| 436 |
+
|
| 437 |
+
def main():
|
| 438 |
+
from accelerate import Accelerator
|
| 439 |
+
from accelerate.utils import set_seed
|
| 440 |
+
|
| 441 |
+
args = parse_args()
|
| 442 |
+
|
| 443 |
+
accelerator = Accelerator(
|
| 444 |
+
gradient_accumulation_steps=args.grad_accum,
|
| 445 |
+
mixed_precision="bf16",
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
is_main = accelerator.is_main_process
|
| 449 |
+
if is_main:
|
| 450 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 451 |
+
else:
|
| 452 |
+
logging.basicConfig(level=logging.WARNING)
|
| 453 |
+
|
| 454 |
+
set_seed(args.seed)
|
| 455 |
+
device = accelerator.device
|
| 456 |
+
dtype = torch.bfloat16
|
| 457 |
+
|
| 458 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 459 |
+
|
| 460 |
+
# Save training args
|
| 461 |
+
if is_main:
|
| 462 |
+
import yaml
|
| 463 |
+
args_dict = vars(args).copy()
|
| 464 |
+
args_dict["_meta"] = {
|
| 465 |
+
"world_size": accelerator.num_processes,
|
| 466 |
+
"dtype": str(dtype),
|
| 467 |
+
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
| 468 |
+
"script": "train_audio_iclora.py",
|
| 469 |
+
"pattern": "IC-LoRA (ref appended to end)",
|
| 470 |
+
}
|
| 471 |
+
with open(os.path.join(args.output_dir, "training_args.yaml"), "w") as f:
|
| 472 |
+
yaml.dump(args_dict, f, default_flow_style=False, sort_keys=False)
|
| 473 |
+
|
| 474 |
+
from ltx_core.components.patchifiers import AudioPatchifier
|
| 475 |
+
from ltx_core.model.transformer.modality import Modality
|
| 476 |
+
from ltx_core.guidance.perturbations import BatchedPerturbationConfig
|
| 477 |
+
from ltx_core.tools import AudioLatentTools
|
| 478 |
+
from ltx_core.types import AudioLatentShape, LatentState
|
| 479 |
+
from ltx_pipelines.utils.helpers import modality_from_latent_state, timesteps_from_mask
|
| 480 |
+
|
| 481 |
+
# Build speaker map
|
| 482 |
+
if is_main:
|
| 483 |
+
logging.info("Building speaker map...")
|
| 484 |
+
speaker_map = build_speaker_map(args.speaker_index, args.data_dir)
|
| 485 |
+
if is_main:
|
| 486 |
+
logging.info(f"Speaker map: {len(speaker_map)} speakers, "
|
| 487 |
+
f"{sum(len(v) for v in speaker_map.values())} samples")
|
| 488 |
+
|
| 489 |
+
# Load model
|
| 490 |
+
if is_main:
|
| 491 |
+
logging.info("Loading audio-only model...")
|
| 492 |
+
model = build_audio_only_model(args.checkpoint, device, dtype)
|
| 493 |
+
|
| 494 |
+
if is_main:
|
| 495 |
+
logging.info("Loading audio connector...")
|
| 496 |
+
audio_connector = load_audio_connector(args.full_checkpoint, device, dtype)
|
| 497 |
+
audio_connector.eval()
|
| 498 |
+
for p in audio_connector.parameters():
|
| 499 |
+
p.requires_grad = False
|
| 500 |
+
|
| 501 |
+
if is_main:
|
| 502 |
+
logging.info(f"Applying LoRA (rank={args.lora_rank}, alpha={args.lora_alpha})...")
|
| 503 |
+
model = apply_lora(model, args.lora_rank, args.lora_alpha, args.lora_dropout)
|
| 504 |
+
|
| 505 |
+
# Resume from checkpoint
|
| 506 |
+
if args.resume_lora:
|
| 507 |
+
from safetensors.torch import load_file as st_load
|
| 508 |
+
if is_main:
|
| 509 |
+
logging.info(f"Resuming from: {args.resume_lora}")
|
| 510 |
+
lora_sd = st_load(args.resume_lora)
|
| 511 |
+
mapped = {}
|
| 512 |
+
for k, v in lora_sd.items():
|
| 513 |
+
nk = k.replace(".lora_A.weight", ".lora_A.default.weight").replace(
|
| 514 |
+
".lora_B.weight", ".lora_B.default.weight")
|
| 515 |
+
mapped[nk] = v
|
| 516 |
+
model.load_state_dict(mapped, strict=False)
|
| 517 |
+
|
| 518 |
+
# Determine step offset for save filenames. Without this, resuming a run
|
| 519 |
+
# restarts step numbering at 0 and would overwrite earlier phase-1
|
| 520 |
+
# checkpoints with the same save_every cadence.
|
| 521 |
+
if args.resume_step_offset is None:
|
| 522 |
+
resume_offset = 0
|
| 523 |
+
if args.resume_lora:
|
| 524 |
+
import re as _re
|
| 525 |
+
m = _re.search(r"lora_step_(\d+)", os.path.basename(args.resume_lora))
|
| 526 |
+
if m:
|
| 527 |
+
resume_offset = int(m.group(1))
|
| 528 |
+
args.resume_step_offset = resume_offset
|
| 529 |
+
if is_main and args.resume_step_offset:
|
| 530 |
+
logging.info(f"Save-step offset: +{args.resume_step_offset}")
|
| 531 |
+
|
| 532 |
+
model.train()
|
| 533 |
+
model.base_model.model.set_gradient_checkpointing(True)
|
| 534 |
+
|
| 535 |
+
# Dataset & DataLoader
|
| 536 |
+
dataset = IDLoRADataset(speaker_map)
|
| 537 |
+
if is_main:
|
| 538 |
+
logging.info(f"Dataset: {len(dataset)} samples, {len(dataset.speaker_map)} speakers")
|
| 539 |
+
|
| 540 |
+
def collate_fn(batch):
    """Collate variable-length samples into a single zero-padded batch.

    Each item's ``tgt_latent`` / ``ref_latent`` is right-padded with zeros
    along dim 1 (the ``T`` axis of a ``[C, T, F]`` latent) up to the longest
    item in the batch. The original, unpadded lengths are returned alongside
    so the caller can mask padding out of the loss.

    Args:
        batch: list of dicts with keys ``tgt_latent``, ``ref_latent``,
            ``audio_features`` and ``attention_mask``.

    Returns:
        dict of stacked tensors plus ``tgt_lengths`` / ``ref_lengths``.
    """
    tgt_lengths = [item["tgt_latent"].shape[1] for item in batch]
    ref_lengths = [item["ref_latent"].shape[1] for item in batch]
    max_tgt_T = max(tgt_lengths)
    max_ref_T = max(ref_lengths)

    # Channel / feature sizes come from the first target latent and are reused
    # for every pad tensor, reference latents included.
    first_tgt = batch[0]["tgt_latent"]
    C, F_dim = first_tgt.shape[0], first_tgt.shape[2]

    def _pad_time(latent, wanted_T):
        """Append zero frames on dim 1 until the latent has ``wanted_T`` frames."""
        missing = wanted_T - latent.shape[1]
        if missing <= 0:
            return latent
        filler = torch.zeros(C, missing, F_dim, dtype=latent.dtype)
        return torch.cat([latent, filler], dim=1)

    padded_tgt = [_pad_time(item["tgt_latent"], max_tgt_T) for item in batch]
    padded_ref = [_pad_time(item["ref_latent"], max_ref_T) for item in batch]

    return {
        "tgt_latent": torch.stack(padded_tgt),
        "ref_latent": torch.stack(padded_ref),
        "audio_features": torch.stack([item["audio_features"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "tgt_lengths": torch.tensor(tgt_lengths),
        "ref_lengths": torch.tensor(ref_lengths),
    }
|
| 577 |
+
|
| 578 |
+
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2,
|
| 579 |
+
pin_memory=True, drop_last=True, collate_fn=collate_fn)
|
| 580 |
+
|
| 581 |
+
# Optimizer & Scheduler
|
| 582 |
+
optimizer = torch.optim.AdamW(
|
| 583 |
+
[p for p in model.parameters() if p.requires_grad],
|
| 584 |
+
lr=args.lr, betas=(0.9, 0.999), weight_decay=0.01,
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR, ConstantLR
|
| 588 |
+
warmup = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=args.warmup_steps)
|
| 589 |
+
remaining = args.steps - args.warmup_steps
|
| 590 |
+
if args.lr_scheduler == "cosine":
|
| 591 |
+
# Warmup -> constant hold (20% of remaining) -> cosine decay
|
| 592 |
+
hold_steps = max(remaining // 5, 0)
|
| 593 |
+
decay_steps = max(remaining - hold_steps, 1)
|
| 594 |
+
hold_sched = ConstantLR(optimizer, factor=1.0, total_iters=hold_steps)
|
| 595 |
+
decay_sched = CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=1e-6)
|
| 596 |
+
scheduler = SequentialLR(
|
| 597 |
+
optimizer,
|
| 598 |
+
[warmup, hold_sched, decay_sched],
|
| 599 |
+
milestones=[args.warmup_steps, args.warmup_steps + hold_steps],
|
| 600 |
+
)
|
| 601 |
+
elif args.lr_scheduler == "linear":
|
| 602 |
+
main_sched = LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=max(remaining, 1))
|
| 603 |
+
scheduler = SequentialLR(optimizer, [warmup, main_sched], milestones=[args.warmup_steps])
|
| 604 |
+
else:
|
| 605 |
+
main_sched = ConstantLR(optimizer, factor=1.0, total_iters=max(remaining, 1))
|
| 606 |
+
scheduler = SequentialLR(optimizer, [warmup, main_sched], milestones=[args.warmup_steps])
|
| 607 |
+
|
| 608 |
+
# Prepare with Accelerate — but NOT the scheduler. AcceleratedScheduler
|
| 609 |
+
# calls the underlying scheduler.step() `num_processes` times per sync,
|
| 610 |
+
# which silently scales down our warmup/cosine spans by that factor.
|
| 611 |
+
# We call scheduler.step() ourselves, gated on sync_gradients → exactly
|
| 612 |
+
# one advance per optimizer step, as the yaml spec intends.
|
| 613 |
+
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
| 614 |
+
|
| 615 |
+
patchifier = AudioPatchifier(patch_size=1)
|
| 616 |
+
|
| 617 |
+
# Select timestep sampler based on base model type
|
| 618 |
+
if args.base_model == "distilled":
|
| 619 |
+
timestep_sampler = DistilledTimestepSampler()
|
| 620 |
+
if is_main:
|
| 621 |
+
logging.info("Using DistilledTimestepSampler (matching distilled model sigmas)")
|
| 622 |
+
else:
|
| 623 |
+
timestep_sampler = ShiftedLogitNormalTimestepSampler()
|
| 624 |
+
if is_main:
|
| 625 |
+
logging.info("Using ShiftedLogitNormalTimestepSampler (dev model)")
|
| 626 |
+
|
| 627 |
+
# Training loop
|
| 628 |
+
if is_main:
|
| 629 |
+
logging.info(f"Training: {args.steps} steps, lr={args.lr}, scheduler={args.lr_scheduler}, "
|
| 630 |
+
f"batch={args.batch_size}, grad_accum={args.grad_accum}, "
|
| 631 |
+
f"world_size={accelerator.num_processes}, "
|
| 632 |
+
f"ref_ratio={args.ref_ratio}, max_ref_tokens={args.max_ref_tokens}")
|
| 633 |
+
logging.info("IC-LoRA pattern: ref tokens APPENDED to target, loss on target only")
|
| 634 |
+
|
| 635 |
+
data_iter = iter(dataloader)
|
| 636 |
+
step = 0
|
| 637 |
+
accum_loss = 0.0
|
| 638 |
+
best_loss = float("inf")
|
| 639 |
+
best_step = 0
|
| 640 |
+
t0 = time.time()
|
| 641 |
+
|
| 642 |
+
total_micro_steps = args.steps * args.grad_accum
|
| 643 |
+
|
| 644 |
+
for micro_step in range(total_micro_steps):
|
| 645 |
+
try:
|
| 646 |
+
batch = next(data_iter)
|
| 647 |
+
except StopIteration:
|
| 648 |
+
data_iter = iter(dataloader)
|
| 649 |
+
batch = next(data_iter)
|
| 650 |
+
|
| 651 |
+
is_opt_step = (micro_step + 1) % args.grad_accum == 0
|
| 652 |
+
if is_opt_step:
|
| 653 |
+
step += 1
|
| 654 |
+
|
| 655 |
+
with accelerator.accumulate(model):
|
| 656 |
+
tgt_latent = batch["tgt_latent"].to(dtype=dtype) # [B, C, max_tgt_T, F]
|
| 657 |
+
ref_latent = batch["ref_latent"].to(dtype=dtype) # [B, C, max_ref_T, F]
|
| 658 |
+
tgt_lengths = batch["tgt_lengths"].to(device=device) # [B]
|
| 659 |
+
B = tgt_latent.shape[0]
|
| 660 |
+
|
| 661 |
+
# ── Random silence padding (0-1s) ── ltx_audio_tts baseline.
|
| 662 |
+
# User observed reference-audio leak at end of generations when this
|
| 663 |
+
# was reduced to 5 (v14) or 10 frames (v16/v17) — the model seemed
|
| 664 |
+
# to use the extra target budget to regurgitate ref content. Full
|
| 665 |
+
# 25 frames (0-1s avg 500ms) was apparently load-bearing for
|
| 666 |
+
# regularising the boundary and reducing hallucinations.
|
| 667 |
+
# Uses the real silence latent (not zeros) so the VAE decodes it as
|
| 668 |
+
# true silence instead of static noise.
|
| 669 |
+
max_pad_frames = 25 # ~1s at 25 latent frames/sec
|
| 670 |
+
pad_frames = random.randint(0, max_pad_frames)
|
| 671 |
+
if pad_frames > 0:
|
| 672 |
+
C, F_dim = tgt_latent.shape[1], tgt_latent.shape[3]
|
| 673 |
+
if not hasattr(args, '_silence_frame') or args._silence_frame is None:
|
| 674 |
+
_sf_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "assets", "silence_latent_frame.pt")
|
| 675 |
+
if os.path.exists(_sf_path):
|
| 676 |
+
args._silence_frame = torch.load(_sf_path, weights_only=True) # [C, 1, F]
|
| 677 |
+
if is_main:
|
| 678 |
+
logging.info(f"Loaded silence latent from {_sf_path}")
|
| 679 |
+
else:
|
| 680 |
+
args._silence_frame = False # fallback to zeros
|
| 681 |
+
if is_main:
|
| 682 |
+
logging.warning(f"silence_latent_frame.pt not found, using zeros")
|
| 683 |
+
if args._silence_frame is not False:
|
| 684 |
+
sf = args._silence_frame.to(dtype=dtype, device=device) # [C, 1, F]
|
| 685 |
+
silence_pad = sf.unsqueeze(0).expand(B, -1, pad_frames, -1) # [B, C, pad, F]
|
| 686 |
+
else:
|
| 687 |
+
silence_pad = torch.zeros(B, C, pad_frames, F_dim, dtype=dtype, device=device)
|
| 688 |
+
tgt_latent = torch.cat([silence_pad, tgt_latent], dim=2)
|
| 689 |
+
|
| 690 |
+
# Cap reference to max_ref_tokens (in latent frames, before patchification)
|
| 691 |
+
# After patchification, ref_T tokens = ref frames (patch_size=1)
|
| 692 |
+
ref_T_frames = min(ref_latent.shape[2], args.max_ref_tokens)
|
| 693 |
+
ref_latent = ref_latent[:, :, :ref_T_frames, :]
|
| 694 |
+
|
| 695 |
+
tgt_T_frames = tgt_latent.shape[2] # max (padded) target frames
|
| 696 |
+
|
| 697 |
+
# ── Step 1: Create target AudioLatentShape and AudioLatentTools ──
|
| 698 |
+
tgt_shape = AudioLatentShape(
|
| 699 |
+
batch=B,
|
| 700 |
+
channels=tgt_latent.shape[1], # 8
|
| 701 |
+
frames=tgt_T_frames,
|
| 702 |
+
mel_bins=tgt_latent.shape[3], # 16
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
audio_tools = AudioLatentTools(
|
| 706 |
+
patchifier=patchifier,
|
| 707 |
+
target_shape=tgt_shape,
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
# ── Step 2: Create initial state from target latent ──
|
| 711 |
+
# create_initial_state patchifies: [B, C, T, F] -> [B, T, C*F]
|
| 712 |
+
# Also creates denoise_mask=1 (all target tokens will be denoised)
|
| 713 |
+
# and computes temporal positions
|
| 714 |
+
state = audio_tools.create_initial_state(
|
| 715 |
+
device=device,
|
| 716 |
+
dtype=dtype,
|
| 717 |
+
initial_latent=tgt_latent,
|
| 718 |
+
)
|
| 719 |
+
# state.latent: [B, tgt_T, 128], state.denoise_mask: [B, tgt_T, 1]
|
| 720 |
+
# state.positions: [B, 1, tgt_T, 2]
|
| 721 |
+
|
| 722 |
+
tgt_T = audio_tools.target_shape.token_count() # = tgt_T_frames
|
| 723 |
+
|
| 724 |
+
# ── Step 3: Apply flow-matching noise to target BEFORE appending ref ──
|
| 725 |
+
# Sample sigma
|
| 726 |
+
total_tokens = tgt_T + ref_T_frames
|
| 727 |
+
sigma = timestep_sampler.sample(B, total_tokens, device=device)
|
| 728 |
+
sigma_exp = sigma.view(-1, 1, 1) # [B, 1, 1]
|
| 729 |
+
|
| 730 |
+
noise = torch.randn_like(state.latent) # [B, tgt_T, 128]
|
| 731 |
+
noisy_tgt = (1 - sigma_exp) * state.latent + sigma_exp * noise
|
| 732 |
+
|
| 733 |
+
# Replace the latent in state with the noisy version
|
| 734 |
+
# (clean_latent stays clean for post_process_latent pattern)
|
| 735 |
+
state = LatentState(
|
| 736 |
+
latent=noisy_tgt,
|
| 737 |
+
denoise_mask=state.denoise_mask,
|
| 738 |
+
positions=state.positions,
|
| 739 |
+
clean_latent=state.clean_latent,
|
| 740 |
+
attention_mask=state.attention_mask,
|
| 741 |
+
)
|
| 742 |
+
|
| 743 |
+
# ── Step 4: Append reference tokens using AudioConditionByReferenceLatent ──
|
| 744 |
+
# This appends ref tokens to the END with denoise_mask=0 (frozen/clean)
|
| 745 |
+
# Skip entirely when ref_T=0 (SFX / song samples): the model trains
|
| 746 |
+
# target-only for those categories since there's no voice to clone.
|
| 747 |
+
if ref_T_frames > 0:
|
| 748 |
+
ref_conditioning = AudioConditionByReferenceLatent(
|
| 749 |
+
latent=ref_latent,
|
| 750 |
+
strength=1.0, # 1.0 = ref fully clean (denoise_mask=0)
|
| 751 |
+
)
|
| 752 |
+
state = ref_conditioning.apply_to(
|
| 753 |
+
latent_state=state,
|
| 754 |
+
latent_tools=audio_tools,
|
| 755 |
+
)
|
| 756 |
+
# state.latent: [B, tgt_T + ref_T, 128]
|
| 757 |
+
# state.denoise_mask: [B, tgt_T + ref_T, 1]
|
| 758 |
+
# target tokens: 1.0 (denoise), ref tokens: 0.0 (frozen)
|
| 759 |
+
# state.positions: [B, 1, tgt_T + ref_T, 2]
|
| 760 |
+
|
| 761 |
+
# ── Step 5: Build loss mask for target tokens (excluding padding) ──
|
| 762 |
+
# loss_mask: 1 for real target tokens, 0 for padding and ref tokens
|
| 763 |
+
loss_mask = torch.zeros(B, tgt_T, device=device)
|
| 764 |
+
for b_idx in range(B):
|
| 765 |
+
real_len = min(tgt_lengths[b_idx].item(), tgt_T)
|
| 766 |
+
loss_mask[b_idx, :real_len] = 1.0
|
| 767 |
+
|
| 768 |
+
# ── Step 6: Prepare text context ──
|
| 769 |
+
# Text conditioning dropout: randomly zero out text context to force
|
| 770 |
+
# the model to rely on the voice reference for identity/style.
|
| 771 |
+
with torch.no_grad():
|
| 772 |
+
audio_context = prepare_audio_context(
|
| 773 |
+
audio_connector, batch["audio_features"],
|
| 774 |
+
batch["attention_mask"], device, dtype)
|
| 775 |
+
if args.text_dropout > 0 and random.random() < args.text_dropout:
|
| 776 |
+
audio_context = torch.zeros_like(audio_context)
|
| 777 |
+
|
| 778 |
+
# ── Step 7: Build Modality using modality_from_latent_state ──
|
| 779 |
+
# timesteps = sigma * denoise_mask (ref gets 0, target gets sigma)
|
| 780 |
+
audio_mod = modality_from_latent_state(
|
| 781 |
+
state=state,
|
| 782 |
+
context=audio_context,
|
| 783 |
+
sigma=sigma,
|
| 784 |
+
enabled=True,
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
# ── Step 8: Forward pass ──
|
| 788 |
+
perturbations = BatchedPerturbationConfig.empty(B)
|
| 789 |
+
with torch.autocast(device_type="cuda", dtype=dtype):
|
| 790 |
+
_, velocity_pred = model(video=None, audio=audio_mod, perturbations=perturbations)
|
| 791 |
+
|
| 792 |
+
# ── Step 9: Compute loss (IC-LoRA pattern) ──
|
| 793 |
+
# Target is at the FRONT (indices 0..tgt_T), ref at the END
|
| 794 |
+
# velocity target = noise - clean
|
| 795 |
+
tgt_patchified = audio_tools.patchifier.patchify(tgt_latent) # [B, tgt_T, 128]
|
| 796 |
+
target_velocity = noise - tgt_patchified
|
| 797 |
+
|
| 798 |
+
# Extract target portion of prediction
|
| 799 |
+
pred_tgt = velocity_pred[:, :tgt_T] # [B, tgt_T, 128]
|
| 800 |
+
|
| 801 |
+
# MSE loss with mask: only on real target tokens (not padding or ref)
|
| 802 |
+
per_token_mse = (pred_tgt - target_velocity).pow(2).mean(dim=-1) # [B, tgt_T]
|
| 803 |
+
loss = per_token_mse.mul(loss_mask).div(loss_mask.mean().clamp(min=1e-6)).mean()
|
| 804 |
+
|
| 805 |
+
accelerator.backward(loss)
|
| 806 |
+
|
| 807 |
+
if accelerator.sync_gradients and args.max_grad_norm > 0:
|
| 808 |
+
accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
| 809 |
+
|
| 810 |
+
optimizer.step()
|
| 811 |
+
optimizer.zero_grad()
|
| 812 |
+
# Only advance the LR scheduler once per OPTIMIZER step (not per
|
| 813 |
+
# micro-step). Mirrors AcceleratedOptimizer.step() which is
|
| 814 |
+
# internally gated on sync_gradients.
|
| 815 |
+
if accelerator.sync_gradients:
|
| 816 |
+
scheduler.step()
|
| 817 |
+
|
| 818 |
+
accum_loss += loss.item()
|
| 819 |
+
|
| 820 |
+
# Logging & saving on optimization steps only
|
| 821 |
+
if is_opt_step and step % args.log_every == 0 and is_main:
|
| 822 |
+
avg_loss = accum_loss / (args.log_every * args.grad_accum)
|
| 823 |
+
lr = optimizer.param_groups[0]["lr"]
|
| 824 |
+
elapsed = time.time() - t0
|
| 825 |
+
sps = step / elapsed if elapsed > 0 else 0
|
| 826 |
+
eta = (args.steps - step) / sps if sps > 0 else 0
|
| 827 |
+
logging.info(
|
| 828 |
+
f"Step {step}/{args.steps} | loss={avg_loss:.4f} | lr={lr:.2e} | "
|
| 829 |
+
f"tgt_T={tgt_T} ref_T={ref_T_frames} total={tgt_T + ref_T_frames} | "
|
| 830 |
+
f"{sps:.1f} steps/s | ETA {eta/60:.0f}min"
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
# Save best whenever loss improves — no warmup gate, so we can
|
| 834 |
+
# observe best checkpoints during warmup too.
|
| 835 |
+
if avg_loss < best_loss:
|
| 836 |
+
best_loss = avg_loss
|
| 837 |
+
old_best = os.path.join(args.output_dir, f"best_step_{best_step:05d}.safetensors")
|
| 838 |
+
best_step = step + args.resume_step_offset
|
| 839 |
+
new_best = os.path.join(args.output_dir, f"best_step_{best_step:05d}.safetensors")
|
| 840 |
+
unwrapped = _unwrap_model_safe(model)
|
| 841 |
+
unwrapped.save_pretrained(args.output_dir)
|
| 842 |
+
adapter = os.path.join(args.output_dir, "adapter_model.safetensors")
|
| 843 |
+
if os.path.exists(adapter):
|
| 844 |
+
shutil.copy(adapter, new_best)
|
| 845 |
+
if old_best != new_best and os.path.exists(old_best):
|
| 846 |
+
os.remove(old_best)
|
| 847 |
+
logging.info(f"New best: loss={best_loss:.4f} at step {best_step}")
|
| 848 |
+
|
| 849 |
+
accum_loss = 0.0
|
| 850 |
+
|
| 851 |
+
if is_opt_step and step % args.save_every == 0 and is_main:
|
| 852 |
+
global_step = step + args.resume_step_offset
|
| 853 |
+
save_path = os.path.join(args.output_dir, f"lora_step_{global_step:05d}.safetensors")
|
| 854 |
+
logging.info(f"Saving: {save_path}")
|
| 855 |
+
unwrapped = _unwrap_model_safe(model)
|
| 856 |
+
unwrapped.save_pretrained(args.output_dir)
|
| 857 |
+
adapter = os.path.join(args.output_dir, "adapter_model.safetensors")
|
| 858 |
+
if os.path.exists(adapter):
|
| 859 |
+
shutil.copy(adapter, save_path)
|
| 860 |
+
|
| 861 |
+
if args.val_config:
|
| 862 |
+
logging.info(f"Running validation at step {global_step}...")
|
| 863 |
+
model.eval()
|
| 864 |
+
run_validation(save_path, args.val_config, args.output_dir, global_step,
|
| 865 |
+
lora_rank=args.lora_rank)
|
| 866 |
+
model.train()
|
| 867 |
+
|
| 868 |
+
# Final save
|
| 869 |
+
if is_main:
|
| 870 |
+
unwrapped = _unwrap_model_safe(model)
|
| 871 |
+
unwrapped.save_pretrained(args.output_dir)
|
| 872 |
+
adapter = os.path.join(args.output_dir, "adapter_model.safetensors")
|
| 873 |
+
global_step = step + args.resume_step_offset
|
| 874 |
+
save_path = os.path.join(args.output_dir, f"lora_step_{global_step:05d}.safetensors")
|
| 875 |
+
if os.path.exists(adapter):
|
| 876 |
+
shutil.copy(adapter, save_path)
|
| 877 |
+
logging.info(f"Training complete! {step} steps in {time.time()-t0:.0f}s")
|
| 878 |
+
logging.info(f"Best loss: {best_loss:.4f} at step {best_step}")
|
| 879 |
+
|
| 880 |
+
|
| 881 |
+
if __name__ == "__main__":
|
| 882 |
+
main()
|
dramabox_src/validate.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Warm validation runner — loads base dev + LoRA + all aux models ONCE,
|
| 3 |
+
then iterates every speaker in val_config generating each output.
|
| 4 |
+
|
| 5 |
+
Matches the same generation path as inference.py but keeps Gemma / audio VAE
|
| 6 |
+
/ velocity model / audio decoder resident across entries. Inference
|
| 7 |
+
settings default to the Gradio warm-server values (cfg=2.5, stg=1.5,
|
| 8 |
+
modality=1.0, rescale=0, 30 steps, fps=25) — use --inference-params to
|
| 9 |
+
override.
|
| 10 |
+
"""
|
| 11 |
+
import argparse
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import time
|
| 16 |
+
import traceback
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torchaudio
|
| 20 |
+
|
| 21 |
+
# Directory that holds this script; its parent is treated as the repo root.
_SRC_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_DIR = os.path.dirname(_SRC_DIR)
MODEL_DIR = REPO_DIR

# Both inserts go to the front of sys.path, so the script directory (inserted
# last) is searched before the bundled ``ltx2`` directory.
sys.path.insert(0, os.path.join(REPO_DIR, "ltx2"))
sys.path.insert(0, _SRC_DIR)

# Default checkpoint / Gemma locations; each can be overridden through the
# environment variable of the same purpose.
DEV_FULL_CKPT = os.environ.get(
    "LTX_FULL_CHECKPOINT",
    os.path.join(REPO_DIR, "ltx-2.3-22b-dev.safetensors"),
)
GEMMA_ROOT = os.environ.get(
    "GEMMA_ROOT",
    os.path.expanduser("~/.cache/dramabox/gemma-3-12b-it-bnb-4bit"),
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def parse_args():
    """Parse the command-line options of the validation runner.

    ``--val-config`` and ``--output-dir`` are mandatory; every other flag has
    a default, so a plain invocation uses the stock inference settings.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    for required_flag in ("--val-config", "--output-dir"):
        parser.add_argument(required_flag, required=True)

    parser.add_argument("--lora", default=None)
    parser.add_argument("--lora-rank", type=int, default=128)
    parser.add_argument("--full-checkpoint", default=DEV_FULL_CKPT)
    parser.add_argument("--gemma-root", default=GEMMA_ROOT)

    # (flag, converter, default) for the typed scalar options, in help order.
    scalar_options = (
        ("--cfg-scale", float, 2.5),
        ("--stg-scale", float, 1.5),
        ("--rescale-scale", float, 0.0),
        ("--modality-scale", float, 1.0),
        ("--steps", int, 30),
        ("--fps", float, 25.0),
        ("--stg-block", int, 29),
        ("--cfg-clamp", float, 0.0),
        ("--seed", int, 42),
        ("--duration-multiplier", float, 1.1),
    )
    for flag, converter, fallback in scalar_options:
        parser.add_argument(flag, type=converter, default=fallback)

    # Match Gradio / inference_server.py DEFAULT_NEG exactly
    default_negative = (
        "worst quality, inconsistent, robotic, distorted, noise, static, "
        "muffled, unclear, unnatural, monotone"
    )
    parser.add_argument("--negative-prompt", default=default_negative)

    return parser.parse_args()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def estimate_speech_duration(prompt: str, speed: float = 1.0) -> float:
    """Estimate how many seconds of audio the spoken part of *prompt* needs.

    The spoken text is taken from double-quoted spans when there are any,
    otherwise from single-quoted spans, otherwise the whole prompt is used.
    The estimate is 0.065 s per character, scaled by ``speed`` (clamped to at
    least 0.1), plus 1.5 s, rounded to one decimal and never below 3.0 s.

    Args:
        prompt: Generation prompt, optionally containing quoted dialogue.
        speed: Speaking-rate multiplier; larger values shorten the estimate.

    Returns:
        Estimated duration in seconds (>= 3.0).
    """
    import re

    spoken = re.findall(r'"([^"]*)"', prompt)
    if not spoken:
        # Only accept single quotes that are not glued to a word character on
        # the outside; otherwise the apostrophes in "don't ... it's" are read
        # as a quoted span and the prompt collapses to the text between them.
        spoken = re.findall(r"(?<!\w)'([^']*)'(?!\w)", prompt)

    text = " ".join(spoken) if spoken else prompt
    duration = len(text) * 0.065 / max(speed, 0.1) + 1.5
    return max(3.0, round(duration, 1))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class WarmValidator:
|
| 71 |
+
def __init__(self, full_checkpoint, gemma_root, lora_path=None, lora_rank=128,
             device="cuda", dtype=torch.bfloat16):
    """Load every model the validator needs, once, and keep them on the instance.

    Args:
        full_checkpoint: Checkpoint path passed to the prompt encoder, audio
            conditioner, audio decoder and velocity-model builder.
        gemma_root: Path passed to the prompt encoder as ``gemma_root``.
        lora_path: Optional LoRA file forwarded to ``_build_velocity_model``.
        lora_rank: LoRA rank forwarded to ``_build_velocity_model``.
        device: Device spec, wrapped in ``torch.device``.
        dtype: dtype handed to every component.
    """
    from audio_conditioning import AudioConditionByReferenceLatent  # noqa: F401 (imported by inference.py)
    from ltx_core.components.patchifiers import AudioPatchifier
    from ltx_pipelines.utils.blocks import PromptEncoder, AudioConditioner, AudioDecoder

    self.device = torch.device(device)
    self.dtype = dtype
    self.full_checkpoint = full_checkpoint
    self.gemma_root = gemma_root
    self.patchifier = AudioPatchifier(patch_size=1)

    def _load_timed(announcement, label, factory):
        """Log, build one component via ``factory`` and log how long it took."""
        logging.info(announcement)
        started = time.time()
        component = factory()
        logging.info(f" {label} ready in {time.time()-started:.1f}s")
        return component

    self.prompt_encoder = _load_timed(
        "Loading PromptEncoder (Gemma + embeddings_processor)...",
        "PromptEncoder",
        lambda: PromptEncoder(
            checkpoint_path=full_checkpoint, gemma_root=gemma_root,
            dtype=dtype, device=self.device, warm=True, audio_only=True,
        ),
    )
    self.audio_conditioner = _load_timed(
        "Loading AudioConditioner (audio VAE encoder)...",
        "AudioConditioner",
        lambda: AudioConditioner(
            checkpoint_path=full_checkpoint, dtype=dtype, device=self.device, warm=True,
        ),
    )
    self.audio_decoder = _load_timed(
        "Loading AudioDecoder...",
        "AudioDecoder",
        lambda: AudioDecoder(
            checkpoint_path=full_checkpoint, dtype=dtype, device=self.device, warm=True,
        ),
    )

    # The velocity model reports its parameter count as well, so it is timed
    # inline instead of going through the helper.
    logging.info("Building velocity model (audio-only from base dev)...")
    build_started = time.time()
    self.velocity_model = self._build_velocity_model(full_checkpoint, lora_path, lora_rank)
    build_seconds = time.time() - build_started
    billions = sum(p.numel() for p in self.velocity_model.parameters()) / 1e9
    logging.info(f" Velocity model ready in {build_seconds:.1f}s "
                 f"({billions:.1f}B params)")
|
| 110 |
+
|
| 111 |
+
def _build_velocity_model(self, checkpoint_path, lora_path, lora_rank):
    """Build the audio-only velocity model and optionally merge a LoRA into it.

    Args:
        checkpoint_path: Checkpoint handed to the model builder; only keys
            under ``model.diffusion_model.`` are matched, with that prefix
            stripped.
        lora_path: Optional safetensors LoRA file. Two key layouts are
            recognised: PEFT (keys containing ``base_model.model.``) and
            IC-LoRA (keys containing ``diffusion_model.``).
        lora_rank: Rank of the temporary PEFT adapter; also used as
            ``lora_alpha``.

    Returns:
        The model in eval mode on ``self.device``. When a LoRA file was
        found, the adapter has been merged and unloaded before returning.
    """
    from ltx_core.loader.registry import DummyRegistry
    from ltx_core.loader.sd_ops import SDOps
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
    from ltx_core.model.model_protocol import ModelConfigurator
    from ltx_core.model.transformer.attention import AttentionFunction
    from ltx_core.model.transformer.model import LTXModel, LTXModelType
    from ltx_core.model.transformer.rope import LTXRopeType

    sd_ops = (
        SDOps("AO")
        .with_matching(prefix="model.diffusion_model.")
        .with_replacement("model.diffusion_model.", "")
    )

    class Cfg(ModelConfigurator[LTXModel]):
        """Configurator that instantiates an audio-only ``LTXModel``."""

        @classmethod
        def from_config(cls, config):
            t = config.get("transformer", {})
            cp = None
            # The caption projection is only created inside the model when
            # the config does not place it before the connector.
            if not t.get("caption_proj_before_connector", False):
                from ltx_core.model.transformer.text_projection import create_caption_projection
                with torch.device("meta"):
                    cp = create_caption_projection(t, audio=True)
            return LTXModel(
                model_type=LTXModelType.AudioOnly,
                audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
                audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
                audio_in_channels=t.get("audio_in_channels", 128),
                audio_out_channels=t.get("audio_out_channels", 128),
                num_layers=t.get("num_layers", 48),
                audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
                norm_eps=t.get("norm_eps", 1e-6),
                attention_type=AttentionFunction(t.get("attention_type", "default")),
                positional_embedding_theta=10000.0,
                audio_positional_embedding_max_pos=[20.0],
                timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
                use_middle_indices_grid=t.get("use_middle_indices_grid", True),
                rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
                double_precision_rope=t.get("frequencies_precision", False) == "float64",
                apply_gated_attention=t.get("apply_gated_attention", False),
                audio_caption_projection=cp,
                cross_attention_adaln=t.get("cross_attention_adaln", False),
            )

    builder = Builder(
        model_path=checkpoint_path, model_class_configurator=Cfg,
        model_sd_ops=sd_ops, registry=DummyRegistry(),
    )
    velocity = builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()

    if not lora_path:
        return velocity
    if not os.path.exists(lora_path):
        # Previously silent: validation would run on the bare base model
        # while the caller believed a LoRA was attached.
        logging.warning(f"LoRA file not found: {lora_path} - continuing with base weights only")
        return velocity

    from peft import LoraConfig, get_peft_model
    from safetensors.torch import load_file as st_load
    logging.info(f"Attaching LoRA: {lora_path}")
    lora_sd = st_load(lora_path)
    is_peft = any("base_model.model." in k for k in lora_sd.keys())
    is_iclora = any("diffusion_model." in k for k in lora_sd.keys())
    cfg = LoraConfig(
        r=lora_rank, lora_alpha=lora_rank, lora_dropout=0.0, bias="none",
        target_modules=[
            "audio_attn1.to_k", "audio_attn1.to_q",
            "audio_attn1.to_v", "audio_attn1.to_out.0",
            "audio_attn2.to_k", "audio_attn2.to_q",
            "audio_attn2.to_v", "audio_attn2.to_out.0",
            "audio_ff.net.0.proj", "audio_ff.net.2",
        ],
    )
    velocity = get_peft_model(velocity, cfg)

    if is_peft:
        # Saved PEFT keys lack the adapter name; re-insert ".default".
        mapped = {}
        for k, v in lora_sd.items():
            nk = k
            if ".lora_A.weight" in k and ".lora_A.default.weight" not in k:
                nk = k.replace(".lora_A.weight", ".lora_A.default.weight")
            if ".lora_B.weight" in k and ".lora_B.default.weight" not in k:
                nk = k.replace(".lora_B.weight", ".lora_B.default.weight")
            mapped[nk] = v
        _, unexpected = velocity.load_state_dict(mapped, strict=False)
        logging.info(f" Loaded {len(mapped) - len(unexpected)} LoRA weights (peft)")
    elif is_iclora:
        # Keep only the audio branches and rename to the PEFT key layout.
        audio_keys = {k: v for k, v in lora_sd.items()
                      if "audio_attn1" in k or "audio_attn2" in k or "audio_ff" in k}
        mapped = {}
        for k, v in audio_keys.items():
            nk = k.replace("diffusion_model.", "base_model.model.")
            nk = nk.replace(".lora_A.weight", ".lora_A.default.weight")
            nk = nk.replace(".lora_B.weight", ".lora_B.default.weight")
            mapped[nk] = v
        _, unexpected = velocity.load_state_dict(mapped, strict=False)
        logging.info(f" Loaded {len(mapped) - len(unexpected)} LoRA weights (iclora)")
    else:
        # Neither layout matched: nothing is loaded, so say so instead of
        # silently merging a freshly initialised adapter.
        logging.warning(f" Unrecognised LoRA key layout in {lora_path}; no LoRA weights were loaded")

    velocity = velocity.merge_and_unload()
    logging.info(" Merged LoRA into base weights")

    return velocity
|
| 208 |
+
|
| 209 |
+
@torch.inference_mode()
def generate(self, prompt, output_path, voice_ref=None, args=None):
    """Synthesize speech for ``prompt`` and write it to ``output_path``.

    The already-loaded components on ``self`` (``patchifier``,
    ``audio_conditioner``, ``prompt_encoder``, ``velocity_model``,
    ``audio_decoder``) are reused, so repeated calls do not reload models.

    Args:
        prompt: Text to speak. Also drives the estimated clip duration.
        output_path: Destination audio file; parent directories are created.
        voice_ref: Optional path to a reference recording. When it exists
            and decodes to a non-empty waveform, its latent is applied as
            conditioning; otherwise a warning is logged and generation
            proceeds without a reference.
        args: Namespace of generation settings. The attributes read here are
            ``duration_multiplier``, ``fps``, ``seed``, ``steps``,
            ``cfg_scale``, ``negative_prompt``, ``stg_scale``, ``stg_block``,
            ``modality_scale``, ``rescale_scale`` and ``cfg_clamp``.

    Raises:
        ValueError: If ``args`` is ``None`` (the default is only a
            placeholder; every setting is read from it).
    """
    # Fail fast with a clear message instead of an AttributeError deep in
    # the duration computation, and before paying for the heavy imports.
    if args is None:
        raise ValueError("generate() requires 'args' with the generation settings")

    from audio_conditioning import AudioConditionByReferenceLatent
    from ltx_core.batch_split import BatchSplitAdapter
    from ltx_core.components.diffusion_steps import EulerDiffusionStep
    from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
    from ltx_core.components.noisers import GaussianNoiser
    from ltx_core.components.schedulers import LTX2Scheduler
    from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
    from ltx_core.model.transformer.model import X0Model
    from ltx_core.tools import AudioLatentTools
    from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
    from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
    from ltx_pipelines.utils.media_io import decode_audio_from_file
    from ltx_pipelines.utils.samplers import euler_denoising_loop

    t_total = time.time()

    # ---- Duration + shape ----
    gen_dur = estimate_speech_duration(prompt) * args.duration_multiplier
    raw_frames = int(round(gen_dur * args.fps)) + 1
    # Snap (raw_frames - 1) to the nearest multiple of 8, then add 1, so the
    # frame count always has the form 8k + 1.
    num_frames = ((raw_frames - 1 + 4) // 8) * 8 + 1
    # NOTE(review): the 64x64 spatial size looks like a placeholder needed only
    # to derive the audio latent shape -- confirm against ltx_core.
    pixel_shape = VideoPixelShape(batch=1, frames=num_frames, height=64, width=64, fps=args.fps)
    tgt_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
    audio_tools = AudioLatentTools(patchifier=self.patchifier, target_shape=tgt_shape)

    state = audio_tools.create_initial_state(self.device, self.dtype)

    # ---- Voice reference ----
    ref_seconds = 10.0   # length the reference is tiled/trimmed to
    ref_peak_db = -4.0   # peak level the reference is normalised to
    if voice_ref and not os.path.exists(voice_ref):
        # Previously skipped silently, which made unconditioned samples look
        # like failed voice clones.
        logging.warning(f" voice reference not found, generating without it: {voice_ref}")
    elif voice_ref:
        # presumably the two floats are (start, duration) in seconds -- TODO confirm
        voice = decode_audio_from_file(voice_ref, self.device, 0.0, ref_seconds)
        if voice is None or voice.waveform.numel() == 0:
            # An empty waveform would otherwise divide by zero when tiling.
            logging.warning(f" voice reference could not be decoded, generating without it: {voice_ref}")
        else:
            w = voice.waveform
            # Bring the waveform to a 3-D tensor with two channels, duplicating
            # a single channel where needed.
            if w.dim() == 1:
                w = w.unsqueeze(0)
            if w.dim() == 2:
                if w.shape[0] == 1:
                    w = w.repeat(2, 1)
                w = w.unsqueeze(0)
            elif w.dim() == 3 and w.shape[1] == 1:
                w = w.repeat(1, 2, 1)

            # Loop short clips until they cover the target length, then trim.
            target_samples = int(ref_seconds * voice.sampling_rate)
            if w.shape[-1] < target_samples:
                w = w.repeat(1, 1, (target_samples // w.shape[-1]) + 1)
            w = w[..., :target_samples]

            # Peak-normalise; an all-zero clip is left untouched.
            peak = w.abs().max()
            if peak > 0:
                w = w * (10 ** (ref_peak_db / 20) / peak)

            ref_audio = Audio(waveform=w, sampling_rate=voice.sampling_rate)
            ref_latent = self.audio_conditioner(lambda enc: vae_encode_audio(ref_audio, enc, None))
            cond = AudioConditionByReferenceLatent(
                latent=ref_latent.to(self.device, self.dtype), strength=1.0,
            )
            state = cond.apply_to(latent_state=state, latent_tools=audio_tools)

    # ---- Noise ----
    gen = torch.Generator(device=self.device).manual_seed(args.seed)
    noiser = GaussianNoiser(generator=gen)
    state = noiser(state, noise_scale=1.0)

    # ---- Prompt encode ----
    # The negative prompt is only encoded when CFG is actually active.
    use_cfg = args.cfg_scale > 1.0
    prompts = [prompt, args.negative_prompt] if use_cfg else [prompt]
    ctx = self.prompt_encoder(prompts, streaming_prefetch_count=None)
    a_ctx = ctx[0].audio_encoding
    a_ctx_neg = ctx[1].audio_encoding if use_cfg else None

    # ---- Denoiser ----
    needs_guidance = args.cfg_scale > 1.0 or args.stg_scale > 0.0 or args.modality_scale > 1.0
    if needs_guidance:
        guider = MultiModalGuider(
            params=MultiModalGuiderParams(
                cfg_scale=args.cfg_scale, stg_scale=args.stg_scale,
                stg_blocks=[args.stg_block] if args.stg_scale > 0 else [],
                rescale_scale=args.rescale_scale,
                modality_scale=args.modality_scale,
                cfg_clamp_scale=args.cfg_clamp,
            ),
            negative_context=a_ctx_neg,
        )
        denoiser = GuidedDenoiser(
            v_context=None, a_context=a_ctx,
            video_guider=None, audio_guider=guider,
        )
    else:
        denoiser = SimpleDenoiser(v_context=None, a_context=a_ctx)

    sigmas = LTX2Scheduler().execute(steps=args.steps, latent=state.latent).to(self.device)

    # ---- Denoise ----
    # NOTE: don't wrap in gpu_model() — that context manager moves the
    # model back off GPU on exit, which breaks subsequent iterations of
    # our warm validator. We keep the velocity model resident.
    x0 = X0Model(self.velocity_model)
    batched = BatchSplitAdapter(x0, max_batch_size=1)
    _, audio_state = euler_denoising_loop(
        sigmas=sigmas, video_state=None, audio_state=state,
        stepper=EulerDiffusionStep(), transformer=batched, denoiser=denoiser,
    )

    # ---- Decode + save ----
    audio_state = audio_tools.clear_conditioning(audio_state)
    audio_state = audio_tools.unpatchify(audio_state)
    decoded = self.audio_decoder(audio_state.latent)

    wav = decoded.waveform
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)
    # dirname is "" for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    torchaudio.save(output_path, wav.float().cpu(), decoded.sampling_rate)
    logging.info(f" -> {output_path} ({wav.shape[-1]/decoded.sampling_rate:.1f}s, "
                 f"{time.time()-t_total:.1f}s)")
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def main():
    """Run the validation sweep described by ``args.val_config``.

    Loads the YAML config, builds one ``WarmValidator`` that is reused for
    every entry, and writes ``<output_dir>/<name>.wav`` for each item under
    the config's ``speakers`` key. A failing entry is logged with its
    traceback and counted; it does not stop the remaining entries.

    Raises:
        ValueError: If the top level of the YAML config is not a mapping.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    args = parse_args()
    import yaml
    with open(args.val_config, encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        val_cfg = yaml.safe_load(f) or {}
    if not isinstance(val_cfg, dict):
        raise ValueError(f"validation config must be a mapping: {args.val_config}")
    os.makedirs(args.output_dir, exist_ok=True)

    # Build validator once (models warm for all entries).
    validator = WarmValidator(
        full_checkpoint=args.full_checkpoint,
        gemma_root=args.gemma_root,
        lora_path=args.lora,
        lora_rank=args.lora_rank,
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )

    n_ok = n_fail = 0
    t0 = time.time()
    # "or []" also covers an explicit `speakers: null`.
    for idx, entry in enumerate(val_cfg.get("speakers") or []):
        # Placeholder label so a malformed entry can still be reported.
        name = f"entry_{idx}"
        try:
            # Field lookups live inside the try so one bad entry is counted
            # as a failure instead of aborting the whole sweep.
            name = entry["name"]
            out_path = os.path.join(args.output_dir, f"{name}.wav")
            validator.generate(
                prompt=entry["prompt"],
                output_path=out_path,
                voice_ref=entry.get("reference"),
                args=args,
            )
        except Exception as e:
            n_fail += 1
            logging.warning(f" [{name}] FAILED: {e}")
            traceback.print_exc()
        else:
            n_ok += 1
            logging.info(f" [{name}] OK")

    logging.info(f"Validation done: ok={n_ok} fail={n_fail} in {(time.time()-t0)/60:.1f}min "
                 f"at {args.output_dir}")
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
if __name__ == "__main__":
    # Direct execution only; importing this module must not start a run.
    main()
|