Spaces:
Running on Zero
Running on Zero
feat(lora): add single-lora stack apply + disable
Browse files
- lora_stack.py +33 -0
- tests/test_lora_stack.py +41 -1
lora_stack.py
CHANGED
|
@@ -20,10 +20,13 @@ function can be extended without changing the wrapper's call sites.
|
|
| 20 |
from __future__ import annotations
|
| 21 |
|
| 22 |
import json
|
|
|
|
| 23 |
import struct
|
| 24 |
from dataclasses import dataclass
|
| 25 |
from pathlib import Path
|
| 26 |
|
|
|
|
|
|
|
| 27 |
# Expected DiT module suffixes for ACE-Step 1.5 XL SFT.
|
| 28 |
# Match against `*.to_q.lora_A.weight`, etc.
|
| 29 |
_EXPECTED_MODULES = {"to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"}
|
|
@@ -149,3 +152,33 @@ def download_preset(name: str) -> Path:
|
|
| 149 |
f"Could not download preset {name!r} from {p['hf_id']!r}: {e}"
|
| 150 |
) from e
|
| 151 |
raise LoRAValidationError(f"Unknown preset: {name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
from __future__ import annotations
|
| 21 |
|
| 22 |
import json
|
| 23 |
+
import logging
|
| 24 |
import struct
|
| 25 |
from dataclasses import dataclass
|
| 26 |
from pathlib import Path
|
| 27 |
|
| 28 |
+
_log = logging.getLogger("ams.lora")
|
| 29 |
+
|
| 30 |
# Expected DiT module suffixes for ACE-Step 1.5 XL SFT.
|
| 31 |
# Match against `*.to_q.lora_A.weight`, etc.
|
| 32 |
_EXPECTED_MODULES = {"to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"}
|
|
|
|
| 152 |
f"Could not download preset {name!r} from {p['hf_id']!r}: {e}"
|
| 153 |
) from e
|
| 154 |
raise LoRAValidationError(f"Unknown preset: {name}")
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def apply_stack(pipe, stack: list[dict]) -> None:
    """Make ``stack`` the active LoRA configuration on ``pipe``'s DiT handler.

    The Apple-Silicon ACE-Step fork keeps at most one LoRA active, so only
    the head of ``stack`` is ever applied:

    - empty ``stack``: the current LoRA is unloaded and LoRA use is
      switched off.
    - one entry: its ``path`` is loaded, its ``scale`` is applied (coerced
      to ``float``), and LoRA use is switched on.
    - two or more entries: the first entry is handled as above; a warning
      naming it is logged and every other entry is ignored.
    """
    handler = pipe._dit  # the pipeline's internal handler (AceStepHandler)

    if stack:
        primary = stack[0]
        requested = len(stack)
        if requested > 1:
            # Only the head of the stack can be honoured; say so loudly.
            _log.warning(
                "apply_stack received %d LoRAs but only one is supported by "
                "the apple-silicon ACE-Step fork; activating %r and ignoring the rest.",
                requested,
                primary["name"],
            )
        handler.load_lora(primary["path"])
        handler.set_lora_scale(float(primary["scale"]))
        handler.set_use_lora(True)
    else:
        # Nothing requested: drop whatever is loaded and turn LoRA use off.
        handler.unload_lora()
        handler.set_use_lora(False)
|
tests/test_lora_stack.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
-
"""L1 tests for LoRA header sniffing — no torch, no pipeline."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
|
|
|
| 6 |
import struct
|
| 7 |
from pathlib import Path
|
|
|
|
| 8 |
|
| 9 |
import pytest
|
| 10 |
|
|
@@ -63,3 +65,41 @@ def test_sniff_rejects_oversize(tmp_path):
|
|
| 63 |
p.write_bytes(b"\0" * (600 * 1024 * 1024))
|
| 64 |
with pytest.raises(ls.LoRAValidationError, match="too large"):
|
| 65 |
ls.sniff(p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""L1 tests for LoRA header sniffing + apply_stack — no torch, no pipeline."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
| 6 |
+
import logging
|
| 7 |
import struct
|
| 8 |
from pathlib import Path
|
| 9 |
+
from unittest.mock import MagicMock
|
| 10 |
|
| 11 |
import pytest
|
| 12 |
|
|
|
|
| 65 |
p.write_bytes(b"\0" * (600 * 1024 * 1024))
|
| 66 |
with pytest.raises(ls.LoRAValidationError, match="too large"):
|
| 67 |
ls.sniff(p)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def test_apply_stack_empty_disables_lora():
    """An empty stack unloads the LoRA and switches LoRA use off."""
    handler = MagicMock()
    pipeline = MagicMock()
    pipeline._dit = handler

    ls.apply_stack(pipeline, [])

    handler.unload_lora.assert_called_once()
    handler.set_use_lora.assert_called_with(False)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def test_apply_stack_single_lora_loads_and_enables(tmp_path):
    """A one-entry stack is loaded, scaled and enabled on the handler."""
    handler = MagicMock()
    pipeline = MagicMock()
    pipeline._dit = handler

    # The file only needs to exist; apply_stack is handed its path as a string.
    weights = tmp_path / "psy.safetensors"
    weights.write_bytes(b"\0")
    entry = {
        "name": "psytrance_v2",
        "scale": 0.95,
        "path": str(weights),
        "sha256": "a" * 64,
    }

    ls.apply_stack(pipeline, [entry])

    handler.load_lora.assert_called_once_with(str(weights))
    handler.set_lora_scale.assert_called_once_with(0.95)
    handler.set_use_lora.assert_called_with(True)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def test_apply_stack_multi_lora_uses_first_and_warns(tmp_path, caplog):
    """With two entries only the first is applied and a warning is logged."""
    handler = MagicMock()
    pipeline = MagicMock()
    pipeline._dit = handler

    first_file = tmp_path / "a.safetensors"
    first_file.write_bytes(b"\0")
    second_file = tmp_path / "b.safetensors"
    second_file.write_bytes(b"\0")
    entries = [
        {"name": "a", "scale": 0.85, "path": str(first_file), "sha256": "1" * 64},
        {"name": "b", "scale": 0.95, "path": str(second_file), "sha256": "2" * 64},
    ]

    with caplog.at_level(logging.WARNING):
        ls.apply_stack(pipeline, entries)

    # Only the head of the stack reaches the handler.
    handler.load_lora.assert_called_once_with(str(first_file))
    handler.set_lora_scale.assert_called_once_with(0.85)

    # Some captured record must mention the single-LoRA limitation.
    lowered = (record.message.lower() for record in caplog.records)
    assert any("only one" in text or "single" in text for text in lowered)
|