import hashlib
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Protocol

from ltx_core.loader.primitives import StateDict
from ltx_core.loader.sd_ops import SDOps


class Registry(Protocol):
    """
    Protocol for managing state dictionaries in a registry.
    It is used to store state dictionaries and reuse them later without loading them again.
    Implementations must provide:
    - add: Add a state dictionary to the registry
    - pop: Remove a state dictionary from the registry
    - get: Retrieve a state dictionary from the registry
    - clear: Clear all state dictionaries from the registry
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
        """Store ``state_dict`` under the key identified by ``paths`` and ``sd_ops``."""
        ...

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Remove and return the entry for ``paths``/``sd_ops``, or ``None`` if there is none."""
        ...

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Return the entry for ``paths``/``sd_ops`` without removing it, or ``None`` if there is none."""
        ...

    def clear(self) -> None:
        """Remove every entry from the registry."""
        ...


class DummyRegistry(Registry):
    """
    Dummy registry that does not store state dictionaries.
    Every lookup reports a miss, so nothing is ever reused through it.
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
        """Ignore the state dictionary; nothing is stored."""

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Always return ``None``: there is never an entry to remove."""
        return None

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Always return ``None``: there is never an entry to retrieve."""
        return None

    def clear(self) -> None:
        """Do nothing; the registry holds no entries."""


@dataclass
class StateDictRegistry(Registry):
    """
    Registry that stores state dictionaries in a dictionary.
    Entries are keyed by a SHA-256 hex digest derived from the resolved paths and,
    when one is given, the name of the SDOps. Every access to the underlying
    dictionary is performed while holding an internal lock.
    """

    # Maps the generated key to the stored state dictionary.
    _state_dicts: dict[str, StateDict] = field(default_factory=dict)
    # Guards every read and write of ``_state_dicts``.
    _lock: threading.Lock = field(default_factory=threading.Lock)

    def _generate_id(self, paths: list[str], sd_ops: SDOps | None) -> str:
        """
        Build the registry key for a combination of paths and SDOps.
        Args:
            paths: Paths the state dictionary was loaded from. Each one is passed
                through ``Path.resolve()`` before hashing.
            sd_ops: Optional SDOps; when not ``None`` its ``name`` is part of the key.
        Returns:
            Hex-encoded SHA-256 digest of the NUL-joined key parts.
        """
        key_parts = [str(Path(p).resolve()) for p in paths]
        if sd_ops is not None:
            # NOTE(review): the name is appended to the same list as the paths, so the
            # key does not record where the paths end and the name begins.
            key_parts.append(sd_ops.name)
        return hashlib.sha256("\0".join(key_parts).encode("utf-8")).hexdigest()

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
        """
        Store ``state_dict`` under the key derived from ``paths`` and ``sd_ops``.
        Unlike the ``Registry`` protocol signature, which declares ``None``, this
        implementation returns the generated key.
        Returns:
            The key under which the state dictionary was stored.
        Raises:
            ValueError: If an entry with the same key is already present.
        """
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            if sd_id in self._state_dicts:
                raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
            self._state_dicts[sd_id] = state_dict
        return sd_id

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """
        Remove and return the state dictionary stored for ``paths`` and ``sd_ops``.
        Returns:
            The stored state dictionary, or ``None`` if no entry matches.
        """
        # Compute the key before taking the lock (as ``add`` does) so that path
        # resolution never runs inside the critical section.
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            return self._state_dicts.pop(sd_id, None)

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """
        Return the state dictionary stored for ``paths`` and ``sd_ops`` without removing it.
        Returns:
            The stored state dictionary, or ``None`` if no entry matches.
        """
        # Compute the key before taking the lock (as ``add`` does) so that path
        # resolution never runs inside the critical section.
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            return self._state_dicts.get(sd_id, None)

    def clear(self) -> None:
        """Remove every stored state dictionary."""
        with self._lock:
            self._state_dicts.clear()