Spaces:
Running on Zero
Running on Zero
feat(lora): safetensors header sniff + zimage key validation
Browse files- lora.py +70 -0
- tests/test_lora.py +45 -0
lora.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LoRA file validation and apply/revert context manager."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import struct
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
# Key prefixes that every tensor name in a LoRA file must start with to pass
# sniff()'s validation. Kept as a tuple so it can be handed to str.startswith().
ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LoRAValidationError(ValueError):
    """Signal that a LoRA safetensors file failed validation.

    Subclasses ValueError, so callers that already catch ValueError also catch
    this. The message is prefixed with the offending file's name by sniff().
    """
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass(frozen=True)
class LoRAInfo:
    """Immutable summary of a LoRA file, as produced by sniff()."""

    # Location of the inspected file.
    path: Path
    # LoRA rank; sniff() reports 0 when it can neither read nor infer one.
    rank: int
    # Which submodule it applies to ("transformer" for Z-Image).
    target: str
    # Size of the file on disk, in bytes.
    size_bytes: int
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def sniff(path: Path | str) -> LoRAInfo:
    """Validate a LoRA safetensors file from its header alone and describe it.

    Only the 8-byte length prefix and the JSON header are read from disk; the
    tensor payload is never loaded, so this is cheap enough to call before
    @spaces.GPU fires.

    Args:
        path: Location of the candidate ``.safetensors`` file.

    Returns:
        A LoRAInfo with the file path, the LoRA rank (taken from the ``rank``
        metadata entry when usable, otherwise inferred from the first
        ``lora_A``/``lora_down`` tensor shape, otherwise 0), the fixed target
        ``"transformer"``, and the file size in bytes.

    Raises:
        LoRAValidationError: The file is too short, its header length is
            implausible, the header is not a JSON object, it lists no tensors,
            or any tensor key does not start with one of ZIMAGE_LORA_PREFIXES.
        OSError: The file cannot be opened or read.
    """
    path = Path(path)
    header, size_bytes = _read_safetensors_header(path)

    tensor_keys = [k for k in header if not k.startswith("__")]
    if not tensor_keys:
        raise LoRAValidationError(f"{path.name}: no tensors in file")

    bad = [k for k in tensor_keys if not k.startswith(ZIMAGE_LORA_PREFIXES)]
    if bad:
        sample = bad[0]
        raise LoRAValidationError(
            f"{path.name}: unexpected key '{sample}' — Z-Image LoRAs must target "
            f"{ZIMAGE_LORA_PREFIXES} (got {len(bad)}/{len(tensor_keys)} mismatched keys)"
        )

    meta = header.get("__metadata__")
    if not isinstance(meta, dict):
        meta = {}

    # Metadata is free-form, so an unusable rank (missing, null, non-numeric)
    # must fall through to shape inference rather than crash the validation.
    try:
        rank = int(meta.get("rank", 0))
    except (TypeError, ValueError):
        rank = 0

    if rank <= 0:
        rank = 0
        # Infer from the first down-projection tensor that declares a shape.
        for key in tensor_keys:
            if "lora_A" not in key and "lora_down" not in key:
                continue
            entry = header[key]
            shape = entry.get("shape") if isinstance(entry, dict) else None
            if (
                isinstance(shape, list)
                and shape
                and all(isinstance(dim, int) for dim in shape)
            ):
                rank = int(min(shape))
                break

    return LoRAInfo(
        path=path,
        rank=rank,
        target="transformer",
        size_bytes=size_bytes,
    )


def _read_safetensors_header(path: Path) -> tuple[dict, int]:
    """Read and decode only the JSON header of a safetensors file.

    Returns the parsed header dict and the file size in bytes. Raises
    LoRAValidationError when the length prefix or the header is malformed.
    """
    # Upper bound on the declared header length, so a hostile or corrupt prefix
    # cannot make us buffer most of a multi-gigabyte file.
    # NOTE(review): chosen to match the safetensors library's own header cap —
    # confirm against the format documentation.
    max_header_bytes = 100_000_000

    size_bytes = path.stat().st_size
    with path.open("rb") as fh:
        prefix = fh.read(8)
        if len(prefix) < 8:
            raise LoRAValidationError(f"{path.name}: file too short to be safetensors")
        (header_len,) = struct.unpack("<Q", prefix)
        if (
            header_len <= 0
            or header_len > max_header_bytes
            or header_len + 8 > size_bytes
        ):
            raise LoRAValidationError(f"{path.name}: not a valid safetensors header")
        raw_header = fh.read(header_len)

    if len(raw_header) != header_len:
        # File shrank between stat() and read(); treat as a truncated header.
        raise LoRAValidationError(f"{path.name}: not a valid safetensors header")

    try:
        header = json.loads(raw_header)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # json.loads decodes bytes itself, so undecodable input surfaces as
        # UnicodeDecodeError rather than JSONDecodeError.
        raise LoRAValidationError(f"{path.name}: safetensors header is not JSON ({e})") from e

    if not isinstance(header, dict):
        raise LoRAValidationError(f"{path.name}: safetensors header is not a JSON object")

    return header, size_bytes
|
tests/test_lora.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import struct
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
import lora
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _write_safetensors(path: Path, header: dict) -> None:
    """Write a tensor-less safetensors file containing only *header*.

    Layout: an 8-byte little-endian unsigned length, then the UTF-8 encoded
    JSON header. No tensor data follows.
    """
    payload = json.dumps(header).encode("utf-8")
    length_prefix = struct.pack("<Q", len(payload))
    path.write_bytes(length_prefix + payload)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_sniff_valid_zimage_lora_returns_metadata(tmp_path):
    """A well-formed Z-Image LoRA header yields its rank, target and file size."""
    lora_file = tmp_path / "ok.safetensors"
    header = {
        "transformer.layer1.lora_A.weight": {"dtype": "BF16", "shape": [64, 3840]},
        "transformer.layer1.lora_B.weight": {"dtype": "BF16", "shape": [3840, 64]},
        "__metadata__": {"rank": "64"},
    }
    _write_safetensors(lora_file, header)

    result = lora.sniff(lora_file)

    assert result.rank == 64
    assert result.target == "transformer"
    assert result.size_bytes == lora_file.stat().st_size
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_sniff_rejects_non_safetensors(tmp_path):
    """Arbitrary bytes are rejected with an error that mentions safetensors."""
    junk_file = tmp_path / "bad.bin"
    junk_file.write_bytes(b"this is not a safetensors file at all")

    with pytest.raises(lora.LoRAValidationError) as caught:
        lora.sniff(junk_file)

    message = str(caught.value).lower()
    assert "safetensors" in message
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_sniff_rejects_non_zimage_keys(tmp_path):
    """Tensor keys outside the Z-Image prefixes are reported as unexpected."""
    foreign_file = tmp_path / "wrong.safetensors"
    header = {
        "down_blocks.0.weight": {"dtype": "F32", "shape": [320, 320]},
    }
    _write_safetensors(foreign_file, header)

    with pytest.raises(lora.LoRAValidationError) as caught:
        lora.sniff(foreign_file)

    msg = str(caught.value).lower()
    assert "down_blocks" in msg or "unexpected" in msg
|