Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING, NamedTuple, Protocol | |
| import torch | |
| from ltx_core.loader.module_ops import ModuleOps | |
| from ltx_core.loader.sd_ops import SDOps | |
| from ltx_core.model.model_protocol import ModelType | |
| if TYPE_CHECKING: | |
| from ltx_core.loader.registry import Registry | |
| class StateDict: | |
| """ | |
| Immutable container for a PyTorch state dictionary. | |
| Contains: | |
| - sd: Dictionary of tensors (weights, buffers, etc.) | |
| - device: Device where tensors are stored | |
| - size: Total memory footprint in bytes | |
| - dtype: Set of tensor dtypes present | |
| """ | |
| sd: dict | |
| device: torch.device | |
| size: int | |
| dtype: set[torch.dtype] | |
| def footprint(self) -> tuple[int, torch.device]: | |
| return self.size, self.device | |
| class StateDictLoader(Protocol): | |
| """ | |
| Protocol for loading state dictionaries from various sources. | |
| Implementations must provide: | |
| - metadata: Extract model metadata from a single path | |
| - load: Load state dict from path(s) and apply SDOps transformations | |
| """ | |
| def metadata(self, path: str) -> dict: | |
| """ | |
| Load metadata from path | |
| """ | |
| def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict: | |
| """ | |
| Load state dict from path or paths (for sharded model storage) and apply sd_ops | |
| """ | |
| class ModelBuilderProtocol(Protocol[ModelType]): | |
| """ | |
| Protocol for building PyTorch models from configuration dictionaries. | |
| Implementations must provide: | |
| - meta_model: Create a model from configuration dictionary and apply module operations | |
| - build: Create and initialize a model from state dictionary and apply dtype transformations | |
| """ | |
| model_sd_ops: SDOps | None | |
| module_ops: tuple[ModuleOps, ...] | |
| loras: tuple["LoraPathStrengthAndSDOps", ...] | |
| registry: "Registry" | |
| def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType: | |
| """ | |
| Create a model on the meta device from a configuration dictionary. | |
| This decouples model creation from weight loading, allowing the model | |
| architecture to be instantiated without allocating memory for parameters. | |
| Args: | |
| config: Model configuration dictionary. | |
| module_ops: Optional list of module operations to apply (e.g., quantization). | |
| Returns: | |
| Model instance on meta device (no actual memory allocated for parameters). | |
| """ | |
| ... | |
| def with_sd_ops(self, sd_ops: SDOps | None) -> "ModelBuilderProtocol[ModelType]": | |
| """Return a copy of this builder with the given state-dict key remapping ops.""" | |
| ... | |
| def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> "ModelBuilderProtocol[ModelType]": | |
| """Return a copy of this builder with the given module operations (e.g. quantization).""" | |
| ... | |
| def with_loras(self, loras: tuple["LoraPathStrengthAndSDOps", ...]) -> "ModelBuilderProtocol[ModelType]": | |
| """Return a copy of this builder with the given LoRAs to fuse at build time.""" | |
| ... | |
| def with_registry(self, registry: "Registry") -> "ModelBuilderProtocol[ModelType]": | |
| """Return a copy of this builder using the given weight registry for allocation.""" | |
| ... | |
| def with_lora_load_device(self, device: torch.device) -> "ModelBuilderProtocol[ModelType]": | |
| """Return a copy of this builder that loads LoRA weights onto the given device.""" | |
| ... | |
| def build( | |
| self, device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: object | |
| ) -> ModelType: | |
| """ | |
| Build the model | |
| Args: | |
| device: Target device for the model | |
| dtype: Target dtype for the model, if None, uses the dtype of the model_path model | |
| Returns: | |
| Model instance | |
| """ | |
| ... | |
| def model_config(self) -> dict: | |
| """Return the model configuration dictionary extracted from the checkpoint metadata.""" | |
| ... | |
| class LoRAAdaptableProtocol(Protocol): | |
| """ | |
| Protocol for models that can be adapted with LoRAs. | |
| Implementations must provide: | |
| - lora: Add a LoRA to the model | |
| """ | |
| def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol": | |
| pass | |
| class LoraPathStrengthAndSDOps(NamedTuple): | |
| """ | |
| Tuple containing a LoRA path, strength, and SDOps for applying to the LoRA state dict. | |
| """ | |
| path: str | |
| strength: float | |
| sd_ops: SDOps | |
| class LoraStateDictWithStrength(NamedTuple): | |
| """ | |
| Tuple containing a LoRA state dict and strength for applying to the model. | |
| """ | |
| state_dict: StateDict | |
| strength: float | |