Spaces:
Running on Zero
Running on Zero
feat(backend): apply lora stack before mode dispatch
Browse files- backend.py +2 -0
- tests/test_backend.py +29 -0
backend.py
CHANGED
|
@@ -20,6 +20,7 @@ except ImportError: # pragma: no cover - covered by manual local testing
|
|
| 20 |
_HAS_SPACES = False
|
| 21 |
|
| 22 |
import ace_pipeline as ap
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def _maybe_seed(seed: int | None) -> int:
|
|
@@ -51,6 +52,7 @@ class ACEStepStudioBackend:
|
|
| 51 |
params["seed"] = _maybe_seed(params.get("seed"))
|
| 52 |
t0 = time.time()
|
| 53 |
pipe = ap.get_pipeline()
|
|
|
|
| 54 |
out_path = self._call_pipe_for_mode(pipe, mode, params)
|
| 55 |
meta = {
|
| 56 |
"mode": mode,
|
|
|
|
| 20 |
_HAS_SPACES = False
|
| 21 |
|
| 22 |
import ace_pipeline as ap
|
| 23 |
+
import lora_stack
|
| 24 |
|
| 25 |
|
| 26 |
def _maybe_seed(seed: int | None) -> int:
|
|
|
|
| 52 |
params["seed"] = _maybe_seed(params.get("seed"))
|
| 53 |
t0 = time.time()
|
| 54 |
pipe = ap.get_pipeline()
|
| 55 |
+
lora_stack.apply_stack(pipe, params.get("loras", []))
|
| 56 |
out_path = self._call_pipe_for_mode(pipe, mode, params)
|
| 57 |
meta = {
|
| 58 |
"mode": mode,
|
tests/test_backend.py
CHANGED
|
@@ -69,3 +69,32 @@ def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
|
|
| 69 |
# The seed-resolved value is the one forwarded to the wrapper
|
| 70 |
sent_params = fake_pipe.generate.call_args.args[0]
|
| 71 |
assert sent_params["seed"] == meta["seed"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
# The seed-resolved value is the one forwarded to the wrapper
|
| 70 |
sent_params = fake_pipe.generate.call_args.args[0]
|
| 71 |
assert sent_params["seed"] == meta["seed"]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_dispatch_applies_lora_stack(monkeypatch, tmp_path):
|
| 75 |
+
fake_pipe = MagicMock()
|
| 76 |
+
fake_pipe.generate.return_value = str(tmp_path / "x.wav")
|
| 77 |
+
(tmp_path / "x.wav").write_bytes(b"RIFF")
|
| 78 |
+
monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
|
| 79 |
+
|
| 80 |
+
apply_mock = MagicMock()
|
| 81 |
+
monkeypatch.setattr("lora_stack.apply_stack", apply_mock)
|
| 82 |
+
|
| 83 |
+
b = be.ACEStepStudioBackend()
|
| 84 |
+
stack = [{"name": "RapMachine", "scale": 0.85, "path": "/x.safetensors", "sha256": "a" * 64}]
|
| 85 |
+
b.dispatch(
|
| 86 |
+
mode="generate",
|
| 87 |
+
params={
|
| 88 |
+
"prompt": "p",
|
| 89 |
+
"lyrics": "",
|
| 90 |
+
"duration_s": 5,
|
| 91 |
+
"instrumental": False,
|
| 92 |
+
"seed": 1,
|
| 93 |
+
"loras": stack,
|
| 94 |
+
"advanced": {},
|
| 95 |
+
"lm": {},
|
| 96 |
+
"dcw": {},
|
| 97 |
+
},
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
apply_mock.assert_called_once_with(fake_pipe, stack)
|