Dramabox / ltx2 /ltx_trainer /model_loader.py
Manmay's picture
DramaBox Space — initial app + vendored ltx2
08c5e28 verified
raw
history blame
13.8 kB
# ruff: noqa: PLC0415
"""
Model loader for LTX-2 trainer using the new ltx-core package.
This module provides a unified interface for loading LTX-2 model components
for training, using SingleGPUModelBuilder from ltx-core.
Example usage:
# Load individual components
vae_encoder = load_video_vae_encoder("/path/to/checkpoint.safetensors", device="cuda")
vae_decoder = load_video_vae_decoder("/path/to/checkpoint.safetensors", device="cuda")
text_encoder = load_text_encoder("/path/to/gemma", device="cuda")
# Load all components at once
components = load_model("/path/to/checkpoint.safetensors", text_encoder_path="/path/to/gemma")
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import torch
from ltx_trainer import logger
# Type alias for device specification: either a device string (e.g. "cpu", "cuda")
# or an existing torch.device. The loaders below normalize it with _to_torch_device.
Device = str | torch.device
# Type checking imports (not loaded at runtime)
if TYPE_CHECKING:
from ltx_core.components.schedulers import LTX2Scheduler
from ltx_core.model.audio_vae import AudioDecoder, AudioEncoder, Vocoder
from ltx_core.model.transformer import LTXModel
from ltx_core.model.video_vae import VideoDecoder, VideoEncoder
from ltx_core.text_encoders.gemma import GemmaTextEncoder
from ltx_core.text_encoders.gemma.embeddings_processor import EmbeddingsProcessor
def _to_torch_device(device: Device) -> torch.device:
"""Convert device specification to torch.device."""
return torch.device(device) if isinstance(device, str) else device
# =============================================================================
# Individual Component Loaders
# =============================================================================
def load_transformer(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "LTXModel":
    """Build the LTX-2 transformer from a safetensors checkpoint.

    Args:
        checkpoint_path: Location of the ``.safetensors`` checkpoint.
        device: Target device, as a string (e.g. ``"cuda"``) or ``torch.device``.
        dtype: Parameter dtype for the constructed model.

    Returns:
        The constructed ``LTXModel``.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.transformer.model_configurator import (
        LTXV_MODEL_COMFY_RENAMING_MAP,
        LTXModelConfigurator,
    )

    # State-dict keys are passed through LTXV_MODEL_COMFY_RENAMING_MAP before
    # being matched to the model built by LTXModelConfigurator.
    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=LTXModelConfigurator,
        model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
    )
    target_device = _to_torch_device(device)
    return builder.build(device=target_device, dtype=dtype)
def load_video_vae_encoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "VideoEncoder":
    """Build the video VAE encoder used during preprocessing.

    Args:
        checkpoint_path: Location of the ``.safetensors`` checkpoint.
        device: Target device, as a string or ``torch.device``.
        dtype: Parameter dtype for the constructed encoder.

    Returns:
        The constructed ``VideoEncoder``.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.video_vae import VAE_ENCODER_COMFY_KEYS_FILTER, VideoEncoderConfigurator

    encoder_builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VideoEncoderConfigurator,
        # Selects the encoder's slice of the shared checkpoint.
        model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
    )
    return encoder_builder.build(device=_to_torch_device(device), dtype=dtype)
def load_video_vae_decoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "VideoDecoder":
    """Build the video VAE decoder used for inference and validation.

    Args:
        checkpoint_path: Location of the ``.safetensors`` checkpoint.
        device: Target device, as a string or ``torch.device``.
        dtype: Parameter dtype for the constructed decoder.

    Returns:
        The constructed ``VideoDecoder``.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.video_vae import VAE_DECODER_COMFY_KEYS_FILTER, VideoDecoderConfigurator

    weights_file = str(checkpoint_path)
    decoder_builder = SingleGPUModelBuilder(
        model_path=weights_file,
        model_class_configurator=VideoDecoderConfigurator,
        model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
    )
    target_device = _to_torch_device(device)
    return decoder_builder.build(device=target_device, dtype=dtype)
def load_audio_vae_encoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AudioEncoder":
    """Build the audio VAE encoder used during preprocessing.

    Args:
        checkpoint_path: Location of the ``.safetensors`` checkpoint.
        device: Target device, as a string or ``torch.device``.
        dtype: Parameter dtype. Defaults to bfloat16, although float32 is
            recommended for quality.

    Returns:
        The constructed ``AudioEncoder``.
    """
    from ltx_core.loader import SingleGPUModelBuilder
    from ltx_core.model.audio_vae import AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, AudioEncoderConfigurator

    audio_encoder_builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AudioEncoderConfigurator,
        model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
    )
    target_device = _to_torch_device(device)
    return audio_encoder_builder.build(device=target_device, dtype=dtype)
