Spaces:
Running on Zero
Running on Zero
feat(lora): add safetensors header sniff with ace-step module check
Browse files- lora_stack.py +122 -0
- tests/test_lora_stack.py +65 -0
lora_stack.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LoRA stack: sniff/validate user-uploaded .safetensors files and
|
| 2 |
+
manage which one is active on the ACE-Step DiT handler.
|
| 3 |
+
|
| 4 |
+
Single-LoRA semantics
|
| 5 |
+
---------------------
|
| 6 |
+
The Apple-Silicon ACE-Step fork's AceStepHandler exposes a one-LoRA-
|
| 7 |
+
at-a-time API (load_lora / unload_lora / set_use_lora / set_lora_scale),
|
| 8 |
+
not the multi-adapter PEFT pattern the plan's Task D3 originally
|
| 9 |
+
described. ``apply_stack(pipe, stack)`` therefore supports:
|
| 10 |
+
|
| 11 |
+
- empty stack -> ``unload_lora`` + ``set_use_lora(False)``
|
| 12 |
+
- single-entry stack -> ``load_lora(path)`` + ``set_lora_scale(scale)``
|
| 13 |
+
+ ``set_use_lora(True)``
|
| 14 |
+
- multi-entry stack -> use only the first, log a warning
|
| 15 |
+
|
| 16 |
+
If the upstream pipeline ever exposes multi-adapter support, this
|
| 17 |
+
function can be extended without changing the wrapper's call sites.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
import struct
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
# Module-name suffixes a compatible LoRA is expected to target
# (ACE-Step 1.5 XL SFT DiT). Header keys are matched against patterns such
# as `*.to_q.lora_A.weight` / `*.to_q.lora_B.weight`.
_EXPECTED_MODULES = {
    "to_q",
    "to_k",
    "to_v",
    "to_out.0",
    "ff.net.0.proj",
    "ff.net.2",
}

# Upper bound on the on-disk size of an accepted file (500 * 1024 * 1024 bytes).
_MAX_FILE_BYTES = 500 * 1024 * 1024

# Largest LoRA rank that is still reported as compatible.
_MAX_RANK = 256
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class LoRAValidationError(ValueError):
    """Signal that a LoRA file did not pass validation.

    Derives from ``ValueError``, so handlers written for ``ValueError``
    also catch it.
    """
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class LoRAInfo:
    """Result record produced by sniffing a LoRA file's safetensors header."""

    # Location of the inspected file.
    path: Path
    # Whether the file passed the compatibility checks.
    compatible: bool
    # LoRA rank read from the header (0 when none was found).
    rank: int
    # ``lora_alpha`` from the header metadata, or None when absent/unparseable.
    alpha: int | None
    # Expected module suffixes that were actually found among the tensor keys.
    target_modules: set[str]
    # Human-readable verdict: "OK" or an explanation of the mismatch.
    diagnostic: str
    # Size of the file on disk, in bytes.
    file_size: int
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _leading_dim(key: str, shape: object) -> int:
    """Return the first dimension of a header ``shape`` entry as an int.

    Raises ``LoRAValidationError`` instead of leaking ``IndexError`` /
    ``TypeError`` when the (untrusted) header carries a malformed shape.
    """
    if not isinstance(shape, list) or not shape:
        raise LoRAValidationError(f"Malformed tensor shape for {key!r}: {shape!r}")
    try:
        return int(shape[0])
    except (TypeError, ValueError, OverflowError) as e:
        raise LoRAValidationError(f"Malformed tensor shape for {key!r}: {shape!r}") from e


