techfreakworm committed on
Commit
99375d0
·
unverified ·
1 Parent(s): 65ab3e7

refactor(pipeline): rewrite for real acestep AceStepHandler+LLMHandler api

Browse files

The original ace_pipeline assumed ``ACEStepPipeline.from_pretrained()`` —
a clean library entry point that does NOT exist in the installed
``acestep`` package (apple-silicon fork on Mac, upstream on CUDA). The
real API is a split-handler pattern:

from acestep.handler import AceStepHandler
from acestep.llm_inference import LLMHandler
from acestep.inference import GenerationParams, GenerationConfig, generate_music

dit = AceStepHandler(); dit.initialize_service(project_root, config_path, device)
lm = LLMHandler(); lm.initialize(checkpoint_dir, lm_model_path, backend, device)
result = generate_music(dit, lm, GenerationParams(...), GenerationConfig(...))

To keep ``modes.py`` and ``backend.py`` clean, ``ace_pipeline`` now
exposes a single ``ACEStepStudio`` wrapper that owns both handlers and
exposes ``generate(params: dict) -> str`` returning the audio path.

Defaults:
- DiT: ACE-Step/acestep-v15-xl-sft (~16 GB) → ``./checkpoints/acestep-v15-xl-sft/``
- LM: ACE-Step/acestep-5Hz-lm-0.6B (~1.4 GB) → ``./checkpoints/acestep-5Hz-lm-0.6B/``

The fork auto-routes ``backend='vllm'`` to ``mlx`` on ``device='mps'``
when mlx-lm is installed, so the same code path works on Mac and CUDA.

Tests updated to mock the wrapper interface: ``pipe.generate(params)``
instead of ``pipe(...)``. 17/17 L1+L2 pass; the GPU smoke (deselected
by default) exercises the real pipeline once checkpoints are downloaded.

Closes spec §14.1 open question (canonical ACE-Step Python API).

ace_pipeline.py CHANGED
@@ -1,12 +1,49 @@
1
  """ACE-Step pipeline lifecycle: device autodetect, lazy load, cache mirror.
2
 
3
- Mirrors z-image-studio's `models.py` pattern. M0 only implements device
4
- detection — the pipeline class itself is filled in at M1.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
 
7
  from __future__ import annotations
8
 
9
  import os
 
 
 
 
 
 
 
10
 
11
 
12
  def detect_device() -> str:
@@ -27,9 +64,9 @@ def detect_device() -> str:
27
  def vram_limit_for(device: str) -> int | None:
28
  """Returns a VRAM cap in bytes for CUDA, None otherwise.
29
 
30
- `torch.mps` has no `mem_get_info` — calling DiffSynth-style free-VRAM
31
- gates with a numeric limit would crash on MPS. Returning None lets the
32
- pipeline short-circuit those checks.
33
  """
34
  if device != "cuda":
35
  return None
@@ -43,30 +80,139 @@ def vram_limit_for(device: str) -> int | None:
43
  return None
44
 
45
 
46
- _PIPELINE = None # module-level lazy singleton
47
- _DEFAULT_MODEL_ID = "ACE-Step/acestep-v15-xl-sft"
48
-
49
-
50
- def _load_pipeline(device: str, model_path: str):
51
- """Construct the ACE-Step pipeline. Heavy import is local so unit tests can mock."""
52
- from ace_step import ACEStepPipeline # type: ignore[import-not-found]
53
-
54
- # On Mac, the apple-silicon fork sets dtype + backend automatically.
55
- # On CUDA we pass bf16 explicitly.
56
- if device == "cuda":
57
- pipe = ACEStepPipeline.from_pretrained(model_path, torch_dtype="bf16")
58
- else:
59
- pipe = ACEStepPipeline.from_pretrained(model_path)
60
-
61
- pipe.to(device)
62
- return pipe
63
 
 
 
 
64
 
65
- def get_pipeline():
66
- """Lazy-load the ACE-Step pipeline once per process."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  global _PIPELINE
68
  if _PIPELINE is None:
