techfreakworm commited on
Commit
14dcc06
·
unverified ·
1 Parent(s): d3e4302

feat(spaces): dynamic ZeroGPU duration + auto-retry on timeout

Browse files

Tier 1 — dynamic duration estimator:
@spaces.GPU(duration=callable) where the callable receives the same args
as _execute_workflow and returns int seconds. Estimate from
(mode, preset, frames) with per-mode base (t2v 90s, lipsync 240s,
style 360s, etc.), preset multiplier (fast 1×, balanced 1.5×, quality 3×),
+60s cold-cache buffer, +0.3s/frame VAE decode. Clamped to [60s, 900s].

Effect: light T2V calls reserve ~190s and jump the ZeroGPU queue;
heavy lipsync/style calls reserve real headroom only when needed.
No more one-size-fits-all 300s/600s wall.

Tier 2 — auto-retry on timeout:
ErrorEvent.category 'gpu_timeout' is now distinguished from generic
execution errors via _classify (matches "GPU task aborted" in the
exception message). The _on_generate handler catches this on the
first attempt only, shows a friendly "Retrying with extended budget"
banner, then re-submits with duration_multiplier=2.0 (which clamps to
the same 900s ceiling). One retry — if that also times out, surface
the error.

submit() signature: now takes preset and duration_multiplier kwargs,
old gpu_duration arg kept as ignored legacy.

Files changed (2) hide show
  1. app.py +74 -31
  2. backend.py +94 -22
app.py CHANGED
@@ -796,48 +796,91 @@ async def _on_generate(mode_name: str, **inputs: Any):
796
  wf_module.set_input(workflow, *patch)
797
 
798
  backend = _get_backend()
799
- duration = PRESET_DURATION.get(str(inputs.get("preset", "Balanced")), 120)
800
 
801
- started = time.time()
802
- async for event in backend.submit(mode_name, workflow, gpu_duration=duration):
803
- elapsed = time.time() - started
 
 
 
804
  if isinstance(event, backend_module.DownloadEvent):
805
- status = ui.render_status(
806
- stage_index=0,
807
- stage_label=f"Downloading {event.filename}",
808
- step=int(event.mb_done),
809
- total_steps=int(max(event.mb_total, 1)),
810
- elapsed_s=elapsed,
811
- eta_s=0,
 
 
 
812
  )
813
- yield status, gr.update()
814
- elif isinstance(event, backend_module.ProgressEvent):
815
- # Each sampler in the workflow gets its own stage label "Diffusion (n)".
816
- # The static `mode.stage_map` describes the full pipeline (encode →
817
- # diffusion → upscale → diffusion → decode) but our progress hook
818
- # only fires inside samplers, so we label by sampler index instead.
819
  label = f"Diffusion (Stage {event.stage})"
820
  eta = (elapsed / max(event.step, 1)) * (event.total_steps - event.step)
821
- status = ui.render_status(
822
- stage_index=event.stage,
823
- stage_label=label,
824
- step=event.step,
825
- total_steps=event.total_steps,
826
- elapsed_s=elapsed,
827
- eta_s=eta,
 
 
 
828
  )
829
- yield status, gr.update()
830
- elif isinstance(event, backend_module.OutputEvent):
831
  video_update = event.video_path if event.video_path else gr.update()
832
- yield ui._render_idle(), video_update
833
- elif isinstance(event, backend_module.ErrorEvent):
834
- error_html = (
835
  f'<div class="status-card status-error">'
836
  f' <div class="status-row"><span class="status-stage">Error · {event.category}</span></div>'
837
  f" <div>{event.message}</div>"
838
- f"</div>"
 
839
  )
840
- yield error_html, gr.update()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
841
 
842
 
843
  def _input_keys_for_mode(mode_name: str, h: dict) -> list[str]:
 
796
  wf_module.set_input(workflow, *patch)
797
 
798
  backend = _get_backend()
799
+ preset = params["preset"] # already lowercased above
800
 
801
+ async def _translate(event, started_at):
802
+ """Translate one backend event into Gradio (status_html, video) yields.
803
+
804
+ Returns the tuple to yield, plus a flag indicating terminal state.
805
+ """
806
+ elapsed = time.time() - started_at
807
  if isinstance(event, backend_module.DownloadEvent):
