techfreakworm commited on
Commit
52f41b8
·
unverified ·
1 Parent(s): eb3bcb4

feat(modes): add generate mode handler with input validation

Browse files
Files changed (2) hide show
  1. modes.py +40 -0
  2. tests/test_modes_generate.py +38 -0
modes.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pure mode handlers — one function per generation mode.
2
+
3
+ Each handler validates inputs, builds the ACE-Step kwargs for its mode, and
4
+ hands off to `backend.dispatch(...)`. Backend ownership of @spaces.GPU and
5
+ pipeline lifecycle keeps these handlers cheap to test.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any
11
+
12
+
13
+ def _require(params: dict[str, Any], field: str) -> Any:
14
+ v = params.get(field)
15
+ if v is None or (isinstance(v, str) and not v.strip()):
16
+ raise ValueError(f"Missing required field: {field}")
17
+ return v
18
+
19
+
20
+ def generate(backend, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
21
+ """Text → song. Vocals + instruments in one stream."""
22
+ prompt = _require(params, "prompt")
23
+ lyrics = params.get("lyrics", "")
24
+ duration_s = int(params.get("duration_s", 30))
25
+ instrumental = bool(params.get("instrumental", False))
26
+
27
+ return backend.dispatch(
28
+ mode="generate",
29
+ params={
30
+ "prompt": prompt,
31
+ "lyrics": lyrics,
32
+ "duration_s": duration_s,
33
+ "instrumental": instrumental,
34
+ "seed": params.get("seed"),
35
+ "loras": params.get("loras", []),
36
+ "advanced": params.get("advanced", {}),
37
+ "lm": params.get("lm", {}),
38
+ "dcw": params.get("dcw", {}),
39
+ },
40
+ )
tests/test_modes_generate.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """L2 tests for the generate mode handler — backend is mocked at the pipeline boundary."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from unittest.mock import MagicMock
6
+
7
+ import pytest
8
+
9
+ import modes
10
+
11
+
12
+ def test_generate_validates_prompt_required():
13
+ backend = MagicMock()
14
+ with pytest.raises(ValueError, match="prompt"):
15
+ modes.generate(backend, params={"prompt": "", "lyrics": "[verse] x", "duration_s": 10})
16
+
17
+
18
+ def test_generate_passes_params_to_backend(monkeypatch):
19
+ backend = MagicMock()
20
+ backend.dispatch.return_value = ("/tmp/audio.wav", {"seed": 42})
21
+ out_path, meta = modes.generate(
22
+ backend,
23
+ params={
24
+ "prompt": "psytrance",
25
+ "lyrics": "[verse] x",
26
+ "duration_s": 30,
27
+ "instrumental": False,
28
+ "seed": 42,
29
+ },
30
+ )
31
+
32
+ assert out_path == "/tmp/audio.wav"
33
+ assert meta["seed"] == 42
34
+ backend.dispatch.assert_called_once()
35
+ call_kwargs = backend.dispatch.call_args.kwargs
36
+ assert call_kwargs["mode"] == "generate"
37
+ # Cover-style params must be absent for the generate mode
38
+ assert "audio_cover_strength" not in call_kwargs["params"]