69
- device = detect_device()
70
- model_path = os.environ.get("ACE_MODEL_PATH", _DEFAULT_MODEL_ID)
71
- _PIPELINE = _load_pipeline(device, model_path)
72
  return _PIPELINE
 
1
  """ACE-Step pipeline lifecycle: device autodetect, lazy load, cache mirror.
2
 
3
+ The installed ``acestep`` package (apple-silicon fork on Mac, upstream on
4
+ CUDA) does NOT expose a single ``ACEStepPipeline.from_pretrained`` entry
5
+ point. The real API is a split-handler pattern:
6
+
7
+ from acestep.handler import AceStepHandler # DiT side
8
+ from acestep.llm_inference import LLMHandler # 5Hz LM planner
9
+ from acestep.inference import (
10
+ GenerationParams, GenerationConfig, generate_music,
11
+ )
12
+
13
+ dit = AceStepHandler()
14
+ dit.initialize_service(project_root=..., config_path="acestep-v15-xl-sft",
15
+ device="mps")
16
+ lm = LLMHandler()
17
+ lm.initialize(checkpoint_dir=..., lm_model_path="acestep-5Hz-lm-0.6B",
18
+ backend="vllm", # auto-routes to mlx on mps
19
+ device="mps")
20
+ params = GenerationParams(caption=..., lyrics=..., duration=..., seed=...)
21
+ cfg = GenerationConfig(batch_size=1, audio_format="wav")
22
+ result = generate_music(dit, lm, params, cfg)
23
+ # result.audios[0]["path"] is the WAV file
24
+
25
+ To keep ``backend.py`` and ``modes.py`` clean, this module exposes a
26
+ single ``ACEStepStudio`` wrapper that owns both handlers and exposes a
27
+ ``generate(params: dict) -> str`` method returning the audio path.
28
+ ``get_pipeline()`` returns the lazy singleton wrapper.
29
+
30
+ Checkpoints live under ``{project_root}/checkpoints/{config_path}/``.
31
+ On Mac with the apple-silicon fork, the fork auto-downloads from
32
+ HuggingFace if a checkpoint is missing, but in practice we pre-download
33
+ via ``hf download`` before the first inference call to avoid pytest
34
+ timeouts.
35
  """
36
 
37
  from __future__ import annotations
38
 
39
  import os
40
+ from pathlib import Path
41
+
42
+ _REPO_ROOT = Path(__file__).resolve().parent
43
+ _CHECKPOINTS_DIR = _REPO_ROOT / "checkpoints"
44
+
45
+ _DEFAULT_DIT_CONFIG = "acestep-v15-xl-sft"
46
+ _DEFAULT_LM_MODEL = "acestep-5Hz-lm-0.6B"
47
 
48
 
49
  def detect_device() -> str:
 
64
  def vram_limit_for(device: str) -> int | None:
65
  """Returns a VRAM cap in bytes for CUDA, None otherwise.
66
 
67
+ ``torch.mps`` has no ``mem_get_info`` — calling DiffSynth-style
68
+ free-VRAM gates with a numeric limit would crash on MPS. Returning
69
+ None lets the pipeline short-circuit those checks.
70
  """
71
  if device != "cuda":
72
  return None
 
80
  return None
81
 
82
 
83
+ class ACEStepStudio:
84
+ """Wrapper around the ``acestep`` split-handler API (apple-silicon fork on Mac, upstream on CUDA).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ Owns one ``AceStepHandler`` (DiT) and one ``LLMHandler`` (5Hz LM
87
+ planner). Both are lazy-loaded on the first ``generate(...)`` call.
88
+ """
89
 
