Spaces:
Paused
Paused
File size: 4,704 Bytes
cc5df64 17d51c2 cc5df64 17d51c2 4ec8323 17d51c2 cc5df64 17d51c2 9678ca5 17d51c2 9678ca5 17d51c2 9678ca5 17d51c2 9678ca5 17d51c2 4ec8323 cc5df64 17d51c2 9678ca5 17d51c2 9678ca5 17d51c2 9678ca5 17d51c2 9678ca5 17d51c2 9678ca5 17d51c2 45895ff 17d51c2 4ec8323 17d51c2 cc5df64 1380a3b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import gradio as gr
import gc
import os
import random
import tempfile
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from fastapi.responses import HTMLResponse
from gradio.data_classes import FileData
# ZeroGPU. Degrade gracefully off-Spaces so `python app.py` works locally.
try:
import spaces
_HAS_SPACES = True
except ImportError:
_HAS_SPACES = False
# --- Model load ---------------------------------------------------------------
# Heavy startup is wrapped in `gr.NO_RELOAD` so that hot reload under
# `gradio app.py` re-executes this module without reloading the model weights
# every time a file is saved.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16
if gr.NO_RELOAD:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
    from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
    from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3

    # Edit pipeline from FireRed-Image-Edit-1.1, with its transformer replaced
    # by the Qwen-Image-Edit-Rapid-AIO-V19 weights.
    PIPE = QwenImageEditPlusPipeline.from_pretrained(
        "FireRedTeam/FireRed-Image-Edit-1.1",
        transformer=QwenImageTransformer2DModel.from_pretrained(
            "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
            torch_dtype=DTYPE,
            # Follow DEVICE instead of hard-coding "cuda", so loading also
            # works on CPU-only machines where DEVICE falls back to "cpu".
            device_map=DEVICE,
        ),
        torch_dtype=DTYPE,
    ).to(DEVICE)

    # Best-effort: the FA3 processor may not be usable in every environment
    # (presumably it needs a compatible GPU/kernel). On failure the
    # transformer keeps whatever attention processor it already had.
    try:
        PIPE.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        print("Flash Attention 3 processor set.")
    except Exception as e:
        print(f"FA3 processor not set: {e}")
# Comma-separated quality/artifact terms passed to the pipeline as its
# negative prompt.
NEGATIVE_PROMPT = ", ".join(
    [
        "worst quality",
        "low quality",
        "bad anatomy",
        "bad hands",
        "text",
        "error",
        "missing fingers",
        "extra digit",
        "fewer digits",
        "cropped",
        "jpeg artifacts",
        "signature",
        "watermark",
        "username",
        "blurry",
    ]
)

