# Source: Dramabox/ltx2/ltx_core/loader/registry.py
# DramaBox Space — initial app + vendored ltx2 (commit 08c5e28, verified; author: Manmay)
import hashlib
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Protocol
from ltx_core.loader.primitives import StateDict
from ltx_core.loader.sd_ops import SDOps
class Registry(Protocol):
    """
    Structural interface for a cache of already-loaded state dictionaries.

    A registry keeps state dictionaries around so they can be handed out again
    later instead of being loaded a second time. Every entry is addressed by the
    pair (``paths``, ``sd_ops``) it was produced from.

    Implementations must provide:
    - add: store a state dictionary under (paths, sd_ops)
    - pop: remove the entry for (paths, sd_ops) and return it
    - get: return the entry for (paths, sd_ops) without removing it
    - clear: drop every stored entry
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
        """Store ``state_dict`` under the key derived from ``paths`` and ``sd_ops``."""

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Remove and return the entry for ``paths``/``sd_ops``; ``None`` signals no entry."""

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Return the entry for ``paths``/``sd_ops`` without removing it; ``None`` signals no entry."""

    def clear(self) -> None:
        """Drop every entry held by the registry."""
class DummyRegistry(Registry):
    """
    Registry that never keeps anything.

    ``add`` discards its input and every lookup (``get``/``pop``) answers
    ``None``, so a caller that consults the registry first always sees a miss.
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
        """Ignore ``state_dict``; nothing is stored."""
        return None

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Always report a miss."""
        return None

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Always report a miss."""
        return None

    def clear(self) -> None:
        """Nothing to clear."""
        return None
@dataclass
class StateDictRegistry(Registry):
    """
    In-memory registry that keeps state dictionaries in a dict.

    Entries are keyed by a SHA-256 digest of the resolved source paths plus the
    ``name`` of the ``SDOps`` applied to them (if any). Every read or write of
    the backing dict happens while holding ``_lock``.
    """

    # Both fields are hidden from the generated __repr__: the values are whole
    # state dictionaries (potentially very large), and a lock has no useful repr.
    _state_dicts: dict[str, StateDict] = field(default_factory=dict, repr=False)
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)

    def _generate_id(self, paths: list[str], sd_ops: SDOps | None) -> str:
        """
        Derive the registry key for ``paths`` transformed by ``sd_ops``.

        Paths are resolved to absolute form, so different spellings of the same
        file map to the same key. The number of paths is hashed first so that
        the key for ``paths=[a]`` with an ``SDOps`` named ``b`` cannot coincide
        with the key for ``paths=[a, b]`` and no ``SDOps``.

        Note: ``SDOps`` instances are told apart only by their ``name``.
        Resolving paths may touch the filesystem, so callers compute the key
        before acquiring ``_lock``.

        Returns:
            Hex-encoded SHA-256 digest used as the dict key.
        """
        parts = [str(len(paths))]
        parts.extend(str(Path(p).resolve()) for p in paths)
        if sd_ops is not None:
            parts.append(sd_ops.name)
        return hashlib.sha256("\0".join(parts).encode("utf-8")).hexdigest()

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
        """
        Store ``state_dict`` under the key derived from ``paths`` and ``sd_ops``.

        Returns:
            The registry key the state dict was stored under.

        Raises:
            ValueError: if an entry for the same ``paths``/``sd_ops`` already exists.
        """
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            # Membership check and insert share one lock acquisition, so two
            # threads cannot both register the same key.
            if sd_id in self._state_dicts:
                raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
            self._state_dicts[sd_id] = state_dict
        return sd_id

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Remove and return the entry for ``paths``/``sd_ops``, or ``None`` if absent."""
        # Only the dict access needs the lock; key derivation stays outside it.
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            return self._state_dicts.pop(sd_id, None)

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Return the entry for ``paths``/``sd_ops`` without removing it, or ``None`` if absent."""
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            return self._state_dicts.get(sd_id, None)

    def clear(self) -> None:
        """Drop every stored state dictionary."""
        with self._lock:
            self._state_dicts.clear()