90
+ def __init__(
91
+ self,
92
+ dit_config: str | None = None,
93
+ lm_model: str | None = None,
94
+ device: str | None = None,
95
+ ) -> None:
96
+ self._dit = None
97
+ self._llm = None
98
+ self._dit_config = dit_config or os.environ.get("ACE_DIT_CONFIG", _DEFAULT_DIT_CONFIG)
99
+ self._lm_model = lm_model or os.environ.get("ACE_LM_MODEL", _DEFAULT_LM_MODEL)
100
+ self._device = device or detect_device()
101
+
102
+ @property
103
+ def device(self) -> str:
104
+ return self._device
105
+
106
+ @property
107
+ def is_loaded(self) -> bool:
108
+ return self._dit is not None and self._llm is not None
109
+
110
+ def _ensure_loaded(self) -> None:
111
+ """First-call lazy load of both handlers. Heavy imports stay local."""
112
+ if self.is_loaded:
113
+ return
114
+
115
+ from acestep.handler import AceStepHandler
116
+ from acestep.llm_inference import LLMHandler
117
+
118
+ dit = AceStepHandler()
119
+ dit.initialize_service(
120
+ project_root=str(_REPO_ROOT),
121
+ config_path=self._dit_config,
122
+ device=self._device,
123
+ )
124
+
125
+ llm = LLMHandler()
126
+ llm.initialize(
127
+ checkpoint_dir=str(_CHECKPOINTS_DIR),
128
+ lm_model_path=self._lm_model,
129
+ backend="vllm", # fork auto-routes to mlx on mps + mlx-lm installed
130
+ device=self._device,
131
+ )
132
+
133
+ self._dit = dit
134
+ self._llm = llm
135
+
136
+ def generate(self, params: dict) -> str:
137
+ """Run a single text→song generation.
138
+
139
+ ``params`` is the dict produced by ``modes.generate``:
140
+ ``{"prompt", "lyrics", "duration_s", "instrumental", "seed",
141
+ "loras", "advanced", "lm", "dcw"}``. Returns the path to the
142
+ produced audio file.
143
+ """
144
+ self._ensure_loaded()
145
+
146
+ from acestep.inference import (
147
+ GenerationConfig,
148
+ GenerationParams,
149
+ generate_music,
150
+ )
151
+
152
+ advanced = params.get("advanced", {}) or {}
153
+ lm_opts = params.get("lm", {}) or {}
154
+
155
+ # Map our internal dict to ACE-Step's GenerationParams.
156
+ # Lyrics "[Instrumental]" is the ACE-Step convention for instrumental.
157
+ lyrics = params.get("lyrics", "") or ""
158
+ instrumental = bool(params.get("instrumental", False))
159
+ if instrumental and not lyrics:
160
+ lyrics = "[Instrumental]"
161
+
162
+ gen_params = GenerationParams(
163
+ task_type="text2music",
164
+ caption=params.get("prompt", ""),
165
+ lyrics=lyrics,
166
+ instrumental=instrumental,
167
+ duration=int(params.get("duration_s", 30)),
168
+ seed=int(params.get("seed", -1)),
169
+ inference_steps=int(advanced.get("steps", 32)),
170
+ guidance_scale=float(advanced.get("cfg", 4.0)),
171
+ shift=float(advanced.get("shift", 1.0)),
172
+ bpm=advanced.get("bpm"),
173
+ keyscale=advanced.get("keyscale", ""),
174
+ timesignature=advanced.get("timesignature", ""),
175
+ vocal_language=advanced.get("vocal_language", "unknown"),
176
+ cfg_interval_start=float(advanced.get("cfg_interval_start", 0.0)),
177
+ cfg_interval_end=float(advanced.get("cfg_interval_end", 1.0)),
178
+ thinking=bool(lm_opts.get("thinking", False)),
179
+ lm_temperature=float(lm_opts.get("temperature", 0.85)),
180
+ lm_cfg_scale=float(lm_opts.get("cfg", 2.0)),
181
+ lm_top_k=int(lm_opts.get("top_k", 0)),
182
+ lm_top_p=float(lm_opts.get("top_p", 0.9)),
183
+ lm_negative_prompt=lm_opts.get("negative_prompt", ""),
184
+ use_cot_metas=bool(lm_opts.get("cot_metas", False)),
185
+ use_cot_caption=bool(lm_opts.get("cot_caption", False)),
186
+ use_cot_language=bool(lm_opts.get("cot_language", False)),
187
+ )
188
+
189
+ gen_config = GenerationConfig(
190
+ batch_size=1,
191
+ audio_format=advanced.get("audio_format", "wav"),
192
+ use_random_seed=False,
193
+ seeds=[int(params.get("seed", 1))],
194
+ )
195
+
196
+ result = generate_music(self._dit, self._llm, gen_params, gen_config)
197
+
198
+ if not result.success:
199
+ raise RuntimeError(f"ACE-Step generation failed: {result.error}")
200
+ if not result.audios:
201
+ raise RuntimeError("ACE-Step returned no audio outputs")
202
+
203
+ return result.audios[0]["path"]
204
+
205
+
206
+ _PIPELINE: ACEStepStudio | None = None # module-level lazy singleton
207
+
208
+
209
+ def get_pipeline() -> ACEStepStudio:
210
+ """Lazy-construct the ACE Music Studio wrapper.
211
+
212
+ The wrapper itself is cheap to construct; both handlers (DiT, LM)
213
+ are only loaded on the first ``generate(...)`` call.
214
+ """
215
  global _PIPELINE
