Spaces:
Running on Zero
Running on Zero
| # ruff: noqa: PLC0415 | |
| """ | |
| Model loader for LTX-2 trainer using the new ltx-core package. | |
| This module provides a unified interface for loading LTX-2 model components | |
| for training, using SingleGPUModelBuilder from ltx-core. | |
| Example usage: | |
| # Load individual components | |
| vae_encoder = load_video_vae_encoder("/path/to/checkpoint.safetensors", device="cuda") | |
| vae_decoder = load_video_vae_decoder("/path/to/checkpoint.safetensors", device="cuda") | |
| text_encoder = load_text_encoder("/path/to/gemma", device="cuda") | |
| # Load all components at once | |
| components = load_model("/path/to/checkpoint.safetensors", text_encoder_path="/path/to/gemma") | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING | |
| import torch | |
| from ltx_trainer import logger | |
| # Type alias for device specification | |
| Device = str | torch.device | |
| # Type checking imports (not loaded at runtime) | |
| if TYPE_CHECKING: | |
| from ltx_core.components.schedulers import LTX2Scheduler | |
| from ltx_core.model.audio_vae import AudioDecoder, AudioEncoder, Vocoder | |
| from ltx_core.model.transformer import LTXModel | |
| from ltx_core.model.video_vae import VideoDecoder, VideoEncoder | |
| from ltx_core.text_encoders.gemma import GemmaTextEncoder | |
| from ltx_core.text_encoders.gemma.embeddings_processor import EmbeddingsProcessor | |
def _to_torch_device(device: Device) -> torch.device:
    """Normalize a device specification to a ``torch.device``.

    Args:
        device: Either a device string (e.g. ``"cpu"``) or an existing
            ``torch.device``.

    Returns:
        A ``torch.device`` built from the string, or the argument itself
        (unchanged, same object) when it is not a string.
    """
    if isinstance(device, str):
        return torch.device(device)
    # Anything that is not a string is passed through untouched.
    return device
| # ============================================================================= | |
| # Individual Component Loaders | |
| # ============================================================================= | |
def load_transformer(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "LTXModel":
    """Build the LTX transformer from a checkpoint.

    The model is constructed by ``SingleGPUModelBuilder`` using
    ``LTXModelConfigurator`` as the class configurator and
    ``LTXV_MODEL_COMFY_RENAMING_MAP`` as the state-dict ops.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file
        device: Device to build the model on (string or ``torch.device``)
        dtype: Data type for model weights

    Returns:
        The ``LTXModel`` produced by the builder
    """
    # ltx-core is imported lazily so importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.transformer.model_configurator import (
        LTXV_MODEL_COMFY_RENAMING_MAP,
        LTXModelConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=LTXModelConfigurator,
        model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_video_vae_encoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "VideoEncoder":
    """Build the video VAE encoder (used for preprocessing) from a checkpoint.

    The model is constructed by ``SingleGPUModelBuilder`` using
    ``VideoEncoderConfigurator`` and ``VAE_ENCODER_COMFY_KEYS_FILTER`` as the
    state-dict ops.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file
        device: Device to build the model on (string or ``torch.device``)
        dtype: Data type for model weights

    Returns:
        The ``VideoEncoder`` produced by the builder
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.video_vae import VAE_ENCODER_COMFY_KEYS_FILTER, VideoEncoderConfigurator

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VideoEncoderConfigurator,
        model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_video_vae_decoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "VideoDecoder":
    """Build the video VAE decoder (used for inference/validation) from a checkpoint.

    The model is constructed by ``SingleGPUModelBuilder`` using
    ``VideoDecoderConfigurator`` and ``VAE_DECODER_COMFY_KEYS_FILTER`` as the
    state-dict ops.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file
        device: Device to build the model on (string or ``torch.device``)
        dtype: Data type for model weights

    Returns:
        The ``VideoDecoder`` produced by the builder
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.video_vae import VAE_DECODER_COMFY_KEYS_FILTER, VideoDecoderConfigurator

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VideoDecoderConfigurator,
        model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_audio_vae_encoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AudioEncoder":
    """Build the audio VAE encoder (used for preprocessing) from a checkpoint.

    The model is constructed by ``SingleGPUModelBuilder`` using
    ``AudioEncoderConfigurator`` and ``AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER``
    as the state-dict ops.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file
        device: Device to build the model on (string or ``torch.device``)
        dtype: Data type for model weights. Defaults to bfloat16, but float32
            is recommended for quality.

    Returns:
        The ``AudioEncoder`` produced by the builder
    """
    from ltx_core.loader import SingleGPUModelBuilder
    from ltx_core.model.audio_vae import AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, AudioEncoderConfigurator

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AudioEncoderConfigurator,
        model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_audio_vae_decoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AudioDecoder":
    """Build the audio VAE decoder from a checkpoint.

    The model is constructed by ``SingleGPUModelBuilder`` using
    ``AudioDecoderConfigurator`` and ``AUDIO_VAE_DECODER_COMFY_KEYS_FILTER``
    as the state-dict ops.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file
        device: Device to build the model on (string or ``torch.device``)
        dtype: Data type for model weights

    Returns:
        The ``AudioDecoder`` produced by the builder
    """
    from ltx_core.loader import SingleGPUModelBuilder
    from ltx_core.model.audio_vae import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, AudioDecoderConfigurator

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AudioDecoderConfigurator,
        model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_vocoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "Vocoder":
    """Build the vocoder (used for audio waveform generation) from a checkpoint.

    The model is constructed by ``SingleGPUModelBuilder`` using
    ``VocoderConfigurator`` and ``VOCODER_COMFY_KEYS_FILTER`` as the
    state-dict ops.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file
        device: Device to build the model on (string or ``torch.device``)
        dtype: Data type for model weights

    Returns:
        The ``Vocoder`` produced by the builder
    """
    from ltx_core.loader import SingleGPUModelBuilder
    from ltx_core.model.audio_vae import VOCODER_COMFY_KEYS_FILTER, VocoderConfigurator

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VocoderConfigurator,
        model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_text_encoder(
    gemma_model_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
    load_in_8bit: bool = False,
) -> "GemmaTextEncoder":
    """Build the Gemma text encoder from a model directory.

    Args:
        gemma_model_path: Path to the Gemma model directory
        device: Device to build the model on (string or ``torch.device``)
        dtype: Data type for model weights
        load_in_8bit: Whether to load the Gemma model in 8-bit precision using
            bitsandbytes. When True, loading is delegated to
            ``load_8bit_gemma`` (loaded with device_map="auto"), and the
            ``device`` argument is ignored for the Gemma backbone.

    Returns:
        Loaded GemmaTextEncoder

    Raises:
        ValueError: If ``gemma_model_path`` is not an existing directory.
    """
    gemma_root = Path(gemma_model_path)
    if not gemma_root.is_dir():
        raise ValueError(f"Gemma model path is not a directory: {gemma_model_path}")

    # 8-bit path: only the model path and dtype are forwarded.
    if load_in_8bit:
        from ltx_trainer.gemma_8bit import load_8bit_gemma

        return load_8bit_gemma(gemma_model_path, dtype)

    # Standard path: assemble the encoder from the safetensors shards.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.text_encoders.gemma import (
        GEMMA_LLM_KEY_OPS,
        GEMMA_MODEL_OPS,
        GemmaTextEncoderConfigurator,
        module_ops_from_gemma_root,
    )
    from ltx_core.utils import find_matching_file

    target_device = _to_torch_device(device)

    # The weights folder is the parent of the first file matching
    # "model*.safetensors"; every *.safetensors file below it is collected.
    weights_dir = find_matching_file(str(gemma_model_path), "model*.safetensors").parent
    weight_files = [str(shard) for shard in weights_dir.rglob("*.safetensors")]

    builder = SingleGPUModelBuilder(
        model_path=tuple(weight_files),
        model_class_configurator=GemmaTextEncoderConfigurator,
        model_sd_ops=GEMMA_LLM_KEY_OPS,
        module_ops=(GEMMA_MODEL_OPS, *module_ops_from_gemma_root(str(gemma_model_path))),
    )
    return builder.build(device=target_device, dtype=dtype)
def load_embeddings_processor(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "EmbeddingsProcessor":
    """Build the embeddings processor (feature extractor + video/audio connectors).

    The model is constructed by ``SingleGPUModelBuilder`` using
    ``EmbeddingsProcessorConfigurator`` and ``EMBEDDINGS_PROCESSOR_KEY_OPS``
    as the state-dict ops.

    Args:
        checkpoint_path: Path to the LTX-2 safetensors checkpoint file
        device: Device to build the model on (string or ``torch.device``)
        dtype: Data type for model weights

    Returns:
        The ``EmbeddingsProcessor`` produced by the builder
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.text_encoders.gemma import (
        EMBEDDINGS_PROCESSOR_KEY_OPS,
        EmbeddingsProcessorConfigurator,
    )

    # Resolve the device before constructing the builder (same order as before).
    target_device = _to_torch_device(device)
    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=EmbeddingsProcessorConfigurator,
        model_sd_ops=EMBEDDINGS_PROCESSOR_KEY_OPS,
    )
    return builder.build(device=target_device, dtype=dtype)
