techfreakworm committed on
Commit
05f6b53
·
unverified ·
1 Parent(s): a5459fd

fix(deploy): per-mode duration estimator handles cover/extend/edit/lyrics signatures

Browse files
Files changed (2) hide show
  1. app.py +58 -10
  2. tests/test_gpu_estimator.py +82 -0
app.py CHANGED
@@ -159,21 +159,69 @@ def _estimate_gpu_duration(mode: str, params: dict, multiplier: float = 1.0) ->
159
  return max(_GPU_CLAMP_MIN, min(_GPU_CLAMP_MAX, int(estimated)))
160
 
161
 
162
- def _gpu_call_to_estimator(mode: str, *, duration_arg_index: int = 2):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  """Bridge spaces.GPU's per-call (*args, **kwargs) → our (mode, params, multiplier) estimator.
164
 
165
- spaces.GPU(duration=callable) invokes the callable with the handler's actual
166
- runtime args. The handlers here have signature roughly:
167
- on_<mode>_click(prompt_or_seed, lyrics_or_other, duration_s, ...)
168
- so duration_s is at position 2 by default. The kwargs path also works.
169
  """
170
 
171
  def from_call(*args, **kwargs):
172
- duration_s = kwargs.get("duration_s")
173
- if duration_s is None and len(args) > duration_arg_index:
174
- candidate = args[duration_arg_index]
175
- if isinstance(candidate, (int, float)):
176
- duration_s = candidate
 
 
 
 
 
177
  return _estimate_gpu_duration(mode, {"duration_s": duration_s})
178
 
179
  return from_call
 
159
  return max(_GPU_CLAMP_MIN, min(_GPU_CLAMP_MAX, int(estimated)))
160
 
161
 
162
+ # Per-mode hints for where the duration is in the handler's call args.
163
+ # Each entry: a (positional_index, kwarg_name) tuple, the marker string "segment_window", or None.
164
+ # For "edit" mode, the duration is computed as (segment_end_s - segment_start_s).
165
+ # For "lyrics", there's no audio duration; we just default.
166
+ _GPU_DURATION_HINTS: dict[str, tuple[int, str] | str | None] = {
167
+ "generate": (2, "duration_s"),
168
+ "cover": (3, "duration_s"),
169
+ "extend": (3, "extra_duration_s"),
170
+ "edit": "segment_window", # special: end - start
171
+ "lyrics": None, # no audio length
172
+ }
173
+
174
+
175
+ def _extract_duration_s(mode: str, args: tuple, kwargs: dict) -> float | None:
176
+ """Pull the requested audio duration out of a handler's call args, mode-aware.
177
+
178
+ Returns None when the mode has no audio duration concept (lyrics) or when
179
+ the value can't be found. Caller falls back to a per-mode default.
180
+ """
181
+ hint = _GPU_DURATION_HINTS.get(mode)
182
+ if hint is None:
183
+ return None
184
+
185
+ if hint == "segment_window":
186
+ # edit: (source_audio, sub_mode, source_lyrics, target_lyrics, segment_start_s, segment_end_s, ...)
187
+ start = kwargs.get("segment_start_s")
188
+ end = kwargs.get("segment_end_s")
189
+ if start is None and len(args) > 4:
190
+ start = args[4] if isinstance(args[4], (int, float)) else None
191
+ if end is None and len(args) > 5:
192
+ end = args[5] if isinstance(args[5], (int, float)) else None
193
+ if start is not None and end is not None:
194
+ window = float(end) - float(start)
195
+ return window if window > 0 else None
196
+ return None
197
+
198
+ pos_idx, kw_name = hint
199
+ if kw_name in kwargs and isinstance(kwargs[kw_name], (int, float)):
200
+ return float(kwargs[kw_name])
201
+ if len(args) > pos_idx and isinstance(args[pos_idx], (int, float)):
202
+ return float(args[pos_idx])
203
+ return None
204
+
205
+
206
+ def _gpu_call_to_estimator(mode: str):
207
  """Bridge spaces.GPU's per-call (*args, **kwargs) → our (mode, params, multiplier) estimator.
208
 
209
+ Per-mode duration extraction handles the different signatures of the five
210
+ handlers. Falls back to a per-mode default when extraction fails so the
211
+ estimator still produces a reasonable timeout.
 
212
  """