216
  if _PIPELINE is None:
217
+ _PIPELINE = ACEStepStudio()
 
 
218
  return _PIPELINE
backend.py CHANGED
@@ -68,14 +68,17 @@ class ACEStepStudioBackend:
68
  return out_path, meta
69
 
70
  def _call_pipe_for_mode(self, pipe, mode: str, params: dict[str, Any]) -> str:
71
- """Mode-specific kwargs translation. Filled out per milestone."""
 
 
 
 
 
 
 
 
 
72
  if mode == "generate":
73
- return pipe(
74
- prompt=params["prompt"],
75
- lyrics=params.get("lyrics", ""),
76
- duration_s=params["duration_s"],
77
- instrumental=params.get("instrumental", False),
78
- seed=params["seed"],
79
- )
80
  # cover / extend / edit / lyrics get filled in at M3 / M4
81
  raise NotImplementedError(f"Mode {mode!r} is not wired yet")
 
68
  return out_path, meta
69
 
70
  def _call_pipe_for_mode(self, pipe, mode: str, params: dict[str, Any]) -> str:
71
+ """Dispatch to the pipeline wrapper.
72
+
73
+ ``pipe`` is the ``ACEStepStudio`` wrapper returned by
74
+ ``ace_pipeline.get_pipeline()``. It exposes a single
75
+ ``generate(params)`` method that handles the underlying
76
+ AceStepHandler + LLMHandler + generate_music plumbing.
77
+
78
+ Cover / Extend / Edit / Lyrics task_types are mapped here at
79
+ M3 / M4 by switching ``params["task_type"]`` before calling.
80
+ """
81
  if mode == "generate":
82
+ return pipe.generate(params)
 
 
 
 
 
 
83
  # cover / extend / edit / lyrics get filled in at M3 / M4
84
  raise NotImplementedError(f"Mode {mode!r} is not wired yet")
tests/test_ace_pipeline_lazy.py CHANGED
@@ -1,39 +1,166 @@
1
- """L2 tests for pipeline lazy load — mock the heavy ACE-Step import."""
2
 
3
  from __future__ import annotations
4
 
 
5
  from unittest.mock import MagicMock
6
 
 
 
7
  import ace_pipeline as ap
8
 
9
 
10
- def test_get_pipeline_loads_lazily_first_call_only(monkeypatch):
11
- fake_pipe = MagicMock(name="fake_ace_pipeline")
12
- loader = MagicMock(return_value=fake_pipe)
13
- monkeypatch.setattr(ap, "_load_pipeline", loader)
14
  monkeypatch.setattr(ap, "_PIPELINE", None, raising=False)
15
-
16
  p1 = ap.get_pipeline()
17
  p2 = ap.get_pipeline()
 
 
18
 
19
- assert p1 is fake_pipe
20
- assert p2 is fake_pipe
21
- assert loader.call_count == 1, "pipeline should load exactly once"
22
 
 
 
 
 
 
23
 
