Spaces:
Running on Zero
Running on Zero
File size: 4,815 Bytes
08c5e28 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Protocol
import torch
from ltx_core.loader.module_ops import ModuleOps
from ltx_core.loader.sd_ops import SDOps
from ltx_core.model.model_protocol import ModelType
if TYPE_CHECKING:
from ltx_core.loader.registry import Registry
@dataclass(frozen=True)
class StateDict:
    """
    Frozen bundle of a PyTorch state dictionary plus summary information about it.

    Attributes:
        sd: Dictionary of tensors (weights, buffers, etc.).
        device: Device on which the tensors are stored.
        size: Total memory footprint, in bytes.
        dtype: Set of the tensor dtypes present.

    Note:
        Immutability is shallow: the fields cannot be reassigned, but the ``sd``
        dict and the ``dtype`` set are ordinary mutable objects.
    """

    sd: dict
    device: torch.device
    size: int
    dtype: set[torch.dtype]

    def footprint(self) -> tuple[int, torch.device]:
        """Return ``(size, device)``: the byte footprint and where it lives."""
        nbytes, location = self.size, self.device
        return nbytes, location
class StateDictLoader(Protocol):
    """
    Structural interface for objects that read state dictionaries from storage.

    Implementations must provide:
    - metadata: read the model metadata stored at a single path
    - load: read a state dict from one path, or from several paths when the
      model storage is sharded, and apply the optional SDOps transformations
    """

    def metadata(self, path: str) -> dict:
        """Load the metadata stored at ``path``."""
        ...

    def load(
        self,
        path: str | list[str],
        sd_ops: SDOps | None = None,
        device: torch.device | None = None,
    ) -> StateDict:
        """
        Load a state dict from ``path`` and apply ``sd_ops``.

        Args:
            path: A single path, or a list of paths for sharded model storage.
            sd_ops: Optional state-dict transformations to apply after loading.
            device: Presumably the device to place the loaded tensors on;
                None is left to the implementation — confirm against implementers.
        """
        ...
class ModelBuilderProtocol(Protocol[ModelType]):
    """
    Protocol for builders that construct PyTorch models from configuration dictionaries and weights.

    Implementations must provide:
    - model_sd_ops, module_ops, loras, registry: the builder's current configuration
    - with_sd_ops, with_module_ops, with_loras, with_registry, with_lora_load_device:
      return a copy of the builder with one configuration item replaced
    - model_config: return the model configuration dictionary from the checkpoint metadata
    - meta_model: create a model from a configuration dictionary and apply module operations
    - build: create and initialize a model from its state dictionary and apply dtype transformations
    """

    # State-dict key remapping ops applied to the model weights (None = no remapping).
    model_sd_ops: SDOps | None
    # Module operations (e.g. quantization) applied to the model.
    module_ops: tuple[ModuleOps, ...]
    # LoRAs to fuse into the model at build time.
    loras: tuple["LoraPathStrengthAndSDOps", ...]
    # Weight registry used for allocation.
    registry: "Registry"

    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
        """
        Create a model on the meta device from a configuration dictionary.

        This decouples model creation from weight loading, allowing the model
        architecture to be instantiated without allocating memory for parameters.

        Args:
            config: Model configuration dictionary.
            module_ops: Optional list of module operations to apply (e.g., quantization).
                NOTE(review): annotated as ``list`` here while the ``module_ops``
                attribute and ``with_module_ops`` use ``tuple`` — callers forwarding
                the builder's own ops may need to convert; confirm intended type.

        Returns:
            Model instance on meta device (no actual memory allocated for parameters).
        """
        ...

    def with_sd_ops(self, sd_ops: SDOps | None) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder with the given state-dict key remapping ops."""
        ...

    def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder with the given module operations (e.g. quantization)."""
        ...

    def with_loras(self, loras: tuple["LoraPathStrengthAndSDOps", ...]) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder with the given LoRAs to fuse at build time."""
        ...

    def with_registry(self, registry: "Registry") -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder using the given weight registry for allocation."""
        ...

    def with_lora_load_device(self, device: torch.device) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder that loads LoRA weights onto the given device."""
        ...

    def build(
        self, device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: object
    ) -> ModelType:
        """
        Create and initialize the model from its state dictionary.

        Args:
            device: Target device for the model (None is implementation-defined).
            dtype: Target dtype for the model. If None, the dtype of the loaded
                checkpoint weights is kept.
            **kwargs: Additional implementation-specific build options.

        Returns:
            Model instance.
        """
        ...

    def model_config(self) -> dict:
        """Return the model configuration dictionary extracted from the checkpoint metadata."""
        ...
class LoRAAdaptableProtocol(Protocol):
    """
    Structural interface for models that can be adapted with LoRAs.

    Implementations must provide:
    - lora: add the LoRA found at a path, with a given strength, to the model
    """

    def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
        """Add the LoRA at ``lora_path`` with ``strength`` and return the adapted object."""
        ...
class LoraPathStrengthAndSDOps(NamedTuple):
    """
    A LoRA to load: where it lives, how strongly to apply it, and which SDOps
    transform its state dict.
    """

    # Location of the LoRA weights.
    path: str
    # Strength with which the LoRA is applied.
    strength: float
    # Transformations applied to the LoRA state dict.
    sd_ops: SDOps
class LoraStateDictWithStrength(NamedTuple):
    """
    A loaded LoRA state dict paired with the strength used when applying it to the model.
    """

    # The LoRA weights, already loaded.
    state_dict: StateDict
    # Strength with which the LoRA is applied.
    strength: float
|