Spaces:
Running on Zero
Running on Zero
| import json | |
| import safetensors | |
| import torch | |
| from ltx_core.loader.primitives import StateDict, StateDictLoader | |
| from ltx_core.loader.sd_ops import SDOps | |
class SafetensorsStateDictLoader(StateDictLoader):
    """
    Load raw weights from one or more safetensors files.

    This loader does not read file metadata: ``metadata()`` always raises
    ``NotImplementedError``. For model files that also carry configuration
    metadata, use ``SafetensorsModelStateDictLoader`` instead.
    """

    def metadata(self, path: str) -> dict:
        """Not supported by this loader; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not implemented")

    def load(
        self,
        path: str | list[str],
        sd_ops: SDOps | None = None,
        device: torch.device | None = None,
    ) -> StateDict:
        """
        Load a state dict from a file, or from every shard in a list of files.

        Args:
            path: A single safetensors file, or a list of shard files whose
                tensors are merged into one dict. If the same final key is
                produced more than once, the last tensor written wins.
            sd_ops: Optional key/value transformation. A key for which
                ``sd_ops.apply_to_key`` returns ``None`` is skipped. Each kept
                tensor is passed to ``sd_ops.apply_to_key_value``, whose
                ``(key, tensor)`` pairs are stored. When ``sd_ops`` is ``None``,
                tensors are stored under their original names.
            device: Target device; defaults to CPU.

        Returns:
            A ``StateDict`` holding the tensors, the target device, the total
            ``nbytes`` of the returned tensors, and the set of their dtypes.
        """
        if device is None:
            device = torch.device("cpu")
        shard_paths = path if isinstance(path, list) else [path]

        sd = {}
        for shard_path in shard_paths:
            with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
                for name in f.keys():
                    expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
                    if expected_name is None:
                        # sd_ops asked for this tensor to be dropped.
                        continue
                    # safe_open was already given `device`, so this is presumably a
                    # no-op move; copy=False avoids an extra copy when it is.
                    tensor = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
                    if sd_ops is None:
                        sd[expected_name] = tensor
                    else:
                        for key, transformed in sd_ops.apply_to_key_value(expected_name, tensor):
                            sd[key] = transformed

        # Derive size/dtypes from the final dict so entries overwritten by a later
        # shard (or by sd_ops emitting a duplicate key) are not double-counted.
        size = sum(t.nbytes for t in sd.values())
        dtypes = {t.dtype for t in sd.values()}
        return StateDict(sd=sd, device=device, size=size, dtype=dtypes)
class SafetensorsModelStateDictLoader(StateDictLoader):
    """
    Load weights together with the JSON model configuration stored in safetensors metadata.

    Tensor loading is delegated to a wrapped ``SafetensorsStateDictLoader``.
    This class adds ``metadata()``, which decodes the ``"config"`` entry of a
    file's header metadata.
    """

    def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
        """
        Args:
            weight_loader: Loader used for the tensors. A new
                ``SafetensorsStateDictLoader`` is created when omitted.
        """
        if weight_loader is None:
            weight_loader = SafetensorsStateDictLoader()
        self.weight_loader = weight_loader

    def metadata(self, path: str) -> dict:
        """
        Return the decoded ``"config"`` metadata entry of the file at ``path``.

        Returns an empty dict when the file has no metadata or no ``"config"``
        key. The entry is parsed with ``json.loads``, so a malformed value raises
        ``json.JSONDecodeError``.
        """
        with safetensors.safe_open(path, framework="pt") as handle:
            header_meta = handle.metadata()
        has_config = header_meta is not None and "config" in header_meta
        return json.loads(header_meta["config"]) if has_config else {}

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """Load tensors by delegating to ``self.weight_loader.load`` with the same arguments."""
        delegate = self.weight_loader
        return delegate.load(path, sd_ops, device)