Dramabox / ltx2 /ltx_core /loader /primitives.py
Manmay's picture
DramaBox Space — initial app + vendored ltx2
08c5e28 verified
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Protocol
import torch
from ltx_core.loader.module_ops import ModuleOps
from ltx_core.loader.sd_ops import SDOps
from ltx_core.model.model_protocol import ModelType
if TYPE_CHECKING:
from ltx_core.loader.registry import Registry
@dataclass(frozen=True)
class StateDict:
    """
    Frozen record bundling a PyTorch state dictionary with summary information.

    Fields:
    - sd: the dictionary of tensors (weights, buffers, etc.)
    - device: the device on which the tensors are stored
    - size: total memory footprint, in bytes
    - dtype: the set of tensor dtypes present

    ``frozen=True`` only prevents rebinding the fields on an instance; the
    ``sd`` dict and ``dtype`` set objects themselves can still be mutated.
    """

    sd: dict
    device: torch.device
    size: int
    dtype: set[torch.dtype]

    def footprint(self) -> tuple[int, torch.device]:
        """Return the ``(size, device)`` pair recorded for this state dict."""
        byte_count, location = self.size, self.device
        return byte_count, location
class StateDictLoader(Protocol):
    """
    Structural interface for objects that load state dictionaries from various sources.

    An implementation supplies two methods:
    - metadata: extract model metadata from a single path
    - load: load a state dict from one or more paths and apply SDOps transformations
    """

    def metadata(self, path: str) -> dict:
        """
        Return the metadata loaded from ``path``.
        """
        ...

    def load(
        self,
        path: str | list[str],
        sd_ops: SDOps | None = None,
        device: torch.device | None = None,
    ) -> StateDict:
        """
        Load a state dict and apply ``sd_ops`` to it.

        Args:
            path: A single path, or a list of paths for sharded model storage.
            sd_ops: Optional state-dict operations to apply.
            device: Optional device for the load; how ``None`` is handled is
                left to the implementation.
        Returns:
            The loaded state dict.
        """
        ...
class ModelBuilderProtocol(Protocol[ModelType]):
    """
    Structural interface for objects that build PyTorch models from configuration dictionaries.

    An implementation supplies, among the members declared below:
    - meta_model: create a model from a configuration dictionary and apply module operations
    - build: create and initialize a model from a state dictionary and apply dtype transformations
    """

    # State-dict operations for the model weights, if any.
    model_sd_ops: SDOps | None
    # Module operations associated with this builder.
    module_ops: tuple[ModuleOps, ...]
    # LoRAs (path, strength, SDOps) associated with this builder.
    loras: tuple["LoraPathStrengthAndSDOps", ...]
    # Weight registry used by this builder.
    registry: "Registry"

    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
        """
        Instantiate the model on the meta device from a configuration dictionary.

        Model creation is kept separate from weight loading, so the
        architecture can be instantiated without allocating memory for
        its parameters.

        Args:
            config: Model configuration dictionary.
            module_ops: Optional list of module operations to apply (e.g., quantization).
        Returns:
            Model instance on the meta device (no actual memory allocated for parameters).
        """
        ...

    def with_sd_ops(self, sd_ops: SDOps | None) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder that uses ``sd_ops`` for state-dict key remapping."""
        ...

    def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder that uses ``module_ops`` as its module operations (e.g. quantization)."""
        ...

    def with_loras(self, loras: tuple["LoraPathStrengthAndSDOps", ...]) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder that fuses ``loras`` at build time."""
        ...

    def with_registry(self, registry: "Registry") -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder that uses ``registry`` as its weight registry for allocation."""
        ...

    def with_lora_load_device(self, device: torch.device) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder that loads LoRA weights onto ``device``."""
        ...

    def build(
        self,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        **kwargs: object,
    ) -> ModelType:
        """
        Build the model.

        Args:
            device: Target device for the model.
            dtype: Target dtype for the model; if None, the dtype of the
                model_path model is used.
            **kwargs: Extra keyword arguments; their meaning is not defined
                by this protocol and is left to the implementation.
        Returns:
            Model instance.
        """
        ...

    def model_config(self) -> dict:
        """Return the model configuration dictionary taken from the checkpoint metadata."""
        ...
class LoRAAdaptableProtocol(Protocol):
    """
    Structural interface for models that can be adapted with LoRAs.

    An implementation supplies:
    - lora: add a LoRA to the model
    """

    def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
        """
        Add the LoRA at ``lora_path`` to the model with the given ``strength``.
        """
        ...
class LoraPathStrengthAndSDOps(NamedTuple):
    """
    Three-field record describing a LoRA: its path, its strength, and the
    SDOps to apply to the LoRA state dict.
    """

    # Path of the LoRA.
    path: str
    # Strength of the LoRA.
    strength: float
    # SDOps applied to the LoRA state dict.
    sd_ops: SDOps
class LoraStateDictWithStrength(NamedTuple):
    """
    Two-field record pairing a LoRA state dict with the strength used when
    applying it to the model.
    """

    # The LoRA state dict.
    state_dict: StateDict
    # Strength used when applying the LoRA to the model.
    strength: float