from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Protocol

import torch

from ltx_core.loader.module_ops import ModuleOps
from ltx_core.loader.sd_ops import SDOps
from ltx_core.model.model_protocol import ModelType

if TYPE_CHECKING:
    from ltx_core.loader.registry import Registry


@dataclass(frozen=True)
class StateDict:
    """
    Frozen container for a PyTorch state dictionary.

    ``frozen=True`` only prevents reassigning the fields; the ``sd`` dict and the
    ``dtype`` set are still mutable objects. Because the dataclass-generated
    ``__hash__`` hashes every field and ``dict``/``set`` are unhashable, calling
    ``hash()`` on an instance raises ``TypeError``.

    Contains:
    - sd: Dictionary of tensors (weights, buffers, etc.)
    - device: Device where tensors are stored
    - size: Total memory footprint in bytes
    - dtype: Set of tensor dtypes present
    """

    sd: dict
    device: torch.device
    size: int
    dtype: set[torch.dtype]

    def footprint(self) -> tuple[int, torch.device]:
        """Return ``(size, device)``: the byte footprint and the device holding the tensors."""
        return self.size, self.device


class StateDictLoader(Protocol):
    """
    Protocol for loading state dictionaries from various sources.

    Implementations must provide:
    - metadata: Extract model metadata from a single path
    - load: Load state dict from path(s) and apply SDOps transformations
    """

    def metadata(self, path: str) -> dict:
        """
        Load metadata from path.

        Args:
            path: A single path to read metadata from.

        Returns:
            The metadata dictionary.
        """
        ...

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """
        Load state dict from path or paths (for sharded model storage) and apply sd_ops.

        Args:
            path: A single path, or a list of shard paths.
            sd_ops: Optional state-dict transformation ops to apply to the loaded dict.
            device: Optional target device for the loaded tensors; None leaves the
                choice to the implementation.

        Returns:
            The loaded ``StateDict``.
        """
        ...


class ModelBuilderProtocol(Protocol[ModelType]):
    """
    Protocol for building PyTorch models from configuration dictionaries.

    Implementations must provide:
    - meta_model: Create a model from configuration dictionary and apply module operations
    - build: Create and initialize a model from state dictionary and apply dtype transformations

    The ``with_*`` methods return a modified copy of the builder instead of mutating it.
    """

    # State-dict key remapping ops for the model weights (set via with_sd_ops).
    model_sd_ops: SDOps | None
    # Module operations such as quantization (set via with_module_ops).
    module_ops: tuple[ModuleOps, ...]
    # LoRAs to fuse into the model at build time (set via with_loras).
    loras: tuple[LoraPathStrengthAndSDOps, ...]
    # Weight registry used for allocation (set via with_registry).
    registry: Registry

    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
        """
        Create a model on the meta device from a configuration dictionary.

        This decouples model creation from weight loading, allowing the model architecture
        to be instantiated without allocating memory for parameters.

        Args:
            config: Model configuration dictionary.
            module_ops: Optional list of module operations to apply (e.g., quantization).
                NOTE(review): the ``module_ops`` attribute and ``with_module_ops`` use a
                tuple; this parameter is typed as a list — confirm which is intended.

        Returns:
            Model instance on meta device (no actual memory allocated for parameters).
        """
        ...

    def with_sd_ops(self, sd_ops: SDOps | None) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder with the given state-dict key remapping ops."""
        ...

    def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder with the given module operations (e.g. quantization)."""
        ...

    def with_loras(self, loras: tuple[LoraPathStrengthAndSDOps, ...]) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder with the given LoRAs to fuse at build time."""
        ...

    def with_registry(self, registry: Registry) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder using the given weight registry for allocation."""
        ...

    def with_lora_load_device(self, device: torch.device) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder that loads LoRA weights onto the given device."""
        ...

    def build(
        self, device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: object
    ) -> ModelType:
        """
        Build the model.

        Args:
            device: Target device for the model.
            dtype: Target dtype for the model. If None, the dtype of the loaded
                checkpoint weights is used.
            **kwargs: Implementation-specific build options.

        Returns:
            Model instance.
        """
        ...

    def model_config(self) -> dict:
        """Return the model configuration dictionary extracted from the checkpoint metadata."""
        ...


class LoRAAdaptableProtocol(Protocol):
    """
    Protocol for models that can be adapted with LoRAs.

    Implementations must provide:
    - lora: Add a LoRA to the model
    """

    def lora(self, lora_path: str, strength: float) -> LoRAAdaptableProtocol:
        """
        Add a LoRA to the model.

        Args:
            lora_path: Path to the LoRA weights.
            strength: Strength with which the LoRA is applied.

        Returns:
            A LoRA-adaptable model (whether this is ``self`` or a new object is up to
            the implementation).
        """
        ...


class LoraPathStrengthAndSDOps(NamedTuple):
    """
    Tuple containing a LoRA path, strength, and SDOps for applying to the LoRA state dict.
    """

    path: str  # location of the LoRA weights
    strength: float  # strength with which the LoRA is applied
    sd_ops: SDOps  # ops applied to the LoRA state dict


class LoraStateDictWithStrength(NamedTuple):
    """
    Tuple containing a LoRA state dict and strength for applying to the model.
    """

    state_dict: StateDict  # the loaded LoRA weights
    strength: float  # strength with which the LoRA is applied