Spaces:
Running on Zero
refactor(pipeline): rewrite for real acestep AceStepHandler+LLMHandler api
Browse files
The original ace_pipeline assumed ``ACEStepPipeline.from_pretrained()`` —
a clean library entry point that does NOT exist in the installed
``acestep`` package (apple-silicon fork on Mac, upstream on CUDA). The
real API is a split-handler pattern:
    from acestep.handler import AceStepHandler
    from acestep.llm_inference import LLMHandler
    from acestep.inference import GenerationParams, GenerationConfig, generate_music
    dit = AceStepHandler(); dit.initialize_service(project_root, config_path, device)
    lm = LLMHandler(); lm.initialize(checkpoint_dir, lm_model_path, backend, device)
    result = generate_music(dit, lm, GenerationParams(...), GenerationConfig(...))
To keep ``modes.py`` and ``backend.py`` clean, ``ace_pipeline`` now
exposes a single ``ACEStepStudio`` wrapper that owns both handlers and
exposes ``generate(params: dict) -> str`` returning the audio path.
Defaults:
- DiT: ACE-Step/acestep-v15-xl-sft (~16 GB) → ``./checkpoints/acestep-v15-xl-sft/``
- LM: ACE-Step/acestep-5Hz-lm-0.6B (~1.4 GB) → ``./checkpoints/acestep-5Hz-lm-0.6B/``
The fork auto-routes ``backend='vllm'`` to ``mlx`` on ``device='mps'``
when mlx-lm is installed, so the same code path works on Mac and CUDA.
Tests updated to mock the wrapper interface: ``pipe.generate(params)``
instead of ``pipe(...)``. 17/17 L1+L2 pass; the GPU smoke (deselected
by default) exercises the real pipeline once checkpoints are downloaded.
Closes spec §14.1 open question (canonical ACE-Step Python API).
- ace_pipeline.py +173 -27
- backend.py +11 -8
- tests/test_ace_pipeline_lazy.py +147 -20
- tests/test_backend.py +17 -7
- tests/test_smoke_gpu.py +22 -15
|
@@ -1,12 +1,49 @@
|
|
| 1 |
"""ACE-Step pipeline lifecycle: device autodetect, lazy load, cache mirror.
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def detect_device() -> str:
|
|
@@ -27,9 +64,9 @@ def detect_device() -> str:
|
|
| 27 |
def vram_limit_for(device: str) -> int | None:
|
| 28 |
"""Returns a VRAM cap in bytes for CUDA, None otherwise.
|
| 29 |
|
| 30 |
-
`torch.mps` has no `mem_get_info` — calling DiffSynth-style
|
| 31 |
-
gates with a numeric limit would crash on MPS. Returning
|
| 32 |
-
pipeline short-circuit those checks.
|
| 33 |
"""
|
| 34 |
if device != "cuda":
|
| 35 |
return None
|
|
@@ -43,30 +80,139 @@ def vram_limit_for(device: str) -> int | None:
|
|
| 43 |
return None
|
| 44 |
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def _load_pipeline(device: str, model_path: str):
|
| 51 |
-
"""Construct the ACE-Step pipeline. Heavy import is local so unit tests can mock."""
|
| 52 |
-
from ace_step import ACEStepPipeline # type: ignore[import-not-found]
|
| 53 |
-
|
| 54 |
-
# On Mac, the apple-silicon fork sets dtype + backend automatically.
|
| 55 |
-
# On CUDA we pass bf16 explicitly.
|
| 56 |
-
if device == "cuda":
|
| 57 |
-
pipe = ACEStepPipeline.from_pretrained(model_path, torch_dtype="bf16")
|
| 58 |
-
else:
|
| 59 |
-
pipe = ACEStepPipeline.from_pretrained(model_path)
|
| 60 |
-
|
| 61 |
-
pipe.to(device)
|
| 62 |
-
return pipe
|
| 63 |
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
def
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
global _PIPELINE
|
| 68 |
if _PIPELINE is None:
|
| 69 |
-
|
| 70 |
-
model_path = os.environ.get("ACE_MODEL_PATH", _DEFAULT_MODEL_ID)
|
| 71 |
-
_PIPELINE = _load_pipeline(device, model_path)
|
| 72 |
return _PIPELINE
|
|
|
|
| 1 |
"""ACE-Step pipeline lifecycle: device autodetect, lazy load, cache mirror.
|
| 2 |
|
| 3 |
+
The installed ``acestep`` package (apple-silicon fork on Mac, upstream on
|
| 4 |
+
CUDA) does NOT expose a single ``ACEStepPipeline.from_pretrained`` entry
|
| 5 |
+
point. The real API is a split-handler pattern:
|
| 6 |
+
|
| 7 |
+
from acestep.handler import AceStepHandler # DiT side
|
| 8 |
+
from acestep.llm_inference import LLMHandler # 5Hz LM planner
|
| 9 |
+
from acestep.inference import (
|
| 10 |
+
GenerationParams, GenerationConfig, generate_music,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
dit = AceStepHandler()
|
| 14 |
+
dit.initialize_service(project_root=..., config_path="acestep-v15-xl-sft",
|
| 15 |
+
device="mps")
|
| 16 |
+
lm = LLMHandler()
|
| 17 |
+
lm.initialize(checkpoint_dir=..., lm_model_path="acestep-5Hz-lm-0.6B",
|
| 18 |
+
backend="vllm", # auto-routes to mlx on mps
|
| 19 |
+
device="mps")
|
| 20 |
+
params = GenerationParams(caption=..., lyrics=..., duration=..., seed=...)
|
| 21 |
+
cfg = GenerationConfig(batch_size=1, audio_format="wav")
|
| 22 |
+
result = generate_music(dit, lm, params, cfg)
|
| 23 |
+
# result.audios[0]["path"] is the WAV file
|
| 24 |
+
|
| 25 |
+
To keep ``backend.py`` and ``modes.py`` clean, this module exposes a
|
| 26 |
+
single ``ACEStepStudio`` wrapper that owns both handlers and exposes a
|
| 27 |
+
``generate(params: dict) -> str`` method returning the audio path.
|
| 28 |
+
``get_pipeline()`` returns the lazy singleton wrapper.
|
| 29 |
+
|
| 30 |
+
Checkpoints live under ``{project_root}/checkpoints/{config_path}/``.
|
| 31 |
+
On Mac with the apple-silicon fork, the fork auto-downloads from
|
| 32 |
+
HuggingFace if a checkpoint is missing, but in practice we pre-download
|
| 33 |
+
via ``hf download`` before the first inference call to avoid pytest
|
| 34 |
+
timeouts.
|
| 35 |
"""
|
| 36 |
|
| 37 |
from __future__ import annotations
|
| 38 |
|
| 39 |
import os
|
| 40 |
+
from pathlib import Path
|
| 41 |
+
|
| 42 |
+
_REPO_ROOT = Path(__file__).resolve().parent
|
| 43 |
+
_CHECKPOINTS_DIR = _REPO_ROOT / "checkpoints"
|
| 44 |
+
|
| 45 |
+
_DEFAULT_DIT_CONFIG = "acestep-v15-xl-sft"
|
| 46 |
+
_DEFAULT_LM_MODEL = "acestep-5Hz-lm-0.6B"
|
| 47 |
|
| 48 |
|
| 49 |
def detect_device() -> str:
|
|
|
|
| 64 |
def vram_limit_for(device: str) -> int | None:
|
| 65 |
"""Returns a VRAM cap in bytes for CUDA, None otherwise.
|
| 66 |
|
| 67 |
+
``torch.mps`` has no ``mem_get_info`` — calling DiffSynth-style
|
| 68 |
+
free-VRAM gates with a numeric limit would crash on MPS. Returning
|
| 69 |
+
None lets the pipeline short-circuit those checks.
|
| 70 |
"""
|
| 71 |
if device != "cuda":
|
| 72 |
return None
|
|
|
|
| 80 |
return None
|
| 81 |
|
| 82 |
|
| 83 |
+
class ACEStepStudio:
    """Wrapper around the apple-silicon fork's split-handler API.

    Owns one ``AceStepHandler`` (DiT) and one ``LLMHandler`` (5Hz LM
    planner). Both are lazy-loaded on the first ``generate(...)`` call, so
    constructing the wrapper never imports ``acestep``.

    Args:
        dit_config: DiT config name. Falls back to the ``ACE_DIT_CONFIG``
            environment variable, then to ``_DEFAULT_DIT_CONFIG``.
        lm_model: LM model name. Falls back to the ``ACE_LM_MODEL``
            environment variable, then to ``_DEFAULT_LM_MODEL``.
        device: Device string handed to both handlers. Falls back to
            ``detect_device()``.
    """

    # Seed used when the caller supplies none. GenerationParams used to fall
    # back to -1 while GenerationConfig.seeds fell back to [1]; both now share
    # this value, which is the one GenerationConfig already pinned together
    # with ``use_random_seed=False``.
    _DEFAULT_SEED = 1

    def __init__(
        self,
        dit_config: str | None = None,
        lm_model: str | None = None,
        device: str | None = None,
    ) -> None:
        self._dit = None
        self._llm = None
        self._dit_config = dit_config or os.environ.get("ACE_DIT_CONFIG", _DEFAULT_DIT_CONFIG)
        self._lm_model = lm_model or os.environ.get("ACE_LM_MODEL", _DEFAULT_LM_MODEL)
        self._device = device or detect_device()

    @property
    def device(self) -> str:
        """Device string both handlers are (or will be) initialised with."""
        return self._device

    @property
    def is_loaded(self) -> bool:
        """True once both the DiT handler and the LM handler are set."""
        return self._dit is not None and self._llm is not None

    def _ensure_loaded(self) -> None:
        """First-call lazy load of both handlers. Heavy imports stay local."""
        if self.is_loaded:
            return

        from acestep.handler import AceStepHandler
        from acestep.llm_inference import LLMHandler

        dit = AceStepHandler()
        dit.initialize_service(
            project_root=str(_REPO_ROOT),
            config_path=self._dit_config,
            device=self._device,
        )

        llm = LLMHandler()
        llm.initialize(
            checkpoint_dir=str(_CHECKPOINTS_DIR),
            lm_model_path=self._lm_model,
            backend="vllm",  # fork auto-routes to mlx on mps + mlx-lm installed
            device=self._device,
        )

        # Publish both handlers only after both initialised, so a failure in
        # either leaves ``is_loaded`` False and the next call retries.
        self._dit = dit
        self._llm = llm

    @staticmethod
    def _build_request(params: dict) -> tuple[dict, dict]:
        """Translate the internal params dict into ACE-Step keyword arguments.

        Args:
            params: Dict with the keys ``generate`` documents. ``advanced``
                and ``lm`` may be missing or ``None``.

        Returns:
            ``(generation_params_kwargs, generation_config_kwargs)``, ready
            to be splatted into ``GenerationParams`` / ``GenerationConfig``.
        """
        advanced = params.get("advanced", {}) or {}
        lm_opts = params.get("lm", {}) or {}

        # Lyrics "[Instrumental]" is the ACE-Step convention for instrumental.
        lyrics = params.get("lyrics", "") or ""
        instrumental = bool(params.get("instrumental", False))
        if instrumental and not lyrics:
            lyrics = "[Instrumental]"

        # Resolve the seed once so GenerationParams and GenerationConfig can
        # never disagree; a missing or None seed uses the shared default.
        raw_seed = params.get("seed")
        seed = ACEStepStudio._DEFAULT_SEED if raw_seed is None else int(raw_seed)

        param_kwargs = dict(
            task_type="text2music",
            caption=params.get("prompt", ""),
            lyrics=lyrics,
            instrumental=instrumental,
            duration=int(params.get("duration_s", 30)),
            seed=seed,
            inference_steps=int(advanced.get("steps", 32)),
            guidance_scale=float(advanced.get("cfg", 4.0)),
            shift=float(advanced.get("shift", 1.0)),
            bpm=advanced.get("bpm"),
            keyscale=advanced.get("keyscale", ""),
            timesignature=advanced.get("timesignature", ""),
            vocal_language=advanced.get("vocal_language", "unknown"),
            cfg_interval_start=float(advanced.get("cfg_interval_start", 0.0)),
            cfg_interval_end=float(advanced.get("cfg_interval_end", 1.0)),
            thinking=bool(lm_opts.get("thinking", False)),
            lm_temperature=float(lm_opts.get("temperature", 0.85)),
            lm_cfg_scale=float(lm_opts.get("cfg", 2.0)),
            lm_top_k=int(lm_opts.get("top_k", 0)),
            lm_top_p=float(lm_opts.get("top_p", 0.9)),
            lm_negative_prompt=lm_opts.get("negative_prompt", ""),
            use_cot_metas=bool(lm_opts.get("cot_metas", False)),
            use_cot_caption=bool(lm_opts.get("cot_caption", False)),
            use_cot_language=bool(lm_opts.get("cot_language", False)),
        )
        config_kwargs = dict(
            batch_size=1,
            audio_format=advanced.get("audio_format", "wav"),
            use_random_seed=False,
            seeds=[seed],
        )
        return param_kwargs, config_kwargs

    def generate(self, params: dict) -> str:
        """Run a single text→song generation.

        ``params`` is the dict produced by ``modes.generate``:
        ``{"prompt", "lyrics", "duration_s", "instrumental", "seed",
        "loras", "advanced", "lm", "dcw"}``. Only ``prompt``, ``lyrics``,
        ``duration_s``, ``instrumental``, ``seed``, ``advanced`` and ``lm``
        are read here.

        Returns:
            The ``"path"`` entry of the first audio in the result.

        Raises:
            RuntimeError: If the result reports failure or has no audios.
        """
        self._ensure_loaded()

        from acestep.inference import (
            GenerationConfig,
            GenerationParams,
            generate_music,
        )

        param_kwargs, config_kwargs = self._build_request(params)
        gen_params = GenerationParams(**param_kwargs)
        gen_config = GenerationConfig(**config_kwargs)

        result = generate_music(self._dit, self._llm, gen_params, gen_config)

        if not result.success:
            raise RuntimeError(f"ACE-Step generation failed: {result.error}")
        if not result.audios:
            raise RuntimeError("ACE-Step returned no audio outputs")

        return result.audios[0]["path"]
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
_PIPELINE: ACEStepStudio | None = None # module-level lazy singleton
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def get_pipeline() -> ACEStepStudio:
    """Return the module-wide ``ACEStepStudio`` wrapper, creating it on first use.

    Building the wrapper is cheap: the DiT and LM handlers are only
    loaded on the first ``generate(...)`` call.
    """
    global _PIPELINE
    if _PIPELINE is not None:
        # Already built by an earlier call; reuse it.
        return _PIPELINE
    _PIPELINE = ACEStepStudio()
    return _PIPELINE
|
|
@@ -68,14 +68,17 @@ class ACEStepStudioBackend:
|
|
| 68 |
return out_path, meta
|
| 69 |
|
| 70 |
def _call_pipe_for_mode(self, pipe, mode: str, params: dict[str, Any]) -> str:
|
| 71 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
if mode == "generate":
|
| 73 |
-
return pipe(
|
| 74 |
-
prompt=params["prompt"],
|
| 75 |
-
lyrics=params.get("lyrics", ""),
|
| 76 |
-
duration_s=params["duration_s"],
|
| 77 |
-
instrumental=params.get("instrumental", False),
|
| 78 |
-
seed=params["seed"],
|
| 79 |
-
)
|
| 80 |
# cover / extend / edit / lyrics get filled in at M3 / M4
|
| 81 |
raise NotImplementedError(f"Mode {mode!r} is not wired yet")
|
|
|
|
| 68 |
return out_path, meta
|
| 69 |
|
| 70 |
def _call_pipe_for_mode(self, pipe, mode: str, params: dict[str, Any]) -> str:
    """Route one request to the pipeline wrapper.

    ``pipe`` is the ``ACEStepStudio`` wrapper returned by
    ``ace_pipeline.get_pipeline()``. Its ``generate(params)`` method
    handles the AceStepHandler + LLMHandler + generate_music plumbing.

    Only ``"generate"`` is wired. Cover / Extend / Edit / Lyrics are
    planned for M3 / M4, by switching ``params["task_type"]`` before
    calling.

    Raises:
        NotImplementedError: For any mode other than ``"generate"``.
    """
    if mode != "generate":
        # cover / extend / edit / lyrics get filled in at M3 / M4
        raise NotImplementedError(f"Mode {mode!r} is not wired yet")
    return pipe.generate(params)
|
|
@@ -1,39 +1,166 @@
|
|
| 1 |
-
"""L2 tests for
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 5 |
from unittest.mock import MagicMock
|
| 6 |
|
|
|
|
|
|
|
| 7 |
import ace_pipeline as ap
|
| 8 |
|
| 9 |
|
| 10 |
-
def
|
| 11 |
-
fake_pipe = MagicMock(name="fake_ace_pipeline")
|
| 12 |
-
loader = MagicMock(return_value=fake_pipe)
|
| 13 |
-
monkeypatch.setattr(ap, "_load_pipeline", loader)
|
| 14 |
monkeypatch.setattr(ap, "_PIPELINE", None, raising=False)
|
| 15 |
-
|
| 16 |
p1 = ap.get_pipeline()
|
| 17 |
p2 = ap.get_pipeline()
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
assert p1 is fake_pipe
|
| 20 |
-
assert p2 is fake_pipe
|
| 21 |
-
assert loader.call_count == 1, "pipeline should load exactly once"
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
monkeypatch.setattr(ap, "detect_device", lambda: "mps")
|
| 27 |
-
captured = {}
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
assert captured["
|
|
|
|
|
|
| 1 |
+
"""L2 tests for the ACEStepStudio wrapper — mocks the heavy acestep imports."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
+
import sys
|
| 6 |
from unittest.mock import MagicMock
|
| 7 |
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
import ace_pipeline as ap
|
| 11 |
|
| 12 |
|
| 13 |
+
def test_get_pipeline_returns_singleton(monkeypatch):
    """Two consecutive ``get_pipeline()`` calls return the same wrapper."""
    # Reset the cached singleton so an earlier test cannot leak one in.
    monkeypatch.setattr(ap, "_PIPELINE", None, raising=False)
    first, second = ap.get_pipeline(), ap.get_pipeline()
    assert first is second
    assert isinstance(first, ap.ACEStepStudio)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
def test_studio_constructor_uses_detected_device(monkeypatch):
    """With no explicit device, the wrapper adopts ``detect_device()``."""
    monkeypatch.setattr(ap, "detect_device", lambda: "mps")
    wrapper = ap.ACEStepStudio()
    assert wrapper.device == "mps"
    # Construction alone must not load either handler.
    assert wrapper.is_loaded is False
|
| 26 |
|
| 27 |
+
|
| 28 |
+
def test_studio_constructor_respects_env_overrides(monkeypatch):
    """``ACE_DIT_CONFIG`` / ``ACE_LM_MODEL`` override the built-in defaults."""
    overrides = {"ACE_DIT_CONFIG": "custom-dit", "ACE_LM_MODEL": "custom-lm"}
    for variable, value in overrides.items():
        monkeypatch.setenv(variable, value)
    monkeypatch.setattr(ap, "detect_device", lambda: "cpu")

    wrapper = ap.ACEStepStudio()

    assert wrapper._dit_config == "custom-dit"
    assert wrapper._lm_model == "custom-lm"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_studio_ensure_loaded_constructs_both_handlers(monkeypatch):
    """``_ensure_loaded`` builds and initialises one DiT and one LM handler."""
    dit_instance = MagicMock()
    lm_instance = MagicMock()
    dit_factory = MagicMock(name="AceStepHandler", return_value=dit_instance)
    lm_factory = MagicMock(name="LLMHandler", return_value=lm_instance)

    # Stand-ins for the heavy modules the wrapper imports lazily.
    handler_module = MagicMock()
    handler_module.AceStepHandler = dit_factory
    llm_module = MagicMock()
    llm_module.LLMHandler = lm_factory
    monkeypatch.setitem(sys.modules, "acestep.handler", handler_module)
    monkeypatch.setitem(sys.modules, "acestep.llm_inference", llm_module)
    monkeypatch.setattr(ap, "detect_device", lambda: "mps")

    wrapper = ap.ACEStepStudio()
    wrapper._ensure_loaded()

    dit_factory.assert_called_once()
    lm_factory.assert_called_once()
    dit_instance.initialize_service.assert_called_once()
    lm_instance.initialize.assert_called_once()

    dit_kwargs = dit_instance.initialize_service.call_args.kwargs
    lm_kwargs = lm_instance.initialize.call_args.kwargs
    assert dit_kwargs["device"] == "mps"
    assert lm_kwargs["device"] == "mps"
    assert dit_kwargs["config_path"] == "acestep-v15-xl-sft"
    assert lm_kwargs["lm_model_path"] == "acestep-5Hz-lm-0.6B"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _install_fake_inference(monkeypatch, success=True, audios=None, error=None):
    """Register a fake ``acestep.inference`` module and return its spies.

    Returns ``(fake_generate, captured)``. ``fake_generate`` is the mock
    installed as ``generate_music``. ``captured["gp"]`` / ``captured["gc"]``
    hold the keyword arguments most recently passed to the fake
    ``GenerationParams`` / ``GenerationConfig``.
    """
    audios = [{"path": "/tmp/x.wav"}] if audios is None else audios
    captured = {"gp": {}, "gc": {}}

    def _recorder(slot):
        # Build a constructor stand-in that stores its kwargs under ``slot``
        # and returns them as the "constructed" object.
        def _record(**kw):
            captured[slot] = kw
            return kw

        return _record

    outcome = MagicMock(success=success, audios=audios, error=error)
    fake_generate = MagicMock(return_value=outcome)

    fake_inference = MagicMock()
    fake_inference.generate_music = fake_generate
    fake_inference.GenerationParams = MagicMock(side_effect=_recorder("gp"))
    fake_inference.GenerationConfig = MagicMock(side_effect=_recorder("gc"))
    monkeypatch.setitem(sys.modules, "acestep.inference", fake_inference)
    return fake_generate, captured
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def test_studio_generate_builds_params_and_calls_generate_music(monkeypatch, tmp_path):
    """``generate`` maps the params dict onto GenerationParams and returns the path."""
    wav_file = tmp_path / "out.wav"
    wav_file.write_bytes(b"RIFF" + b"\0" * 100)

    fake_generate, captured = _install_fake_inference(
        monkeypatch, audios=[{"path": str(wav_file)}]
    )

    studio = ap.ACEStepStudio()
    # Pre-seed both handlers so ``_ensure_loaded`` is a no-op.
    studio._dit = MagicMock(name="dit")
    studio._llm = MagicMock(name="llm")

    request = {
        "prompt": "psytrance",
        "lyrics": "[verse]",
        "duration_s": 30,
        "instrumental": False,
        "seed": 42,
        "loras": [],
        "advanced": {"steps": 32, "cfg": 4.0, "bpm": 135},
        "lm": {"thinking": False},
        "dcw": {},
    }
    result_path = studio.generate(request)

    assert result_path == str(wav_file)
    fake_generate.assert_called_once()
    sent = captured["gp"]
    assert sent["caption"] == "psytrance"
    assert sent["duration"] == 30
    assert sent["seed"] == 42
    assert sent["inference_steps"] == 32
    assert sent["bpm"] == 135
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def test_studio_generate_raises_on_failure(monkeypatch):
    """A result with ``success=False`` surfaces as ``RuntimeError`` carrying the error."""
    _install_fake_inference(monkeypatch, success=False, audios=[], error="OOM")

    studio = ap.ACEStepStudio()
    # Pre-seed both handlers so ``_ensure_loaded`` is a no-op.
    studio._dit = MagicMock()
    studio._llm = MagicMock()

    request = {
        "prompt": "p",
        "lyrics": "",
        "duration_s": 5,
        "instrumental": True,
        "seed": 1,
        "advanced": {},
        "lm": {},
        "dcw": {},
    }
    with pytest.raises(RuntimeError, match="OOM"):
        studio.generate(request)
|
| 143 |
+
|
| 144 |
|
| 145 |
+
def test_studio_generate_uses_instrumental_marker_when_lyrics_empty(monkeypatch):
    """Instrumental requests with empty lyrics get the ``[Instrumental]`` marker."""
    _fake_generate, captured = _install_fake_inference(monkeypatch)

    studio = ap.ACEStepStudio()
    # Pre-seed both handlers so ``_ensure_loaded`` is a no-op.
    studio._dit = MagicMock()
    studio._llm = MagicMock()

    request = {
        "prompt": "drone",
        "lyrics": "",
        "duration_s": 5,
        "instrumental": True,
        "seed": 1,
        "advanced": {},
        "lm": {},
        "dcw": {},
    }
    studio.generate(request)

    # Instrumental + empty lyrics → ACE-Step convention is "[Instrumental]"
    sent = captured["gp"]
    assert sent["lyrics"] == "[Instrumental]"
    assert sent["instrumental"] is True
|
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""L2 tests for backend.dispatch — pipeline is mocked at the boundary."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
@@ -7,12 +7,13 @@ from unittest.mock import MagicMock
|
|
| 7 |
import backend as be
|
| 8 |
|
| 9 |
|
| 10 |
-
def
|
| 11 |
-
|
| 12 |
fake_out = tmp_path / "out.wav"
|
| 13 |
fake_out.write_bytes(b"RIFF" + b"\0" * 1000)
|
| 14 |
-
fake_pipe.return_value = str(fake_out)
|
| 15 |
|
|
|
|
|
|
|
| 16 |
monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
|
| 17 |
|
| 18 |
b = be.ACEStepStudioBackend()
|
|
@@ -34,13 +35,19 @@ def test_dispatch_generate_calls_pipeline_with_expected_kwargs(monkeypatch, tmp_
|
|
| 34 |
assert out_path == str(fake_out)
|
| 35 |
assert meta["mode"] == "generate"
|
| 36 |
assert meta["seed"] == 42
|
| 37 |
-
fake_pipe.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
|
| 43 |
-
(tmp_path / "x.wav").write_bytes(b"RIFF")
|
| 44 |
|
| 45 |
b = be.ACEStepStudioBackend()
|
| 46 |
_, meta = b.dispatch(
|
|
@@ -59,3 +66,6 @@ def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
|
|
| 59 |
)
|
| 60 |
|
| 61 |
assert 1 <= meta["seed"] <= 2_147_483_647
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""L2 tests for backend.dispatch — pipeline is mocked at the wrapper boundary."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 7 |
import backend as be
|
| 8 |
|
| 9 |
|
| 10 |
+
def test_dispatch_generate_calls_pipeline_generate(monkeypatch, tmp_path):
|
| 11 |
+
"""Backend should call ``pipe.generate(params)`` and return its path."""
|
| 12 |
fake_out = tmp_path / "out.wav"
|
| 13 |
fake_out.write_bytes(b"RIFF" + b"\0" * 1000)
|
|
|
|
| 14 |
|
| 15 |
+
fake_pipe = MagicMock()
|
| 16 |
+
fake_pipe.generate.return_value = str(fake_out)
|
| 17 |
monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
|
| 18 |
|
| 19 |
b = be.ACEStepStudioBackend()
|
|
|
|
| 35 |
assert out_path == str(fake_out)
|
| 36 |
assert meta["mode"] == "generate"
|
| 37 |
assert meta["seed"] == 42
|
| 38 |
+
fake_pipe.generate.assert_called_once()
|
| 39 |
+
# The full params dict is forwarded to pipe.generate
|
| 40 |
+
sent_params = fake_pipe.generate.call_args.args[0]
|
| 41 |
+
assert sent_params["prompt"] == "psytrance"
|
| 42 |
+
assert sent_params["seed"] == 42
|
| 43 |
|
| 44 |
|
| 45 |
def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
|
| 46 |
+
out = tmp_path / "x.wav"
|
| 47 |
+
out.write_bytes(b"RIFF")
|
| 48 |
+
fake_pipe = MagicMock()
|
| 49 |
+
fake_pipe.generate.return_value = str(out)
|
| 50 |
monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
|
|
|
|
| 51 |
|
| 52 |
b = be.ACEStepStudioBackend()
|
| 53 |
_, meta = b.dispatch(
|
|
|
|
| 66 |
)
|
| 67 |
|
| 68 |
assert 1 <= meta["seed"] <= 2_147_483_647
|
| 69 |
+
# The seed-resolved value is the one forwarded to the wrapper
|
| 70 |
+
sent_params = fake_pipe.generate.call_args.args[0]
|
| 71 |
+
assert sent_params["seed"] == meta["seed"]
|
|
@@ -6,14 +6,16 @@ pipeline. Run before each release tag.
|
|
| 6 |
Skipped automatically in CI by the pyproject ``addopts = -m 'not gpu'``
|
| 7 |
default. Requires:
|
| 8 |
|
| 9 |
-
- ``
|
| 10 |
-
-
|
|
|
|
|
|
|
|
|
|
| 11 |
- A real MPS / CUDA device — CPU inference is functionally untested
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
-
import os
|
| 17 |
from pathlib import Path
|
| 18 |
|
| 19 |
import pytest
|
|
@@ -21,31 +23,36 @@ import pytest
|
|
| 21 |
pytestmark = pytest.mark.gpu
|
| 22 |
|
| 23 |
|
| 24 |
-
def test_generate_minimum_song(
|
| 25 |
-
"""Smallest end-to-end:
|
| 26 |
-
os.environ.setdefault("ACE_MODEL_PATH", "ACE-Step/acestep-v15-xl-sft")
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
from backend import ACEStepStudioBackend
|
| 29 |
|
| 30 |
b = ACEStepStudioBackend()
|
| 31 |
out_path, meta = b.dispatch(
|
| 32 |
mode="generate",
|
| 33 |
params={
|
| 34 |
-
"prompt": "
|
| 35 |
-
"lyrics": "
|
| 36 |
-
"duration_s":
|
| 37 |
"instrumental": True,
|
| 38 |
"seed": 1,
|
| 39 |
"loras": [],
|
| 40 |
-
|
| 41 |
-
"
|
|
|
|
| 42 |
"dcw": {},
|
| 43 |
},
|
| 44 |
)
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
| 47 |
assert meta["mode"] == "generate"
|
| 48 |
assert meta["seed"] == 1
|
| 49 |
-
# Wall time should be < 5 min even on first cold run + 16 GB weight download.
|
| 50 |
-
# Subsequent runs should be < 30 s on M5 Max.
|
| 51 |
assert meta["wall_seconds"] > 0
|
|
|
|
| 6 |
Skipped automatically in CI by the pyproject ``addopts = -m 'not gpu'``
|
| 7 |
default. Requires:
|
| 8 |
|
| 9 |
+
- ``acestep`` package installed (Apple Silicon fork on Mac, upstream on CUDA)
|
| 10 |
+
- DiT checkpoint at ``./checkpoints/acestep-v15-xl-sft/`` (~16 GB) — download via
|
| 11 |
+
``hf download ACE-Step/acestep-v15-xl-sft --local-dir checkpoints/acestep-v15-xl-sft``
|
| 12 |
+
- LM checkpoint at ``./checkpoints/acestep-5Hz-lm-0.6B/`` (~1.4 GB) — download via
|
| 13 |
+
``hf download ACE-Step/acestep-5Hz-lm-0.6B --local-dir checkpoints/acestep-5Hz-lm-0.6B``
|
| 14 |
- A real MPS / CUDA device — CPU inference is functionally untested
|
| 15 |
"""
|
| 16 |
|
| 17 |
from __future__ import annotations
|
| 18 |
|
|
|
|
| 19 |
from pathlib import Path
|
| 20 |
|
| 21 |
import pytest
|
|
|
|
| 23 |
pytestmark = pytest.mark.gpu
|
| 24 |
|
| 25 |
|
| 26 |
+
def test_generate_minimum_song():
    """Smallest end-to-end: 10 s instrumental drone, seed=1, 16 diffusion steps.

    Checks that the pipeline writes a non-empty audio file and that the
    returned metadata echoes the mode and seed. Wall time on cold start
    (handlers + weight loading) should be < 5 min on M5 Max with
    checkpoints pre-downloaded; subsequent calls in the same process are
    bounded by the diffusion compute itself (~10-30 s for these settings).
    """
    from backend import ACEStepStudioBackend

    smoke_params = {
        "prompt": "ambient drone, sine pad, slow swell",
        "lyrics": "",
        "duration_s": 10,
        "instrumental": True,
        "seed": 1,
        "loras": [],
        # Tune for smoke speed: fewer steps, lower CFG, skip LM CoT
        "advanced": {"steps": 16, "cfg": 3.0, "audio_format": "wav"},
        "lm": {"thinking": False},
        "dcw": {},
    }
    out_path, meta = ACEStepStudioBackend().dispatch(mode="generate", params=smoke_params)

    produced = Path(out_path)
    assert produced.exists(), f"generated file missing: {out_path}"
    assert produced.stat().st_size > 0, "generated file is empty"
    assert meta["mode"] == "generate"
    assert meta["seed"] == 1
    assert meta["wall_seconds"] > 0
|