Spaces:
Running on Zero
Running on Zero
feat(backend): add ace-step studio backend with dispatch + zerogpu wrap
Browse files- backend.py +81 -0
- tests/test_backend.py +61 -0
backend.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ACEStepStudioBackend — dispatch + ZeroGPU lifetime + duration estimator.
|
| 2 |
+
|
| 3 |
+
Off Spaces, @spaces.GPU is a no-op identity decorator (`spaces` may not be
|
| 4 |
+
installed locally). On Spaces, the HF runtime injects it at startup and
|
| 5 |
+
the decorator applies for real.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import random
|
| 11 |
+
import time
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import spaces # type: ignore[import-not-found]
|
| 16 |
+
|
| 17 |
+
_HAS_SPACES = True
|
| 18 |
+
except ImportError: # pragma: no cover - covered by manual local testing
|
| 19 |
+
spaces = None
|
| 20 |
+
_HAS_SPACES = False
|
| 21 |
+
|
| 22 |
+
import ace_pipeline as ap
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _maybe_seed(seed: int | None) -> int:
|
| 26 |
+
if seed and int(seed) > 0:
|
| 27 |
+
return int(seed)
|
| 28 |
+
return random.randint(1, 2_147_483_647)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _duration_estimate(mode: str, params: dict[str, Any]) -> int:
|
| 32 |
+
"""ZeroGPU per-call duration cap, clamped [60, 180] s."""
|
| 33 |
+
base = 60
|
| 34 |
+
duration_s = int(params.get("duration_s", 30) or 30)
|
| 35 |
+
if duration_s > 60:
|
| 36 |
+
base = 90
|
| 37 |
+
if duration_s > 120:
|
| 38 |
+
base = 120
|
| 39 |
+
if mode == "edit":
|
| 40 |
+
base = max(base, 90)
|
| 41 |
+
if mode == "extend":
|
| 42 |
+
base = max(base, 120)
|
| 43 |
+
return min(180, max(60, base))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ACEStepStudioBackend:
|
| 47 |
+
"""Lazy backend singleton. Owns @spaces.GPU and pipeline lifecycle."""
|
| 48 |
+
|
| 49 |
+
def dispatch(self, mode: str, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
| 50 |
+
params = dict(params)
|
| 51 |
+
params["seed"] = _maybe_seed(params.get("seed"))
|
| 52 |
+
t0 = time.time()
|
| 53 |
+
pipe = ap.get_pipeline()
|
| 54 |
+
out_path = self._call_pipe_for_mode(pipe, mode, params)
|
| 55 |
+
meta = {
|
| 56 |
+
"mode": mode,
|
| 57 |
+
"seed": params["seed"],
|
| 58 |
+
"duration_s": params.get("duration_s"),
|
| 59 |
+
"wall_seconds": round(time.time() - t0, 2),
|
| 60 |
+
"estimated_duration_s": _duration_estimate(mode, params),
|
| 61 |
+
"loras": [
|
| 62 |
+
{"name": lora.get("name"), "scale": lora.get("scale"), "sha256": lora.get("sha256")}
|
| 63 |
+
for lora in params.get("loras", [])
|
| 64 |
+
],
|
| 65 |
+
"lm": params.get("lm", {}),
|
| 66 |
+
"dcw": params.get("dcw", {}),
|
| 67 |
+
}
|
| 68 |
+
return out_path, meta
|
| 69 |
+
|
| 70 |
+
def _call_pipe_for_mode(self, pipe, mode: str, params: dict[str, Any]) -> str:
|
| 71 |
+
"""Mode-specific kwargs translation. Filled out per milestone."""
|
| 72 |
+
if mode == "generate":
|
| 73 |
+
return pipe(
|
| 74 |
+
prompt=params["prompt"],
|
| 75 |
+
lyrics=params.get("lyrics", ""),
|
| 76 |
+
duration_s=params["duration_s"],
|
| 77 |
+
instrumental=params.get("instrumental", False),
|
| 78 |
+
seed=params["seed"],
|
| 79 |
+
)
|
| 80 |
+
# cover / extend / edit / lyrics get filled in at M3 / M4
|
| 81 |
+
raise NotImplementedError(f"Mode {mode!r} is not wired yet")
|
tests/test_backend.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""L2 tests for backend.dispatch — pipeline is mocked at the boundary."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from unittest.mock import MagicMock
|
| 6 |
+
|
| 7 |
+
import backend as be
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_dispatch_generate_calls_pipeline_with_expected_kwargs(monkeypatch, tmp_path):
|
| 11 |
+
fake_pipe = MagicMock()
|
| 12 |
+
fake_out = tmp_path / "out.wav"
|
| 13 |
+
fake_out.write_bytes(b"RIFF" + b"\0" * 1000)
|
| 14 |
+
fake_pipe.return_value = str(fake_out)
|
| 15 |
+
|
| 16 |
+
monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
|
| 17 |
+
|
| 18 |
+
b = be.ACEStepStudioBackend()
|
| 19 |
+
out_path, meta = b.dispatch(
|
| 20 |
+
mode="generate",
|
| 21 |
+
params={
|
| 22 |
+
"prompt": "psytrance",
|
| 23 |
+
"lyrics": "[verse]",
|
| 24 |
+
"duration_s": 10,
|
| 25 |
+
"instrumental": False,
|
| 26 |
+
"seed": 42,
|
| 27 |
+
"loras": [],
|
| 28 |
+
"advanced": {},
|
| 29 |
+
"lm": {},
|
| 30 |
+
"dcw": {},
|
| 31 |
+
},
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
assert out_path == str(fake_out)
|
| 35 |
+
assert meta["mode"] == "generate"
|
| 36 |
+
assert meta["seed"] == 42
|
| 37 |
+
fake_pipe.assert_called_once()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
|
| 41 |
+
fake_pipe = MagicMock(return_value=str(tmp_path / "x.wav"))
|
| 42 |
+
monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
|
| 43 |
+
(tmp_path / "x.wav").write_bytes(b"RIFF")
|
| 44 |
+
|
| 45 |
+
b = be.ACEStepStudioBackend()
|
| 46 |
+
_, meta = b.dispatch(
|
| 47 |
+
mode="generate",
|
| 48 |
+
params={
|
| 49 |
+
"prompt": "p",
|
| 50 |
+
"lyrics": "",
|
| 51 |
+
"duration_s": 5,
|
| 52 |
+
"instrumental": False,
|
| 53 |
+
"seed": 0,
|
| 54 |
+
"loras": [],
|
| 55 |
+
"advanced": {},
|
| 56 |
+
"lm": {},
|
| 57 |
+
"dcw": {},
|
| 58 |
+
},
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
assert 1 <= meta["seed"] <= 2_147_483_647
|