techfreakworm commited on
Commit
a2eae29
·
unverified ·
1 Parent(s): 24b2854

feat(modes): add cover, extend, edit handlers with sub-mode dispatch

Browse files
Files changed (2) hide show
  1. modes.py +117 -0
  2. tests/test_modes_other.py +125 -0
modes.py CHANGED
@@ -38,3 +38,120 @@ def generate(backend, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
38
  "dcw": params.get("dcw", {}),
39
  },
40
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  "dcw": params.get("dcw", {}),
39
  },
40
  )
41
+
42
+
43
+ def cover(backend, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
44
+ """Audio-reference cover — reference audio + new prompt -> song in that style.
45
+
46
+ Maps to ACE-Step's ``GenerationParams(task_type="cover")`` with
47
+ ``reference_audio`` set to the uploaded clip and ``audio_cover_strength``
48
+ controlling how tightly the new song hugs the reference timbre/structure.
49
+ """
50
+ ref_audio = _require(params, "ref_audio")
51
+ prompt = params.get("prompt", "")
52
+ lyrics = params.get("lyrics", "")
53
+ duration_s = int(params.get("duration_s", 30))
54
+
55
+ return backend.dispatch(
56
+ mode="cover",
57
+ params={
58
+ "prompt": prompt,
59
+ "ref_audio": ref_audio,
60
+ "lyrics": lyrics,
61
+ "duration_s": duration_s,
62
+ "audio_cover_strength": float(params.get("audio_cover_strength", 0.93)),
63
+ "cover_noise_strength": float(params.get("cover_noise_strength", 0.0)),
64
+ "seed": params.get("seed"),
65
+ "loras": params.get("loras", []),
66
+ "advanced": params.get("advanced", {}),
67
+ "lm": params.get("lm", {}),
68
+ "dcw": params.get("dcw", {}),
69
+ },
70
+ )
71
+
72
+
73
+ def extend(backend, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
74
+ """Song continuation — seed audio + extension prompt -> extended song.
75
+
76
+ Maps to ACE-Step's ``GenerationParams(task_type="repaint")`` with
77
+ ``src_audio`` set to the seed and ``repainting_start``/``repainting_end``
78
+ pointing past the end of the seed so the model paints new audio after it.
79
+ """
80
+ seed_audio = _require(params, "seed_audio")
81
+ extra_prompt = params.get("extra_prompt", "")
82
+ extra_duration_s = int(params.get("extra_duration_s", 60))
83
+
84
+ return backend.dispatch(
85
+ mode="extend",
86
+ params={
87
+ "seed_audio": seed_audio,
88
+ "extra_prompt": extra_prompt,
89
+ "extension_lyrics": params.get("extension_lyrics", ""),
90
+ "extra_duration_s": extra_duration_s,
91
+ "repaint_mode": params.get("repaint_mode", "balanced"),
92
+ "repaint_strength": float(params.get("repaint_strength", 0.5)),
93
+ "wav_crossfade_s": float(params.get("wav_crossfade_s", 2.0)),
94
+ "latent_crossfade_frames": int(params.get("latent_crossfade_frames", 10)),
95
+ "chunk_mask_mode": params.get("chunk_mask_mode", "auto"),
96
+ "seed": params.get("seed"),
97
+ "loras": params.get("loras", []),
98
+ "advanced": params.get("advanced", {}),
99
+ "lm": params.get("lm", {}),
100
+ "dcw": params.get("dcw", {}),
101
+ },
102
+ )
103
+
104
+
105
+ def edit(backend, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
106
+ """Segment-level edit — repaint a region OR morph caption-to-caption.
107
+
108
+ Two sub-modes:
109
+
110
+ - ``"repaint"`` (default): paint over ``[segment_start_s, segment_end_s]``
111
+ using ACE-Step's ``task_type="repaint"`` with the segment bounds wired
112
+ into ``repainting_start`` / ``repainting_end``.
113
+ - ``"flow_edit"``: caption-to-caption morph. The installed ACE-Step
114
+ ``GenerationParams`` dataclass has no native ``flow_edit_*`` fields, so
115
+ flow-edit is implemented downstream as a ``task_type="repaint"`` pass
116
+ with a lower ``audio_cover_strength`` to allow more style drift. The
117
+ ``flow_source_caption`` / ``flow_n_*`` knobs are carried through the
118
+ internal params dict so the pipeline wrapper can use them if/when the
119
+ upstream dataclass grows native support.
120
+ """
121
+ source_audio = _require(params, "source_audio")
122
+ sub_mode = params.get("sub_mode", "repaint")
123
+
124
+ out_params: dict[str, Any] = {
125
+ "source_audio": source_audio,
126
+ "source_lyrics": params.get("source_lyrics", ""),
127
+ "target_lyrics": params.get("target_lyrics", ""),
128
+ "segment_start_s": float(params.get("segment_start_s", 0.0)),
129
+ "segment_end_s": float(params.get("segment_end_s", 30.0)),
130
+ "sub_mode": sub_mode,
131
+ "seed": params.get("seed"),
132
+ "loras": params.get("loras", []),
133
+ "advanced": params.get("advanced", {}),
134
+ "lm": params.get("lm", {}),
135
+ "dcw": params.get("dcw", {}),
136
+ }
137
+ if sub_mode == "repaint":
138
+ out_params.update(
139
+ {
140
+ "repaint_mode": params.get("repaint_mode", "balanced"),
141
+ "repaint_strength": float(params.get("repaint_strength", 0.5)),
142
+ "chunk_mask_mode": params.get("chunk_mask_mode", "auto"),
143
+ "latent_crossfade_frames": int(params.get("latent_crossfade_frames", 10)),
144
+ "wav_crossfade_s": float(params.get("wav_crossfade_s", 0.0)),
145
+ }
146
+ )
147
+ elif sub_mode == "flow_edit":
148
+ out_params.update(
149
+ {
150
+ "flow_source_caption": params.get("flow_source_caption", ""),
151
+ "flow_n_min": float(params.get("flow_n_min", 0.0)),
152
+ "flow_n_max": float(params.get("flow_n_max", 1.0)),
153
+ "flow_n_avg": int(params.get("flow_n_avg", 1)),
154
+ }
155
+ )
156
+
157
+ return backend.dispatch(mode="edit", params=out_params)
tests/test_modes_other.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """L2 tests for cover / extend / edit mode handlers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from unittest.mock import MagicMock
6
+
7
+ import pytest
8
+
9
+ import modes
10
+
11
+
12
+ def test_cover_requires_ref_audio():
13
+ with pytest.raises(ValueError, match="ref_audio"):
14
+ modes.cover(
15
+ MagicMock(),
16
+ params={"prompt": "p", "lyrics": "[v]", "duration_s": 30, "ref_audio": None},
17
+ )
18
+
19
+
20
+ def test_cover_passes_audio_cover_strength():
21
+ backend = MagicMock()
22
+ backend.dispatch.return_value = ("/tmp/x.wav", {})
23
+ modes.cover(
24
+ backend,
25
+ params={
26
+ "prompt": "p",
27
+ "lyrics": "[v]",
28
+ "duration_s": 30,
29
+ "ref_audio": "/tmp/ref.wav",
30
+ "audio_cover_strength": 0.9,
31
+ "loras": [],
32
+ "advanced": {},
33
+ "lm": {},
34
+ "dcw": {},
35
+ },
36
+ )
37
+ args = backend.dispatch.call_args.kwargs
38
+ assert args["mode"] == "cover"
39
+ assert args["params"]["audio_cover_strength"] == 0.9
40
+ assert args["params"]["ref_audio"] == "/tmp/ref.wav"
41
+
42
+
43
+ def test_extend_requires_seed_audio():
44
+ with pytest.raises(ValueError, match="seed_audio"):
45
+ modes.extend(
46
+ MagicMock(),
47
+ params={"extra_prompt": "p", "extra_duration_s": 60, "seed_audio": None},
48
+ )
49
+
50
+
51
+ def test_extend_passes_repaint_params():
52
+ backend = MagicMock()
53
+ backend.dispatch.return_value = ("/tmp/x.wav", {})
54
+ modes.extend(
55
+ backend,
56
+ params={
57
+ "seed_audio": "/tmp/seed.wav",
58
+ "extra_prompt": "more",
59
+ "extra_duration_s": 60,
60
+ "extension_lyrics": "[v]",
61
+ "repaint_strength": 0.5,
62
+ "wav_crossfade_s": 2.0,
63
+ "loras": [],
64
+ "advanced": {},
65
+ "lm": {},
66
+ "dcw": {},
67
+ },
68
+ )
69
+ args = backend.dispatch.call_args.kwargs
70
+ assert args["mode"] == "extend"
71
+ assert args["params"]["repaint_strength"] == 0.5
72
+ assert args["params"]["wav_crossfade_s"] == 2.0
73
+
74
+
75
+ def test_edit_repaint_passes_segment_bounds():
76
+ backend = MagicMock()
77
+ backend.dispatch.return_value = ("/tmp/x.wav", {})
78
+ modes.edit(
79
+ backend,
80
+ params={
81
+ "source_audio": "/tmp/src.wav",
82
+ "source_lyrics": "[v]",
83
+ "target_lyrics": "[c] new",
84
+ "segment_start_s": 50.0,
85
+ "segment_end_s": 90.0,
86
+ "sub_mode": "repaint",
87
+ "repaint_strength": 0.5,
88
+ "loras": [],
89
+ "advanced": {},
90
+ "lm": {},
91
+ "dcw": {},
92
+ },
93
+ )
94
+ args = backend.dispatch.call_args.kwargs
95
+ assert args["mode"] == "edit"
96
+ assert args["params"]["segment_start_s"] == 50.0
97
+ assert args["params"]["segment_end_s"] == 90.0
98
+ assert args["params"]["sub_mode"] == "repaint"
99
+
100
+
101
+ def test_edit_flow_morph_sub_mode_passes_through():
102
+ backend = MagicMock()
103
+ backend.dispatch.return_value = ("/tmp/x.wav", {})
104
+ modes.edit(
105
+ backend,
106
+ params={
107
+ "source_audio": "/tmp/src.wav",
108
+ "source_lyrics": "[v]",
109
+ "target_lyrics": "[c]",
110
+ "segment_start_s": 0.0,
111
+ "segment_end_s": 30.0,
112
+ "sub_mode": "flow_edit",
113
+ "flow_source_caption": "acoustic ballad",
114
+ "flow_n_min": 0.0,
115
+ "flow_n_max": 1.0,
116
+ "flow_n_avg": 1,
117
+ "loras": [],
118
+ "advanced": {},
119
+ "lm": {},
120
+ "dcw": {},
121
+ },
122
+ )
123
+ args = backend.dispatch.call_args.kwargs
124
+ assert args["params"]["sub_mode"] == "flow_edit"
125
+ assert args["params"]["flow_source_caption"] == "acoustic ballad"