808
+ return (
809
+ ui.render_status(
810
+ stage_index=0,
811
+ stage_label=f"Downloading {event.filename}",
812
+ step=int(event.mb_done),
813
+ total_steps=int(max(event.mb_total, 1)),
814
+ elapsed_s=elapsed,
815
+ eta_s=0,
816
+ ),
817
+ gr.update(),
818
  )
819
+ if isinstance(event, backend_module.ProgressEvent):
 
 
 
 
 
820
  label = f"Diffusion (Stage {event.stage})"
821
  eta = (elapsed / max(event.step, 1)) * (event.total_steps - event.step)
822
+ return (
823
+ ui.render_status(
824
+ stage_index=event.stage,
825
+ stage_label=label,
826
+ step=event.step,
827
+ total_steps=event.total_steps,
828
+ elapsed_s=elapsed,
829
+ eta_s=eta,
830
+ ),
831
+ gr.update(),
832
  )
833
+ if isinstance(event, backend_module.OutputEvent):
 
834
  video_update = event.video_path if event.video_path else gr.update()
835
+ return (ui._render_idle(), video_update)
836
+ if isinstance(event, backend_module.ErrorEvent):
837
+ return (
838
  f'<div class="status-card status-error">'
839
  f' <div class="status-row"><span class="status-stage">Error · {event.category}</span></div>'
840
  f" <div>{event.message}</div>"
841
+ f"</div>",
842
+ gr.update(),
843
  )
844
+ return None
845
+
846
+ # Tier 1 + Tier 2: one normal attempt; if it aborts on ZeroGPU duration
847
+ # cap, retry once with a 2× duration multiplier. Each multiplier is
848
+ # capped at 900s server-side, so the second attempt never exceeds that.
849
+ started = time.time()
850
+ multiplier = 1.0
851
+ timed_out = False
852
+ for attempt in (0, 1):
853
+ if attempt == 1:
854
+ # Show a friendly retry banner before the second submit
855
+ yield (
856
+ '<div class="status-card status-error">'
857
+ ' <div class="status-row"><span class="status-stage">'
858
+ "Retrying with extended GPU budget</span></div>"
859
+ " <div>First attempt hit the per-call duration cap "
860
+ "(usually a cold model cache or a heavier mode than estimated). "
861
+ "Reserving 2× the budget and trying once more.</div>"
862
+ "</div>",
863
+ gr.update(),
864
+ )
865
+ multiplier = 2.0
866
+ started = time.time() # reset so progress ETAs are sensible
867
+
868
+ timed_out = False
869
+ async for event in backend.submit(
870
+ mode_name, workflow, preset=preset, duration_multiplier=multiplier
871
+ ):
872
+ if (
873
+ isinstance(event, backend_module.ErrorEvent)
874
+ and event.category == "gpu_timeout"
875
+ and attempt == 0
876
+ ):
877
+ timed_out = True
878
+ break # don't yield the timeout error — auto-retry instead
879
+ translated = await _translate(event, started)
880
+ if translated is not None:
881
+ yield translated
882
+ if not timed_out:
883
+ return
884
 
885
 
886
  def _input_keys_for_mode(mode_name: str, h: dict) -> list[str]:
backend.py CHANGED
@@ -63,25 +63,78 @@ def _identity(fn):
63
  return fn
64
 
65
 
66
- # ZeroGPU's startup detector scans loaded modules for spaces.GPU-wrapped
67
- # functions. The decorator must be applied at module load time — runtime
68
- # wrapping inside a request handler isn't detected. `duration` is the per-call
69
- # timeout, NOT a billing cap (HF bills actual usage). Setting it generously
70
- # (10 min) so heavy modes like lipsync (audio encoder + extra LoRAs + VAE
71
- # decode + ffmpeg mux) don't hit the 300s wall mid-mux. Light modes return
72
- # in ~30-60s and free the GPU back into the pool.
73
- _GPU = spaces.GPU(duration=600) if (spaces is not None and _on_spaces()) else _identity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
 
76
  @_GPU
