ysharma's picture
ysharma HF Staff
Update app.py
1380a3b verified
raw
history blame
4.7 kB
import gradio as gr
import gc
import os
import random
import tempfile
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from fastapi.responses import HTMLResponse
from gradio.data_classes import FileData
# ZeroGPU. Degrade gracefully off-Spaces so `python app.py` works locally.
try:
import spaces
_HAS_SPACES = True
except ImportError:
_HAS_SPACES = False
# --- Model load ---------------------------------------------------------------
# Heavy startup is wrapped in `gr.NO_RELOAD` so `gradio app.py` hot reload
# does not redownload weights every time you save the HTML.
# DEVICE falls back to CPU so the module can still load without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16
if gr.NO_RELOAD:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
    from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
    from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3

    # Base pipeline from FireRed, with its transformer replaced by the
    # Rapid-AIO transformer checkpoint loaded separately below.
    PIPE = QwenImageEditPlusPipeline.from_pretrained(
        "FireRedTeam/FireRed-Image-Edit-1.1",
        transformer=QwenImageTransformer2DModel.from_pretrained(
            "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
            torch_dtype=DTYPE,
            # Previously hard-coded to "cuda", which made loading fail on
            # CPU-only hosts despite the DEVICE fallback above; the whole
            # pipeline is moved to DEVICE right after anyway.
            device_map=DEVICE,
        ),
        torch_dtype=DTYPE,
    ).to(DEVICE)
    # Best-effort speed-up: FA3 needs a supported GPU/kernel build. On failure
    # the error is logged and the pipeline presumably keeps its default
    # attention processor — confirm against qwenimage.
    try:
        PIPE.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        print("Flash Attention 3 processor set.")
    except Exception as e:
        print(f"FA3 processor not set: {e}")
