techfreakworm committed on
Commit
26dc3a4
·
unverified ·
1 Parent(s): a2eae29

feat(pipeline): wire cover/extend/edit task_types in studio.generate

Browse files
Files changed (3) hide show
  1. ace_pipeline.py +83 -10
  2. backend.py +8 -5
  3. tests/test_backend.py +48 -0
ace_pipeline.py CHANGED
@@ -135,12 +135,30 @@ class ACEStepStudio:
135
  self._llm = llm
136
 
137
  def generate(self, params: dict) -> str:
138
- """Run a single text→song generation.
139
-
140
- ``params`` is the dict produced by ``modes.generate``:
141
- ``{"prompt", "lyrics", "duration_s", "instrumental", "seed",
142
- "loras", "advanced", "lm", "dcw"}``. Returns the path to the
143
- produced audio file.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  """
145
  self._ensure_loaded()
146
 
@@ -152,20 +170,68 @@ class ACEStepStudio:
152
 
153
  advanced = params.get("advanced", {}) or {}
154
  lm_opts = params.get("lm", {}) or {}
 
155
 
156
  # Map our internal dict to ACE-Step's GenerationParams.
157
  # Lyrics "[Instrumental]" is the ACE-Step convention for instrumental.
158
- lyrics = params.get("lyrics", "") or ""
 
 
159
  instrumental = bool(params.get("instrumental", False))
160
  if instrumental and not lyrics:
