techfreakworm commited on
Commit
a2c44d8
·
verified ·
1 Parent(s): a5c8e5e

Revert "feat(spaces): user-controlled GPU budget slider + pre-flight gate"

Browse files

This reverts commit ecc159588e928140424b9f985f7f40d5fbc03157.

Files changed (3) hide show
  1. app.py +8 -48
  2. backend.py +25 -41
  3. tests/test_duration.py +0 -82
app.py CHANGED
@@ -556,18 +556,6 @@ def build_app() -> gr.Blocks:
556
  )
557
  for name, m in modes.MODE_REGISTRY.items()
558
  }
559
- # ZeroGPU per-call cap, placed right under the mode list so
560
- # it's visible without scrolling. The pre-flight gate in
561
- # _on_generate refuses calls whose estimate exceeds this.
562
- gpu_budget_slider = gr.Slider(
563
- minimum=60,
564
- maximum=1800,
565
- value=240,
566
- step=30,
567
- label="GPU budget (seconds)",
568
- info="Max GPU time per generation. Higher = heavy modes fit; uses more of your daily quota per call.",
569
- elem_classes=["aio-gpu-budget"],
570
- )
571
  gr.Markdown("Models", elem_classes=["aio-drawer-heading"])
572
  model_status = gr.HTML(_render_model_status_idle(), elem_id="aio-model-status")
573
  refresh_btn = gr.Button("Refresh", size="sm", variant="secondary")
@@ -593,11 +581,9 @@ def build_app() -> gr.Blocks:
593
  with gr.Column(scale=4, elem_classes=["aio-body"]):
594
  handles, tabs_component = _render_mode_panels()
595
 
596
- # Wire generate buttons. The GPU-budget slider lives in the drawer and
597
- # is the same instance for every mode — append it last so the handler
598
- # receives it as `gpu_budget` (see `_input_keys_for_mode`).
599
  for name, h in handles.items():
