techfreakworm commited on
Commit
dfa2ff6
·
unverified ·
1 Parent(s): 52f41b8

feat(backend): add ace-step studio backend with dispatch + zerogpu wrap

Browse files
Files changed (2) hide show
  1. backend.py +81 -0
  2. tests/test_backend.py +61 -0
backend.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ACEStepStudioBackend — dispatch + ZeroGPU lifetime + duration estimator.
2
+
3
+ Off Spaces, @spaces.GPU is a no-op identity decorator (`spaces` may not be
4
+ installed locally). On Spaces, the HF runtime injects it at startup and
5
+ the decorator applies for real.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import random
11
+ import time
12
+ from typing import Any
13
+
14
+ try:
15
+ import spaces # type: ignore[import-not-found]
16
+
17
+ _HAS_SPACES = True
18
+ except ImportError: # pragma: no cover - covered by manual local testing
19
+ spaces = None
20
+ _HAS_SPACES = False
21
+
22
+ import ace_pipeline as ap
23
+
24
+
25
+ def _maybe_seed(seed: int | None) -> int:
26
+ if seed and int(seed) > 0:
27
+ return int(seed)
28
+ return random.randint(1, 2_147_483_647)
29
+
30
+
31
+ def _duration_estimate(mode: str, params: dict[str, Any]) -> int:
32
+ """ZeroGPU per-call duration cap, clamped [60, 180] s."""
33
+ base = 60
34
+ duration_s = int(params.get("duration_s", 30) or 30)
35
+ if duration_s > 60:
36
+ base = 90
37
+ if duration_s > 120:
38
+ base = 120
39
+ if mode == "edit":
40
+ base = max(base, 90)
41
+ if mode == "extend":
42
+ base = max(base, 120)
43
+ return min(180, max(60, base))
44
+
45
+
46
+ class ACEStepStudioBackend:
47
+ """Lazy backend singleton. Owns @spaces.GPU and pipeline lifecycle."""
48
+
49
+ def dispatch(self, mode: str, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
50
+ params = dict(params)
51
+ params["seed"] = _maybe_seed(params.get("seed"))
52
+ t0 = time.time()
53
+ pipe = ap.get_pipeline()
54
+ out_path = self._call_pipe_for_mode(pipe, mode, params)
55
+ meta = {
56
+ "mode": mode,
57
+ "seed": params["seed"],
58
+ "duration_s": params.get("duration_s"),
59
+ "wall_seconds": round(time.time() - t0, 2),
60
+ "estimated_duration_s": _duration_estimate(mode, params),
61
+ "loras": [
62
+ {"name": lora.get("name"), "scale": lora.get("scale"), "sha256": lora.get("sha256")}
63
+ for lora in params.get("loras", [])
64
+ ],
65
+ "lm": params.get("lm", {}),
66
+ "dcw": params.get("dcw", {}),
67
+ }
68
+ return out_path, meta
69
+
70
+ def _call_pipe_for_mode(self, pipe, mode: str, params: dict[str, Any]) -> str:
71
+ """Mode-specific kwargs translation. Filled out per milestone."""
72
+ if mode == "generate":
73
+ return pipe(
74
+ prompt=params["prompt"],
75
+ lyrics=params.get("lyrics", ""),
76
+ duration_s=params["duration_s"],
77
+ instrumental=params.get("instrumental", False),
78
+ seed=params["seed"],
79
+ )
80
+ # cover / extend / edit / lyrics get filled in at M3 / M4
81
+ raise NotImplementedError(f"Mode {mode!r} is not wired yet")
tests/test_backend.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """L2 tests for backend.dispatch — pipeline is mocked at the boundary."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from unittest.mock import MagicMock
6
+
7
+ import backend as be
8
+
9
+
10
+ def test_dispatch_generate_calls_pipeline_with_expected_kwargs(monkeypatch, tmp_path):
11
+ fake_pipe = MagicMock()
12
+ fake_out = tmp_path / "out.wav"
13
+ fake_out.write_bytes(b"RIFF" + b"\0" * 1000)
14
+ fake_pipe.return_value = str(fake_out)
15
+
16
+ monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
17
+
18
+ b = be.ACEStepStudioBackend()
19
+ out_path, meta = b.dispatch(
20
+ mode="generate",
21
+ params={
22
+ "prompt": "psytrance",
23
+ "lyrics": "[verse]",
24
+ "duration_s": 10,
25
+ "instrumental": False,
26
+ "seed": 42,
27
+ "loras": [],
28
+ "advanced": {},
29
+ "lm": {},
30
+ "dcw": {},
31
+ },
32
+ )
33
+
34
+ assert out_path == str(fake_out)
35
+ assert meta["mode"] == "generate"
36
+ assert meta["seed"] == 42
37
+ fake_pipe.assert_called_once()
38
+
39
+
40
+ def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
41
+ fake_pipe = MagicMock(return_value=str(tmp_path / "x.wav"))
42
+ monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
43
+ (tmp_path / "x.wav").write_bytes(b"RIFF")
44
+
45
+ b = be.ACEStepStudioBackend()
46
+ _, meta = b.dispatch(
47
+ mode="generate",
48
+ params={
49
+ "prompt": "p",
50
+ "lyrics": "",
51
+ "duration_s": 5,
52
+ "instrumental": False,
53
+ "seed": 0,
54
+ "loras": [],
55
+ "advanced": {},
56
+ "lm": {},
57
+ "dcw": {},
58
+ },
59
+ )
60
+
61
+ assert 1 <= meta["seed"] <= 2_147_483_647