import gradio as gr
import gc
import os
import random
import tempfile
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from fastapi.responses import HTMLResponse
from gradio.data_classes import FileData

# ZeroGPU. Degrade gracefully off-Spaces so `python app.py` works locally.
try:
    import spaces

    _HAS_SPACES = True
except ImportError:
    _HAS_SPACES = False

# --- Model load ---------------------------------------------------------------
# Heavy startup is wrapped in `gr.NO_RELOAD` so `gradio app.py` hot reload
# does not redownload weights every time you save the HTML.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16

if gr.NO_RELOAD:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
    from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
    from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3

    PIPE = QwenImageEditPlusPipeline.from_pretrained(
        "FireRedTeam/FireRed-Image-Edit-1.1",
        transformer=QwenImageTransformer2DModel.from_pretrained(
            "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
            torch_dtype=DTYPE,
            # Follow DEVICE instead of hard-coding "cuda", so a CPU-only local
            # run does not fail while loading the transformer.
            device_map=DEVICE,
        ),
        torch_dtype=DTYPE,
    ).to(DEVICE)

    try:
        PIPE.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        print("Flash Attention 3 processor set.")
    except Exception as e:
        # Best effort: if the FA3 processor cannot be installed in this
        # environment, keep the pipeline's default attention processor.
        print(f"FA3 processor not set: {e}")

NEGATIVE_PROMPT = (
    "worst quality, low quality, bad anatomy, bad hands, text, error, "
    "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
    "signature, watermark, username, blurry"
)
MAX_SEED = np.iinfo(np.int32).max


def _round_dims(
    image: Image.Image, target: int = 1024, multiple: int = 8
) -> tuple[int, int]:
    """Return ``(width, height)`` for *image* with its longer side set to *target*.

    The aspect ratio is preserved. Both sides are rounded down to a multiple of
    *multiple* (default 8, presumably what the pipeline's VAE expects). Each
    side is clamped to at least *multiple*, so an extremely thin image cannot
    produce a zero-sized output dimension.

    Args:
        image: Source image; only its ``size`` attribute is read.
        target: Length in pixels of the longer output side before rounding.
        multiple: Granularity that both output sides are rounded down to.

    Returns:
        The ``(width, height)`` tuple to pass to the pipeline.
    """
    w, h = image.size
    if w > h:
        new_w, new_h = target, int(target * h / w)
    else:
        new_h, new_w = target, int(target * w / h)
    # Without the clamp, e.g. a 2000x5 image would round its height to 0.
    return (
        max(multiple, (new_w // multiple) * multiple),
        max(multiple, (new_h // multiple) * multiple),
    )
# --- Inner GPU function -------------------------------------------------------
# @spaces.GPU goes on the *inner* function that actually runs the model; the
# outer @server.api route below just plugs it into the queue.
if _HAS_SPACES:

    @spaces.GPU
    def _edit(image: Image.Image, prompt: str, seed: int, steps: int) -> Image.Image:
        """Run the edit on a ZeroGPU-allocated GPU (delegates to _run_pipe)."""
        return _run_pipe(image, prompt, seed, steps)

else:

    def _edit(image: Image.Image, prompt: str, seed: int, steps: int) -> Image.Image:
        """Run the edit directly when the `spaces` package is unavailable."""
        return _run_pipe(image, prompt, seed, steps)


def _run_pipe(image, prompt, seed, steps):
    """Run one edit through PIPE and return the first output image.

    Args:
        image: RGB source image; output size is derived from it via _round_dims.
        prompt: Edit instruction passed to the pipeline.
        seed: Seed for the torch.Generator, so results are reproducible.
        steps: Number of inference steps.
    """
    print(f"[_run_pipe] start cuda_avail={torch.cuda.is_available()}", flush=True)
    # Release cached memory from any previous request before a new run.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    width, height = _round_dims(image)
    print(f"[_run_pipe] dims w={width} h={height} steps={steps}", flush=True)

    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    # NOTE(review): upstream diffusers Qwen-Image pipelines only apply the
    # negative prompt when true_cfg_scale > 1, so NEGATIVE_PROMPT likely has
    # no effect at 1.0 -- confirm against qwenimage.pipeline_qwenimage_edit_plus.
    out = PIPE(
        image=[image],
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        width=width,
        height=height,
        num_inference_steps=steps,
        true_cfg_scale=1.0,
        generator=generator,
    ).images[0]
    print(f"[_run_pipe] done size={out.size}", flush=True)
    return out


# --- Server -------------------------------------------------------------------
server = gr.Server()
HOME = Path(__file__).parent

# Generated images are written here. Only this directory is passed to
# `allowed_paths`, rather than the whole system temp dir, so clients cannot
# fetch unrelated files that happen to live in the temp directory.
OUTPUT_DIR = Path(tempfile.gettempdir()) / "firered_edit_outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


@server.api(name="edit_image")
def edit_image(image: FileData, prompt: str) -> dict:
    """Edit an image guided by a text prompt using FireRed-Image-Edit 1.1.

    Returns ``{"image": FileData, "seed": int}`` on success, or
    ``{"error": str}`` when the prompt is empty or no image was supplied.
    """
    from PIL import ImageOps

    # Accept both a plain dict payload (as the original `image["path"]`
    # access assumed) and a FileData-like object exposing `.path`.
    if isinstance(image, dict):
        src_path = image.get("path")
    else:
        src_path = getattr(image, "path", None)
    print(f"[edit_image] received prompt={prompt!r} path={src_path}", flush=True)

    if not prompt or not prompt.strip():
        return {"error": "Please enter an edit prompt."}
    if not src_path:
        return {"error": "Please upload an image."}

    # The context manager closes the file handle; Pillow only auto-closes it
    # after load() for single-frame images.
    with Image.open(src_path) as opened:
        # Honour EXIF orientation so phone photos are edited upright (as the
        # browser displays them), then normalise to RGB.
        src = ImageOps.exif_transpose(opened).convert("RGB")
    print(f"[edit_image] image opened size={src.size}", flush=True)

    seed = random.randint(0, MAX_SEED)
    print(f"[edit_image] calling _edit seed={seed}", flush=True)
    result = _edit(src, prompt.strip(), seed, steps=4)
    print(f"[edit_image] _edit returned size={result.size}", flush=True)

    fd, out_path = tempfile.mkstemp(suffix=".png", dir=OUTPUT_DIR)
    os.close(fd)
    result.save(out_path)
    print(f"[edit_image] saved to {out_path} exists={os.path.exists(out_path)} size={os.path.getsize(out_path)}", flush=True)

    payload = {"image": FileData(path=out_path), "seed": seed}
    print(f"[edit_image] returning payload keys={list(payload.keys())} image={payload['image']}", flush=True)
    return payload


@server.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve index.html; it is read per request so HTML edits show up live."""
    return (HOME / "index.html").read_text(encoding="utf-8")


if __name__ == "__main__":
    server.launch(show_error=True, allowed_paths=[str(OUTPUT_DIR)])