techfreakworm commited on
Commit
2fd6ed6
·
unverified ·
1 Parent(s): 62e7a26

feat(app): generate handler — async streaming, status banner, video output

Browse files
Files changed (1) hide show
  1. app.py +154 -1
app.py CHANGED
@@ -93,7 +93,14 @@ def build_app() -> gr.Blocks:
93
  _render_sidebar()
94
  with gr.Column(scale=4):
95
  handles = _render_mode_panels()
96
- # Generate-handler wiring deferred to Task 23.
 
 
 
 
 
 
 
97
  return app
98
 
99
 
@@ -161,6 +168,152 @@ def _render_one_mode(name: str) -> dict:
161
  return handles
162
 
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  if __name__ == "__main__":
165
  app = build_app()
166
  app.launch(server_name="0.0.0.0", server_port=7860)
 
93
  _render_sidebar()
94
  with gr.Column(scale=4):
95
  handles = _render_mode_panels()
96
+
97
+ for name, h in handles.items():
98
+ inputs = _collect_inputs_for_mode(name, h)
99
+ h["generate_btn"].click(
100
+ fn=_make_handler(name, h),
101
+ inputs=inputs,
102
+ outputs=[h["status"], h["video_out"]],
103
+ )
104
  return app
105
 
106
 
 
168
  return handles
169
 
170
 