| # ============================================================================= | |
| # Combined Component Loader | |
| # ============================================================================= | |
@dataclass
class LtxModelComponents:
    """Container for all LTX-2 model components.

    Declared as a dataclass so it can be constructed with keyword arguments
    (as ``load_model`` does). Without the decorator the annotated fields
    generate no ``__init__`` and keyword construction raises ``TypeError``.

    ``transformer`` is the only required field; every other component
    defaults to ``None``.
    """

    # Required: the LTX transformer model.
    transformer: "LTXModel"
    # Optional video VAE encoder (for preprocessing).
    video_vae_encoder: "VideoEncoder | None" = None
    # Optional video VAE decoder (for inference/validation).
    video_vae_decoder: "VideoDecoder | None" = None
    # Optional audio VAE decoder.
    audio_vae_decoder: "AudioDecoder | None" = None
    # Optional vocoder (for audio waveform generation).
    vocoder: "Vocoder | None" = None
    # Optional Gemma text encoder.
    text_encoder: "GemmaTextEncoder | None" = None
    # Optional scheduler instance.
    scheduler: "LTX2Scheduler | None" = None
def load_model(
    checkpoint_path: str | Path,
    text_encoder_path: str | Path | None = None,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
    with_video_vae_encoder: bool = False,
    with_video_vae_decoder: bool = True,
    with_audio_vae_decoder: bool = True,
    with_vocoder: bool = True,
    with_text_encoder: bool = True,
) -> LtxModelComponents:
    """
    Load LTX-2 model components from a safetensors checkpoint.

    This is a convenience function that loads multiple components at once.
    For loading individual components, use the dedicated functions:
    - load_transformer()
    - load_video_vae_encoder()
    - load_video_vae_decoder()
    - load_audio_vae_decoder()
    - load_vocoder()
    - load_text_encoder()

    All arguments are validated before any component is loaded, so an invalid
    call fails immediately instead of after the transformer and VAEs have
    already been built.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file
        text_encoder_path: Path to Gemma model directory (required if with_text_encoder=True)
        device: Device to load models on ("cuda", "cpu", etc.)
        dtype: Data type for model weights
        with_video_vae_encoder: Whether to load the video VAE encoder (for preprocessing)
        with_video_vae_decoder: Whether to load the video VAE decoder (for inference/validation)
        with_audio_vae_decoder: Whether to load the audio VAE decoder
        with_vocoder: Whether to load the vocoder
        with_text_encoder: Whether to load the text encoder

    Returns:
        LtxModelComponents containing all loaded model components. Components
        that were not requested are ``None``; the transformer and scheduler
        are always populated.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        ValueError: If ``with_text_encoder`` is True but ``text_encoder_path``
            is None.
    """
    checkpoint_path = Path(checkpoint_path)

    # Fail fast: check every precondition before doing any expensive loading.
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if with_text_encoder and text_encoder_path is None:
        raise ValueError("text_encoder_path must be provided when with_text_encoder=True")

    # Imported only after validation so argument errors surface without
    # touching ltx-core.
    from ltx_core.components.schedulers import LTX2Scheduler

    logger.info(f"Loading LTX-2 model from {checkpoint_path}")
    torch_device = _to_torch_device(device)

    # Load transformer (always required)
    logger.debug("Loading transformer...")
    transformer = load_transformer(checkpoint_path, torch_device, dtype)

    # Load video VAE encoder
    video_vae_encoder = None
    if with_video_vae_encoder:
        logger.debug("Loading video VAE encoder...")
        video_vae_encoder = load_video_vae_encoder(checkpoint_path, torch_device, dtype)

    # Load video VAE decoder
    video_vae_decoder = None
    if with_video_vae_decoder:
        logger.debug("Loading video VAE decoder...")
        video_vae_decoder = load_video_vae_decoder(checkpoint_path, torch_device, dtype)

    # Load audio VAE decoder
    audio_vae_decoder = None
    if with_audio_vae_decoder:
        logger.debug("Loading audio VAE decoder...")
        audio_vae_decoder = load_audio_vae_decoder(checkpoint_path, torch_device, dtype)

    # Load vocoder
    vocoder = None
    if with_vocoder:
        logger.debug("Loading vocoder...")
        vocoder = load_vocoder(checkpoint_path, torch_device, dtype)

    # Load text encoder (its path was validated above)
    text_encoder = None
    if with_text_encoder:
        logger.debug("Loading Gemma text encoder...")
        text_encoder = load_text_encoder(text_encoder_path, torch_device, dtype)

    # Create scheduler (stateless, no loading needed)
    scheduler = LTX2Scheduler()

    return LtxModelComponents(
        transformer=transformer,
        video_vae_encoder=video_vae_encoder,
        video_vae_decoder=video_vae_decoder,
        audio_vae_decoder=audio_vae_decoder,
        vocoder=vocoder,
        text_encoder=text_encoder,
        scheduler=scheduler,
    )