File size: 4,704 Bytes
cc5df64
17d51c2
 
 
 
 
 
 
cc5df64
17d51c2
4ec8323
17d51c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc5df64
17d51c2
 
 
 
 
 
 
 
 
 
 
 
 
 
9678ca5
17d51c2
 
 
 
9678ca5
17d51c2
9678ca5
17d51c2
 
 
 
 
 
 
 
 
9678ca5
 
17d51c2
 
 
 
4ec8323
cc5df64
17d51c2
9678ca5
17d51c2
 
9678ca5
17d51c2
 
 
9678ca5
17d51c2
9678ca5
17d51c2
9678ca5
17d51c2
 
 
 
45895ff
 
 
 
17d51c2
 
 
 
4ec8323
 
17d51c2
cc5df64
1380a3b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import gc
import os
import random
import tempfile
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from fastapi.responses import HTMLResponse
from gradio.data_classes import FileData

# ZeroGPU. Degrade gracefully off-Spaces so `python app.py` works locally.
try:
    import spaces
except ImportError:
    # Package not installed (e.g. running locally): run without the decorator.
    _HAS_SPACES = False
else:
    # Import succeeded, so `spaces.GPU` can be applied to the GPU function.
    _HAS_SPACES = True

# --- Model load ---------------------------------------------------------------
# Heavy startup is wrapped in `gr.NO_RELOAD` so `gradio app.py` hot reload
# does not redownload weights every time you save the HTML.
# Target device and dtype are decided once at import time and reused for
# every load/placement call below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16

if gr.NO_RELOAD:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
    from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
    from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3

    # The pipeline comes from the FireRed checkpoint, but its transformer is
    # swapped for the weights of a separate checkpoint.
    PIPE = QwenImageEditPlusPipeline.from_pretrained(
        "FireRedTeam/FireRed-Image-Edit-1.1",
        transformer=QwenImageTransformer2DModel.from_pretrained(
            "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
            torch_dtype=DTYPE,
            # Follow DEVICE instead of hard-coding "cuda" so that a CPU-only
            # machine does not fail while loading the transformer.
            device_map=DEVICE,
        ),
        torch_dtype=DTYPE,
    ).to(DEVICE)

    try:
        PIPE.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        print("Flash Attention 3 processor set.")
    except Exception as e:
        # Deliberate best effort: if the FA3 processor cannot be installed the
        # pipeline keeps whatever attention processor it was loaded with.
        print(f"FA3 processor not set: {e}")

# Comma-separated artefacts handed to the pipeline as its negative prompt.
NEGATIVE_PROMPT = ", ".join(
    (
        "worst quality",
        "low quality",
        "bad anatomy",
        "bad hands",
        "text",
        "error",
        "missing fingers",
        "extra digit",
        "fewer digits",
        "cropped",
        "jpeg artifacts",
        "signature",
        "watermark",
        "username",
        "blurry",
    )
)
# Largest signed 32-bit integer; used as the upper bound for random seeds.
MAX_SEED = np.iinfo(np.int32).max