171
+ import time
172
+ from typing import Any
173
+
174
+ import workflow as wf_module
175
+ import backend as backend_module
176
+
177
+ _BACKEND: backend_module.ComfyUILibraryBackend | None = None
178
+
179
+
180
+ def _get_backend() -> backend_module.ComfyUILibraryBackend:
181
+ global _BACKEND
182
+ if _BACKEND is None:
183
+ _BACKEND = backend_module.ComfyUILibraryBackend()
184
+ return _BACKEND
185
+
186
+
187
+ PRESET_DURATION = {"Fast": 60, "Balanced": 120, "Quality": 300}
188
+
189
+
190
+ async def _on_generate(mode_name: str, **inputs: Any):
191
+ """Generate handler — async generator yielding (status_html, video_path)."""
192
+ mode = modes.MODE_REGISTRY[mode_name]
193
+
194
+ # Translate UI inputs into the parameterize_fn input dict.
195
+ params: dict[str, Any] = {
196
+ "prompt": inputs.get("prompt", ""),
197
+ "negative_prompt": inputs.get("negative_prompt", ""),
198
+ "preset": inputs.get("preset", "Balanced").lower(),
199
+ "width": int(inputs.get("width", 512)),
200
+ "height": int(inputs.get("height", 768)),
201
+ "frames": int(inputs.get("frames", 81)),
202
+ "fps": int(inputs.get("fps", 24)),
203
+ "seed": int(inputs.get("seed", 42)),
204
+ }
205
+ for k in ("image", "audio", "first_frame", "last_frame", "input_video",
206
+ "camera_lora", "camera_strength",
207
+ "detailer_on", "detailer_strength",
208
+ "ic_lora", "ic_strength", "pose_on", "audio_cfg", "image_strength"):
209
+ if k in inputs:
210
+ params[k] = inputs[k]
211
+
212
+ patches = mode.parameterize_fn(params)
213
+ workflow = wf_module.load_template(mode_name)
214
+ for patch in patches:
215
+ wf_module.set_input(workflow, *patch)
216
+ wf_module.validate(workflow)
217
+
218
+ backend = _get_backend()
219
+ duration = PRESET_DURATION.get(inputs.get("preset", "Balanced"), 120)
220
+
221
+ started = time.time()
222
+ async for event in backend.submit(mode_name, workflow, gpu_duration=duration):
223
+ elapsed = time.time() - started
224
+ if isinstance(event, backend_module.DownloadEvent):
225
+ status = ui.render_status(
226
+ stage_index=0,
227
+ stage_label=f"Downloading {event.filename}",
228
+ step=int(event.mb_done),
229
+ total_steps=int(max(event.mb_total, 1)),
230
+ elapsed_s=elapsed, eta_s=0,
231
+ )
232
+ yield status, gr.update()
233
+ elif isinstance(event, backend_module.ProgressEvent):
234
+ stage = (
235
+ mode.stage_map[event.stage]
236
+ if event.stage < len(mode.stage_map)
237
+ else mode.stage_map[-1]
238
+ )
239
+ eta = (elapsed / max(event.step, 1)) * (event.total_steps - event.step)
240
+ status = ui.render_status(
241
+ stage_index=event.stage + 1,
242
+ stage_label=stage.label,
243
+ step=event.step,
244
+ total_steps=event.total_steps,
245
+ elapsed_s=elapsed, eta_s=eta,
246
+ )
247
+ yield status, gr.update()
248
+ elif isinstance(event, backend_module.OutputEvent):
249
+ yield ui._render_idle(), event.video_path
250
+ elif isinstance(event, backend_module.ErrorEvent):
251
+ error_html = (
252
+ f'<div class="status-card status-error">'
253
+ f' <div class="status-row"><span class="status-stage">Error · {event.category}</span></div>'
254
+ f' <div>{event.message}</div>'
255
+ f'</div>'
256
+ )
257
+ yield error_html, gr.update()
258
+
259
+
260
+ def _input_keys_for_mode(mode_name: str, h: dict) -> list[str]:
261
+ base = ["prompt", "preset", "width", "height", "frames", "fps", "seed"]
262
+ if mode_name == "i2v":
263
+ base.append("image")
264
+ elif mode_name == "a2v":
265
+ base.append("audio")
266
+ elif mode_name == "lipsync":
267
+ base.extend(["image", "audio"])
268
+ elif mode_name == "keyframe":
269
+ base.extend(["first_frame", "last_frame"])
270
+ elif mode_name == "style":
271
+ base.append("input_video")
272
+ base.append("negative_prompt")
273
+ base.extend(["camera_lora", "camera_strength", "detailer_on", "detailer_strength"])
274
+ if h["lora"].ic_lora is not None:
275
+ base.extend(["ic_lora", "ic_strength"])
276
+ if h["lora"].pose_on is not None:
277
+ base.append("pose_on")
278
+ return base
279
+
280
+
281
+ def _collect_inputs_for_mode(mode_name: str, h: dict) -> list:
282
+ """Gather the gr.Component handles to pass into _on_generate."""
283
+ base = [h["prompt"], h["preset"], h["width"], h["height"], h["frames"], h["fps"], h["seed"]]
284
+ if mode_name == "i2v":
285
+ base.append(h["image"])
286
+ elif mode_name == "a2v":
287
+ base.append(h["audio"])
288
+ elif mode_name == "lipsync":
289
+ base.extend([h["image"], h["audio"]])
290
+ elif mode_name == "keyframe":
291
+ base.extend([h["first_frame"], h["last_frame"]])
292
+ elif mode_name == "style":
293
+ base.append(h["input_video"])
294
+ base.append(h["negative_prompt"])
295
+ base.extend([
296
+ h["lora"].camera_lora, h["lora"].camera_strength,
297
+ h["lora"].detailer_on, h["lora"].detailer_strength,
298
+ ])
299
+ if h["lora"].ic_lora is not None:
300
+ base.extend([h["lora"].ic_lora, h["lora"].ic_strength])
301
+ if h["lora"].pose_on is not None:
302
+ base.append(h["lora"].pose_on)
303
+ return base
304
+
305
+
306
+ def _make_handler(mode_name: str, h: dict):
307
+ keys = _input_keys_for_mode(mode_name, h)
308
+
309
+ async def handler(*values):
310
+ kwargs = dict(zip(keys, values))
311
+ async for output in _on_generate(mode_name, **kwargs):
312
+ yield output
313
+
314
+ return handler
315
+
316
+
317
  if __name__ == "__main__":
318
  app = build_app()
319
  app.launch(server_name="0.0.0.0", server_port=7860)