161
  lyrics = "[Instrumental]"
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  gen_params = GenerationParams(
164
- task_type="text2music",
165
- caption=params.get("prompt", ""),
166
  lyrics=lyrics,
167
  instrumental=instrumental,
168
- duration=int(params.get("duration_s", 30)),
169
  seed=int(params.get("seed", -1)),
170
  inference_steps=int(advanced.get("steps", 32)),
171
  guidance_scale=float(advanced.get("cfg", 4.0)),
@@ -176,6 +242,13 @@ class ACEStepStudio:
176
  vocal_language=advanced.get("vocal_language", "unknown"),
177
  cfg_interval_start=float(advanced.get("cfg_interval_start", 0.0)),
178
  cfg_interval_end=float(advanced.get("cfg_interval_end", 1.0)),
 
 
 
 
 
 
 
179
  thinking=bool(lm_opts.get("thinking", False)),
180
  lm_temperature=float(lm_opts.get("temperature", 0.85)),
181
  lm_cfg_scale=float(lm_opts.get("cfg", 2.0)),
 
135
  self._llm = llm
136
 
137
  def generate(self, params: dict) -> str:
138
+ """Run a single song generation across all four modes.
139
+
140
+ ``params`` is the dict produced by the mode handlers in ``modes.py``.
141
+ The ``params["mode"]`` key (``generate`` | ``cover`` | ``extend`` |
142
+ ``edit``) selects the ACE-Step ``task_type`` and which audio inputs
143
+ get wired through to ``GenerationParams``:
144
+
145
+ - ``generate``: ``task_type="text2music"``
146
+ - ``cover``: ``task_type="cover"`` + ``reference_audio`` +
147
+ ``audio_cover_strength``
148
+ - ``extend``: ``task_type="repaint"`` + ``src_audio`` set to the
149
+ seed, with ``repainting_start=-1`` / ``repainting_end=-1`` as a
150
+ sentinel meaning "paint after the end of the seed". The actual
151
+ mask shaping ultimately lives inside ACE-Step's repaint path.
152
+ - ``edit``: ``task_type="repaint"`` + ``src_audio`` + explicit
153
+ ``[segment_start_s, segment_end_s]`` segment bounds.
154
+
155
+ Flow-edit (``sub_mode="flow_edit"``) is implemented as a repaint
156
+ pass: the installed ACE-Step ``GenerationParams`` dataclass has no
157
+ native ``flow_edit_*`` fields, so the extra flow-edit knobs carried
158
+ in the internal params dict are ignored at the ``GenerationParams``
159
+ instantiation level and will need wiring once upstream grows them.
160
+
161
+ Returns the path to the produced audio file.
162
  """
163
  self._ensure_loaded()
164
 
 
170
 
171
  advanced = params.get("advanced", {}) or {}
172
  lm_opts = params.get("lm", {}) or {}
173
+ mode = params.get("mode", "generate")
174
 
175
  # Map our internal dict to ACE-Step's GenerationParams.
176
  # Lyrics "[Instrumental]" is the ACE-Step convention for instrumental.
177
+ lyrics = params.get("lyrics", "") or params.get("extension_lyrics", "") or ""
178
+ if mode == "edit":
179
+ lyrics = params.get("target_lyrics", "") or lyrics
180
  instrumental = bool(params.get("instrumental", False))
181
  if instrumental and not lyrics:
182
  lyrics = "[Instrumental]"
183
 
184
+ # Mode-specific task_type + audio inputs.
185
+ # All five fields below MUST resolve before we instantiate
186
+ # GenerationParams so that the dataclass ctor sees consistent values.
187
+ ref_audio: str | None = None
188
+ src_audio: str | None = None
189
+ audio_cover_strength = 0.0
190
+ repainting_start = 0.0
191
+ repainting_end = -1.0
192
+
193
+ if mode == "generate":
194
+ task_type = "text2music"
195
+ elif mode == "cover":
196
+ task_type = "cover"
197
+ ref_audio = params.get("ref_audio")
198
+ audio_cover_strength = float(params.get("audio_cover_strength", 0.93))
199
+ elif mode == "extend":
200
+ task_type = "repaint"
201
+ src_audio = params.get("seed_audio")
202
+ # Sentinel: -1 / -1 means "append after the seed audio's end".
203
+ # ACE-Step's repaint path interprets these bounds against the
204
+ # src_audio duration; the actual semantics need verifying once
205
+ # we run a full pass on real hardware (M3 GPU smoke).
206
+ repainting_start = -1.0
207
+ repainting_end = -1.0
208
+ elif mode == "edit":
209
+ task_type = "repaint"
210
+ src_audio = params.get("source_audio")
211
+ repainting_start = float(params.get("segment_start_s", 0.0))
212
+ repainting_end = float(params.get("segment_end_s", 30.0))
213
+ # flow_edit sub-mode: lower audio_cover_strength to allow style
214
+ # drift while still using the repaint task type. The extra
215
+ # flow_* fields in our internal params dict are kept around for
216
+ # future use but not forwarded to GenerationParams (no native
217
+ # support in the installed dataclass).
218
+ if params.get("sub_mode") == "flow_edit":
219
+ audio_cover_strength = 0.3
220
+ else:
221
+ raise ValueError(f"Unknown mode: {mode!r}")
222
+
223
+ # Caption can come from the per-mode handlers under different keys.
224
+ caption = (
225
+ params.get("prompt") or params.get("extra_prompt") or params.get("flow_source_caption") or ""
226
+ )
227
+ duration_s = int(params.get("duration_s") or params.get("extra_duration_s") or 30)
228
+
229
  gen_params = GenerationParams(
230
+ task_type=task_type,
231
+ caption=caption,
232
  lyrics=lyrics,
233
  instrumental=instrumental,
234
+ duration=duration_s,
235
  seed=int(params.get("seed", -1)),
236
  inference_steps=int(advanced.get("steps", 32)),
237
  guidance_scale=float(advanced.get("cfg", 4.0)),
 
242
  vocal_language=advanced.get("vocal_language", "unknown"),
243
  cfg_interval_start=float(advanced.get("cfg_interval_start", 0.0)),
244
  cfg_interval_end=float(advanced.get("cfg_interval_end", 1.0)),
245
+ # Mode-specific audio inputs + repaint bounds
246
+ reference_audio=ref_audio,
247
+ src_audio=src_audio,
248
+ audio_cover_strength=audio_cover_strength,
249
+ repainting_start=repainting_start,
250
+ repainting_end=repainting_end,
251
+ # 5Hz language model knobs
252
  thinking=bool(lm_opts.get("thinking", False)),
253
  lm_temperature=float(lm_opts.get("temperature", 0.85)),
254
  lm_cfg_scale=float(lm_opts.get("cfg", 2.0)),
backend.py CHANGED
@@ -77,10 +77,13 @@ class ACEStepStudioBackend:
77
  ``generate(params)`` method that handles the underlying
78
  AceStepHandler + LLMHandler + generate_music plumbing.
79
 
80
- Cover / Extend / Edit / Lyrics task_types are mapped here at
81
- M3 / M4 by switching ``params["task_type"]`` before calling.
 
 
 
82
  """
83
- if mode == "generate":
84
- return pipe.generate(params)
85
- # cover / extend / edit / lyrics get filled in at M3 / M4
86
  raise NotImplementedError(f"Mode {mode!r} is not wired yet")
 
77
  ``generate(params)`` method that handles the underlying
78
  AceStepHandler + LLMHandler + generate_music plumbing.
79
 
80
+ All four song modes (``generate``, ``cover``, ``extend``, ``edit``)
81
+ flow through ``pipe.generate(params)``. The pipeline wrapper
82
+ switches its ``GenerationParams.task_type`` based on ``params["mode"]``
83
+ — see ``ace_pipeline.ACEStepStudio.generate`` for the mapping. The
84
+ ``lyrics`` mode is wired separately at M4.
85
  """
86
+ if mode in ("generate", "cover", "extend", "edit"):
87
+ params_with_mode = {**params, "mode": mode}
88
+ return pipe.generate(params_with_mode)
89
  raise NotImplementedError(f"Mode {mode!r} is not wired yet")
tests/test_backend.py CHANGED
@@ -4,6 +4,8 @@ from __future__ import annotations
4
 
5
  from unittest.mock import MagicMock
6
 
 
 
7
  import backend as be
8
 
9
 
@@ -98,3 +100,49 @@ def test_dispatch_applies_lora_stack(monkeypatch, tmp_path):
98
  )
99
 
100
  apply_mock.assert_called_once_with(fake_pipe, stack)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  from unittest.mock import MagicMock
6
 
7
+ import pytest
8
+
9
  import backend as be
10
 
11
 
 
100
  )
