# Dramabox / ltx2 / ltx_core / loader / sft_loader.py
# Committed by Manmay — "DramaBox Space — initial app + vendored ltx2"
# Commit 08c5e28 (verified)
import json
import safetensors
import torch
from ltx_core.loader.primitives import StateDict, StateDictLoader
from ltx_core.loader.sd_ops import SDOps
class SafetensorsStateDictLoader(StateDictLoader):
    """
    Loads weights from safetensors files without metadata support.

    Use this for loading raw weight files. For model files that include
    configuration metadata, use SafetensorsModelStateDictLoader instead.
    """

    def metadata(self, path: str) -> dict:
        """Metadata is not supported by this loader; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    def load(
        self,
        path: str | list[str],
        sd_ops: SDOps | None = None,
        device: torch.device | None = None,
    ) -> StateDict:
        """
        Load a state dict from one safetensors file or several shards and apply sd_ops.

        Args:
            path: A single file path, or a list of shard paths (sharded model storage).
            sd_ops: Optional key/value transformations. ``apply_to_key(name)`` maps a stored
                tensor name to a new key, or returns None to skip that tensor;
                ``apply_to_key_value(key, tensor)`` returns an iterable of (key, tensor)
                pairs that are stored in the result. When None, tensors are kept as-is.
            device: Device the tensors are opened on; defaults to CPU.

        Returns:
            A StateDict holding the loaded tensors, the device, the total byte size of
            the tensors in the final dict, and the set of their dtypes.
        """
        device = device or torch.device("cpu")
        shard_paths = path if isinstance(path, list) else [path]
        sd: dict[str, torch.Tensor] = {}

        for shard_path in shard_paths:
            with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
                for name in f.keys():
                    expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
                    if expected_name is None:
                        # sd_ops asked to drop this tensor; avoid reading it at all.
                        continue
                    value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
                    if sd_ops is None:
                        key_value_pairs = ((expected_name, value),)
                    else:
                        key_value_pairs = sd_ops.apply_to_key_value(expected_name, value)
                    for key, tensor in key_value_pairs:
                        sd[key] = tensor

        # Derive size/dtypes from the final dict rather than accumulating per insert:
        # a key written twice (repeated across shards, or produced twice by sd_ops)
        # would otherwise be double-counted and leave a stale dtype in the set.
        size = sum(tensor.nbytes for tensor in sd.values())
        dtype = {tensor.dtype for tensor in sd.values()}
        return StateDict(sd=sd, device=device, size=size, dtype=dtype)
class SafetensorsModelStateDictLoader(StateDictLoader):
    """
    Safetensors loader that also exposes model configuration stored in the file header.

    Weight loading is delegated to a wrapped SafetensorsStateDictLoader. In addition,
    metadata() decodes the JSON string stored under the "config" key of the
    safetensors metadata, which SafetensorsStateDictLoader does not support.
    """

    def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
        """
        Args:
            weight_loader: Loader used for the actual tensor loading; a fresh
                SafetensorsStateDictLoader is created when omitted.
        """
        if weight_loader is None:
            weight_loader = SafetensorsStateDictLoader()
        self.weight_loader = weight_loader

    def metadata(self, path: str) -> dict:
        """
        Return the model configuration decoded from the file's safetensors metadata.

        Returns an empty dict when the file has no metadata or no "config" entry.
        The "config" value is parsed with json.loads, so malformed JSON raises
        json.JSONDecodeError.
        """
        with safetensors.safe_open(path, framework="pt") as handle:
            header_meta = handle.metadata()
        # Parse after the file handle is closed; only the metadata mapping is needed.
        if header_meta is not None and "config" in header_meta:
            return json.loads(header_meta["config"])
        return {}

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """Load weights by forwarding all arguments to the wrapped weight loader."""
        loader = self.weight_loader
        return loader.load(path, sd_ops, device)