techfreakworm committed on
Commit
ecc1595
·
verified ·
1 Parent(s): de6853a

feat(spaces): user-controlled GPU budget slider + pre-flight gate

Browse files

The fixed 240 s ceiling in `_duration_for` was the only thing standing
between heavy modes (style, lipsync at high frame counts) and HF Pro+ /
credit-topped accounts whose per-call cap actually allows longer
durations. Replace it with a per-call budget the user picks, plus a
pre-flight gate so jobs that won't fit are refused before submitting —
no GPU time wasted after several minutes of compute.

- `_duration_for` honours `user_budget` when set; otherwise returns the
unclamped estimate (only the 60 s floor remains).
- `_estimate_duration_unclamped` helper exposes the formula for the
pre-flight diff message in app.py.
- GPU-budget slider in the sidebar drawer: 60-1800 s, default 240 s,
step 30 s, placed under the mode list for discoverability.
- Pre-flight refuses with a friendly diff: "config estimates ~X s,
slider is Y s" - user raises the slider or reduces preset/frames.
- `illegal_duration` + `gpu_timeout` friendly errors now point at the
slider as the first thing to try.
- Drops the dead `duration_multiplier` arg (leftover from the removed
auto-retry path) so the signature stays clean.

Tests cover: user_budget override, the 60 s floor clamp, the unclamped
estimate formula, and the unknown-mode default.

Files changed (3) hide show
  1. app.py +48 -8
  2. backend.py +41 -25
  3. tests/test_duration.py +82 -0
app.py CHANGED
@@ -523,6 +523,18 @@ def build_app() -> gr.Blocks:
523
  )
524
  for name, m in modes.MODE_REGISTRY.items()
525
  }
 
 
 
 
 
 
 
 
 
 
 
 
526
  gr.Markdown("Models", elem_classes=["aio-drawer-heading"])
527
  model_status = gr.HTML(_render_model_status_idle(), elem_id="aio-model-status")
528
  refresh_btn = gr.Button("Refresh", size="sm", variant="secondary")
@@ -539,9 +551,11 @@ def build_app() -> gr.Blocks:
539
  with gr.Column(scale=4, elem_classes=["aio-body"]):
540
  handles, tabs_component = _render_mode_panels()
541
 
542
- # Wire generate buttons
 
 
543
  for name, h in handles.items():
