techfreakworm commited on
Commit
96012ce
·
unverified ·
1 Parent(s): e701df3

feat(backend): apply lora stack before mode dispatch

Browse files
Files changed (2) hide show
  1. backend.py +2 -0
  2. tests/test_backend.py +29 -0
backend.py CHANGED
@@ -20,6 +20,7 @@ except ImportError: # pragma: no cover - covered by manual local testing
20
  _HAS_SPACES = False
21
 
22
  import ace_pipeline as ap
 
23
 
24
 
25
  def _maybe_seed(seed: int | None) -> int:
@@ -51,6 +52,7 @@ class ACEStepStudioBackend:
51
  params["seed"] = _maybe_seed(params.get("seed"))
52
  t0 = time.time()
53
  pipe = ap.get_pipeline()
 
54
  out_path = self._call_pipe_for_mode(pipe, mode, params)
55
  meta = {
56
  "mode": mode,
 
20
  _HAS_SPACES = False
21
 
22
  import ace_pipeline as ap
23
+ import lora_stack
24
 
25
 
26
  def _maybe_seed(seed: int | None) -> int:
 
52
  params["seed"] = _maybe_seed(params.get("seed"))
53
  t0 = time.time()
54
  pipe = ap.get_pipeline()
55
+ lora_stack.apply_stack(pipe, params.get("loras", []))
56
  out_path = self._call_pipe_for_mode(pipe, mode, params)
57
  meta = {
58
  "mode": mode,
tests/test_backend.py CHANGED
@@ -69,3 +69,32 @@ def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
69
  # The seed-resolved value is the one forwarded to the wrapper
70
  sent_params = fake_pipe.generate.call_args.args[0]
71
  assert sent_params["seed"] == meta["seed"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  # The seed-resolved value is the one forwarded to the wrapper
70
  sent_params = fake_pipe.generate.call_args.args[0]
71
  assert sent_params["seed"] == meta["seed"]
72
+
73
+
74
+ def test_dispatch_applies_lora_stack(monkeypatch, tmp_path):
75
+ fake_pipe = MagicMock()
76
+ fake_pipe.generate.return_value = str(tmp_path / "x.wav")
77
+ (tmp_path / "x.wav").write_bytes(b"RIFF")
78
+ monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
79
+
80
+ apply_mock = MagicMock()
81
+ monkeypatch.setattr("lora_stack.apply_stack", apply_mock)
82
+
83
+ b = be.ACEStepStudioBackend()
84
+ stack = [{"name": "RapMachine", "scale": 0.85, "path": "/x.safetensors", "sha256": "a" * 64}]
85
+ b.dispatch(
86
+ mode="generate",
87
+ params={
88
+ "prompt": "p",
89
+ "lyrics": "",
90
+ "duration_s": 5,
91
+ "instrumental": False,
92
+ "seed": 1,
93
+ "loras": stack,
94
+ "advanced": {},
95
+ "lm": {},
96
+ "dcw": {},
97
+ },
98
+ )
99
+
100
+ apply_mock.assert_called_once_with(fake_pipe, stack)