def load_audio_vae_decoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AudioDecoder":
    """Build the audio VAE decoder.

    Args:
        checkpoint_path: Location of the ``.safetensors`` checkpoint.
        device: Target device, as a string or ``torch.device``.
        dtype: Parameter dtype for the constructed decoder.

    Returns:
        The constructed ``AudioDecoder``.
    """
    from ltx_core.loader import SingleGPUModelBuilder
    from ltx_core.model.audio_vae import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, AudioDecoderConfigurator

    audio_decoder_builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AudioDecoderConfigurator,
        model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
    )
    return audio_decoder_builder.build(device=_to_torch_device(device), dtype=dtype)
def load_vocoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "Vocoder":
    """Build the vocoder that produces audio waveforms.

    Args:
        checkpoint_path: Location of the ``.safetensors`` checkpoint.
        device: Target device, as a string or ``torch.device``.
        dtype: Parameter dtype for the constructed vocoder.

    Returns:
        The constructed ``Vocoder``.
    """
    from ltx_core.loader import SingleGPUModelBuilder
    from ltx_core.model.audio_vae import VOCODER_COMFY_KEYS_FILTER, VocoderConfigurator

    vocoder_builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VocoderConfigurator,
        model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
    )
    target_device = _to_torch_device(device)
    return vocoder_builder.build(device=target_device, dtype=dtype)
def load_text_encoder(
    gemma_model_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
    load_in_8bit: bool = False,
) -> "GemmaTextEncoder":
    """Load the Gemma text encoder.

    Args:
        gemma_model_path: Path to the Gemma model directory.
        device: Device to load the model on. Ignored when ``load_in_8bit`` is True.
        dtype: Data type for model weights.
        load_in_8bit: Whether to load the Gemma model in 8-bit precision via
            ``ltx_trainer.gemma_8bit.load_8bit_gemma``. On that path ``device`` is
            not forwarded; the 8-bit loader is expected to place the weights
            itself (e.g. via ``device_map="auto"``).

    Returns:
        Loaded GemmaTextEncoder.

    Raises:
        ValueError: If ``gemma_model_path`` is not an existing directory.
    """
    if not Path(gemma_model_path).is_dir():
        raise ValueError(f"Gemma model path is not a directory: {gemma_model_path}")

    # 8-bit path: delegated entirely to the bitsandbytes-based loader.
    if load_in_8bit:
        from ltx_trainer.gemma_8bit import load_8bit_gemma

        return load_8bit_gemma(gemma_model_path, dtype)

    # Standard loading path
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.text_encoders.gemma import (
        GEMMA_LLM_KEY_OPS,
        GEMMA_MODEL_OPS,
        GemmaTextEncoderConfigurator,
        module_ops_from_gemma_root,
    )
    from ltx_core.utils import find_matching_file

    torch_device = _to_torch_device(device)
    # The folder holding the "model*.safetensors" shards may be nested below the root.
    gemma_model_folder = find_matching_file(str(gemma_model_path), "model*.safetensors").parent
    # rglob yields files in filesystem order, which is not guaranteed to be stable.
    # Sort so the builder always receives the shards in the same, reproducible order.
    gemma_weight_paths = sorted(str(p) for p in gemma_model_folder.rglob("*.safetensors"))

    return SingleGPUModelBuilder(
        model_path=tuple(gemma_weight_paths),
        model_class_configurator=GemmaTextEncoderConfigurator,
        model_sd_ops=GEMMA_LLM_KEY_OPS,
        module_ops=(GEMMA_MODEL_OPS, *module_ops_from_gemma_root(str(gemma_model_path))),
    ).build(device=torch_device, dtype=dtype)
