Spaces:
Running on Zero
Running on Zero
feat(app): generate handler — async streaming, status banner, video output
Browse files
app.py
CHANGED
|
@@ -93,7 +93,14 @@ def build_app() -> gr.Blocks:
|
|
| 93 |
_render_sidebar()
|
| 94 |
with gr.Column(scale=4):
|
| 95 |
handles = _render_mode_panels()
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
return app
|
| 98 |
|
| 99 |
|
|
@@ -161,6 +168,152 @@ def _render_one_mode(name: str) -> dict:
|
|
| 161 |
return handles
|
| 162 |
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
if __name__ == "__main__":
|
| 165 |
app = build_app()
|
| 166 |
app.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
| 93 |
_render_sidebar()
|
| 94 |
with gr.Column(scale=4):
|
| 95 |
handles = _render_mode_panels()
|
| 96 |
+
|
| 97 |
+
for name, h in handles.items():
|
| 98 |
+
inputs = _collect_inputs_for_mode(name, h)
|
| 99 |
+
h["generate_btn"].click(
|
| 100 |
+
fn=_make_handler(name, h),
|
| 101 |
+
inputs=inputs,
|
| 102 |
+
outputs=[h["status"], h["video_out"]],
|
| 103 |
+
)
|
| 104 |
return app
|
| 105 |
|
| 106 |
|
|
|
|
| 168 |
return handles
|
| 169 |
|
| 170 |
|
| 171 |
+
import time
|
| 172 |
+
from typing import Any
|
| 173 |
+
|
| 174 |
+
import workflow as wf_module
|
| 175 |
+
import backend as backend_module
|
| 176 |
+
|
| 177 |
+
_BACKEND: backend_module.ComfyUILibraryBackend | None = None
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _get_backend() -> backend_module.ComfyUILibraryBackend:
|
| 181 |
+
global _BACKEND
|
| 182 |
+
if _BACKEND is None:
|
| 183 |
+
_BACKEND = backend_module.ComfyUILibraryBackend()
|
| 184 |
+
return _BACKEND
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
PRESET_DURATION = {"Fast": 60, "Balanced": 120, "Quality": 300}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
async def _on_generate(mode_name: str, **inputs: Any):
|
| 191 |
+
"""Generate handler — async generator yielding (status_html, video_path)."""
|
| 192 |
+
mode = modes.MODE_REGISTRY[mode_name]
|
| 193 |
+
|
| 194 |
+
# Translate UI inputs into the parameterize_fn input dict.
|
| 195 |
+
params: dict[str, Any] = {
|
| 196 |
+
"prompt": inputs.get("prompt", ""),
|
| 197 |
+
"negative_prompt": inputs.get("negative_prompt", ""),
|
| 198 |
+
"preset": inputs.get("preset", "Balanced").lower(),
|
| 199 |
+
"width": int(inputs.get("width", 512)),
|
| 200 |
+
"height": int(inputs.get("height", 768)),
|
| 201 |
+
"frames": int(inputs.get("frames", 81)),
|
| 202 |
+
"fps": int(inputs.get("fps", 24)),
|
| 203 |
+
"seed": int(inputs.get("seed", 42)),
|
| 204 |
+
}
|
| 205 |
+
for k in ("image", "audio", "first_frame", "last_frame", "input_video",
|
| 206 |
+
"camera_lora", "camera_strength",
|
| 207 |
+
"detailer_on", "detailer_strength",
|
| 208 |
+
"ic_lora", "ic_strength", "pose_on", "audio_cfg", "image_strength"):
|
| 209 |
+
if k in inputs:
|
| 210 |
+
params[k] = inputs[k]
|
| 211 |
+
|
| 212 |
+
patches = mode.parameterize_fn(params)
|
| 213 |
+
workflow = wf_module.load_template(mode_name)
|
| 214 |
+
for patch in patches:
|
| 215 |
+
wf_module.set_input(workflow, *patch)
|
| 216 |
+
wf_module.validate(workflow)
|
| 217 |
+
|
| 218 |
+
backend = _get_backend()
|
| 219 |
+
duration = PRESET_DURATION.get(inputs.get("preset", "Balanced"), 120)
|
| 220 |
+
|
| 221 |
+
started = time.time()
|
| 222 |
+
async for event in backend.submit(mode_name, workflow, gpu_duration=duration):
|
| 223 |
+
elapsed = time.time() - started
|
| 224 |
+
if isinstance(event, backend_module.DownloadEvent):
|
| 225 |
+
status = ui.render_status(
|
| 226 |
+
stage_index=0,
|
| 227 |
+
stage_label=f"Downloading {event.filename}",
|
| 228 |
+
step=int(event.mb_done),
|
| 229 |
+
total_steps=int(max(event.mb_total, 1)),
|
| 230 |
+
elapsed_s=elapsed, eta_s=0,
|
| 231 |
+
)
|
| 232 |
+
yield status, gr.update()
|
| 233 |
+
elif isinstance(event, backend_module.ProgressEvent):
|
| 234 |
+
stage = (
|
| 235 |
+
mode.stage_map[event.stage]
|
| 236 |
+
if event.stage < len(mode.stage_map)
|
| 237 |
+
else mode.stage_map[-1]
|
| 238 |
+
)
|
| 239 |
+
eta = (elapsed / max(event.step, 1)) * (event.total_steps - event.step)
|
| 240 |
+
status = ui.render_status(
|
| 241 |
+
stage_index=event.stage + 1,
|
| 242 |
+
stage_label=stage.label,
|
| 243 |
+
step=event.step,
|
| 244 |
+
total_steps=event.total_steps,
|
| 245 |
+
elapsed_s=elapsed, eta_s=eta,
|
| 246 |
+
)
|
| 247 |
+
yield status, gr.update()
|
| 248 |
+
elif isinstance(event, backend_module.OutputEvent):
|
| 249 |
+
yield ui._render_idle(), event.video_path
|
| 250 |
+
elif isinstance(event, backend_module.ErrorEvent):
|
| 251 |
+
error_html = (
|
| 252 |
+
f'<div class="status-card status-error">'
|
| 253 |
+
f' <div class="status-row"><span class="status-stage">Error · {event.category}</span></div>'
|
| 254 |
+
f' <div>{event.message}</div>'
|
| 255 |
+
f'</div>'
|
| 256 |
+
)
|
| 257 |
+
yield error_html, gr.update()
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _input_keys_for_mode(mode_name: str, h: dict) -> list[str]:
|
| 261 |
+
base = ["prompt", "preset", "width", "height", "frames", "fps", "seed"]
|
| 262 |
+
if mode_name == "i2v":
|
| 263 |
+
base.append("image")
|
| 264 |
+
elif mode_name == "a2v":
|
| 265 |
+
base.append("audio")
|
| 266 |
+
elif mode_name == "lipsync":
|
| 267 |
+
base.extend(["image", "audio"])
|
| 268 |
+
elif mode_name == "keyframe":
|
| 269 |
+
base.extend(["first_frame", "last_frame"])
|
| 270 |
+
elif mode_name == "style":
|
| 271 |
+
base.append("input_video")
|
| 272 |
+
base.append("negative_prompt")
|
| 273 |
+
base.extend(["camera_lora", "camera_strength", "detailer_on", "detailer_strength"])
|
| 274 |
+
if h["lora"].ic_lora is not None:
|
| 275 |
+
base.extend(["ic_lora", "ic_strength"])
|
| 276 |
+
if h["lora"].pose_on is not None:
|
| 277 |
+
base.append("pose_on")
|
| 278 |
+
return base
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _collect_inputs_for_mode(mode_name: str, h: dict) -> list:
|
| 282 |
+
"""Gather the gr.Component handles to pass into _on_generate."""
|
| 283 |
+
base = [h["prompt"], h["preset"], h["width"], h["height"], h["frames"], h["fps"], h["seed"]]
|
| 284 |
+
if mode_name == "i2v":
|
| 285 |
+
base.append(h["image"])
|
| 286 |
+
elif mode_name == "a2v":
|
| 287 |
+
base.append(h["audio"])
|
| 288 |
+
elif mode_name == "lipsync":
|
| 289 |
+
base.extend([h["image"], h["audio"]])
|
| 290 |
+
elif mode_name == "keyframe":
|
| 291 |
+
base.extend([h["first_frame"], h["last_frame"]])
|
| 292 |
+
elif mode_name == "style":
|
| 293 |
+
base.append(h["input_video"])
|
| 294 |
+
base.append(h["negative_prompt"])
|
| 295 |
+
base.extend([
|
| 296 |
+
h["lora"].camera_lora, h["lora"].camera_strength,
|
| 297 |
+
h["lora"].detailer_on, h["lora"].detailer_strength,
|
| 298 |
+
])
|
| 299 |
+
if h["lora"].ic_lora is not None:
|
| 300 |
+
base.extend([h["lora"].ic_lora, h["lora"].ic_strength])
|
| 301 |
+
if h["lora"].pose_on is not None:
|
| 302 |
+
base.append(h["lora"].pose_on)
|
| 303 |
+
return base
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _make_handler(mode_name: str, h: dict):
|
| 307 |
+
keys = _input_keys_for_mode(mode_name, h)
|
| 308 |
+
|
| 309 |
+
async def handler(*values):
|
| 310 |
+
kwargs = dict(zip(keys, values))
|
| 311 |
+
async for output in _on_generate(mode_name, **kwargs):
|
| 312 |
+
yield output
|
| 313 |
+
|
| 314 |
+
return handler
|
| 315 |
+
|
| 316 |
+
|
| 317 |
if __name__ == "__main__":
|
| 318 |
app = build_app()
|
| 319 |
app.launch(server_name="0.0.0.0", server_port=7860)
|