import json

import safetensors
import torch

from ltx_core.loader.primitives import StateDict, StateDictLoader
from ltx_core.loader.sd_ops import SDOps


class SafetensorsStateDictLoader(StateDictLoader):
    """
    Loads weights from safetensors files without metadata support.
    Use this for loading raw weight files. For model files that include configuration metadata,
    use SafetensorsModelStateDictLoader instead.
    """

    def metadata(self, path: str) -> dict:
        """
        Metadata access is not provided by this loader.
        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not implemented")

    def load(
        self,
        path: str | list[str],
        sd_ops: SDOps | None = None,
        device: torch.device | None = None,
    ) -> StateDict:
        """
        Load state dict from path or paths (for sharded model storage) and apply sd_ops
        Args:
            path: A single safetensors file path, or a list of shard paths that are read in order.
            sd_ops: Optional key/value transformations. When None, tensors are stored under their
                original names without modification.
            device: Target device for the loaded tensors; falls back to CPU when not given.
        Returns:
            A StateDict built from the collected tensors, the target device, the summed ``nbytes``
            of the stored tensors and the set of dtypes encountered.
        """
        sd = {}
        size = 0
        dtype = set()
        device = device or torch.device("cpu")
        model_paths = path if isinstance(path, list) else [path]
        for shard_path in model_paths:
            with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
                for name in f.keys():
                    expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
                    # A None result from apply_to_key means this tensor is skipped entirely.
                    if expected_name is None:
                        continue
                    tensor = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
                    if sd_ops is None:
                        key_value_pairs = ((expected_name, tensor),)
                    else:
                        # apply_to_key_value is iterated as (key, tensor) pairs below.
                        key_value_pairs = sd_ops.apply_to_key_value(expected_name, tensor)
                    for key, value in key_value_pairs:
                        # NOTE(review): a key produced more than once (e.g. by two shards) overwrites
                        # the earlier entry in sd but is still counted in size - confirm intended.
                        size += value.nbytes
                        dtype.add(value.dtype)
                        sd[key] = value
        return StateDict(sd=sd, device=device, size=size, dtype=dtype)


class SafetensorsModelStateDictLoader(StateDictLoader):
    """
    Loads weights and configuration metadata from safetensors model files.
    Unlike SafetensorsStateDictLoader, this loader can read model configuration
    from the safetensors file metadata via the metadata() method.
    """

    def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
        """
        Args:
            weight_loader: Loader that load() delegates to. A default
                SafetensorsStateDictLoader is created when None is given.
        """
        self.weight_loader = weight_loader if weight_loader is not None else SafetensorsStateDictLoader()

    def metadata(self, path: str) -> dict:
        """
        Read the model configuration stored in the safetensors file metadata.
        Args:
            path: Path to a safetensors file.
        Returns:
            The JSON-decoded value of the "config" metadata entry, or an empty dict when the file
            has no metadata or no "config" entry.
        """
        with safetensors.safe_open(path, framework="pt") as f:
            meta = f.metadata()
            if meta is None or "config" not in meta:
                return {}
            return json.loads(meta["config"])

    def load(
        self,
        path: str | list[str],
        sd_ops: SDOps | None = None,
        device: torch.device | None = None,
    ) -> StateDict:
        """
        Load weights by delegating to the configured weight loader.
        Args:
            path: A single safetensors file path or a list of shard paths.
            sd_ops: Optional key/value transformations forwarded to the weight loader.
            device: Optional target device forwarded to the weight loader.
        Returns:
            The StateDict produced by the weight loader.
        """
        return self.weight_loader.load(path, sd_ops, device)