Spaces:
Running on Zero
Running on Zero
feat(pipeline): wire cover/extend/edit task_types in studio.generate
Browse files
- ace_pipeline.py +83 -10
- backend.py +8 -5
- tests/test_backend.py +48 -0
ace_pipeline.py
CHANGED
|
@@ -135,12 +135,30 @@ class ACEStepStudio:
|
|
| 135 |
self._llm = llm
|
| 136 |
|
| 137 |
def generate(self, params: dict) -> str:
|
| 138 |
-
"""Run a single
|
| 139 |
-
|
| 140 |
-
``params`` is the dict produced by ``modes.
|
| 141 |
-
``
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
"""
|
| 145 |
self._ensure_loaded()
|
| 146 |
|
|
@@ -152,20 +170,68 @@ class ACEStepStudio:
|
|
| 152 |
|
| 153 |
advanced = params.get("advanced", {}) or {}
|
| 154 |
lm_opts = params.get("lm", {}) or {}
|
|
|
|
| 155 |
|
| 156 |
# Map our internal dict to ACE-Step's GenerationParams.
|
| 157 |
# Lyrics "[Instrumental]" is the ACE-Step convention for instrumental.
|
| 158 |
-
lyrics = params.get("lyrics", "") or ""
|
|
|
|
|
|
|
| 159 |
instrumental = bool(params.get("instrumental", False))
|
| 160 |
if instrumental and not lyrics:
|
| 161 |
lyrics = "[Instrumental]"
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
gen_params = GenerationParams(
|
| 164 |
-
task_type=
|
| 165 |
-
caption=
|
| 166 |
lyrics=lyrics,
|
| 167 |
instrumental=instrumental,
|
| 168 |
-
duration=
|
| 169 |
seed=int(params.get("seed", -1)),
|
| 170 |
inference_steps=int(advanced.get("steps", 32)),
|
| 171 |
guidance_scale=float(advanced.get("cfg", 4.0)),
|
|
@@ -176,6 +242,13 @@ class ACEStepStudio:
|
|
| 176 |
vocal_language=advanced.get("vocal_language", "unknown"),
|
| 177 |
cfg_interval_start=float(advanced.get("cfg_interval_start", 0.0)),
|
| 178 |
cfg_interval_end=float(advanced.get("cfg_interval_end", 1.0)),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
thinking=bool(lm_opts.get("thinking", False)),
|
| 180 |
lm_temperature=float(lm_opts.get("temperature", 0.85)),
|
| 181 |
lm_cfg_scale=float(lm_opts.get("cfg", 2.0)),
|
|
|
|
| 135 |
self._llm = llm
|
| 136 |
|
| 137 |
def generate(self, params: dict) -> str:
|
| 138 |
+
"""Run a single song generation across all four modes.
|
| 139 |
+
|
| 140 |
+
``params`` is the dict produced by the mode handlers in ``modes.py``.
|
| 141 |
+
The ``params["mode"]`` key (``generate`` | ``cover`` | ``extend`` |
|
| 142 |
+
``edit``) selects the ACE-Step ``task_type`` and which audio inputs
|
| 143 |
+
get wired through to ``GenerationParams``:
|
| 144 |
+
|
| 145 |
+
- ``generate``: ``task_type="text2music"``
|
| 146 |
+
- ``cover``: ``task_type="cover"`` + ``reference_audio`` +
|
| 147 |
+
``audio_cover_strength``
|
| 148 |
+
- ``extend``: ``task_type="repaint"`` + ``src_audio`` set to the
|
| 149 |
+
seed, with ``repainting_start=-1`` / ``repainting_end=-1`` as a
|
| 150 |
+
sentinel meaning "paint after the end of the seed". The actual
|
| 151 |
+
mask shaping ultimately lives inside ACE-Step's repaint path.
|
| 152 |
+
- ``edit``: ``task_type="repaint"`` + ``src_audio`` + explicit
|
| 153 |
+
``[segment_start_s, segment_end_s]`` segment bounds.
|
| 154 |
+
|
| 155 |
+
Flow-edit (``sub_mode="flow_edit"``) is implemented as a repaint
|
| 156 |
+
pass: the installed ACE-Step ``GenerationParams`` dataclass has no
|
| 157 |
+
native ``flow_edit_*`` fields, so the extra flow-edit knobs carried
|
| 158 |
+
in the internal params dict are ignored at the ``GenerationParams``
|
| 159 |
+
instantiation level and will need wiring once upstream grows them.
|
| 160 |
+
|
| 161 |
+
Returns the path to the produced audio file.
|
| 162 |
"""
|
| 163 |
self._ensure_loaded()
|
| 164 |
|
|
|
|
| 170 |
|
| 171 |
advanced = params.get("advanced", {}) or {}
|
| 172 |
lm_opts = params.get("lm", {}) or {}
|
| 173 |
+
mode = params.get("mode", "generate")
|
| 174 |
|
| 175 |
# Map our internal dict to ACE-Step's GenerationParams.
|
| 176 |
# Lyrics "[Instrumental]" is the ACE-Step convention for instrumental.
|
| 177 |
+
lyrics = params.get("lyrics", "") or params.get("extension_lyrics", "") or ""
|
| 178 |
+
if mode == "edit":
|
| 179 |
+
lyrics = params.get("target_lyrics", "") or lyrics
|
| 180 |
instrumental = bool(params.get("instrumental", False))
|
| 181 |
if instrumental and not lyrics:
|
| 182 |
lyrics = "[Instrumental]"
|
| 183 |
|
| 184 |
+
# Mode-specific task_type + audio inputs.
|
| 185 |
+
# All five fields below MUST resolve before we instantiate
|
| 186 |
+
# GenerationParams so that the dataclass ctor sees consistent values.
|
| 187 |
+
ref_audio: str | None = None
|
| 188 |
+
src_audio: str | None = None
|
| 189 |
+
audio_cover_strength = 0.0
|
| 190 |
+
repainting_start = 0.0
|
| 191 |
+
repainting_end = -1.0
|
| 192 |
+
|
| 193 |
+
if mode == "generate":
|
| 194 |
+
task_type = "text2music"
|
| 195 |
+
elif mode == "cover":
|
| 196 |
+
task_type = "cover"
|
| 197 |
+
ref_audio = params.get("ref_audio")
|
| 198 |
+
audio_cover_strength = float(params.get("audio_cover_strength", 0.93))
|
| 199 |
+
elif mode == "extend":
|
| 200 |
+
task_type = "repaint"
|
| 201 |
+
src_audio = params.get("seed_audio")
|
| 202 |
+
# Sentinel: -1 / -1 means "append after the seed audio's end".
|
| 203 |
+
# ACE-Step's repaint path interprets these bounds against the
|
| 204 |
+
# src_audio duration; the actual semantics need verifying once
|
| 205 |
+
# we run a full pass on real hardware (M3 GPU smoke).
|
| 206 |
+
repainting_start = -1.0
|
| 207 |
+
repainting_end = -1.0
|
| 208 |
+
elif mode == "edit":
|
| 209 |
+
task_type = "repaint"
|
| 210 |
+
src_audio = params.get("source_audio")
|
| 211 |
+
repainting_start = float(params.get("segment_start_s", 0.0))
|
| 212 |
+
repainting_end = float(params.get("segment_end_s", 30.0))
|
| 213 |
+
# flow_edit sub-mode: lower audio_cover_strength to allow style
|
| 214 |
+
# drift while still using the repaint task type. The extra
|
| 215 |
+
# flow_* fields in our internal params dict are kept around for
|
| 216 |
+
# future use but not forwarded to GenerationParams (no native
|
| 217 |
+
# support in the installed dataclass).
|
| 218 |
+
if params.get("sub_mode") == "flow_edit":
|
| 219 |
+
audio_cover_strength = 0.3
|
| 220 |
+
else:
|
| 221 |
+
raise ValueError(f"Unknown mode: {mode!r}")
|
| 222 |
+
|
| 223 |
+
# Caption can come from the per-mode handlers under different keys.
|
| 224 |
+
caption = (
|
| 225 |
+
params.get("prompt") or params.get("extra_prompt") or params.get("flow_source_caption") or ""
|
| 226 |
+
)
|
| 227 |
+
duration_s = int(params.get("duration_s") or params.get("extra_duration_s") or 30)
|
| 228 |
+
|
| 229 |
gen_params = GenerationParams(
|
| 230 |
+
task_type=task_type,
|
| 231 |
+
caption=caption,
|
| 232 |
lyrics=lyrics,
|
| 233 |
instrumental=instrumental,
|
| 234 |
+
duration=duration_s,
|
| 235 |
seed=int(params.get("seed", -1)),
|
| 236 |
inference_steps=int(advanced.get("steps", 32)),
|
| 237 |
guidance_scale=float(advanced.get("cfg", 4.0)),
|
|
|
|
| 242 |
vocal_language=advanced.get("vocal_language", "unknown"),
|
| 243 |
cfg_interval_start=float(advanced.get("cfg_interval_start", 0.0)),
|
| 244 |
cfg_interval_end=float(advanced.get("cfg_interval_end", 1.0)),
|
| 245 |
+
# Mode-specific audio inputs + repaint bounds
|
| 246 |
+
reference_audio=ref_audio,
|
| 247 |
+
src_audio=src_audio,
|
| 248 |
+
audio_cover_strength=audio_cover_strength,
|
| 249 |
+
repainting_start=repainting_start,
|
| 250 |
+
repainting_end=repainting_end,
|
| 251 |
+
# 5Hz language model knobs
|
| 252 |
thinking=bool(lm_opts.get("thinking", False)),
|
| 253 |
lm_temperature=float(lm_opts.get("temperature", 0.85)),
|
| 254 |
lm_cfg_scale=float(lm_opts.get("cfg", 2.0)),
|
backend.py
CHANGED
|
@@ -77,10 +77,13 @@ class ACEStepStudioBackend:
|
|
| 77 |
``generate(params)`` method that handles the underlying
|
| 78 |
AceStepHandler + LLMHandler + generate_music plumbing.
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
| 82 |
"""
|
| 83 |
-
if mode
|
| 84 |
-
|
| 85 |
-
|
| 86 |
raise NotImplementedError(f"Mode {mode!r} is not wired yet")
|
|
|
|
| 77 |
``generate(params)`` method that handles the underlying
|
| 78 |
AceStepHandler + LLMHandler + generate_music plumbing.
|
| 79 |
|
| 80 |
+
All four song modes (``generate``, ``cover``, ``extend``, ``edit``)
|
| 81 |
+
flow through ``pipe.generate(params)``. The pipeline wrapper
|
| 82 |
+
switches its ``GenerationParams.task_type`` based on ``params["mode"]``
|
| 83 |
+
— see ``ace_pipeline.ACEStepStudio.generate`` for the mapping. The
|
| 84 |
+
``lyrics`` mode is wired separately at M4.
|
| 85 |
"""
|
| 86 |
+
if mode in ("generate", "cover", "extend", "edit"):
|
| 87 |
+
params_with_mode = {**params, "mode": mode}
|
| 88 |
+
return pipe.generate(params_with_mode)
|
| 89 |
raise NotImplementedError(f"Mode {mode!r} is not wired yet")
|
tests/test_backend.py
CHANGED
|
@@ -4,6 +4,8 @@ from __future__ import annotations
|
|
| 4 |
|
| 5 |
from unittest.mock import MagicMock
|
| 6 |
|
|
|
|
|
|
|
| 7 |
import backend as be
|
| 8 |
|
| 9 |
|
|
@@ -98,3 +100,49 @@ def test_dispatch_applies_lora_stack(monkeypatch, tmp_path):
|
|
| 98 |
)
|
| 99 |
|
| 100 |
apply_mock.assert_called_once_with(fake_pipe, stack)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
from unittest.mock import MagicMock
|
| 6 |
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
import backend as be
|
| 10 |
|
| 11 |
|
|
|
|
| 100 |
)
|
| 101 |
|
| 102 |
apply_mock.assert_called_once_with(fake_pipe, stack)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@pytest.mark.parametrize(
    "mode,extra",
    [
        ("cover", {"ref_audio": "/tmp/ref.wav", "audio_cover_strength": 0.9}),
        ("extend", {"seed_audio": "/tmp/seed.wav", "extra_duration_s": 60}),
        (
            "edit",
            {
                "source_audio": "/tmp/src.wav",
                "segment_start_s": 50.0,
                "segment_end_s": 90.0,
                "sub_mode": "repaint",
            },
        ),
    ],
)
def test_dispatch_forwards_mode_to_pipe_generate(monkeypatch, tmp_path, mode, extra):
    """Check that ``dispatch`` hands the mode and mode-specific keys to the pipe.

    For each parametrized ``(mode, extra)`` pair, the backend is dispatched
    with a common base params dict merged with ``extra``. The test then
    asserts that the fake pipeline's ``generate`` was called exactly once,
    that the dict it received carries ``"mode"`` equal to the dispatched
    mode, and that every key from ``extra`` arrived with its value unchanged.
    """
    # Stand-in pipeline whose generate() reports an on-disk wav path.
    wav_path = tmp_path / "x.wav"
    wav_path.write_bytes(b"RIFF")
    pipe_double = MagicMock()
    pipe_double.generate.return_value = str(wav_path)

    # Replace the pipeline factory and the LoRA applier with test doubles
    # (presumably both are looked up by the backend during dispatch).
    monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: pipe_double)
    monkeypatch.setattr("lora_stack.apply_stack", MagicMock())

    # Base params shared by all modes; mode-specific keys are merged last.
    common = {
        "prompt": "p",
        "lyrics": "",
        "duration_s": 10,
        "instrumental": True,
        "seed": 42,
        "loras": [],
        "advanced": {},
        "lm": {},
        "dcw": {},
    }
    request = {**common, **extra}

    backend_under_test = be.ACEStepStudioBackend()
    backend_under_test.dispatch(mode=mode, params=request)

    pipe_double.generate.assert_called_once()
    # First positional argument of the single generate() call.
    forwarded = pipe_double.generate.call_args[0][0]
    assert forwarded["mode"] == mode
    # Mode-specific keys propagate to pipe.generate
    for key, expected in extra.items():
        assert forwarded[key] == expected
|