techfreakworm commited on
Commit
bc2513c
·
unverified ·
1 Parent(s): 261639d

feat(lora): safetensors header sniff + zimage key validation

Browse files
Files changed (2) hide show
  1. lora.py +70 -0
  2. tests/test_lora.py +45 -0
lora.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LoRA file validation and apply/revert context manager."""
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import struct
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+
9
+ ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.")
10
+
11
+
12
+ class LoRAValidationError(ValueError):
13
+ """Raised when a LoRA safetensors file doesn't match Z-Image's key layout."""
14
+
15
+
16
+ @dataclass(frozen=True)
17
+ class LoRAInfo:
18
+ path: Path
19
+ rank: int
20
+ target: str # which submodule it applies to ("transformer" for Z-Image)
21
+ size_bytes: int
22
+
23
+
24
+ def sniff(path: Path | str) -> LoRAInfo:
25
+ """Read just the safetensors header to verify and infer rank + target.
26
+
27
+ Doesn't load tensors. Doesn't allocate GPU memory. Cheap enough to call before
28
+ @spaces.GPU fires.
29
+ """
30
+ path = Path(path)
31
+ raw = path.read_bytes()
32
+ if len(raw) < 8:
33
+ raise LoRAValidationError(f"{path.name}: file too short to be safetensors")
34
+ (header_len,) = struct.unpack("<Q", raw[:8])
35
+ if header_len <= 0 or header_len + 8 > len(raw):
36
+ raise LoRAValidationError(f"{path.name}: not a valid safetensors header")
37
+ try:
38
+ header = json.loads(raw[8 : 8 + header_len])
39
+ except json.JSONDecodeError as e:
40
+ raise LoRAValidationError(f"{path.name}: safetensors header is not JSON ({e})") from e
41
+
42
+ tensor_keys = [k for k in header.keys() if not k.startswith("__")]
43
+ if not tensor_keys:
44
+ raise LoRAValidationError(f"{path.name}: no tensors in file")
45
+
46
+ bad = [k for k in tensor_keys if not k.startswith(ZIMAGE_LORA_PREFIXES)]
47
+ if bad:
48
+ sample = bad[0]
49
+ raise LoRAValidationError(
50
+ f"{path.name}: unexpected key '{sample}' — Z-Image LoRAs must target "
51
+ f"{ZIMAGE_LORA_PREFIXES} (got {len(bad)}/{len(tensor_keys)} mismatched keys)"
52
+ )
53
+
54
+ meta = header.get("__metadata__") or {}
55
+ rank = int(meta.get("rank", 0))
56
+ if not rank:
57
+ # Infer from any A/B tensor pair shape
58
+ for k, v in header.items():
59
+ if "lora_A" in k or "lora_down" in k:
60
+ shape = v.get("shape") or []
61
+ if shape:
62
+ rank = int(min(shape))
63
+ break
64
+
65
+ return LoRAInfo(
66
+ path=path,
67
+ rank=rank,
68
+ target="transformer",
69
+ size_bytes=path.stat().st_size,
70
+ )
tests/test_lora.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import struct
3
+ from pathlib import Path
4
+
5
+ import pytest
6
+
7
+ import lora
8
+
9
+
10
+ def _write_safetensors(path: Path, header: dict) -> None:
11
+ """Minimal safetensors file: 8-byte LE header length + JSON header (no tensor data)."""
12
+ h = json.dumps(header).encode("utf-8")
13
+ path.write_bytes(struct.pack("<Q", len(h)) + h)
14
+
15
+
16
+ def test_sniff_valid_zimage_lora_returns_metadata(tmp_path):
17
+ p = tmp_path / "ok.safetensors"
18
+ _write_safetensors(p, {
19
+ "transformer.layer1.lora_A.weight": {"dtype": "BF16", "shape": [64, 3840]},
20
+ "transformer.layer1.lora_B.weight": {"dtype": "BF16", "shape": [3840, 64]},
21
+ "__metadata__": {"rank": "64"},
22
+ })
23
+ info = lora.sniff(p)
24
+ assert info.rank == 64
25
+ assert info.target == "transformer"
26
+ assert info.size_bytes == p.stat().st_size
27
+
28
+
29
+ def test_sniff_rejects_non_safetensors(tmp_path):
30
+ p = tmp_path / "bad.bin"
31
+ p.write_bytes(b"this is not a safetensors file at all")
32
+ with pytest.raises(lora.LoRAValidationError) as exc:
33
+ lora.sniff(p)
34
+ assert "safetensors" in str(exc.value).lower()
35
+
36
+
37
+ def test_sniff_rejects_non_zimage_keys(tmp_path):
38
+ p = tmp_path / "wrong.safetensors"
39
+ _write_safetensors(p, {
40
+ "down_blocks.0.weight": {"dtype": "F32", "shape": [320, 320]},
41
+ })
42
+ with pytest.raises(lora.LoRAValidationError) as exc:
43
+ lora.sniff(p)
44
+ msg = str(exc.value).lower()
45
+ assert "down_blocks" in msg or "unexpected" in msg