544
- inputs = _collect_inputs_for_mode(name, h)
545
  h["generate_btn"].click(
546
  fn=_make_handler(name, h),
547
  inputs=inputs,
@@ -818,8 +832,9 @@ PRESET_DURATION = {"Fast": 60, "Balanced": 120, "Quality": 300}
818
  _FRIENDLY_ERRORS: dict[str, tuple[str, str]] = {
819
  "gpu_timeout": (
820
  "Hit the GPU time limit",
821
- "This run took longer than the GPU budget. Try the Fast preset, a "
822
- "shorter video, or a smaller resolution then click Generate again.",
 
823
  ),
824
  "expired_token": (
825
  "Session timed out",
@@ -827,9 +842,10 @@ _FRIENDLY_ERRORS: dict[str, tuple[str, str]] = {
827
  "you'll keep your spot in the GPU queue.",
828
  ),
829
  "illegal_duration": (
830
- "GPU budget too high",
831
- "The estimator asked for more GPU time than the server allows. "
832
- "Try Fast preset or a shorter video.",
 
833
  ),
834
  "unlogged": (
835
  "Sign-in not detected",
@@ -956,6 +972,29 @@ async def _on_generate(mode_name: str, *, progress: Any = None, **inputs: Any):
956
  backend = _get_backend()
957
  preset = params["preset"] # already lowercased above
958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
959
  async def _translate(event, started_at):
960
  """Translate one backend event into Gradio (status_html, video) yields.
961
 
@@ -1008,7 +1047,7 @@ async def _on_generate(mode_name: str, *, progress: Any = None, **inputs: Any):
1008
  started = time.time()
1009
  async for event in backend.submit(
1010
  mode_name, workflow,
1011
- preset=preset, duration_multiplier=1.0,
1012
  progress=progress,
1013
  ):
1014
  translated = await _translate(event, started)
@@ -1034,6 +1073,7 @@ def _input_keys_for_mode(mode_name: str, h: dict) -> list[str]:
1034
  base.extend(["ic_lora", "ic_strength"])
1035
  if h["lora"].pose_on is not None:
1036
  base.append("pose_on")
 
1037
  return base
1038
 
1039
 
 
523
  )
524
  for name, m in modes.MODE_REGISTRY.items()
525
  }
526
+ # ZeroGPU per-call cap, placed right under the mode list so
527
+ # it's visible without scrolling. The pre-flight gate in
528
+ # _on_generate refuses calls whose estimate exceeds this.
529
+ gpu_budget_slider = gr.Slider(
530
+ minimum=60,
531
+ maximum=1800,
532
+ value=240,
533
+ step=30,
534
+ label="GPU budget (seconds)",
535
+ info="Max GPU time per generation. Higher = heavy modes fit; uses more of your daily quota per call.",
536
+ elem_classes=["aio-gpu-budget"],
537
+ )
538
  gr.Markdown("Models", elem_classes=["aio-drawer-heading"])
539
  model_status = gr.HTML(_render_model_status_idle(), elem_id="aio-model-status")
540
  refresh_btn = gr.Button("Refresh", size="sm", variant="secondary")
 
551
  with gr.Column(scale=4, elem_classes=["aio-body"]):
552
  handles, tabs_component = _render_mode_panels()
553
 
554
+ # Wire generate buttons. The GPU-budget slider lives in the drawer and
555
+ # is the same instance for every mode — append it last so the handler
556
+ # receives it as `gpu_budget` (see `_input_keys_for_mode`).
557
  for name, h in handles.items():
558
+ inputs = _collect_inputs_for_mode(name, h) + [gpu_budget_slider]
559
  h["generate_btn"].click(
560
  fn=_make_handler(name, h),
561
  inputs=inputs,
 
832
  _FRIENDLY_ERRORS: dict[str, tuple[str, str]] = {
833
  "gpu_timeout": (
834
  "Hit the GPU time limit",
835
+ "This run took longer than the GPU budget. Raise the GPU-budget "
836
+ "slider (in the sidebar), or try the Fast preset / a shorter video, "
837
+ "then click Generate again.",
838
  ),
839
  "expired_token": (
840
  "Session timed out",
 
842
  "you'll keep your spot in the GPU queue.",
843
  ),
844
  "illegal_duration": (
845
+ "GPU budget too high for your account",
846
+ "HF rejected the requested duration as exceeding your account's "
847
+ "per-call cap. Lower the GPU-budget slider (sidebar) and try again, "
848
+ "or drop the preset / shorten the video.",
849
  ),
850
  "unlogged": (
851
  "Sign-in not detected",
 
972
  backend = _get_backend()
973
  preset = params["preset"] # already lowercased above
974
 
975
+ # Pre-flight gate: refuse to submit if the estimator says this config
976
+ # needs more GPU time than the user has allocated. ZeroGPU charges actual
977
+ # usage, not declared duration, so under-allocating means the call still
978
+ # burns quota before timing out. Refuse here and tell the user to either
979
+ # bump the GPU-budget slider or reduce frames/preset.
980
+ user_budget: int | None = None
981
+ if "gpu_budget" in inputs and inputs["gpu_budget"] is not None:
982
+ user_budget = int(inputs["gpu_budget"])
983
+ estimate = backend_module._estimate_duration_unclamped(
984
+ mode=mode_name, preset=preset, frames=frames,
985
+ )
986
+ if estimate > user_budget:
987
+ yield (
988
+ f'<div class="status-card status-error">'
989
+ f' <div class="status-row"><span class="status-stage">GPU budget too low</span></div>'
990
+ f" <div>This config estimates ~{estimate}s of GPU time, but the "
991
+ f"GPU-budget slider is set to {user_budget}s. Raise the slider, drop "
992
+ f"the preset to Fast or Balanced, or reduce the duration / frame count.</div>"
993
+ f"</div>",
994
+ gr.update(),
995
+ )
996
+ return
997
+
998
  async def _translate(event, started_at):
999
  """Translate one backend event into Gradio (status_html, video) yields.
1000
 
 
1047
  started = time.time()
1048
  async for event in backend.submit(
1049
  mode_name, workflow,
1050
+ preset=preset, user_budget=user_budget,
1051
  progress=progress,
1052
  ):
1053
  translated = await _translate(event, started)
 
1073
  base.extend(["ic_lora", "ic_strength"])
1074
  if h["lora"].pose_on is not None:
1075
  base.append("pose_on")
1076
+ base.append("gpu_budget") # appended by build_app() from the global slider
1077
  return base
1078
 
1079
 
backend.py CHANGED
@@ -93,33 +93,49 @@ def _frames_from_workflow(workflow: dict) -> int:
93
  return 121
94
 
95
 
 
 
 
 
 
 
 
 
 
 
 
96
  def _duration_for(
97
  executor: Any,
98
  workflow: dict,
99
  output_ids: list[str],
100
  mode: str,
101
  preset: str,
102
- multiplier: float = 1.0,
103
  progress: Any = None,
 
104
  ) -> int:
105
- """ZeroGPU duration estimator. Same signature as _execute_workflow.
106
-
107
- `progress` is a gr.Progress instance forwarded by the caller; we ignore it
108
- here (estimator doesn't emit progress) but must accept it positionally so
109
- ZeroGPU can call us with the same arg list it'll use for _execute_workflow.
110
-
111
- Estimate = (base × preset multiplier + cold-cache buffer + per-frame VAE
112
- decode time) × retry multiplier, clamped to [60s, 240s]. ZeroGPU rejects
113
- durations above the server's per-call max with "ZeroGPU illegal duration"
114
- (client.py:137); 240s is observed to work for Pro identity (~2 min runs
115
- needed for style + lipsync detailer paths). If the server rejects values
116
- in this range, the user will see a clear error and can retry.
 
 
 
 
 
 
117
  """
118
- base = _BASE_DURATION_S.get(mode, 180)
119
- mult = _PRESET_MULT.get(preset.lower(), 1.5)
120
  frames = _frames_from_workflow(workflow)
121
- est = int((base * mult + 60 + frames * 0.3) * multiplier)
122
- return max(60, min(est, 240))
123
 
124
 
125
  # Decorate at module load time so ZeroGPU's startup analyzer detects it.
@@ -137,14 +153,14 @@ def _execute_workflow(
137
  output_ids: list[str],
138
  mode: str,
139
  preset: str,
140
- multiplier: float = 1.0,
141
  progress: Any = None,
 
142
  ) -> str:
143
  """Run the workflow on GPU and return the path of the first video output.
144
 
145
  Returns just the video path (a plain string, picklable across the
146
- @spaces.GPU subprocess boundary). The `mode`, `preset`, and `multiplier`
147
- args are consumed by `_duration_for` to estimate the GPU slot to reserve.
148
 
149
  `progress` is an optional `gr.Progress` instance. It's the only progress
150
  channel that crosses the @spaces.GPU subprocess boundary on HF Spaces —
@@ -384,15 +400,15 @@ class ComfyUILibraryBackend:
384
  workflow: dict,
385
  *,
386
  preset: str = "balanced",
387
- duration_multiplier: float = 1.0,
388
  gpu_duration: int = 0, # legacy, ignored (now derived from preset+frames)
389
  progress: Any = None,
390
  ) -> AsyncIterator[Any]:
391
  """Run a workflow end-to-end. Yields Download/Progress/Output/Error events.
392
 
393
- `preset` and `duration_multiplier` flow through to the @spaces.GPU
394
- duration estimator. The handler can re-call submit() with
395
- duration_multiplier=2.0 if the first attempt aborts on timeout.
396
  """
397
  # Pre-flight: ensure all model files exist.
398
  try:
@@ -467,7 +483,7 @@ class ComfyUILibraryBackend:
467
  # light calls get fast queue priority while heavy ones reserve
468
  # real headroom. Off-Spaces it's a plain call.
469
  video_path = _execute_workflow(
470
- self._executor, workflow, output_ids, mode, preset, duration_multiplier, progress,
471
  )
472
  # Fallback: if history_result didn't surface a path (rare on
473
  # Spaces — happens when ZeroGPU's subprocess boundary drops
 
93
  return 121
94
 
95
 
96
+ def _estimate_duration_unclamped(*, mode: str, preset: str, frames: int) -> int:
97
+ """Estimator formula minus the 60 s floor.
98
+
99
+ Used by the UI's pre-flight gate so it can show "this config needs ~Xs"
100
+ without re-implementing the constants in app.py.
101
+ """
102
+ base = _BASE_DURATION_S.get(mode, 180)
103
+ mult = _PRESET_MULT.get(preset.lower(), 1.5)
104
+ return int(base * mult + 60 + frames * 0.3)
105
+
106
+
107
  def _duration_for(
108
  executor: Any,
109
  workflow: dict,
110
  output_ids: list[str],
111
  mode: str,
112
  preset: str,
 
113
  progress: Any = None,
114
+ user_budget: int | None = None,
115
  ) -> int:
116
+ """ZeroGPU per-call duration. Same signature as _execute_workflow.
117
+
118
+ `progress` is a `gr.Progress` instance forwarded by the caller; we ignore it
119
+ here but must accept it so ZeroGPU calls us with the same arg list it uses
120
+ for `_execute_workflow`.
121
+
122
+ When `user_budget` is set, it overrides the estimator the user has decided
123
+ how much of their ZeroGPU quota to spend on this call. Clamped to ≥ 60 s
124
+ (HF's documented per-call floor); no upper clamp, so the user can declare
125
+ up to whatever their account tier actually allows. If they exceed the
126
+ account cap, HF raises "ZeroGPU illegal duration" and the UI surfaces it
127
+ via the `illegal_duration` friendly-error category.
128
+
129
+ Without `user_budget`, returns the unclamped estimate (base × preset
130
+ multiplier + cold-cache buffer + per-frame VAE decode). The pre-flight
131
+ gate in app.py refuses calls whose estimate exceeds the user-chosen
132
+ budget — so by the time we get here, either the user opted in or there
133
+ was no override.
134
  """
135
+ if user_budget is not None:
136
+ return max(60, int(user_budget))
137
  frames = _frames_from_workflow(workflow)
138
+ return max(60, _estimate_duration_unclamped(mode=mode, preset=preset, frames=frames))
 
139
 
140
 
141
  # Decorate at module load time so ZeroGPU's startup analyzer detects it.
 
153
  output_ids: list[str],
154
  mode: str,
155
  preset: str,
 
156
  progress: Any = None,
157
+ user_budget: int | None = None,
158
  ) -> str:
159
  """Run the workflow on GPU and return the path of the first video output.
160
 
161
  Returns just the video path (a plain string, picklable across the
162
+ @spaces.GPU subprocess boundary). The `mode`, `preset`, and `user_budget`
163
+ args are consumed by `_duration_for` to set the per-call GPU slot.
164
 
165
  `progress` is an optional `gr.Progress` instance. It's the only progress
166
  channel that crosses the @spaces.GPU subprocess boundary on HF Spaces —
 
400
  workflow: dict,
401
  *,
402
  preset: str = "balanced",
403
+ user_budget: int | None = None,
404
  gpu_duration: int = 0, # legacy, ignored (now derived from preset+frames)
405
  progress: Any = None,
406
  ) -> AsyncIterator[Any]:
407
  """Run a workflow end-to-end. Yields Download/Progress/Output/Error events.
408
 
409
+ `preset` and `user_budget` flow through to the @spaces.GPU duration
410
+ estimator. When `user_budget` is set the user has opted in to a
411
+ specific per-call GPU time cap; otherwise the estimator picks one.
412
  """
413
  # Pre-flight: ensure all model files exist.
414
  try:
 
483
  # light calls get fast queue priority while heavy ones reserve
484
  # real headroom. Off-Spaces it's a plain call.
485
  video_path = _execute_workflow(
486
+ self._executor, workflow, output_ids, mode, preset, progress, user_budget,
487
  )
488
  # Fallback: if history_result didn't surface a path (rare on
489
  # Spaces — happens when ZeroGPU's subprocess boundary drops
tests/test_duration.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the ZeroGPU per-call duration estimator + user-budget override."""
2
+
3
+ import backend
4
+
5
+
6
+ def _t2v_workflow(frames: int = 121) -> dict:
7
+ return {
8
+ "100": {
9
+ "class_type": "EmptyLTXVLatentVideo",
10
+ "inputs": {"length": frames, "width": 512, "height": 512},
11
+ }
12
+ }
13
+
14
+
15
+ def test_duration_for_uses_user_budget_when_set() -> None:
16
+ # 600s should pass through verbatim — the user knows what they're spending.
17
+ assert (
18
+ backend._duration_for(
19
+ executor=None,
20
+ workflow=_t2v_workflow(),
21
+ output_ids=[],
22
+ mode="t2v",
23
+ preset="fast",
24
+ user_budget=600,
25
+ )
26
+ == 600
27
+ )
28
+
29
+
30
+ def test_duration_for_clamps_user_budget_to_floor() -> None:
31
+ # 30s is below the 60s ZeroGPU floor — clamp up, never below.
32
+ assert (
33
+ backend._duration_for(
34
+ executor=None,
35
+ workflow=_t2v_workflow(),
36
+ output_ids=[],
37
+ mode="t2v",
38
+ preset="fast",
39
+ user_budget=30,
40
+ )
41
+ == 60
42
+ )
43
+
44
+
45
+ def test_duration_for_no_budget_returns_unclamped_estimate() -> None:
46
+ # style/quality/121 frames: 360*3 + 60 + 121*0.3 = 1176.3 -> int 1176.
47
+ # No upper ceiling — the whole point of the user-budget refactor.
48
+ result = backend._duration_for(
49
+ executor=None,
50
+ workflow=_t2v_workflow(frames=121),
51
+ output_ids=[],
52
+ mode="style",
53
+ preset="quality",
54
+ )
55
+ assert result == 1176
56
+
57
+
58
+ def test_duration_for_no_budget_honours_floor() -> None:
59
+ # 1-frame t2v/fast: 90*1 + 60 + 0.3 = 150.3 -> int 150; well above floor, so
60
+ # this is really testing that the floor doesn't accidentally fire on real
61
+ # workloads. (See test_duration_for_clamps_user_budget_to_floor for the
62
+ # actual floor case via user_budget.)
63
+ result = backend._duration_for(
64
+ executor=None,
65
+ workflow=_t2v_workflow(frames=1),
66
+ output_ids=[],
67
+ mode="t2v",
68
+ preset="fast",
69
+ )
70
+ assert result == 150
71
+
72
+
73
+ def test_estimate_duration_unclamped_matches_formula() -> None:
74
+ # Surface the formula so the pre-flight gate in app.py can show the user
75
+ # "needs X seconds" without re-implementing it.
76
+ assert backend._estimate_duration_unclamped(mode="t2v", preset="fast", frames=121) == 90 + 60 + int(121 * 0.3)
77
+ assert backend._estimate_duration_unclamped(mode="style", preset="quality", frames=121) == int(360 * 3.0 + 60 + 121 * 0.3)
78
+
79
+
80
+ def test_estimate_duration_unclamped_unknown_mode_uses_default() -> None:
81
+ # Unknown mode -> default base 180. Preset still applies.
82
+ assert backend._estimate_duration_unclamped(mode="nonsense", preset="balanced", frames=100) == int(180 * 1.5 + 60 + 100 * 0.3)