def sniff(path: Path | str) -> LoRAInfo:
    """Read the safetensors header; do not materialise tensors.

    Only the 8-byte little-endian length prefix and the JSON header that
    follows it are read. The header is treated as untrusted input: every
    structural problem is reported as ``LoRAValidationError``.

    Args:
        path: Location of the ``.safetensors`` file to inspect.

    Returns:
        A ``LoRAInfo`` describing the file. ``compatible`` is True only when
        at least one tensor key carries the ``transformer.`` /
        ``transformer_blocks.`` prefix, at least one expected module suffix
        is present, and the detected rank is in ``1.._MAX_RANK``. Otherwise
        ``diagnostic`` explains which check failed.

    Raises:
        LoRAValidationError: The path is missing or not a regular file, the
            file exceeds ``_MAX_FILE_BYTES``, the length prefix or header is
            truncated or unreasonable, the header is not a JSON object, or a
            ``lora_A`` tensor has a malformed shape.
    """
    path = Path(path)
    if not path.exists():
        raise LoRAValidationError(f"File not found: {path}")
    # A directory passes exists() but would fail in open() with a raw OSError.
    if not path.is_file():
        raise LoRAValidationError(f"Not a regular file: {path}")

    file_size = path.stat().st_size
    if file_size > _MAX_FILE_BYTES:
        raise LoRAValidationError(
            f"File too large ({file_size / 1e6:.0f} MB > {_MAX_FILE_BYTES / 1e6:.0f} MB cap)."
        )

    with open(path, "rb") as f:
        header_len_bytes = f.read(8)
        if len(header_len_bytes) < 8:
            raise LoRAValidationError("Not a valid .safetensors file (truncated)")
        header_len = struct.unpack("<Q", header_len_bytes)[0]
        if header_len <= 0 or header_len > 10 * 1024 * 1024:
            raise LoRAValidationError(f"Unreasonable header length: {header_len}")
        header_bytes = f.read(header_len)

    # A short read means the file ends inside the declared header.
    if len(header_bytes) < header_len:
        raise LoRAValidationError("Not a valid .safetensors file (truncated)")

    try:
        header = json.loads(header_bytes)
    except (json.JSONDecodeError, UnicodeDecodeError, RecursionError) as e:
        # json.loads on bytes decodes first, so invalid UTF-8 surfaces as
        # UnicodeDecodeError; pathological nesting surfaces as RecursionError.
        raise LoRAValidationError(f"Invalid header JSON: {e}") from e
    if not isinstance(header, dict):
        raise LoRAValidationError("Invalid header JSON: top-level value is not an object")

    target_modules: set[str] = set()
    rank = 0
    alpha = None
    has_ace_prefix = False

    for k, v in header.items():
        if k == "__metadata__":
            if isinstance(v, dict) and "lora_alpha" in v:
                try:
                    alpha = int(v["lora_alpha"])
                except (TypeError, ValueError):
                    # Best effort: an unparseable alpha is reported as None.
                    pass
            continue
        if not isinstance(v, dict) or "shape" not in v:
            continue
        # ACE-Step DiT keys start with "transformer." (the diffusers DiT prefix).
        # SDXL UNet LoRAs start with "unet." — reject those even though the
        # inner attention layer names overlap (`.to_q.lora_A.weight`).
        if k.startswith(("transformer.", "transformer_blocks.")):
            has_ace_prefix = True
        # Extract module suffix from things like "transformer.blocks.0.attn.to_q.lora_A.weight"
        for suffix in _EXPECTED_MODULES:
            if f".{suffix}.lora_A.weight" in k or f".{suffix}.lora_B.weight" in k:
                target_modules.add(suffix)
                if "lora_A.weight" in k:
                    # The first dimension of a lora_A weight is taken as the rank.
                    rank = max(rank, _leading_dim(k, v["shape"]))
                break

    modules_match = has_ace_prefix and bool(target_modules)
    compatible = modules_match and 0 < rank <= _MAX_RANK

    if compatible:
        diagnostic = "OK"
    elif not modules_match:
        diagnostic = (
            f"Expected ACE-Step DiT modules ({sorted(_EXPECTED_MODULES)}), got modules in: "
            f"{sorted(set(header.keys()) - {'__metadata__'})[:3]}…"
        )
    elif rank <= 0:
        # Modules matched, so the module-name message would be misleading here.
        diagnostic = "No lora_A weight with a positive rank was found."
    else:
        diagnostic = f"LoRA rank {rank} exceeds the supported maximum of {_MAX_RANK}."

    return LoRAInfo(
        path=path,
        compatible=compatible,
        rank=rank,
        alpha=alpha,
        target_modules=target_modules,
        diagnostic=diagnostic,
        file_size=file_size,
    )
|
tests/test_lora_stack.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""L1 tests for LoRA header sniffing — no torch, no pipeline."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import struct
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import pytest
|
| 10 |
+
|
| 11 |
+
import lora_stack as ls
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _write_safetensors(path: Path, key_dict: dict[str, dict]) -> None:
    """Write a bare-bones safetensors file for header-sniffing tests.

    Layout: 8-byte little-endian header length, the UTF-8 JSON encoding of
    ``key_dict``, then 8 zero bytes standing in for tensor data.
    """
    payload = json.dumps(key_dict).encode("utf-8")
    length_prefix = struct.pack("<Q", len(payload))
    dummy_tensor_bytes = b"\0" * 8
    path.write_bytes(b"".join((length_prefix, payload, dummy_tensor_bytes)))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_sniff_accepts_ace_step_lora(tmp_path):
    """A ``transformer.``-prefixed ``to_q`` lora_A/lora_B pair is accepted."""
    tensors = {
        "transformer.blocks.0.attn.to_q.lora_A.weight": {
            "dtype": "BF16",
            "shape": [64, 768],
            "data_offsets": [0, 8],
        },
        "transformer.blocks.0.attn.to_q.lora_B.weight": {
            "dtype": "BF16",
            "shape": [768, 64],
            "data_offsets": [0, 8],
        },
    }
    lora_file = tmp_path / "psytrance.safetensors"
    _write_safetensors(lora_file, tensors)

    result = ls.sniff(lora_file)

    assert result.compatible is True
    # Rank comes from the first dimension of the lora_A weight.
    assert result.rank == 64
    assert "to_q" in result.target_modules
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def test_sniff_rejects_sdxl_lora(tmp_path):
    """A ``unet.``-prefixed key is rejected even though it ends in ``to_q``."""
    sdxl_key = "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight"
    tensors = {
        sdxl_key: {
            "dtype": "F16",
            "shape": [16, 320],
            "data_offsets": [0, 8],
        },
    }
    lora_file = tmp_path / "sdxl.safetensors"
    _write_safetensors(lora_file, tensors)

    result = ls.sniff(lora_file)

    assert result.compatible is False
    assert "expected" in result.diagnostic.lower()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_sniff_rejects_oversize(tmp_path):
    """A file larger than the size cap is rejected before its header is read."""
    p = tmp_path / "huge.safetensors"
    # Extend an empty file with truncate() instead of building and writing a
    # ~600 MiB bytes object: sniff() only looks at the reported size for this
    # check, so the content never matters. One byte past the cap also pins
    # the boundary of the `>` comparison.
    with p.open("wb") as f:
        f.truncate(ls._MAX_FILE_BYTES + 1)
    with pytest.raises(ls.LoRAValidationError, match="too large"):
        ls.sniff(p)
|