Spaces:
Running on Zero
Running on Zero
| import gradio as gr | |
| import gc | |
| import os | |
| import random | |
| import tempfile | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from fastapi.responses import HTMLResponse | |
| from gradio.data_classes import FileData | |
| # ZeroGPU. Degrade gracefully off-Spaces so `python app.py` works locally. | |
| try: | |
| import spaces | |
| _HAS_SPACES = True | |
| except ImportError: | |
| _HAS_SPACES = False | |
# --- Model load ---------------------------------------------------------------
# Heavy startup is wrapped in `gr.NO_RELOAD` so `gradio app.py` hot reload
# does not redownload weights every time you save the HTML.
# Fall back to CPU when no CUDA device is visible at import time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Weights are loaded in bfloat16 for both the transformer and the pipeline.
DTYPE = torch.bfloat16
if gr.NO_RELOAD:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
    from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
    from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3

    # The pipeline is loaded from the FireRed-Image-Edit-1.1 repo, but its
    # transformer is replaced by the one from Qwen-Image-Edit-Rapid-AIO-V19.
    PIPE = QwenImageEditPlusPipeline.from_pretrained(
        "FireRedTeam/FireRed-Image-Edit-1.1",
        transformer=QwenImageTransformer2DModel.from_pretrained(
            "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
            torch_dtype=DTYPE,
            # NOTE(review): "cuda" is hard-coded here and ignores DEVICE, so
            # on a CPU-only machine this load presumably fails despite the CPU
            # fallback above. Confirm whether DEVICE (or no device_map) was meant.
            device_map="cuda",
        ),
        torch_dtype=DTYPE,
    ).to(DEVICE)
    # Best effort: if installing the FA3 attention processor raises for any
    # reason, keep whatever processor the transformer already has and log it.
    try:
        PIPE.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        print("Flash Attention 3 processor set.")
    except Exception as e:
        print(f"FA3 processor not set: {e}")
# Comma-separated negative prompt passed with every edit request.
# NOTE(review): _run_pipe calls the pipeline with true_cfg_scale=1.0; if the
# pipeline only uses the negative prompt when true CFG is enabled (> 1), this
# string has no effect. Confirm against the pipeline implementation.
NEGATIVE_PROMPT = ", ".join(
    (
        "worst quality",
        "low quality",
        "bad anatomy",
        "bad hands",
        "text",
        "error",
        "missing fingers",
        "extra digit",
        "fewer digits",
        "cropped",
        "jpeg artifacts",
        "signature",
        "watermark",
        "username",
        "blurry",
    )
)

