| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import TYPE_CHECKING, NamedTuple, Protocol |
|
|
| import torch |
|
|
| from ltx_core.loader.module_ops import ModuleOps |
| from ltx_core.loader.sd_ops import SDOps |
| from ltx_core.model.model_protocol import ModelType |
|
|
| if TYPE_CHECKING: |
| from ltx_core.loader.registry import Registry |
|
|
|
|
@dataclass(frozen=True)
class StateDict:
    """
    Container for a PyTorch state dictionary and summary information about it.

    The dataclass is frozen, so its fields cannot be rebound after construction.
    Note that ``sd`` and ``dtype`` are ordinary mutable containers: freezing the
    dataclass does not prevent their contents from being modified in place.

    Attributes:
        sd: Dictionary of tensors (weights, buffers, etc.).
        device: Device where tensors are stored.
        size: Total memory footprint in bytes.
        dtype: Set of tensor dtypes present.
    """

    sd: dict
    device: torch.device
    size: int
    dtype: set[torch.dtype]

    def footprint(self) -> tuple[int, torch.device]:
        """
        Return the memory footprint of this state dict.

        Returns:
            A ``(size, device)`` tuple built from the stored fields.
        """
        return self.size, self.device
|
|
|
|
class StateDictLoader(Protocol):
    """
    Protocol for loading state dictionaries from various sources.

    Implementations must provide:
        - metadata: Extract model metadata from a single path
        - load: Load state dict from path(s) and apply SDOps transformations
    """

    def metadata(self, path: str) -> dict:
        """
        Load metadata from path.

        Args:
            path: Single path to read the metadata from.
        Returns:
            Metadata dictionary.
        """
        ...

    def load(
        self,
        path: str | list[str],
        sd_ops: SDOps | None = None,
        device: torch.device | None = None,
    ) -> StateDict:
        """
        Load state dict from path or paths (for sharded model storage) and apply sd_ops.

        Args:
            path: A single path, or a list of paths for sharded model storage.
            sd_ops: Optional SDOps to apply to the loaded state dict.
            device: Optional device argument; presumably the device the tensors
                are loaded onto -- confirm against the implementations.
        Returns:
            The loaded StateDict.
        """
        ...
|
|
|
|
class ModelBuilderProtocol(Protocol[ModelType]):
    """
    Protocol for building PyTorch models from configuration dictionaries.

    Implementations must provide:
        - model_sd_ops, module_ops, loras, registry: builder configuration attributes
        - meta_model: Create a model from configuration dictionary and apply module operations
        - with_sd_ops, with_module_ops, with_loras, with_registry, with_lora_load_device:
          Return a copy of the builder with one setting replaced
        - build: Create and initialize a model from state dictionary and apply dtype transformations
        - model_config: Return the model configuration dictionary
    """

    # State-dict key remapping ops (see with_sd_ops), or None.
    model_sd_ops: SDOps | None
    # Module operations, e.g. quantization (see with_module_ops).
    module_ops: tuple[ModuleOps, ...]
    # LoRAs to fuse at build time (see with_loras).
    loras: tuple[LoraPathStrengthAndSDOps, ...]
    # Weight registry used for allocation (see with_registry).
    registry: Registry

    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
        """
        Create a model on the meta device from a configuration dictionary.
        This decouples model creation from weight loading, allowing the model
        architecture to be instantiated without allocating memory for parameters.
        Args:
            config: Model configuration dictionary.
            module_ops: Optional list of module operations to apply (e.g., quantization).
        Returns:
            Model instance on meta device (no actual memory allocated for parameters).
        """
        ...

    def with_sd_ops(self, sd_ops: SDOps | None) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder with the given state-dict key remapping ops."""
        ...

    def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder with the given module operations (e.g. quantization)."""
        ...

    def with_loras(self, loras: tuple[LoraPathStrengthAndSDOps, ...]) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder with the given LoRAs to fuse at build time."""
        ...

    def with_registry(self, registry: Registry) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder using the given weight registry for allocation."""
        ...

    def with_lora_load_device(self, device: torch.device) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder that loads LoRA weights onto the given device."""
        ...

    def build(
        self, device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: object
    ) -> ModelType:
        """
        Build the model.
        Args:
            device: Target device for the model
            dtype: Target dtype for the model, if None, uses the dtype of the model_path model
            **kwargs: Additional keyword arguments; their meaning is left to the implementation.
        Returns:
            Model instance
        """
        ...

    def model_config(self) -> dict:
        """Return the model configuration dictionary extracted from the checkpoint metadata."""
        ...
|
|
|
|
class LoRAAdaptableProtocol(Protocol):
    """
    Protocol for models that can be adapted with LoRAs.

    Implementations must provide:
        - lora: Add a LoRA to the model
    """

    def lora(self, lora_path: str, strength: float) -> LoRAAdaptableProtocol:
        """
        Add a LoRA to the model.

        Args:
            lora_path: Path of the LoRA to add.
            strength: Strength to apply the LoRA with.
        Returns:
            A LoRA-adaptable object; whether this is the same instance or a new
            one is left to the implementation.
        """
        ...
|
|
|
|
class LoraPathStrengthAndSDOps(NamedTuple):
    """
    Tuple containing a LoRA path, strength, and SDOps for applying to the LoRA state dict.

    Attributes:
        path: Path of the LoRA.
        strength: Strength of the LoRA.
        sd_ops: SDOps for applying to the LoRA state dict.
    """

    path: str
    strength: float
    sd_ops: SDOps
|
|
|
|
class LoraStateDictWithStrength(NamedTuple):
    """
    Tuple containing a LoRA state dict and strength for applying to the model.

    Attributes:
        state_dict: The LoRA state dict.
        strength: Strength for applying the LoRA to the model.
    """

    state_dict: StateDict
    strength: float
|
|