techfreakworm commited on
Commit
e701df3
·
unverified ·
1 Parent(s): 321117e

feat(lora): add single-lora stack apply + disable

Browse files
Files changed (2) hide show
  1. lora_stack.py +33 -0
  2. tests/test_lora_stack.py +41 -1
lora_stack.py CHANGED
@@ -20,10 +20,13 @@ function can be extended without changing the wrapper's call sites.
20
  from __future__ import annotations
21
 
22
  import json
 
23
  import struct
24
  from dataclasses import dataclass
25
  from pathlib import Path
26
 
 
 
27
  # Expected DiT module suffixes for ACE-Step 1.5 XL SFT.
28
  # Match against `*.to_q.lora_A.weight`, etc.
29
  _EXPECTED_MODULES = {"to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"}
@@ -149,3 +152,33 @@ def download_preset(name: str) -> Path:
149
  f"Could not download preset {name!r} from {p['hf_id']!r}: {e}"
150
  ) from e
151
  raise LoRAValidationError(f"Unknown preset: {name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  from __future__ import annotations
21
 
22
  import json
23
+ import logging
24
  import struct
25
  from dataclasses import dataclass
26
  from pathlib import Path
27
 
28
+ _log = logging.getLogger("ams.lora")
29
+
30
  # Expected DiT module suffixes for ACE-Step 1.5 XL SFT.
31
  # Match against `*.to_q.lora_A.weight`, etc.
32
  _EXPECTED_MODULES = {"to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"}
 
152
  f"Could not download preset {name!r} from {p['hf_id']!r}: {e}"
153
  ) from e
154
  raise LoRAValidationError(f"Unknown preset: {name}")
155
+
156
+
157
+ def apply_stack(pipe, stack: list[dict]) -> None:
158
+ """Activate the given LoRA stack on the pipeline's DiT handler.
159
+
160
+ Apple-Silicon fork supports only one active LoRA at a time
161
+ (see module docstring). Behaviour:
162
+
163
+ - ``stack == []``: disable + unload the current LoRA.
164
+ - ``len(stack) == 1``: load + set scale + enable.
165
+ - ``len(stack) >= 2``: load the first, warn that the rest is ignored.
166
+ """
167
+ dit = pipe._dit # internal AceStepHandler reference
168
+ if not stack:
169
+ dit.unload_lora()
170
+ dit.set_use_lora(False)
171
+ return
172
+
173
+ if len(stack) > 1:
174
+ _log.warning(
175
+ "apply_stack received %d LoRAs but only one is supported by "
176
+ "the apple-silicon ACE-Step fork; activating %r and ignoring the rest.",
177
+ len(stack),
178
+ stack[0]["name"],
179
+ )
180
+
181
+ first = stack[0]
182
+ dit.load_lora(first["path"])
183
+ dit.set_lora_scale(float(first["scale"]))
184
+ dit.set_use_lora(True)
tests/test_lora_stack.py CHANGED
@@ -1,10 +1,12 @@
1
- """L1 tests for LoRA header sniffing — no torch, no pipeline."""
2
 
3
  from __future__ import annotations
4
 
5
  import json
 
6
  import struct
7
  from pathlib import Path
 
8
 
9
  import pytest
10
 
@@ -63,3 +65,41 @@ def test_sniff_rejects_oversize(tmp_path):
63
  p.write_bytes(b"\0" * (600 * 1024 * 1024))
64
  with pytest.raises(ls.LoRAValidationError, match="too large"):
65
  ls.sniff(p)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """L1 tests for LoRA header sniffing + apply_stack — no torch, no pipeline."""
2
 
3
  from __future__ import annotations
4
 
5
  import json
6
+ import logging
7
  import struct
8
  from pathlib import Path
9
+ from unittest.mock import MagicMock
10
 
11
  import pytest
12
 
 
65
  p.write_bytes(b"\0" * (600 * 1024 * 1024))
66
  with pytest.raises(ls.LoRAValidationError, match="too large"):
67
  ls.sniff(p)
68
+
69
+
70
+ def test_apply_stack_empty_disables_lora():
71
+ pipe = MagicMock()
72
+ pipe._dit = MagicMock()
73
+ ls.apply_stack(pipe, [])
74
+ pipe._dit.unload_lora.assert_called_once()
75
+ pipe._dit.set_use_lora.assert_called_with(False)
76
+
77
+
78
+ def test_apply_stack_single_lora_loads_and_enables(tmp_path):
79
+ pipe = MagicMock()
80
+ pipe._dit = MagicMock()
81
+ fake_path = tmp_path / "psy.safetensors"
82
+ fake_path.write_bytes(b"\0")
83
+ stack = [{"name": "psytrance_v2", "scale": 0.95, "path": str(fake_path), "sha256": "a" * 64}]
84
+ ls.apply_stack(pipe, stack)
85
+ pipe._dit.load_lora.assert_called_once_with(str(fake_path))
86
+ pipe._dit.set_lora_scale.assert_called_once_with(0.95)
87
+ pipe._dit.set_use_lora.assert_called_with(True)
88
+
89
+
90
+ def test_apply_stack_multi_lora_uses_first_and_warns(tmp_path, caplog):
91
+ pipe = MagicMock()
92
+ pipe._dit = MagicMock()
93
+ a = tmp_path / "a.safetensors"
94
+ a.write_bytes(b"\0")
95
+ b = tmp_path / "b.safetensors"
96
+ b.write_bytes(b"\0")
97
+ stack = [
98
+ {"name": "a", "scale": 0.85, "path": str(a), "sha256": "1" * 64},
99
+ {"name": "b", "scale": 0.95, "path": str(b), "sha256": "2" * 64},
100
+ ]
101
+ with caplog.at_level(logging.WARNING):
102
+ ls.apply_stack(pipe, stack)
103
+ pipe._dit.load_lora.assert_called_once_with(str(a))
104
+ pipe._dit.set_lora_scale.assert_called_once_with(0.85)
105
+ assert any("only one" in r.message.lower() or "single" in r.message.lower() for r in caplog.records)