Spaces:
Running on Zero
Running on Zero
feat(modes): add generate mode handler with input validation
Browse files- modes.py +40 -0
- tests/test_modes_generate.py +38 -0
modes.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pure mode handlers — one function per generation mode.
|
| 2 |
+
|
| 3 |
+
Each handler validates inputs, builds the ACE-Step kwargs for its mode, and
|
| 4 |
+
hands off to `backend.dispatch(...)`. Backend ownership of @spaces.GPU and
|
| 5 |
+
pipeline lifecycle keeps these handlers cheap to test.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _require(params: dict[str, Any], field: str) -> Any:
|
| 14 |
+
v = params.get(field)
|
| 15 |
+
if v is None or (isinstance(v, str) and not v.strip()):
|
| 16 |
+
raise ValueError(f"Missing required field: {field}")
|
| 17 |
+
return v
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def generate(backend, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
| 21 |
+
"""Text → song. Vocals + instruments in one stream."""
|
| 22 |
+
prompt = _require(params, "prompt")
|
| 23 |
+
lyrics = params.get("lyrics", "")
|
| 24 |
+
duration_s = int(params.get("duration_s", 30))
|
| 25 |
+
instrumental = bool(params.get("instrumental", False))
|
| 26 |
+
|
| 27 |
+
return backend.dispatch(
|
| 28 |
+
mode="generate",
|
| 29 |
+
params={
|
| 30 |
+
"prompt": prompt,
|
| 31 |
+
"lyrics": lyrics,
|
| 32 |
+
"duration_s": duration_s,
|
| 33 |
+
"instrumental": instrumental,
|
| 34 |
+
"seed": params.get("seed"),
|
| 35 |
+
"loras": params.get("loras", []),
|
| 36 |
+
"advanced": params.get("advanced", {}),
|
| 37 |
+
"lm": params.get("lm", {}),
|
| 38 |
+
"dcw": params.get("dcw", {}),
|
| 39 |
+
},
|
| 40 |
+
)
|
tests/test_modes_generate.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""L2 tests for the generate mode handler — backend is mocked at the pipeline boundary."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from unittest.mock import MagicMock
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
import modes
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_generate_validates_prompt_required():
|
| 13 |
+
backend = MagicMock()
|
| 14 |
+
with pytest.raises(ValueError, match="prompt"):
|
| 15 |
+
modes.generate(backend, params={"prompt": "", "lyrics": "[verse] x", "duration_s": 10})
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_generate_passes_params_to_backend(monkeypatch):
|
| 19 |
+
backend = MagicMock()
|
| 20 |
+
backend.dispatch.return_value = ("/tmp/audio.wav", {"seed": 42})
|
| 21 |
+
out_path, meta = modes.generate(
|
| 22 |
+
backend,
|
| 23 |
+
params={
|
| 24 |
+
"prompt": "psytrance",
|
| 25 |
+
"lyrics": "[verse] x",
|
| 26 |
+
"duration_s": 30,
|
| 27 |
+
"instrumental": False,
|
| 28 |
+
"seed": 42,
|
| 29 |
+
},
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
assert out_path == "/tmp/audio.wav"
|
| 33 |
+
assert meta["seed"] == 42
|
| 34 |
+
backend.dispatch.assert_called_once()
|
| 35 |
+
call_kwargs = backend.dispatch.call_args.kwargs
|
| 36 |
+
assert call_kwargs["mode"] == "generate"
|
| 37 |
+
# Cover-style params must be absent for the generate mode
|
| 38 |
+
assert "audio_cover_strength" not in call_kwargs["params"]
|