213
 
214
  def from_call(*args, **kwargs):
215
+ duration_s = _extract_duration_s(mode, args, kwargs)
216
+ if duration_s is None:
217
+ # Per-mode default when no duration found in call args.
218
+ duration_s = {
219
+ "generate": 30.0,
220
+ "cover": 30.0,
221
+ "extend": 20.0,
222
+ "edit": 8.0, # typical edit segment window
223
+ "lyrics": 0.0, # no audio; base alone
224
+ }.get(mode, 30.0)
225
  return _estimate_gpu_duration(mode, {"duration_s": duration_s})
226
 
227
  return from_call
tests/test_gpu_estimator.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for the per-mode GPU duration extraction."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
+ def test_extract_generate_positional():
7
+ from app import _extract_duration_s
8
+
9
+ args = ("a prompt", "lyrics body", 45.0, "vocal_male", None)
10
+ kwargs = {}
11
+ assert _extract_duration_s("generate", args, kwargs) == 45.0
12
+
13
+
14
+ def test_extract_cover_at_index_3():
15
+ from app import _extract_duration_s
16
+
17
+ args = ("ref.wav", "new style", "new lyrics", 60.0)
18
+ kwargs = {}
19
+ assert _extract_duration_s("cover", args, kwargs) == 60.0
20
+
21
+
22
+ def test_extract_extend_uses_extra_duration_s_kwarg():
23
+ from app import _extract_duration_s
24
+
25
+ args = ("seed.wav", "more of the same", "extension lyrics", 25.0)
26
+ kwargs = {}
27
+ assert _extract_duration_s("extend", args, kwargs) == 25.0
28
+
29
+
30
+ def test_extract_extend_kwarg_form():
31
+ from app import _extract_duration_s
32
+
33
+ assert _extract_duration_s("extend", (), {"extra_duration_s": 18.5}) == 18.5
34
+
35
+
36
+ def test_extract_edit_segment_window():
37
+ from app import _extract_duration_s
38
+
39
+ args = ("src.wav", "repaint", "src lyrics", "new lyrics", 10.0, 22.5)
40
+ kwargs = {}
41
+ assert _extract_duration_s("edit", args, kwargs) == 12.5
42
+
43
+
44
+ def test_extract_edit_kwarg_window():
45
+ from app import _extract_duration_s
46
+
47
+ kwargs = {"segment_start_s": 5.0, "segment_end_s": 20.0}
48
+ assert _extract_duration_s("edit", (), kwargs) == 15.0
49
+
50
+
51
+ def test_extract_lyrics_returns_none():
52
+ from app import _extract_duration_s
53
+
54
+ assert _extract_duration_s("lyrics", ("brief", "ABAB"), {}) is None
55
+
56
+
57
+ def test_extract_generate_falls_back_when_missing():
58
+ from app import _extract_duration_s
59
+
60
+ # No positional duration, no kwarg → None
61
+ assert _extract_duration_s("generate", ("p", "l"), {}) is None
62
+
63
+
64
+ def test_estimator_clamp_floor():
65
+ from app import _estimate_gpu_duration
66
+
67
+ # lyrics base=15 + 1.0*2 = 17 → clamped up to 60s floor.
68
+ assert _estimate_gpu_duration("lyrics", {"duration_s": 1.0}) == 60
69
+
70
+
71
+ def test_estimator_clamp_ceiling():
72
+ from app import _estimate_gpu_duration
73
+
74
+ # 240s requested * 2 = 480 + base 30 = 510 → clamped to 300
75
+ assert _estimate_gpu_duration("generate", {"duration_s": 240}) == 300
76
+
77
+
78
+ def test_estimator_mode_specific_base():
79
+ from app import _estimate_gpu_duration
80
+
81
+ # 30s requested * 2 = 60 + base 40 (cover) = 100s
82
+ assert _estimate_gpu_duration("cover", {"duration_s": 30}) == 100