File size: 4,815 Bytes
08c5e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Protocol

import torch

from ltx_core.loader.module_ops import ModuleOps
from ltx_core.loader.sd_ops import SDOps
from ltx_core.model.model_protocol import ModelType

if TYPE_CHECKING:
    from ltx_core.loader.registry import Registry


@dataclass(frozen=True)
class StateDict:
    """
    Frozen record describing a loaded PyTorch state dictionary.

    Fields:
    - sd: Dictionary of tensors (weights, buffers, etc.)
    - device: Device where the tensors are stored
    - size: Total memory footprint in bytes
    - dtype: Set of tensor dtypes present

    Note: ``frozen=True`` only prevents re-assigning the attributes; the
    ``sd`` dictionary object itself remains an ordinary mutable dict.
    """

    sd: dict
    device: torch.device
    size: int
    dtype: set[torch.dtype]

    def footprint(self) -> tuple[int, torch.device]:
        """Return the ``(size, device)`` pair recorded for this state dict."""
        size_and_device = (self.size, self.device)
        return size_and_device


class StateDictLoader(Protocol):
    """
    Structural interface for objects that load state dictionaries from various sources.
    Required methods:
    - metadata: Extract model metadata from a single path
    - load: Load a state dict from one or more paths and apply SDOps transformations
    """

    def metadata(self, path: str) -> dict:
        """
        Load the metadata stored at ``path`` and return it as a dictionary.
        """
        ...

    def load(
        self,
        path: str | list[str],
        sd_ops: SDOps | None = None,
        device: torch.device | None = None,
    ) -> StateDict:
        """
        Load a state dict and apply ``sd_ops``.
        Args:
            path: A single path, or a list of paths (for sharded model storage).
            sd_ops: Optional state-dict operations to apply.
            device: Optional device.
                NOTE(review): presumably where the loaded tensors are placed - confirm with implementations.
        Returns:
            The loaded state dict wrapped in a StateDict.
        """
        ...


class ModelBuilderProtocol(Protocol[ModelType]):
    """
    Structural interface for objects that build PyTorch models from configuration dictionaries.
    Required methods include:
    - meta_model: Create a model from a configuration dictionary and apply module operations
    - build: Create and initialize a model from a state dictionary and apply dtype transformations
    The ``with_*`` methods each return a copy of the builder with one setting replaced.
    """

    # Builder settings; each has a matching ``with_*`` method below.
    model_sd_ops: SDOps | None
    module_ops: tuple[ModuleOps, ...]
    loras: tuple[LoraPathStrengthAndSDOps, ...]
    registry: Registry

    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
        """
        Instantiate the model architecture on the meta device.
        Model creation is kept separate from weight loading, so the architecture
        can be created without allocating memory for its parameters.
        Args:
            config: Model configuration dictionary.
            module_ops: Optional list of module operations to apply (e.g., quantization).
        Returns:
            Model instance on meta device (no actual memory allocated for parameters).
        """
        ...

    def with_sd_ops(self, sd_ops: SDOps | None) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder that uses ``sd_ops`` for state-dict key remapping."""
        ...

    def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder that applies ``module_ops`` (e.g. quantization)."""
        ...

    def with_loras(self, loras: tuple[LoraPathStrengthAndSDOps, ...]) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder that fuses the given LoRAs at build time."""
        ...

    def with_registry(self, registry: Registry) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder that allocates through the given weight registry."""
        ...

    def with_lora_load_device(self, device: torch.device) -> ModelBuilderProtocol[ModelType]:
        """Return a copy of this builder that loads LoRA weights onto ``device``."""
        ...

    def build(
        self,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        **kwargs: object,
    ) -> ModelType:
        """
        Build the model.
        Args:
            device: Target device for the model.
            dtype: Target dtype for the model; if None, uses the dtype of the model_path model.
            **kwargs: Additional implementation-specific keyword arguments.
        Returns:
            Model instance.
        """
        ...

    def model_config(self) -> dict:
        """Return the model configuration dictionary extracted from the checkpoint metadata."""
        ...


class LoRAAdaptableProtocol(Protocol):
    """
    Structural interface for models that can be adapted with LoRAs.
    Required methods:
    - lora: Add a LoRA to the model
    """

    def lora(self, lora_path: str, strength: float) -> LoRAAdaptableProtocol:
        """
        Add a LoRA to the model.
        Args:
            lora_path: Path of the LoRA to add.
            strength: Strength value associated with the LoRA.
        Returns:
            A LoRA-adaptable object.
            NOTE(review): whether this is the receiver or a new object is not specified here - confirm with implementations.
        """
        ...


class LoraPathStrengthAndSDOps(NamedTuple):
    """
    Named triple describing one LoRA: its path, its strength, and the SDOps
    to apply to the LoRA state dict.
    """

    # Path of the LoRA.
    path: str
    # Strength associated with this LoRA.
    strength: float
    # State-dict operations applied to the LoRA state dict.
    sd_ops: SDOps


class LoraStateDictWithStrength(NamedTuple):
    """
    Named pair of a LoRA state dict and the strength for applying it to the model.
    """

    # The LoRA state dict.
    state_dict: StateDict
    # Strength associated with this LoRA.
    strength: float