77
- def _execute_workflow(executor: Any, workflow: dict, output_ids: list[str]) -> str:
 
 
 
 
 
 
 
78
  """Run the workflow on GPU and return the path of the first video output.
79
 
80
  Returns just the video path (a plain string, picklable across the
81
- @spaces.GPU subprocess boundary). Returning the full history_result dict
82
- was unreliable on Spaces under ZeroGPU's GPU-context wrapping, the
83
- parent process didn't see the executor's mutated state, so video_path
84
- came back empty even when the file was on disk.
85
  """
86
  executor.execute(
87
  workflow,
@@ -291,9 +344,20 @@ class ComfyUILibraryBackend:
291
  return f"ComfyUILibraryBackend(comfy_dir={self._comfy_dir!r})"
292
 
293
  async def submit(
294
- self, mode: str, workflow: dict, gpu_duration: int = 300
 
 
 
 
 
 
295
  ) -> AsyncIterator[Any]:
296
- """Run a workflow end-to-end. Yields Download/Progress/Output/Error events."""
 
 
 
 
 
297
  # Pre-flight: ensure all model files exist.
298
  try:
299
  needed = models.walk_workflow_for_models(workflow)
@@ -361,12 +425,14 @@ class ComfyUILibraryBackend:
361
  # Use the public setter; it writes the same global the
362
  # ProgressBar class reads, but is the documented API.
363
  comfy.utils.set_progress_bar_global_hook(_hook)
364
- # _execute_workflow is module-level and decorated with
365
- # @spaces.GPU(duration=600) on Spaces — that's what makes the
366
- # heavy compute run on a borrowed H200. Off-Spaces it's a
367
- # plain call. Returns the video path directly (computed
368
- # inside the GPU context so the executor's history is fresh).
369
- video_path = _execute_workflow(self._executor, workflow, output_ids)
 
 
370
  # Fallback: if history_result didn't surface a path (rare on
371
  # Spaces — happens when ZeroGPU's subprocess boundary drops
372
  # mutated state), scan the output dir for the newest mp4
@@ -415,8 +481,14 @@ class ComfyUILibraryBackend:
415
 
416
  def _classify(exc: Exception) -> str:
417
  name = type(exc).__name__.lower()
418
- if "outofmemory" in name or "cuda out of memory" in str(exc).lower():
 
419
  return "oom"
 
 
 
 
 
420
  if "interrupt" in name:
421
  return "interrupt"
422
  return "execution"
 
63
  return fn
64
 
65
 
66
+ # --- Per-call ZeroGPU duration estimator -----------------------------------
67
+ # `duration` is a per-call timeout. Shorter declared duration faster queue
68
+ # priority on the shared ZeroGPU pool. Estimating from (mode, preset, frames)
69
+ # instead of using a one-size-fits-all 600s cap means light T2V calls jump
70
+ # the queue while heavy modes (lipsync, style) reserve real headroom.
71
+
72
+ _BASE_DURATION_S: dict[str, int] = {
73
+ # Rough sampler+decode time at ~120 frames, balanced preset, warm cache.
74
+ "t2v": 90,
75
+ "i2v": 90,
76
+ "a2v": 120,
77
+ "lipsync": 240, # extra: audio encoder + audio VAE + extra LoRAs
78
+ "keyframe": 180,
79
+ "style": 360, # extra: preprocessor (canny/dwpose/depth) + IC-LoRAs
80
+ }
81
+ _PRESET_MULT: dict[str, float] = {"fast": 1.0, "balanced": 1.5, "quality": 3.0}
82
+
83
+
84
+ def _frames_from_workflow(workflow: dict) -> int:
85
+ """Read the frame count from the workflow's EmptyLTXVLatentVideo node."""
86
+ for node in workflow.values():
87
+ if isinstance(node, dict) and node.get("class_type") == "EmptyLTXVLatentVideo":
88
+ try:
89
+ return int((node.get("inputs") or {}).get("length", 121))
90
+ except (TypeError, ValueError):
91
+ return 121
92
+ return 121
93
+
94
+
95
+ def _duration_for(
96
+ executor: Any,
97
+ workflow: dict,
98
+ output_ids: list[str],
99
+ mode: str,
100
+ preset: str,
101
+ multiplier: float = 1.0,
102
+ ) -> int:
103
+ """ZeroGPU duration estimator. Same signature as _execute_workflow.
104
+
105
+ Estimate = (base × preset multiplier + cold-cache buffer + per-frame VAE
106
+ decode time) × retry multiplier, clamped to [60s, 900s]. The 900s ceiling
107
+ keeps a single failed call from torching the daily quota.
108
+ """
109
+ base = _BASE_DURATION_S.get(mode, 180)
110
+ mult = _PRESET_MULT.get(preset.lower(), 1.5)
111
+ frames = _frames_from_workflow(workflow)
112
+ est = int((base * mult + 60 + frames * 0.3) * multiplier)
113
+ return max(60, min(est, 900))
114
+
115
+
116
+ # Decorate at module load time so ZeroGPU's startup analyzer detects it.
117
+ _GPU = (
118
+ spaces.GPU(duration=_duration_for)
119
+ if (spaces is not None and _on_spaces())
120
+ else _identity
121
+ )
122
 
123
 
124
  @_GPU
125
+ def _execute_workflow(
126
+ executor: Any,
127
+ workflow: dict,
128
+ output_ids: list[str],
129
+ mode: str,
130
+ preset: str,
131
+ multiplier: float = 1.0,
132
+ ) -> str:
133
  """Run the workflow on GPU and return the path of the first video output.
134
 
135
  Returns just the video path (a plain string, picklable across the
136
+ @spaces.GPU subprocess boundary). The `mode`, `preset`, and `multiplier`
137
+ args are consumed by `_duration_for` to estimate the GPU slot to reserve.
 
 
138
  """
139
  executor.execute(
140
  workflow,
 
344
  return f"ComfyUILibraryBackend(comfy_dir={self._comfy_dir!r})"
345
 
346
  async def submit(
347
+ self,
348
+ mode: str,
349
+ workflow: dict,
350
+ *,
351
+ preset: str = "balanced",
352
+ duration_multiplier: float = 1.0,
353
+ gpu_duration: int = 0, # legacy, ignored (now derived from preset+frames)
354
  ) -> AsyncIterator[Any]:
355
+ """Run a workflow end-to-end. Yields Download/Progress/Output/Error events.
356
+
357
+ `preset` and `duration_multiplier` flow through to the @spaces.GPU
358
+ duration estimator. The handler can re-call submit() with
359
+ duration_multiplier=2.0 if the first attempt aborts on timeout.
360
+ """
361
  # Pre-flight: ensure all model files exist.
362
  try:
363
  needed = models.walk_workflow_for_models(workflow)
 
425
  # Use the public setter; it writes the same global the
426
  # ProgressBar class reads, but is the documented API.
427
  comfy.utils.set_progress_bar_global_hook(_hook)
428
+ # _execute_workflow is module-level and decorated with a
429
+ # @spaces.GPU(duration=callable) on Spaces — the callable
430
+ # estimates per-call timeout from (mode, preset, frames) so
431
+ # light calls get fast queue priority while heavy ones reserve
432
+ # real headroom. Off-Spaces it's a plain call.
433
+ video_path = _execute_workflow(
434
+ self._executor, workflow, output_ids, mode, preset, duration_multiplier,
435
+ )
436
  # Fallback: if history_result didn't surface a path (rare on
437
  # Spaces — happens when ZeroGPU's subprocess boundary drops
438
  # mutated state), scan the output dir for the newest mp4
 
481
 
482
  def _classify(exc: Exception) -> str:
483
  name = type(exc).__name__.lower()
484
+ msg = str(exc).lower()
485
+ if "outofmemory" in name or "cuda out of memory" in msg:
486
  return "oom"
487
+ # ZeroGPU enforces the @spaces.GPU(duration=N) cap and re-raises as
488
+ # gradio.exceptions.Error('GPU task aborted'). Surface a distinct
489
+ # category so the handler can offer a retry with a bigger budget.
490
+ if "gpu task aborted" in msg or ("gpu" in msg and "aborted" in msg):
491
+ return "gpu_timeout"
492
  if "interrupt" in name:
493
  return "interrupt"
494
  return "execution"