24
- def test_get_pipeline_uses_detected_device(monkeypatch):
25
- monkeypatch.setattr(ap, "_PIPELINE", None, raising=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  monkeypatch.setattr(ap, "detect_device", lambda: "mps")
27
- captured = {}
28
 
29
- def fake_load(device, model_path):
30
- captured["device"] = device
31
- captured["model_path"] = model_path
32
- return MagicMock()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- monkeypatch.setattr(ap, "_load_pipeline", fake_load)
 
 
 
 
35
 
36
- ap.get_pipeline()
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- assert captured["device"] == "mps"
39
- assert captured["model_path"] is not None
 
 
1
+ """L2 tests for the ACEStepStudio wrapper — mocks the heavy acestep imports."""
2
 
3
  from __future__ import annotations
4
 
5
+ import sys
6
  from unittest.mock import MagicMock
7
 
8
+ import pytest
9
+
10
  import ace_pipeline as ap
11
 
12
 
13
+ def test_get_pipeline_returns_singleton(monkeypatch):
 
 
 
14
  monkeypatch.setattr(ap, "_PIPELINE", None, raising=False)
 
15
  p1 = ap.get_pipeline()
16
  p2 = ap.get_pipeline()
17
+ assert p1 is p2
18
+ assert isinstance(p1, ap.ACEStepStudio)
19
 
 
 
 
20
 
21
+ def test_studio_constructor_uses_detected_device(monkeypatch):
22
+ monkeypatch.setattr(ap, "detect_device", lambda: "mps")
23
+ studio = ap.ACEStepStudio()
24
+ assert studio.device == "mps"
25
+ assert studio.is_loaded is False # handlers are lazy
26
 
27
+
28
+ def test_studio_constructor_respects_env_overrides(monkeypatch):
29
+ monkeypatch.setenv("ACE_DIT_CONFIG", "custom-dit")
30
+ monkeypatch.setenv("ACE_LM_MODEL", "custom-lm")
31
+ monkeypatch.setattr(ap, "detect_device", lambda: "cpu")
32
+ studio = ap.ACEStepStudio()
33
+ assert studio._dit_config == "custom-dit"
34
+ assert studio._lm_model == "custom-lm"
35
+
36
+
37
+ def test_studio_ensure_loaded_constructs_both_handlers(monkeypatch):
38
+ fake_dit_cls = MagicMock(name="AceStepHandler")
39
+ fake_lm_cls = MagicMock(name="LLMHandler")
40
+ fake_dit = MagicMock()
41
+ fake_lm = MagicMock()
42
+ fake_dit_cls.return_value = fake_dit
43
+ fake_lm_cls.return_value = fake_lm
44
+
45
+ handler_mod = MagicMock()
46
+ handler_mod.AceStepHandler = fake_dit_cls
47
+ llm_mod = MagicMock()
48
+ llm_mod.LLMHandler = fake_lm_cls
49
+
50
+ monkeypatch.setitem(sys.modules, "acestep.handler", handler_mod)
51
+ monkeypatch.setitem(sys.modules, "acestep.llm_inference", llm_mod)
52
  monkeypatch.setattr(ap, "detect_device", lambda: "mps")
 
53
 
54
+ studio = ap.ACEStepStudio()
55
+ studio._ensure_loaded()
56
+
57
+ fake_dit_cls.assert_called_once()
58
+ fake_lm_cls.assert_called_once()
59
+ fake_dit.initialize_service.assert_called_once()
60
+ fake_lm.initialize.assert_called_once()
61
+ assert fake_dit.initialize_service.call_args.kwargs["device"] == "mps"
62
+ assert fake_lm.initialize.call_args.kwargs["device"] == "mps"
63
+ assert fake_dit.initialize_service.call_args.kwargs["config_path"] == "acestep-v15-xl-sft"
64
+ assert fake_lm.initialize.call_args.kwargs["lm_model_path"] == "acestep-5Hz-lm-0.6B"
65
+
66
+
67
+ def _install_fake_inference(monkeypatch, success=True, audios=None, error=None):
68
+ """Plant a fake ``acestep.inference`` module and return the spies."""
69
+ if audios is None:
70
+ audios = [{"path": "/tmp/x.wav"}]
71
+ fake_result = MagicMock(success=success, audios=audios, error=error)
72
+ fake_generate = MagicMock(return_value=fake_result)
73
+ captured = {"gp": {}, "gc": {}}
74
+
75
+ def fake_gp(**kw):
76
+ captured["gp"] = kw
77
+ return kw
78
+
79
+ def fake_gc(**kw):
80
+ captured["gc"] = kw
81
+ return kw
82
+
83
+ fake_inference = MagicMock()
84
+ fake_inference.generate_music = fake_generate
85
+ fake_inference.GenerationParams = MagicMock(side_effect=fake_gp)
86
+ fake_inference.GenerationConfig = MagicMock(side_effect=fake_gc)
87
+ monkeypatch.setitem(sys.modules, "acestep.inference", fake_inference)
88
+ return fake_generate, captured
89
+
90
+
91
+ def test_studio_generate_builds_params_and_calls_generate_music(monkeypatch, tmp_path):
92
+ out_wav = tmp_path / "out.wav"
93
+ out_wav.write_bytes(b"RIFF" + b"\0" * 100)
94
+
95
+ fake_generate, captured = _install_fake_inference(monkeypatch, audios=[{"path": str(out_wav)}])
96
+
97
+ studio = ap.ACEStepStudio()
98
+ studio._dit = MagicMock(name="dit")
99
+ studio._llm = MagicMock(name="llm")
100
+
101
+ result_path = studio.generate(
102
+ {
103
+ "prompt": "psytrance",
104
+ "lyrics": "[verse]",
105
+ "duration_s": 30,
106
+ "instrumental": False,
107
+ "seed": 42,
108
+ "loras": [],
109
+ "advanced": {"steps": 32, "cfg": 4.0, "bpm": 135},
110
+ "lm": {"thinking": False},
111
+ "dcw": {},
112
+ }
113
+ )
114
+
115
+ assert result_path == str(out_wav)
116
+ fake_generate.assert_called_once()
117
+ assert captured["gp"]["caption"] == "psytrance"
118
+ assert captured["gp"]["duration"] == 30
119
+ assert captured["gp"]["seed"] == 42
120
+ assert captured["gp"]["inference_steps"] == 32
121
+ assert captured["gp"]["bpm"] == 135
122
+
123
+
124
+ def test_studio_generate_raises_on_failure(monkeypatch):
125
+ _install_fake_inference(monkeypatch, success=False, audios=[], error="OOM")
126
+ studio = ap.ACEStepStudio()
127
+ studio._dit = MagicMock()
128
+ studio._llm = MagicMock()
129
+
130
+ with pytest.raises(RuntimeError, match="OOM"):
131
+ studio.generate(
132
+ {
133
+ "prompt": "p",
134
+ "lyrics": "",
135
+ "duration_s": 5,
136
+ "instrumental": True,
137
+ "seed": 1,
138
+ "advanced": {},
139
+ "lm": {},
140
+ "dcw": {},
141
+ }
142
+ )
143
+
144
 
145
+ def test_studio_generate_uses_instrumental_marker_when_lyrics_empty(monkeypatch):
146
+ _fake_generate, captured = _install_fake_inference(monkeypatch)
147
+ studio = ap.ACEStepStudio()
148
+ studio._dit = MagicMock()
149
+ studio._llm = MagicMock()
150
 
151
+ studio.generate(
152
+ {
153
+ "prompt": "drone",
154
+ "lyrics": "",
155
+ "duration_s": 5,
156
+ "instrumental": True,
157
+ "seed": 1,
158
+ "advanced": {},
159
+ "lm": {},
160
+ "dcw": {},
161
+ }
162
+ )
163
 
164
+ # Instrumental + empty lyrics → ACE-Step convention is "[Instrumental]"
165
+ assert captured["gp"]["lyrics"] == "[Instrumental]"
166
+ assert captured["gp"]["instrumental"] is True
tests/test_backend.py CHANGED
@@ -1,4 +1,4 @@
1
- """L2 tests for backend.dispatch — pipeline is mocked at the boundary."""
2
 
3
  from __future__ import annotations
4
 
@@ -7,12 +7,13 @@ from unittest.mock import MagicMock
7
  import backend as be
8
 
9
 
10
- def test_dispatch_generate_calls_pipeline_with_expected_kwargs(monkeypatch, tmp_path):
11
- fake_pipe = MagicMock()
12
  fake_out = tmp_path / "out.wav"
13
  fake_out.write_bytes(b"RIFF" + b"\0" * 1000)
14
- fake_pipe.return_value = str(fake_out)
15
 
 
 
16
  monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
17
 
18
  b = be.ACEStepStudioBackend()
@@ -34,13 +35,19 @@ def test_dispatch_generate_calls_pipeline_with_expected_kwargs(monkeypatch, tmp_
34
  assert out_path == str(fake_out)
35
  assert meta["mode"] == "generate"
36
  assert meta["seed"] == 42
37
- fake_pipe.assert_called_once()
 
 
 
 
38
 
39
 
40
  def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
41
- fake_pipe = MagicMock(return_value=str(tmp_path / "x.wav"))
 
 
 
42
  monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
43
- (tmp_path / "x.wav").write_bytes(b"RIFF")
44
 
45
  b = be.ACEStepStudioBackend()
46
  _, meta = b.dispatch(
@@ -59,3 +66,6 @@ def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
59
  )
60
 
61
  assert 1 <= meta["seed"] <= 2_147_483_647
 
 
 
 
1
+ """L2 tests for backend.dispatch — pipeline is mocked at the wrapper boundary."""
2
 
3
  from __future__ import annotations
4
 
 
7
  import backend as be
8
 
9
 
10
+ def test_dispatch_generate_calls_pipeline_generate(monkeypatch, tmp_path):
11
+ """Backend should call ``pipe.generate(params)`` and return its path."""
12
  fake_out = tmp_path / "out.wav"
13
  fake_out.write_bytes(b"RIFF" + b"\0" * 1000)
 
14
 
15
+ fake_pipe = MagicMock()
16
+ fake_pipe.generate.return_value = str(fake_out)
17
  monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
18
 
19
  b = be.ACEStepStudioBackend()
 
35
  assert out_path == str(fake_out)
36
  assert meta["mode"] == "generate"
37
  assert meta["seed"] == 42
38
+ fake_pipe.generate.assert_called_once()
39
+ # The full params dict is forwarded to pipe.generate
40
+ sent_params = fake_pipe.generate.call_args.args[0]
41
+ assert sent_params["prompt"] == "psytrance"
42
+ assert sent_params["seed"] == 42
43
 
44
 
45
  def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
46
+ out = tmp_path / "x.wav"
47
+ out.write_bytes(b"RIFF")
48
+ fake_pipe = MagicMock()
49
+ fake_pipe.generate.return_value = str(out)
50
  monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
 
51
 
52
  b = be.ACEStepStudioBackend()
53
  _, meta = b.dispatch(
 
66
  )
67
 
68
  assert 1 <= meta["seed"] <= 2_147_483_647
69
+ # The seed-resolved value is the one forwarded to the wrapper
70
+ sent_params = fake_pipe.generate.call_args.args[0]
71
+ assert sent_params["seed"] == meta["seed"]
tests/test_smoke_gpu.py CHANGED
@@ -6,14 +6,16 @@ pipeline. Run before each release tag.
6
  Skipped automatically in CI by the pyproject ``addopts = -m 'not gpu'``
7
  default. Requires:
8
 
9
- - ``ace-step`` installed (Apple Silicon fork on Mac, upstream on CUDA)
10
- - First run downloads ACE-Step 1.5 XL SFT weights (~16 GB) into the HF cache
 
 
 
11
  - A real MPS / CUDA device — CPU inference is functionally untested
12
  """
13
 
14
  from __future__ import annotations
15
 
16
- import os
17
  from pathlib import Path
18
 
19
  import pytest
@@ -21,31 +23,36 @@ import pytest
21
  pytestmark = pytest.mark.gpu
22
 
23
 
24
- def test_generate_minimum_song(tmp_path):
25
- """Smallest end-to-end: 5 s instrumental drone, seed=1."""
26
- os.environ.setdefault("ACE_MODEL_PATH", "ACE-Step/acestep-v15-xl-sft")
27
 
 
 
 
 
 
28
  from backend import ACEStepStudioBackend
29
 
30
  b = ACEStepStudioBackend()
31
  out_path, meta = b.dispatch(
32
  mode="generate",
33
  params={
34
- "prompt": "test tone, simple drone",
35
- "lyrics": "[intro] tone",
36
- "duration_s": 5,
37
  "instrumental": True,
38
  "seed": 1,
39
  "loras": [],
40
- "advanced": {},
41
- "lm": {},
 
42
  "dcw": {},
43
  },
44
  )
45
- assert Path(out_path).exists()
46
- assert Path(out_path).stat().st_size > 0
 
 
47
  assert meta["mode"] == "generate"
48
  assert meta["seed"] == 1
49
- # Wall time should be < 5 min even on first cold run + 16 GB weight download.
50
- # Subsequent runs should be < 30 s on M5 Max.
51
  assert meta["wall_seconds"] > 0
 
6
  Skipped automatically in CI by the pyproject ``addopts = -m 'not gpu'``
7
  default. Requires:
8
 
9
+ - ``acestep`` package installed (Apple Silicon fork on Mac, upstream on CUDA)
10
+ - DiT checkpoint at ``./checkpoints/acestep-v15-xl-sft/`` (~16 GB) — download via
11
+ ``hf download ACE-Step/acestep-v15-xl-sft --local-dir checkpoints/acestep-v15-xl-sft``
12
+ - LM checkpoint at ``./checkpoints/acestep-5Hz-lm-0.6B/`` (~1.4 GB) — download via
13
+ ``hf download ACE-Step/acestep-5Hz-lm-0.6B --local-dir checkpoints/acestep-5Hz-lm-0.6B``
14
  - A real MPS / CUDA device — CPU inference is functionally untested
15
  """
16
 
17
  from __future__ import annotations
18
 
 
19
  from pathlib import Path
20
 
21
  import pytest
 
23
  pytestmark = pytest.mark.gpu
24
 
25
 
26
+ def test_generate_minimum_song():
27
+ """Smallest end-to-end: 10 s instrumental drone, seed=1, 16 diffusion steps.
 
28
 
29
+ Asserts the pipeline produces a non-empty audio file. Wall time on
30
+ cold start (handlers + weight loading) should be < 5 min on M5 Max
31
+ with checkpoints pre-downloaded; subsequent calls in the same process
32
+ are bounded by the diffusion compute itself (~10-30 s for these settings).
33
+ """
34
  from backend import ACEStepStudioBackend
35
 
36
  b = ACEStepStudioBackend()
37
  out_path, meta = b.dispatch(
38
  mode="generate",
39
  params={
40
+ "prompt": "ambient drone, sine pad, slow swell",
41
+ "lyrics": "",
42
+ "duration_s": 10,
43
  "instrumental": True,
44
  "seed": 1,
45
  "loras": [],
46
+ # Tune for smoke speed: fewer steps, lower CFG, skip LM CoT
47
+ "advanced": {"steps": 16, "cfg": 3.0, "audio_format": "wav"},
48
+ "lm": {"thinking": False},
49
  "dcw": {},
50
  },
51
  )
52
+
53
+ p = Path(out_path)
54
+ assert p.exists(), f"generated file missing: {out_path}"
55
+ assert p.stat().st_size > 0, "generated file is empty"
56
  assert meta["mode"] == "generate"
57
  assert meta["seed"] == 1
 
 
58
  assert meta["wall_seconds"] > 0