def _round_dims(
    image: Image.Image, *, max_side: int = 1024, multiple: int = 8
) -> tuple[int, int]:
    """Return the ``(width, height)`` to request from the pipeline for *image*.

    The longer side is set to ``max_side`` and the shorter side is scaled to
    keep the aspect ratio; both are then rounded down to a multiple of
    ``multiple``. Each side is clamped to at least ``multiple`` so that a very
    elongated image (e.g. 2000x10) can never produce a zero dimension.

    Args:
        image: Object exposing a ``size`` attribute of ``(width, height)``.
        max_side: Length given to the longer side before rounding.
        multiple: Granularity both sides are rounded down to.

    Returns:
        A ``(width, height)`` tuple of positive ints.
    """
    w, h = image.size
    if w > h:
        new_w, new_h = max_side, int(max_side * h / w)
    else:
        new_w, new_h = int(max_side * w / h), max_side

    def _snap(value: int) -> int:
        # Round down to the grid, but never below one grid cell.
        return max(multiple, (value // multiple) * multiple)

    return _snap(new_w), _snap(new_h)


# --- Inner GPU function -------------------------------------------------------
# Per the reference: @spaces.GPU goes on the *inner* function that runs the
# model. The outer @server.api route just plugs it into the queue.
def _edit(image: Image.Image, prompt: str, seed: int, steps: int) -> Image.Image:
    """Forward one edit request to `_run_pipe` and return its result."""
    return _run_pipe(image, prompt, seed, steps)


if _HAS_SPACES:
    # Equivalent to decorating the definition with @spaces.GPU; applied only
    # when the `spaces` package could be imported.
    _edit = spaces.GPU(_edit)


def _run_pipe(image, prompt, seed, steps):
    """Run a single edit through the global pipeline.

    Frees cached memory first, sizes the output with `_round_dims`, seeds a
    generator on DEVICE, and returns the first image the pipeline produces.
    Progress is logged to stdout at start, after sizing, and on completion.
    """
    cuda_ok = torch.cuda.is_available()
    print(f"[_run_pipe] start cuda_avail={cuda_ok}", flush=True)

    # Drop unreferenced Python objects, then release cached CUDA blocks.
    gc.collect()
    if cuda_ok:
        torch.cuda.empty_cache()

    target_w, target_h = _round_dims(image)
    print(f"[_run_pipe] dims w={target_w} h={target_h} steps={steps}", flush=True)

    rng = torch.Generator(device=DEVICE)
    rng.manual_seed(seed)

    call_kwargs = {
        "image": [image],
        "prompt": prompt,
        "negative_prompt": NEGATIVE_PROMPT,
        "width": target_w,
        "height": target_h,
        "num_inference_steps": steps,
        "true_cfg_scale": 1.0,
        "generator": rng,
    }
    edited = PIPE(**call_kwargs).images[0]

    print(f"[_run_pipe] done size={edited.size}", flush=True)
    return edited


# --- Server -------------------------------------------------------------------
# Directory containing this file; index.html is read from here.
HOME = Path(__file__).parent
# Application object that the API and page routes are registered on.
server = gr.Server()


@server.api(name="edit_image")
def edit_image(image: FileData, prompt: str) -> dict:
    """Edit an image guided by a text prompt using FireRed-Image-Edit 1.1.

    Args:
        image: Uploaded source image; only its ``path`` entry is read.
        prompt: Edit instruction; surrounding whitespace is stripped.

    Returns:
        ``{"image": FileData, "seed": int}`` on success, or
        ``{"error": str}`` when the prompt or the image is missing.
    """
    # The payload is read mapping-style via ``.get``; attribute access is kept
    # as a fallback in case a FileData model instance is passed instead.
    getter = getattr(image, "get", None)
    src_path = getter("path") if callable(getter) else getattr(image, "path", None)
    print(f"[edit_image] received prompt={prompt!r} path={src_path}", flush=True)
    if not prompt or not prompt.strip():
        return {"error": "Please enter an edit prompt."}
    if not src_path:
        # Report a missing upload the same way as a missing prompt instead of
        # failing with an exception while opening the file.
        return {"error": "Please upload an image."}

    # `convert` returns an independent image, so the source file handle can be
    # closed as soon as the conversion is done.
    with Image.open(src_path) as opened:
        src = opened.convert("RGB")
    print(f"[edit_image] image opened size={src.size}", flush=True)
    seed = random.randint(0, MAX_SEED)
    print(f"[edit_image] calling _edit seed={seed}", flush=True)
    result = _edit(src, prompt.strip(), seed, steps=4)
    print(f"[edit_image] _edit returned size={result.size}", flush=True)

    # mkstemp hands back an open descriptor; close it and let PIL write by path.
    fd, out_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    result.save(out_path)
    print(f"[edit_image] saved to {out_path} exists={os.path.exists(out_path)} size={os.path.getsize(out_path)}", flush=True)
    payload = {"image": FileData(path=out_path), "seed": seed}
    print(f"[edit_image] returning payload keys={list(payload.keys())} image={payload['image']}", flush=True)
    return payload


@server.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve index.html from the directory of this file as the root page."""
    index_file = HOME / "index.html"
    html = index_file.read_text(encoding="utf-8")
    return html


if __name__ == "__main__":
    # The system temp directory is whitelisted so files written there can be
    # served back to the client.
    server.launch(
        allowed_paths=[tempfile.gettempdir()],
        show_error=True,
    )