600
- inputs = _collect_inputs_for_mode(name, h) + [gpu_budget_slider]
601
  h["generate_btn"].click(
602
  fn=_make_handler(name, h),
603
  inputs=inputs,
@@ -874,9 +860,8 @@ PRESET_DURATION = {"Fast": 60, "Balanced": 120, "Quality": 300}
874
  _FRIENDLY_ERRORS: dict[str, tuple[str, str]] = {
875
  "gpu_timeout": (
876
  "Hit the GPU time limit",
877
- "This run took longer than the GPU budget. Raise the GPU-budget "
878
- "slider (in the sidebar), or try the Fast preset / a shorter video, "
879
- "then click Generate again.",
880
  ),
881
  "expired_token": (
882
  "Session timed out",
@@ -884,10 +869,9 @@ _FRIENDLY_ERRORS: dict[str, tuple[str, str]] = {
884
  "you'll keep your spot in the GPU queue.",
885
  ),
886
  "illegal_duration": (
887
- "GPU budget too high for your account",
888
- "HF rejected the requested duration as exceeding your account's "
889
- "per-call cap. Lower the GPU-budget slider (sidebar) and try again, "
890
- "or drop the preset / shorten the video.",
891
  ),
892
  "unlogged": (
893
  "Sign-in not detected",
@@ -1014,29 +998,6 @@ async def _on_generate(mode_name: str, *, progress: Any = None, **inputs: Any):
1014
  backend = _get_backend()
1015
  preset = params["preset"] # already lowercased above
1016
 
1017
- # Pre-flight gate: refuse to submit if the estimator says this config
1018
- # needs more GPU time than the user has allocated. ZeroGPU charges actual
1019
- # usage, not declared duration, so under-allocating means the call still
1020
- # burns quota before timing out. Refuse here and tell the user to either
1021
- # bump the GPU-budget slider or reduce frames/preset.
1022
- user_budget: int | None = None
1023
- if "gpu_budget" in inputs and inputs["gpu_budget"] is not None:
1024
- user_budget = int(inputs["gpu_budget"])
1025
- estimate = backend_module._estimate_duration_unclamped(
1026
- mode=mode_name, preset=preset, frames=frames,
1027
- )
1028
- if estimate > user_budget:
1029
- yield (
1030
- f'<div class="status-card status-error">'
1031
- f' <div class="status-row"><span class="status-stage">GPU budget too low</span></div>'
1032
- f" <div>This config estimates ~{estimate}s of GPU time, but the "
1033
- f"GPU-budget slider is set to {user_budget}s. Raise the slider, drop "
1034
- f"the preset to Fast or Balanced, or reduce the duration / frame count.</div>"
1035
- f"</div>",
1036
- gr.update(),
1037
- )
1038
- return
1039
-
1040
  async def _translate(event, started_at):
1041
  """Translate one backend event into Gradio (status_html, video) yields.
1042
 
@@ -1089,7 +1050,7 @@ async def _on_generate(mode_name: str, *, progress: Any = None, **inputs: Any):
1089
  started = time.time()
1090
  async for event in backend.submit(
1091
  mode_name, workflow,
1092
- preset=preset, user_budget=user_budget,
1093
  progress=progress,
1094
  ):
1095
  translated = await _translate(event, started)
@@ -1115,7 +1076,6 @@ def _input_keys_for_mode(mode_name: str, h: dict) -> list[str]:
1115
  base.extend(["ic_lora", "ic_strength"])
1116
  if h["lora"].pose_on is not None:
1117
  base.append("pose_on")
1118
- base.append("gpu_budget") # appended by build_app() from the global slider
1119
  return base
1120
 
1121
 
 
556
  )
557
  for name, m in modes.MODE_REGISTRY.items()
558
  }
 
 
 
 
 
 
 
 
 
 
 
 
559
  gr.Markdown("Models", elem_classes=["aio-drawer-heading"])
560
  model_status = gr.HTML(_render_model_status_idle(), elem_id="aio-model-status")
561
  refresh_btn = gr.Button("Refresh", size="sm", variant="secondary")
 
581
  with gr.Column(scale=4, elem_classes=["aio-body"]):
582
  handles, tabs_component = _render_mode_panels()
583
 
584
+ # Wire generate buttons
 
 
585
  for name, h in handles.items():
586
+ inputs = _collect_inputs_for_mode(name, h)
587
  h["generate_btn"].click(
588
  fn=_make_handler(name, h),
589
  inputs=inputs,
 
860
  _FRIENDLY_ERRORS: dict[str, tuple[str, str]] = {
861
  "gpu_timeout": (
862
  "Hit the GPU time limit",
863
+ "This run took longer than the GPU budget. Try the Fast preset, a "
864
+ "shorter video, or a smaller resolution then click Generate again.",
 
865
  ),
866
  "expired_token": (
867
  "Session timed out",
 
869
  "you'll keep your spot in the GPU queue.",
870
  ),
871
  "illegal_duration": (
872
+ "GPU budget too high",
873
+ "The estimator asked for more GPU time than the server allows. "
874
+ "Try Fast preset or a shorter video.",
 
875
  ),
876
  "unlogged": (
877
  "Sign-in not detected",
 
998
  backend = _get_backend()
999
  preset = params["preset"] # already lowercased above
1000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1001
  async def _translate(event, started_at):
1002
  """Translate one backend event into Gradio (status_html, video) yields.
1003
 
 
1050
  started = time.time()
1051
  async for event in backend.submit(
1052
  mode_name, workflow,
1053
+ preset=preset, duration_multiplier=1.0,
1054
  progress=progress,
1055
  ):
1056
  translated = await _translate(event, started)
 
1076
  base.extend(["ic_lora", "ic_strength"])
1077
  if h["lora"].pose_on is not None:
1078
  base.append("pose_on")
 
1079
  return base
1080
 
1081
 
backend.py CHANGED
@@ -93,49 +93,33 @@ def _frames_from_workflow(workflow: dict) -> int:
93
  return 121
94
 
95
 
96
- def _estimate_duration_unclamped(*, mode: str, preset: str, frames: int) -> int:
97
- """Estimator formula minus the 60 s floor.
98
-
99
- Used by the UI's pre-flight gate so it can show "this config needs ~Xs"
100
- without re-implementing the constants in app.py.
101
- """
102
- base = _BASE_DURATION_S.get(mode, 180)
103
- mult = _PRESET_MULT.get(preset.lower(), 1.5)
104
- return int(base * mult + 60 + frames * 0.3)
105
-
106
-
107
  def _duration_for(
108
  executor: Any,
109
  workflow: dict,
110
  output_ids: list[str],
111
  mode: str,
112
  preset: str,
 
113
  progress: Any = None,
114
- user_budget: int | None = None,
115
  ) -> int:
116
- """ZeroGPU per-call duration. Same signature as _execute_workflow.
117
-
118
- `progress` is a `gr.Progress` instance forwarded by the caller; we ignore it
119
- here but must accept it so ZeroGPU calls us with the same arg list it uses
120
- for `_execute_workflow`.
121
-
122
- When `user_budget` is set, it overrides the estimator the user has decided
123
- how much of their ZeroGPU quota to spend on this call. Clamped to ≥ 60 s
124
- (HF's documented per-call floor); no upper clamp, so the user can declare
125
- up to whatever their account tier actually allows. If they exceed the
126
- account cap, HF raises "ZeroGPU illegal duration" and the UI surfaces it
127
- via the `illegal_duration` friendly-error category.
128
-
129
- Without `user_budget`, returns the unclamped estimate (base × preset
130
- multiplier + cold-cache buffer + per-frame VAE decode). The pre-flight
131
- gate in app.py refuses calls whose estimate exceeds the user-chosen
132
- budget — so by the time we get here, either the user opted in or there
133
- was no override.
134
  """
135
- if user_budget is not None:
136
- return max(60, int(user_budget))
137
  frames = _frames_from_workflow(workflow)
138
- return max(60, _estimate_duration_unclamped(mode=mode, preset=preset, frames=frames))
 
139
 
140
 
141
  # Decorate at module load time so ZeroGPU's startup analyzer detects it.
@@ -153,14 +137,14 @@ def _execute_workflow(
153
  output_ids: list[str],
154
  mode: str,
155
  preset: str,
 
156
  progress: Any = None,
157
- user_budget: int | None = None,
158
  ) -> str:
159
  """Run the workflow on GPU and return the path of the first video output.
160
 
161
  Returns just the video path (a plain string, picklable across the
162
- @spaces.GPU subprocess boundary). The `mode`, `preset`, and `user_budget`
163
- args are consumed by `_duration_for` to set the per-call GPU slot.
164
 
165
  `progress` is an optional `gr.Progress` instance. It's the only progress
166
  channel that crosses the @spaces.GPU subprocess boundary on HF Spaces —
@@ -400,15 +384,15 @@ class ComfyUILibraryBackend:
400
  workflow: dict,
401
  *,
402
  preset: str = "balanced",
403
- user_budget: int | None = None,
404
  gpu_duration: int = 0, # legacy, ignored (now derived from preset+frames)
405
  progress: Any = None,
406
  ) -> AsyncIterator[Any]:
407
  """Run a workflow end-to-end. Yields Download/Progress/Output/Error events.
408
 
409
- `preset` and `user_budget` flow through to the @spaces.GPU duration
410
- estimator. When `user_budget` is set the user has opted in to a
411
- specific per-call GPU time cap; otherwise the estimator picks one.
412
  """
413
  # Pre-flight: ensure all model files exist.
414
  try:
@@ -483,7 +467,7 @@ class ComfyUILibraryBackend:
483
  # light calls get fast queue priority while heavy ones reserve
484
  # real headroom. Off-Spaces it's a plain call.
485
  video_path = _execute_workflow(
486
- self._executor, workflow, output_ids, mode, preset, progress, user_budget,
487
  )
488
  # Fallback: if history_result didn't surface a path (rare on
489
  # Spaces — happens when ZeroGPU's subprocess boundary drops
 
93
  return 121
94
 
95
 
 
 
 
 
 
 
 
 
 
 
 
96
  def _duration_for(
97
  executor: Any,
98
  workflow: dict,
99
  output_ids: list[str],
100
  mode: str,
101
  preset: str,
102
+ multiplier: float = 1.0,
103
  progress: Any = None,
 
104
  ) -> int:
105
+ """ZeroGPU duration estimator. Same signature as _execute_workflow.
106
+
107
+ `progress` is a gr.Progress instance forwarded by the caller; we ignore it
108
+ here (estimator doesn't emit progress) but must accept it positionally so
109
+ ZeroGPU can call us with the same arg list it'll use for _execute_workflow.
110
+
111
+ Estimate = (base × preset multiplier + cold-cache buffer + per-frame VAE
112
+ decode time) × retry multiplier, clamped to [60s, 240s]. ZeroGPU rejects
113
+ durations above the server's per-call max with "ZeroGPU illegal duration"
114
+ (client.py:137); 240s is observed to work for Pro identity (~2 min runs
115
+ needed for style + lipsync detailer paths). If the server rejects values
116
+ in this range, the user will see a clear error and can retry.
 
 
 
 
 
 
117
  """
118
+ base = _BASE_DURATION_S.get(mode, 180)
119
+ mult = _PRESET_MULT.get(preset.lower(), 1.5)
120
  frames = _frames_from_workflow(workflow)
121
+ est = int((base * mult + 60 + frames * 0.3) * multiplier)
122
+ return max(60, min(est, 240))
123
 
124
 
125
  # Decorate at module load time so ZeroGPU's startup analyzer detects it.
 
137
  output_ids: list[str],
138
  mode: str,
139
  preset: str,
140
+ multiplier: float = 1.0,
141
  progress: Any = None,
 
142
  ) -> str:
143
  """Run the workflow on GPU and return the path of the first video output.
144
 
145
  Returns just the video path (a plain string, picklable across the
146
+ @spaces.GPU subprocess boundary). The `mode`, `preset`, and `multiplier`
147
+ args are consumed by `_duration_for` to estimate the GPU slot to reserve.
148
 
149
  `progress` is an optional `gr.Progress` instance. It's the only progress
150
  channel that crosses the @spaces.GPU subprocess boundary on HF Spaces —
 
384
  workflow: dict,
385
  *,
386
  preset: str = "balanced",
387
+ duration_multiplier: float = 1.0,
388
  gpu_duration: int = 0, # legacy, ignored (now derived from preset+frames)
389
  progress: Any = None,
390
  ) -> AsyncIterator[Any]:
391
  """Run a workflow end-to-end. Yields Download/Progress/Output/Error events.
392
 
393
+ `preset` and `duration_multiplier` flow through to the @spaces.GPU
394
+ duration estimator. The handler can re-call submit() with
395
+ duration_multiplier=2.0 if the first attempt aborts on timeout.
396
  """
397
  # Pre-flight: ensure all model files exist.
398
  try:
 
467
  # light calls get fast queue priority while heavy ones reserve
468
  # real headroom. Off-Spaces it's a plain call.
469
  video_path = _execute_workflow(
470
+ self._executor, workflow, output_ids, mode, preset, duration_multiplier, progress,
471
  )
472
  # Fallback: if history_result didn't surface a path (rare on
473
  # Spaces — happens when ZeroGPU's subprocess boundary drops
tests/test_duration.py DELETED
@@ -1,82 +0,0 @@
1
- """Tests for the ZeroGPU per-call duration estimator + user-budget override."""
2
-
3
- import backend
4
-
5
-
6
- def _t2v_workflow(frames: int = 121) -> dict:
7
- return {
8
- "100": {
9
- "class_type": "EmptyLTXVLatentVideo",
10
- "inputs": {"length": frames, "width": 512, "height": 512},
11
- }
12
- }
13
-
14
-
15
- def test_duration_for_uses_user_budget_when_set() -> None:
16
- # 600s should pass through verbatim — the user knows what they're spending.
17
- assert (
18
- backend._duration_for(
19
- executor=None,
20
- workflow=_t2v_workflow(),
21
- output_ids=[],
22
- mode="t2v",
23
- preset="fast",
24
- user_budget=600,
25
- )
26
- == 600
27
- )
28
-
29
-
30
- def test_duration_for_clamps_user_budget_to_floor() -> None:
31
- # 30s below the 60s ZeroGPU floor — clamp up, never below.
32
- assert (
33
- backend._duration_for(
34
- executor=None,
35
- workflow=_t2v_workflow(),
36
- output_ids=[],
37
- mode="t2v",
38
- preset="fast",
39
- user_budget=30,
40
- )
41
- == 60
42
- )
43
-
44
-
45
- def test_duration_for_no_budget_returns_unclamped_estimate() -> None:
46
- # style/quality/121 frames: 360*3 + 60 + 121*0.3 = 1176.3 -> int 1176.
47
- # No upper ceiling — the whole point of the user-budget refactor.
48
- result = backend._duration_for(
49
- executor=None,
50
- workflow=_t2v_workflow(frames=121),
51
- output_ids=[],
52
- mode="style",
53
- preset="quality",
54
- )
55
- assert result == 1176
56
-
57
-
58
- def test_duration_for_no_budget_honours_floor() -> None:
59
- # 1-frame t2v/fast: 90*1 + 60 + 0.3 = 150 -> int 150; well above floor, so
60
- # this is really testing that the floor doesn't accidentally fire on real
61
- # workloads. (See test_duration_for_clamps_user_budget_to_floor for the
62
- # actual floor case via user_budget.)
63
- result = backend._duration_for(
64
- executor=None,
65
- workflow=_t2v_workflow(frames=1),
66
- output_ids=[],
67
- mode="t2v",
68
- preset="fast",
69
- )
70
- assert result == 150
71
-
72
-
73
- def test_estimate_duration_unclamped_matches_formula() -> None:
74
- # Surface the formula so the pre-flight gate in app.py can show the user
75
- # "needs X seconds" without re-implementing it.
76
- assert backend._estimate_duration_unclamped(mode="t2v", preset="fast", frames=121) == 90 + 60 + int(121 * 0.3)
77
- assert backend._estimate_duration_unclamped(mode="style", preset="quality", frames=121) == int(360 * 3.0 + 60 + 121 * 0.3)
78
-
79
-
80
- def test_estimate_duration_unclamped_unknown_mode_uses_default() -> None:
81
- # Unknown mode -> default base 180. Preset still applies.
82
- assert backend._estimate_duration_unclamped(mode="nonsense", preset="balanced", frames=100) == int(180 * 1.5 + 60 + 100 * 0.3)