# Negative prompt passed on every pipeline call.
# NOTE(review): _run_pipe uses true_cfg_scale=1.0, under which the pipeline
# may not apply a negative prompt at all — confirm in QwenImageEditPlusPipeline.
NEGATIVE_PROMPT = ", ".join(
    (
        "worst quality",
        "low quality",
        "bad anatomy",
        "bad hands",
        "text",
        "error",
        "missing fingers",
        "extra digit",
        "fewer digits",
        "cropped",
        "jpeg artifacts",
        "signature",
        "watermark",
        "username",
        "blurry",
    )
)
# Upper bound (inclusive) for the random seed drawn per request: int32 max.
MAX_SEED = int(np.iinfo("int32").max)
def _round_dims(image: Image.Image) -> tuple[int, int]:
w, h = image.size
if w > h:
new_w, new_h = 1024, int(1024 * h / w)
else:
new_h, new_w = 1024, int(1024 * w / h)
return (new_w // 8) * 8, (new_h // 8) * 8
# --- Inner GPU function -------------------------------------------------------
# Per the reference: @spaces.GPU goes on the *inner* function that runs the
# model. The outer @server.api route just plugs it into the queue.
def _edit(image: Image.Image, prompt: str, seed: int, steps: int) -> Image.Image:
    """Edit *image* per *prompt*; a thin wrapper so it can be GPU-decorated."""
    return _run_pipe(image, prompt, seed, steps)


if _HAS_SPACES:
    # Same as writing @spaces.GPU above the definition.
    _edit = spaces.GPU(_edit)
def _run_pipe(image, prompt, seed, steps):
    """Run one edit through PIPE and return the first generated image.

    Collects Python garbage and, when CUDA is available, empties the CUDA
    cache before the call. The output size comes from ``_round_dims(image)``,
    and sampling uses a generator on DEVICE seeded with *seed*.
    """
    print(f"[_run_pipe] start cuda_avail={torch.cuda.is_available()}", flush=True)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    out_w, out_h = _round_dims(image)
    print(f"[_run_pipe] dims w={out_w} h={out_h} steps={steps}", flush=True)

    # NOTE(review): true_cfg_scale=1.0 likely disables true CFG, in which case
    # negative_prompt is presumably ignored by the pipeline — confirm.
    call_kwargs = {
        "image": [image],
        "prompt": prompt,
        "negative_prompt": NEGATIVE_PROMPT,
        "width": out_w,
        "height": out_h,
        "num_inference_steps": steps,
        "true_cfg_scale": 1.0,
        "generator": torch.Generator(device=DEVICE).manual_seed(seed),
    }
    pipe_output = PIPE(**call_kwargs)
    edited = pipe_output.images[0]

    print(f"[_run_pipe] done size={edited.size}", flush=True)
    return edited
# --- Server -------------------------------------------------------------------
# Directory containing this script; the homepage reads index.html from here.
HOME = Path(__file__).parent
# The Gradio server object that the API and page routes below attach to.
server = gr.Server()
# Edited images are written here, and only this directory (not the whole
# system temp dir) is passed to `allowed_paths`, so the file routes cannot
# serve arbitrary files from /tmp. A fixed name keeps the path stable across
# `gradio app.py` hot reloads.
_OUTPUT_DIR = Path(tempfile.gettempdir()) / "firered_edit_outputs"
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


@server.api(name="edit_image")
def edit_image(image: FileData, prompt: str) -> dict:
    """Edit an image guided by a text prompt using FireRed-Image-Edit 1.1.

    Args:
        image: The uploaded source image. The route has been observed to
            deliver it as a mapping with a ``"path"`` key; an attribute-style
            payload with ``.path`` is accepted as well.
        prompt: Free-text edit instruction; surrounding whitespace is stripped.

    Returns:
        ``{"image": FileData, "seed": int}`` on success, or
        ``{"error": str}`` when the prompt is empty or no image was sent.
    """
    # Accept mapping- or attribute-style payloads; a missing image yields
    # path=None and a friendly error instead of an AttributeError.
    if isinstance(image, dict):
        path = image.get("path")
    else:
        path = getattr(image, "path", None)
    print(f"[edit_image] received prompt={prompt!r} path={path}", flush=True)
    if not prompt or not prompt.strip():
        return {"error": "Please enter an edit prompt."}
    if not path:
        return {"error": "Please upload an image."}
    prompt = prompt.strip()

    from PIL import ImageOps

    # `with` releases the upload's file handle promptly. exif_transpose applies
    # the EXIF Orientation tag so camera photos are edited upright (the saved
    # PNG carries no EXIF, so the result would otherwise appear rotated), and
    # convert() yields a fully loaded RGB copy that outlives the closed file.
    with Image.open(path) as opened:
        src = ImageOps.exif_transpose(opened).convert("RGB")
    print(f"[edit_image] image opened size={src.size}", flush=True)

    seed = random.randint(0, MAX_SEED)
    print(f"[edit_image] calling _edit seed={seed}", flush=True)
    result = _edit(src, prompt, seed, steps=4)
    print(f"[edit_image] _edit returned size={result.size}", flush=True)

    # mkstemp gives a unique name; the fd is closed immediately so PIL can
    # reopen the path for writing. NOTE(review): nothing here deletes these
    # files — consider periodic cleanup of _OUTPUT_DIR.
    fd, out_path = tempfile.mkstemp(suffix=".png", dir=str(_OUTPUT_DIR))
    os.close(fd)
    result.save(out_path)
    print(f"[edit_image] saved to {out_path} exists={os.path.exists(out_path)} size={os.path.getsize(out_path)}", flush=True)
    payload = {"image": FileData(path=out_path), "seed": seed}
    print(f"[edit_image] returning payload keys={list(payload.keys())} image={payload['image']}", flush=True)
    return payload


@server.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the single-page UI from index.html next to this script.

    The file is re-read on every request, so HTML edits show up without a
    restart.
    """
    return (HOME / "index.html").read_text(encoding="utf-8")


if __name__ == "__main__":
    # Only the dedicated output directory is exposed for file serving.
    server.launch(show_error=True, allowed_paths=[str(_OUTPUT_DIR)])