# Upper bound (inclusive) for randomly drawn seeds: the largest signed 32-bit
# integer, i.e. the same value as np.iinfo(np.int32).max.
MAX_SEED = 2**31 - 1
| def _round_dims(image: Image.Image) -> tuple[int, int]: | |
| w, h = image.size | |
| if w > h: | |
| new_w, new_h = 1024, int(1024 * h / w) | |
| else: | |
| new_h, new_w = 1024, int(1024 * w / h) | |
| return (new_w // 8) * 8, (new_h // 8) * 8 | |
# --- Inner GPU function -------------------------------------------------------
# Per the reference: @spaces.GPU goes on the *inner* function that runs the
# model. The outer @server.api route just plugs it into the queue.
def _edit(image: Image.Image, prompt: str, seed: int, steps: int) -> Image.Image:
    """Run one image edit and return the edited image.

    This is a thin wrapper around ``_run_pipe`` that exists so it can carry the
    ZeroGPU decorator. When the ``spaces`` package is importable, the wrapping
    below applies ``spaces.GPU`` so a GPU can be requested for the call (see the
    ``spaces`` ZeroGPU docs). Otherwise it simply calls ``_run_pipe`` directly.

    Args:
        image: Source image forwarded to ``_run_pipe``.
        prompt: Edit instruction forwarded to ``_run_pipe``.
        seed: Seed for the pipeline's random generator.
        steps: Number of inference steps.

    Returns:
        Whatever ``_run_pipe`` returns (the first pipeline output image).
    """
    return _run_pipe(image, prompt, seed, steps)


if _HAS_SPACES:
    # Equivalent to decorating with @spaces.GPU. A single definition serves
    # both environments; previously two identical copies existed and the
    # decorator was never applied.
    _edit = spaces.GPU(_edit)
def _run_pipe(image, prompt, seed, steps):
    """Run the edit pipeline once and return its first output image.

    The output size comes from ``_round_dims(image)``. The call uses
    ``NEGATIVE_PROMPT``, ``true_cfg_scale=1.0`` and a ``torch.Generator`` on
    ``DEVICE`` seeded with ``seed``. Progress is logged to stdout at start, after
    sizing, and on completion.

    Args:
        image: Source image; presumably a PIL image (``_round_dims`` reads its
            ``.size``). It is passed to the pipeline as a one-element list.
        prompt: Edit instruction.
        seed: Seed for the generator.
        steps: Value for ``num_inference_steps``.

    Returns:
        ``PIPE(...).images[0]``.
    """
    print(f"[_run_pipe] start cuda_avail={torch.cuda.is_available()}", flush=True)

    # Drop unreachable Python objects and, when CUDA is present, hand cached
    # allocator blocks back before starting another run.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    width, height = _round_dims(image)
    print(f"[_run_pipe] dims w={width} h={height} steps={steps}", flush=True)

    call_kwargs = dict(
        image=[image],
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        width=width,
        height=height,
        num_inference_steps=steps,
        true_cfg_scale=1.0,
        generator=torch.Generator(device=DEVICE).manual_seed(seed),
    )
    pipe_output = PIPE(**call_kwargs)
    edited = pipe_output.images[0]

    print(f"[_run_pipe] done size={edited.size}", flush=True)
    return edited
# --- Server -------------------------------------------------------------------
# Application object; launched at the bottom of the file under `__main__`.
server = gr.Server()
# Directory containing this file; homepage() reads index.html from here.
HOME = Path(__file__).parent
def edit_image(image: FileData, prompt: str) -> dict:
    """Edit an image guided by a text prompt using FireRed-Image-Edit 1.1.

    Args:
        image: The uploaded source image. Either a ``FileData`` object or a
            plain dict with a ``"path"`` key is accepted.
        prompt: Edit instruction; surrounding whitespace is stripped.

    Returns:
        On success, ``{"image": FileData, "seed": int}``, where the image is a
        PNG written to a fresh temp file. If the prompt or image is missing,
        ``{"error": str}``.
    """
    src_path = _input_path(image)
    print(f"[edit_image] received prompt={prompt!r} path={src_path}", flush=True)
    if not prompt or not prompt.strip():
        return {"error": "Please enter an edit prompt."}
    if not src_path:
        return {"error": "Please upload an image."}

    # Close the upload's file handle right away. convert() returns a separate,
    # fully loaded image, so the handle is not needed afterwards.
    with Image.open(src_path) as opened:
        src = opened.convert("RGB")
    print(f"[edit_image] image opened size={src.size}", flush=True)

    seed = random.randint(0, MAX_SEED)
    print(f"[edit_image] calling _edit seed={seed}", flush=True)
    result = _edit(src, prompt.strip(), seed, steps=4)
    print(f"[edit_image] _edit returned size={result.size}", flush=True)

    # The output must outlive this call, presumably because it is served after
    # return (launch() whitelists the temp dir). NOTE(review): these files are
    # never deleted.
    fd, out_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    result.save(out_path)
    print(f"[edit_image] saved to {out_path} exists={os.path.exists(out_path)} size={os.path.getsize(out_path)}", flush=True)

    payload = {"image": FileData(path=out_path), "seed": seed}
    print(f"[edit_image] returning payload keys={list(payload.keys())} image={payload['image']}", flush=True)
    return payload


def _input_path(image):
    """Return the local path of an uploaded image (dict or object), or None."""
    if image is None:
        return None
    if isinstance(image, dict):
        return image.get("path")
    return getattr(image, "path", None)
async def homepage():
    """Return the contents of index.html (next to this file) as a UTF-8 string.

    The file is read on every call, so edits to the HTML are picked up
    without restarting.
    NOTE(review): no route registration is visible here; HTMLResponse is
    imported but unused in this file as shown. Confirm how this is mounted.
    """
    index_file = HOME / "index.html"
    html = index_file.read_text(encoding="utf-8")
    return html
if __name__ == "__main__":
    # allowed_paths whitelists the system temp dir. edit_image writes its
    # output PNGs there via tempfile.mkstemp, presumably so they can be
    # served back to the client.
    server.launch(show_error=True, allowed_paths=[tempfile.gettempdir()])