101
 
102
  apply_mock.assert_called_once_with(fake_pipe, stack)
103
+
104
+
105
+ @pytest.mark.parametrize(
106
+ "mode,extra",
107
+ [
108
+ ("cover", {"ref_audio": "/tmp/ref.wav", "audio_cover_strength": 0.9}),
109
+ ("extend", {"seed_audio": "/tmp/seed.wav", "extra_duration_s": 60}),
110
+ (
111
+ "edit",
112
+ {
113
+ "source_audio": "/tmp/src.wav",
114
+ "segment_start_s": 50.0,
115
+ "segment_end_s": 90.0,
116
+ "sub_mode": "repaint",
117
+ },
118
+ ),
119
+ ],
120
+ )
121
+ def test_dispatch_forwards_mode_to_pipe_generate(monkeypatch, tmp_path, mode, extra):
122
+ fake_pipe = MagicMock()
123
+ fake_pipe.generate.return_value = str(tmp_path / "x.wav")
124
+ (tmp_path / "x.wav").write_bytes(b"RIFF")
125
+ monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
126
+ monkeypatch.setattr("lora_stack.apply_stack", MagicMock())
127
+
128
+ b = be.ACEStepStudioBackend()
129
+ params = {
130
+ "prompt": "p",
131
+ "lyrics": "",
132
+ "duration_s": 10,
133
+ "instrumental": True,
134
+ "seed": 42,
135
+ "loras": [],
136
+ "advanced": {},
137
+ "lm": {},
138
+ "dcw": {},
139
+ **extra,
140
+ }
141
+ b.dispatch(mode=mode, params=params)
142
+
143
+ fake_pipe.generate.assert_called_once()
144
+ sent_params = fake_pipe.generate.call_args.args[0]
145
+ assert sent_params["mode"] == mode
146
+ # Mode-specific keys propagate to pipe.generate
147
+ for k, v in extra.items():
148
+ assert sent_params[k] == v