techfreakworm committed on
Commit
b23fbf5
·
unverified ·
1 Parent(s): 15811ca

feat(backend): submit() async generator with progress hooks + ZeroGPU

Browse files
Files changed (1) hide show
  1. backend.py +130 -1
backend.py CHANGED
@@ -5,10 +5,13 @@ divergence between local and HF Spaces deployment.
5
  """
6
  from __future__ import annotations
7
 
 
8
  import os
9
  import pathlib
10
  import sys
11
- from collections.abc import AsyncIterator
 
 
12
  from dataclasses import dataclass, field
13
  from typing import Any, Optional
14
 
@@ -84,3 +87,129 @@ class ComfyUILibraryBackend:
84
 
85
  def __repr__(self) -> str:
86
  return f"ComfyUILibraryBackend(comfy_dir={self._comfy_dir!r})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
  from __future__ import annotations
7
 
8
+ import asyncio
9
  import os
10
  import pathlib
11
  import sys
12
+ import threading
13
+ import traceback as tb_mod
14
+ from collections.abc import AsyncIterator, Iterable
15
  from dataclasses import dataclass, field
16
  from typing import Any, Optional
17
 
 
87
 
88
  def __repr__(self) -> str:
89
  return f"ComfyUILibraryBackend(comfy_dir={self._comfy_dir!r})"
90
+
91
async def submit(
    self, mode: str, workflow: dict, gpu_duration: int = 120
) -> AsyncIterator[Any]:
    """Run a workflow end-to-end, yielding events as they are produced.

    Args:
        mode: Accepted for interface compatibility; not read by this method.
        workflow: Workflow graph. It is scanned for required model files and
            then passed to ``self._executor.execute``.
        gpu_duration: Passed as ``duration=`` to ``spaces.GPU`` when
            ``_on_spaces()`` is true; unused otherwise.

    Yields:
        Every event produced by ``models.ensure_models`` during pre-flight,
        then ``ProgressEvent`` objects emitted by the progress hook, then an
        ``OutputEvent`` on success or an ``ErrorEvent`` on failure. The
        generator finishes after the worker thread has signalled completion.
    """
    # Pre-flight: ensure all model files exist before starting inference.
    try:
        needed = models.walk_workflow_for_models(workflow)
        for download_event in models.ensure_models(needed):
            yield download_event
    except Exception as e:
        yield ErrorEvent(
            category="download",
            message=str(e),
            traceback=tb_mod.format_exc(),
        )
        return

    # Run the inference in a worker thread; pass events back through a queue
    # that is only ever touched from the event loop thread.
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def _push(event: Any) -> None:
        # Called from the worker thread, so hand the put() over to the loop.
        asyncio.run_coroutine_threadsafe(queue.put(event), loop)

    def _hook(value: int, total: int, _preview=None, **_extra: Any) -> None:
        # **_extra: tolerate additional keyword arguments so a richer hook
        # call does not raise TypeError inside the sampler.
        # NOTE(review): newer ComfyUI releases appear to pass ``node_id=`` —
        # confirm against the pinned version.
        _push(ProgressEvent(
            stage=0, stage_label="diffusion",
            step=int(value), total_steps=int(total),
        ))

    def _worker() -> None:
        comfy_utils = None
        saved_hook = None
        try:
            # Imported inside the try so that an import failure is reported
            # as an ErrorEvent instead of killing the thread silently.
            import comfy.utils as comfy_utils
            saved_hook = getattr(comfy_utils, "PROGRESS_BAR_HOOK", None)
            # Use the public setter; it writes the same global the
            # ProgressBar class reads, but is the documented API.
            comfy_utils.set_progress_bar_global_hook(_hook)
            # NOTE(review): execute_outputs is an empty list — confirm the
            # executor treats that as "run every output node" and not as
            # "run nothing".
            self._executor.execute(
                workflow,
                prompt_id="ltx23-aio",
                extra_data={"client_id": "ltx23-aio"},
                execute_outputs=[],
            )
            # PromptExecutor writes output files via VHS_VideoCombine; we read
            # its outputs to find the saved video.
            outputs = list(self._executor.outputs.values())
            video_path = _first_video_path(outputs) or ""
            _push(OutputEvent(video_path=video_path))
        except Exception as exc:
            _push(ErrorEvent(
                category=_classify(exc),
                message=str(exc),
                traceback=tb_mod.format_exc(),
            ))
        finally:
            try:
                if comfy_utils is not None:
                    comfy_utils.set_progress_bar_global_hook(saved_hook)
            finally:
                # Memory is released even if restoring the hook fails.
                _free_memory()

    if _on_spaces():
        import spaces
        run_workflow = spaces.GPU(duration=gpu_duration)(_worker)
    else:
        run_workflow = _worker

    def _thread_main() -> None:
        # Guarantees the consumer below always receives the sentinel. Without
        # this, an exception raised outside _worker's own try block (for
        # example by the spaces.GPU wrapper before it calls _worker) would
        # leave ``await queue.get()`` waiting forever.
        try:
            run_workflow()
        except Exception as exc:
            _push(ErrorEvent(
                category=_classify(exc),
                message=str(exc),
                traceback=tb_mod.format_exc(),
            ))
        finally:
            _push(None)  # sentinel: stop the consumer

    thread = threading.Thread(target=_thread_main, daemon=True)
    thread.start()

    while True:
        event = await queue.get()
        if event is None:
            return
        yield event
163
+
164
def interrupt(self) -> None:
    """Cancel the currently running workflow (if any).

    Best-effort: this never raises. If ComfyUI's model management module
    cannot be imported or the interrupt call fails, the failure is logged
    instead of being silently discarded.
    """
    try:
        import comfy.model_management as mm
        mm.interrupt_current_processing()
    except Exception:
        import logging
        logging.getLogger(__name__).warning(
            "interrupt request could not be delivered", exc_info=True
        )
171
+
172
+
173
def _classify(exc: Exception) -> str:
    """Map an exception to an error category string.

    Returns:
        ``"oom"`` when the exception type name contains ``outofmemory`` or
        the message contains ``out of memory`` (case-insensitive, so CUDA and
        MPS allocator failures are both covered); ``"interrupt"`` when the
        type name contains ``interrupt``; ``"execution"`` otherwise.
    """
    type_name = type(exc).__name__.lower()
    message = str(exc).lower()
    # Match on the backend-neutral phrase so that out-of-memory errors from
    # backends other than CUDA are not reported as generic execution errors.
    if "outofmemory" in type_name or "out of memory" in message:
        return "oom"
    if "interrupt" in type_name:
        return "interrupt"
    return "execution"
180
+
181
+
182
def _free_memory() -> None:
    """Free VRAM after a workflow finishes (success or failure).

    Each step is independent and best-effort: a failure in one step is
    logged at debug level and does not prevent the remaining steps from
    running. This function never raises.
    """
    import logging
    log = logging.getLogger(__name__)

    try:
        import comfy.model_management as mm
        mm.unload_all_models()
    except Exception:
        log.debug("unload_all_models step skipped", exc_info=True)

    try:
        import torch
    except Exception:
        # Without torch there is no allocator cache to empty.
        log.debug("torch is not importable; cache cleanup skipped", exc_info=True)
        return

    # The two backends are handled in separate try blocks so that a failure
    # on one (e.g. a torch build without that backend) does not skip the other.
    try:
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
    except Exception:
        log.debug("MPS cache cleanup skipped", exc_info=True)

    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        log.debug("CUDA cache cleanup skipped", exc_info=True)
201
+
202
+
203
def _first_video_path(outputs: Iterable) -> Optional[str]:
    """Find the first video file referenced in a collection of node outputs.

    Args:
        outputs: Iterable of per-node output objects. Only ``dict`` entries
            are inspected; within each, only ``list`` values are scanned for
            ``dict`` items carrying a ``"filename"`` key.

    Returns:
        For the first item whose filename ends in ``.mp4``, ``.webm`` or
        ``.mov`` (case-insensitive): its ``"fullpath"`` value if that key is
        present, otherwise the filename itself. ``None`` when nothing matches.
    """
    video_suffixes = (".mp4", ".webm", ".mov")
    for output in outputs:
        if not isinstance(output, dict):
            continue
        for value in output.values():
            if not isinstance(value, list):
                continue
            for item in value:
                if not isinstance(item, dict):
                    continue
                filename = item.get("filename")
                # Skip non-string filenames instead of raising AttributeError.
                if not isinstance(filename, str):
                    continue
                if filename.lower().endswith(video_suffixes):
                    return item.get("fullpath", filename)
    return None