techfreakworm committed on
Commit
c9f8dd1
·
unverified ·
1 Parent(s): aac47cf

feat(lora): add safetensors header sniff with ace-step module check

Browse files
Files changed (2) hide show
  1. lora_stack.py +122 -0
  2. tests/test_lora_stack.py +65 -0
lora_stack.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LoRA stack: sniff/validate user-uploaded .safetensors files and
2
+ manage which one is active on the ACE-Step DiT handler.
3
+
4
+ Single-LoRA semantics
5
+ ---------------------
6
+ The Apple-Silicon ACE-Step fork's AceStepHandler exposes a one-LoRA-
7
+ at-a-time API (load_lora / unload_lora / set_use_lora / set_lora_scale),
8
+ not the multi-adapter PEFT pattern the plan's Task D3 originally
9
+ described. ``apply_stack(pipe, stack)`` therefore supports:
10
+
11
+ - empty stack -> ``unload_lora`` + ``set_use_lora(False)``
12
+ - single-entry stack -> ``load_lora(path)`` + ``set_lora_scale(scale)``
13
+ + ``set_use_lora(True)``
14
+ - multi-entry stack -> use only the first, log a warning
15
+
16
+ If the upstream pipeline ever exposes multi-adapter support, this
17
+ function can be extended without changing the wrapper's call sites.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import struct
24
+ from dataclasses import dataclass
25
+ from pathlib import Path
26
+
27
# Expected DiT module suffixes for ACE-Step 1.5 XL SFT.
# Match against `*.to_q.lora_A.weight`, etc.
# sniff() records which of these suffixes appear in a file's tensor names.
_EXPECTED_MODULES = {"to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"}
# Upload size limit in bytes; sniff() rejects larger files before opening them.
_MAX_FILE_BYTES = 500 * 1024 * 1024  # 500 MB cap
# Largest LoRA rank that sniff() still reports as compatible.
_MAX_RANK = 256
32
+
33
+
34
class LoRAValidationError(ValueError):
    """Raised when a LoRA file fails validation.

    Subclasses ``ValueError`` so callers may catch either type. ``sniff``
    raises it for missing files, oversized files, and unreadable headers.
    """
36
+
37
+
38
@dataclass
class LoRAInfo:
    """Summary of a LoRA ``.safetensors`` header, as produced by ``sniff``."""

    # Location of the inspected file.
    path: Path
    # True when the header looks like an ACE-Step DiT LoRA of acceptable rank.
    compatible: bool
    # Largest leading dimension seen on a matching ``lora_A.weight``; 0 if none.
    rank: int
    # ``lora_alpha`` from the ``__metadata__`` entry, or None if absent/unparseable.
    alpha: int | None
    # Subset of the expected module suffixes found in the tensor names.
    target_modules: set[str]
    # "OK" when compatible, otherwise a human-readable explanation.
    diagnostic: str
    # Size of the file on disk, in bytes.
    file_size: int