def load_embeddings_processor(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "EmbeddingsProcessor":
    """Build the embeddings processor (feature extractor plus video/audio connectors).

    Args:
        checkpoint_path: Location of the LTX-2 ``.safetensors`` checkpoint.
        device: Target device, as a string or ``torch.device``.
        dtype: Parameter dtype for the constructed processor.

    Returns:
        The constructed ``EmbeddingsProcessor``.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.text_encoders.gemma import (
        EMBEDDINGS_PROCESSOR_KEY_OPS,
        EmbeddingsProcessorConfigurator,
    )

    target_device = _to_torch_device(device)
    processor_builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=EmbeddingsProcessorConfigurator,
        model_sd_ops=EMBEDDINGS_PROCESSOR_KEY_OPS,
    )
    return processor_builder.build(device=target_device, dtype=dtype)
# =============================================================================
# Combined Component Loader
# =============================================================================
@dataclass
class LtxModelComponents:
    """Container for the LTX-2 model components returned by ``load_model``.

    Only ``transformer`` is required. Every other field defaults to ``None``,
    meaning the component was not loaded (for example, because the matching
    ``with_*`` flag of ``load_model`` was False).

    There is no field for the audio VAE encoder or the embeddings processor;
    load those with their dedicated loader functions.
    """

    # LTX transformer; always present.
    transformer: "LTXModel"
    # Video VAE encoder (preprocessing); None unless requested.
    video_vae_encoder: "VideoEncoder | None" = None
    # Video VAE decoder (inference/validation); None when not loaded.
    video_vae_decoder: "VideoDecoder | None" = None
    # Audio VAE decoder; None when not loaded.
    audio_vae_decoder: "AudioDecoder | None" = None
    # Vocoder (audio waveform generation); None when not loaded.
    vocoder: "Vocoder | None" = None
    # Gemma text encoder; None when not loaded.
    text_encoder: "GemmaTextEncoder | None" = None
    # Scheduler; load_model always fills this with a fresh LTX2Scheduler.
    scheduler: "LTX2Scheduler | None" = None
def load_model(
    checkpoint_path: str | Path,
    text_encoder_path: str | Path | None = None,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
    with_video_vae_encoder: bool = False,
    with_video_vae_decoder: bool = True,
    with_audio_vae_decoder: bool = True,
    with_vocoder: bool = True,
    with_text_encoder: bool = True,
) -> LtxModelComponents:
    """
    Load LTX-2 model components from a safetensors checkpoint.

    This is a convenience function that loads multiple components at once.
    For loading individual components, use the dedicated functions:
    - load_transformer()
    - load_video_vae_encoder()
    - load_video_vae_decoder()
    - load_audio_vae_encoder()
    - load_audio_vae_decoder()
    - load_vocoder()
    - load_text_encoder()
    - load_embeddings_processor()

    Args:
        checkpoint_path: Path to the safetensors checkpoint file
        text_encoder_path: Path to Gemma model directory (required if with_text_encoder=True)
        device: Device to load models on ("cuda", "cpu", etc.)
        dtype: Data type for model weights
        with_video_vae_encoder: Whether to load the video VAE encoder (for preprocessing)
        with_video_vae_decoder: Whether to load the video VAE decoder (for inference/validation)
        with_audio_vae_decoder: Whether to load the audio VAE decoder
        with_vocoder: Whether to load the vocoder
        with_text_encoder: Whether to load the text encoder

    Returns:
        LtxModelComponents containing all loaded model components. Components
        that were not requested are None; the scheduler is always created.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        ValueError: If ``with_text_encoder`` is True but ``text_encoder_path`` is None.
            Both checks run before any model is loaded, so a misconfigured call
            fails immediately instead of after the transformer and VAEs are built.
    """
    checkpoint_path = Path(checkpoint_path)

    # Validate every argument up front: loading the transformer and VAEs is
    # expensive, so argument errors must surface before any of that work starts.
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if with_text_encoder and text_encoder_path is None:
        raise ValueError("text_encoder_path must be provided when with_text_encoder=True")

    from ltx_core.components.schedulers import LTX2Scheduler

    logger.info(f"Loading LTX-2 model from {checkpoint_path}")
    torch_device = _to_torch_device(device)

    # Load transformer
    logger.debug("Loading transformer...")
    transformer = load_transformer(checkpoint_path, torch_device, dtype)

    # Load video VAE encoder
    video_vae_encoder = None
    if with_video_vae_encoder:
        logger.debug("Loading video VAE encoder...")
        video_vae_encoder = load_video_vae_encoder(checkpoint_path, torch_device, dtype)

    # Load video VAE decoder
    video_vae_decoder = None
    if with_video_vae_decoder:
        logger.debug("Loading video VAE decoder...")
        video_vae_decoder = load_video_vae_decoder(checkpoint_path, torch_device, dtype)

    # Load audio VAE decoder
    audio_vae_decoder = None
    if with_audio_vae_decoder:
        logger.debug("Loading audio VAE decoder...")
        audio_vae_decoder = load_audio_vae_decoder(checkpoint_path, torch_device, dtype)

    # Load vocoder
    vocoder = None
    if with_vocoder:
        logger.debug("Loading vocoder...")
        vocoder = load_vocoder(checkpoint_path, torch_device, dtype)

    # Load text encoder (text_encoder_path was validated above)
    text_encoder = None
    if with_text_encoder:
        logger.debug("Loading Gemma text encoder...")
        text_encoder = load_text_encoder(text_encoder_path, torch_device, dtype)

    # Create scheduler (stateless, no loading needed)
    scheduler = LTX2Scheduler()

    return LtxModelComponents(
        transformer=transformer,
        video_vae_encoder=video_vae_encoder,
        video_vae_decoder=video_vae_decoder,
        audio_vae_decoder=audio_vae_decoder,
        vocoder=vocoder,
        text_encoder=text_encoder,
        scheduler=scheduler,
    )