# Inclusive upper bound for randomly drawn seeds: the largest signed 32-bit int.
MAX_SEED = np.iinfo(np.int32).max
def _round_dims(image: Image.Image) -> tuple[int, int]:
w, h = image.size
if w > h:
new_w, new_h = 1024, int(1024 * h / w)
else:
new_h, new_w = 1024, int(1024 * w / h)
return (new_w // 8) * 8, (new_h // 8) * 8
# --- Inner GPU function -------------------------------------------------------
# ZeroGPU guidance: `spaces.GPU` belongs on the *inner* function that runs the
# model. The outer @server.api route only plugs it into the queue.
def _edit(image: Image.Image, prompt: str, seed: int, steps: int) -> Image.Image:
    """Run a single edit by delegating to ``_run_pipe``.

    When the ``spaces`` package is available, this function is wrapped with
    ``spaces.GPU`` right after its definition. Otherwise it is a plain call.
    """
    return _run_pipe(image, prompt, seed, steps)


# Equivalent to decorating the def with @spaces.GPU, but only when `spaces`
# could be imported.
if _HAS_SPACES:
    _edit = spaces.GPU(_edit)
def _run_pipe(image, prompt, seed, steps):
    """Edit ``image`` according to ``prompt`` and return the first output image.

    Collects Python garbage and, when CUDA is available, empties the CUDA
    allocator cache before the run. The output size comes from ``_round_dims``,
    and sampling uses a ``torch.Generator`` on ``DEVICE`` seeded with ``seed``.
    """
    print(f"[_run_pipe] start cuda_avail={torch.cuda.is_available()}", flush=True)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    out_w, out_h = _round_dims(image)
    print(f"[_run_pipe] dims w={out_w} h={out_h} steps={steps}", flush=True)

    # NOTE(review): with true_cfg_scale=1.0, true classifier-free guidance is
    # likely disabled, so NEGATIVE_PROMPT may have no effect. Confirm against
    # the pipeline implementation.
    call_kwargs = dict(
        image=[image],
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        width=out_w,
        height=out_h,
        num_inference_steps=steps,
        true_cfg_scale=1.0,
        generator=torch.Generator(device=DEVICE).manual_seed(seed),
    )
    edited = PIPE(**call_kwargs).images[0]
    print(f"[_run_pipe] done size={edited.size}", flush=True)
    return edited
# --- Server -------------------------------------------------------------------
# Directory containing this module; the homepage reads index.html from here.
HOME = Path(__file__).parent
server = gr.Server()
@server.api(name="edit_image")
def edit_image(image: FileData, prompt: str) -> dict:
    """Edit an image guided by a text prompt using FireRed-Image-Edit 1.1.

    Args:
        image: The uploaded file. Only its ``path`` is used. It may arrive as a
            mapping (``image["path"]``) or as an object with a ``.path``
            attribute such as ``FileData``.
        prompt: The edit instruction. Surrounding whitespace is stripped.

    Returns:
        ``{"error": <message>}`` when the prompt or the image path is missing.
        Otherwise ``{"image": FileData, "seed": int}``, referencing a PNG written
        to the system temp directory and the random seed that produced it.
    """
    # The argument is annotated FileData but was read by key; accept both
    # forms. NOTE(review): confirm which shape gr.Server actually delivers.
    if isinstance(image, dict):
        src_path = image.get("path")
    else:
        src_path = getattr(image, "path", None)
    print(f"[edit_image] received prompt={prompt!r} path={src_path}", flush=True)
    if not prompt or not prompt.strip():
        return {"error": "Please enter an edit prompt."}
    if not src_path:
        return {"error": "Please upload an image."}

    # Image.open is lazy and holds the file handle open. The `with` block
    # closes it once convert() has produced an independent in-memory RGB copy.
    with Image.open(src_path) as opened:
        src = opened.convert("RGB")
    print(f"[edit_image] image opened size={src.size}", flush=True)

    seed = random.randint(0, MAX_SEED)
    print(f"[edit_image] calling _edit seed={seed}", flush=True)
    # Fixed 4 steps; presumably chosen for the "Rapid" transformer checkpoint.
    result = _edit(src, prompt.strip(), seed, steps=4)
    print(f"[edit_image] _edit returned size={result.size}", flush=True)

    # NOTE(review): these temp PNGs are never deleted by this module; a
    # long-running Space may want periodic cleanup of the temp directory.
    fd, out_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    result.save(out_path)
    print(f"[edit_image] saved to {out_path} exists={os.path.exists(out_path)} size={os.path.getsize(out_path)}", flush=True)
    payload = {"image": FileData(path=out_path), "seed": seed}
    print(f"[edit_image] returning payload keys={list(payload.keys())} image={payload['image']}", flush=True)
    return payload
@server.get("/", response_class=HTMLResponse)
async def homepage():
    """Return the single-page UI stored in ``index.html`` beside this module.

    The file is read from disk on every request, so edits to the HTML are
    served without restarting the process.
    """
    index_html = HOME / "index.html"
    return index_html.read_text(encoding="utf-8")
if __name__ == "__main__":
    # edit_image writes its result PNGs via tempfile.mkstemp, so the temp
    # directory is added to allowed_paths so those files can be served back.
    server.launch(allowed_paths=[tempfile.gettempdir()], show_error=True)
|