47
+
48
+
49
def sniff(path: Path | str) -> LoRAInfo:
    """Read the safetensors header; do not materialise tensors.

    Only the 8-byte little-endian length prefix and the JSON header that
    follows it are read. Tensor names and shapes in that header are used to
    decide whether the file looks like an ACE-Step DiT LoRA.

    Args:
        path: Location of the ``.safetensors`` file to inspect.

    Returns:
        A ``LoRAInfo`` describing the file. An incompatible but well-formed
        file is reported through ``compatible=False`` and ``diagnostic``
        rather than through an exception.

    Raises:
        LoRAValidationError: If the path is not a regular file, the file
            exceeds ``_MAX_FILE_BYTES``, the length prefix or header is
            truncated or implausible, the header is not a JSON object, or a
            LoRA tensor entry carries a malformed ``shape``.
    """
    path = Path(path)
    # is_file() also rejects directories, which would otherwise fail later in
    # open() with an OSError instead of a validation error.
    if not path.is_file():
        raise LoRAValidationError(f"File not found: {path}")

    file_size = path.stat().st_size
    if file_size > _MAX_FILE_BYTES:
        raise LoRAValidationError(
            f"File too large ({file_size / 1e6:.0f} MB > {_MAX_FILE_BYTES / 1e6:.0f} MB cap)."
        )

    with open(path, "rb") as f:
        header_len_bytes = f.read(8)
        if len(header_len_bytes) < 8:
            raise LoRAValidationError("Not a valid .safetensors file (truncated)")
        header_len = struct.unpack("<Q", header_len_bytes)[0]
        # Bound the header size before reading so a hostile length prefix
        # cannot trigger a huge allocation.
        if header_len <= 0 or header_len > 10 * 1024 * 1024:
            raise LoRAValidationError(f"Unreasonable header length: {header_len}")
        header_bytes = f.read(header_len)

    if len(header_bytes) < header_len:
        raise LoRAValidationError("Not a valid .safetensors file (truncated)")

    try:
        header = json.loads(header_bytes)
    except (json.JSONDecodeError, UnicodeDecodeError, RecursionError) as e:
        # json.loads() on bytes decodes first, so invalid UTF-8 surfaces as
        # UnicodeDecodeError; pathological nesting surfaces as RecursionError.
        raise LoRAValidationError(f"Invalid header JSON: {e}") from e
    if not isinstance(header, dict):
        raise LoRAValidationError(
            f"Invalid header JSON: expected an object, got {type(header).__name__}"
        )

    target_modules: set[str] = set()
    rank = 0
    alpha = None
    has_ace_prefix = False

    for k, v in header.items():
        if k == "__metadata__":
            if isinstance(v, dict) and "lora_alpha" in v:
                try:
                    alpha = int(v["lora_alpha"])
                except (TypeError, ValueError):
                    # Alpha is informational only; leave it as None when the
                    # metadata value is not an integer.
                    pass
            continue
        if not isinstance(v, dict) or "shape" not in v:
            continue
        # ACE-Step DiT keys start with "transformer." (the diffusers DiT prefix).
        # SDXL UNet LoRAs start with "unet." — reject those even though the
        # inner attention layer names overlap (`.to_q.lora_A.weight`).
        if k.startswith(("transformer.", "transformer_blocks.")):
            has_ace_prefix = True
        # Extract module suffix from things like "transformer.blocks.0.attn.to_q.lora_A.weight"
        for suffix in _EXPECTED_MODULES:
            if f".{suffix}.lora_A.weight" in k or f".{suffix}.lora_B.weight" in k:
                target_modules.add(suffix)
                if "lora_A.weight" in k:
                    # The leading dimension of lora_A is taken as the rank.
                    shape = v["shape"]
                    if not isinstance(shape, list) or not shape:
                        raise LoRAValidationError(
                            f"Malformed tensor shape for {k!r}: {shape!r}"
                        )
                    try:
                        leading_dim = int(shape[0])
                    except (TypeError, ValueError, OverflowError) as e:
                        raise LoRAValidationError(
                            f"Malformed tensor shape for {k!r}: {shape!r}"
                        ) from e
                    rank = max(rank, leading_dim)
                break

    compatible = has_ace_prefix and bool(target_modules) and (rank > 0) and (rank <= _MAX_RANK)
    if compatible:
        diagnostic = "OK"
    elif has_ace_prefix and target_modules and rank > _MAX_RANK:
        # The modules matched; the only problem is the rank, so say so.
        diagnostic = f"LoRA rank {rank} exceeds the supported maximum of {_MAX_RANK}."
    else:
        diagnostic = (
            f"Expected ACE-Step DiT modules ({sorted(_EXPECTED_MODULES)}), got modules in: "
            f"{sorted(set(header.keys()) - {'__metadata__'})[:3]}…"
        )

    return LoRAInfo(
        path=path,
        compatible=compatible,
        rank=rank,
        alpha=alpha,
        target_modules=target_modules,
        diagnostic=diagnostic,
        file_size=file_size,
    )
tests/test_lora_stack.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """L1 tests for LoRA header sniffing — no torch, no pipeline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import struct
7
+ from pathlib import Path
8
+
9
+ import pytest
10
+
11
+ import lora_stack as ls
12
+
13
+
14
def _write_safetensors(path: Path, key_dict: dict[str, dict]) -> None:
    """Write a minimal safetensors-shaped file for header tests.

    Layout: 8-byte little-endian header length, the UTF-8 JSON encoding of
    ``key_dict``, then eight zero bytes standing in for tensor data.
    """
    payload = json.dumps(key_dict).encode("utf-8")
    length_prefix = struct.pack("<Q", len(payload))
    path.write_bytes(b"".join((length_prefix, payload, bytes(8))))
19
+
20
+
21
def test_sniff_accepts_ace_step_lora(tmp_path):
    """A header with ``transformer.``-prefixed to_q LoRA keys is compatible."""
    lora_file = tmp_path / "psytrance.safetensors"
    tensors = {
        "transformer.blocks.0.attn.to_q.lora_A.weight": {
            "dtype": "BF16",
            "shape": [64, 768],
            "data_offsets": [0, 8],
        },
        "transformer.blocks.0.attn.to_q.lora_B.weight": {
            "dtype": "BF16",
            "shape": [768, 64],
            "data_offsets": [0, 8],
        },
    }
    _write_safetensors(lora_file, tensors)

    result = ls.sniff(lora_file)

    assert result.compatible is True
    # Rank comes from the leading dimension of lora_A.
    assert result.rank == 64
    assert "to_q" in result.target_modules
42
+
43
+
44
def test_sniff_rejects_sdxl_lora(tmp_path):
    """A ``unet.``-prefixed key is incompatible despite the shared to_q name."""
    lora_file = tmp_path / "sdxl.safetensors"
    tensors = {
        "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight": {
            "dtype": "F16",
            "shape": [16, 320],
            "data_offsets": [0, 8],
        },
    }
    _write_safetensors(lora_file, tensors)

    result = ls.sniff(lora_file)

    assert result.compatible is False
    assert "expected" in result.diagnostic.lower()
59
+
60
+
61
def test_sniff_rejects_oversize(tmp_path):
    """A file one byte over the size cap is rejected with a "too large" error."""
    p = tmp_path / "huge.safetensors"
    # Grow the file with truncate() instead of building and writing hundreds
    # of megabytes of zeros; sniff() only consults the reported file size
    # before raising. Sizing from the cap keeps the test on the boundary.
    with p.open("wb") as f:
        f.truncate(ls._MAX_FILE_BYTES + 1)
    with pytest.raises(ls.LoRAValidationError, match="too large"):
        ls.sniff(p)