Lizug's picture
Upload 2 files
6d84e1c verified
import os
import gc
import random
import numpy as np
import torch
import gradio as gr
import spaces
from PIL import Image
from diffusers.models import QwenImageTransformer2DModel
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image
# ── Device setup ──────────────────────────────────────────────────────────────
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)
# ── Model loading ─────────────────────────────────────────────────────────────
# 8-bit checkpoint = lower VRAM footprint.
MODEL_ID = "prithivMLmods/FireRed-Image-Edit-1.0-8bit"
dtype = torch.bfloat16

# The transformer is loaded on its own and handed to the pipeline explicitly;
# the assembled pipeline is then moved to the device chosen above.
transformer = QwenImageTransformer2DModel.from_pretrained(
    MODEL_ID, subfolder="transformer", torch_dtype=dtype
)
pipe = QwenImageEditPlusPipeline.from_pretrained(
    MODEL_ID, transformer=transformer, torch_dtype=dtype
)
pipe = pipe.to(device)

# Largest seed the UI offers: the maximum of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# ── Helpers ───────────────────────────────────────────────────────────────────
def update_dimensions_on_upload(image, *, max_side=1024, min_side=512, multiple=64):
    """Suggest an output ``(width, height)`` that follows the uploaded image's shape.

    The aspect ratio is preserved. Large images are scaled down so the long
    side fits in ``max_side``. Small images are scaled up so the short side
    reaches ``min_side``, as far as that is possible without pushing the long
    side past ``max_side``. Both sides are then snapped to the nearest multiple
    of ``multiple`` and clamped to ``[min_side, max_side]``. With the defaults
    this is the same 512-1024 range (step 64) as the Width/Height sliders.

    Args:
        image: A numpy array shaped ``(H, W[, C])``, a PIL image, a file path,
            or a ``gr.Gallery`` ``(image, caption)`` pair. ``None`` or an
            empty pair means nothing was uploaded.
        max_side: Upper bound for either side, in pixels.
        min_side: Lower bound for either side, in pixels.
        multiple: Both sides are rounded to a multiple of this value.

    Returns:
        ``(width, height)`` as ints. Returns ``(max_side, max_side)`` when
        there is no usable image.
    """
    # gr.Gallery hands out (image, caption) pairs rather than bare images.
    if isinstance(image, (tuple, list)):
        image = image[0] if image else None
    if image is None:
        return max_side, max_side

    if isinstance(image, np.ndarray):
        # Reading the shape avoids copying all pixels into a PIL image.
        h, w = image.shape[:2]
    elif isinstance(image, (str, os.PathLike)):
        with Image.open(image) as opened:
            w, h = opened.size
    elif isinstance(image, Image.Image):
        w, h = image.size
    else:
        h, w = np.asarray(image).shape[:2]
    if w <= 0 or h <= 0:
        return max_side, max_side

    long_side, short_side = max(w, h), min(w, h)
    if long_side > max_side:
        scale = max_side / long_side
    elif short_side < min_side:
        scale = min(min_side / short_side, max_side / long_side)
    else:
        scale = 1.0

    def snap(side):
        # One shared scale keeps the aspect ratio. The clamp guarantees a
        # non-zero size that the UI sliders can represent.
        snapped = round(side * scale / multiple) * multiple
        return int(min(max(snapped, min_side), max_side))

    return snap(w), snap(h)
# ── Inference ─────────────────────────────────────────────────────────────────
def _gallery_item_to_rgb(item):
    """Turn one ``gr.Gallery`` entry into an RGB PIL image, or None if it is empty.

    ``gr.Gallery`` passes its value as a list of ``(image, caption)`` pairs,
    so the image is unwrapped first. Bare numpy arrays, PIL images and file
    paths are accepted as well.
    """
    if isinstance(item, (tuple, list)):
        item = item[0] if item else None
    if item is None:
        return None
    if isinstance(item, Image.Image):
        return item.convert("RGB")
    if isinstance(item, (str, os.PathLike)):
        with Image.open(item) as opened:
            return opened.convert("RGB")
    return Image.fromarray(item).convert("RGB")


@spaces.GPU(duration=90)
def generate(
    images,
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
):
    """Edit the uploaded image(s) according to ``prompt`` with the global ``pipe``.

    Args:
        images: The ``gr.Gallery`` value, a list of ``(image, caption)`` pairs.
            Bare images are also accepted.
        prompt: Edit instruction text.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Requested output size in pixels.
        guidance_scale: Forwarded to the pipeline unchanged.
        num_inference_steps: Number of denoising steps.

    Returns:
        ``(edited_image, seed_used)``: the first image of the pipeline output
        and the integer seed that produced it.

    Raises:
        gr.Error: If no usable image was uploaded or the prompt is blank.
    """
    if not images:
        raise gr.Error("⚠️ Upload minimal 1 gambar dulu ya!")
    if not prompt or not prompt.strip():
        raise gr.Error("⚠️ Tulis prompt dulu – apa yang mau diubah?")

    pil_images = [pil for pil in map(_gallery_item_to_rgb, images) if pil is not None]
    if not pil_images:
        # Every gallery entry was empty; don't send an empty list to the pipeline.
        raise gr.Error("⚠️ Upload minimal 1 gambar dulu ya!")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; torch.Generator.manual_seed needs an int.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        result = pipe(
            image=pil_images,
            prompt=prompt,
            width=int(width),
            height=int(height),
            # NOTE(review): in diffusers' Qwen-Image pipelines, `guidance_scale`
            # is only honoured by guidance-distilled transformers; classic CFG
            # uses `true_cfg_scale` + `negative_prompt`. Confirm this slider
            # actually affects output for this checkpoint.
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
        )
    finally:
        # Free cached GPU memory even when inference raises.
        gc.collect()
        torch.cuda.empty_cache()
    return result.images[0], seed
# ── UI ────────────────────────────────────────────────────────────────────────
# Header HTML shown at the top of the app.
TITLE = """
<div style="text-align:center; padding: 12px 0 4px">
<h1 style="font-size:2rem; font-weight:700; margin:0">🔥 FireRed Image Edit</h1>
<p style="color:#666; margin:6px 0 0">
Upload foto → tulis instruksi → dapatkan hasil edit berkualitas tinggi
</p>
</div>
"""

# Each row follows generate()'s input order: images, prompt, seed,
# randomize_seed, width, height, guidance_scale, num_inference_steps.
# NOTE(review): not referenced anywhere in this file, and the image slot is
# None; wire it to gr.Examples with real images or remove it.
EXAMPLES = [
    [None, "Ganti background dengan pantai tropis saat sunset", 42, False, 1024, 1024, 3.5, 4],
    [None, "Ubah baju jadi warna merah dengan motif batik", 7, False, 1024, 1024, 3.5, 4],
    [None, "Tambahkan efek snow/salju di seluruh gambar", 0, False, 1024, 1024, 3.5, 4],
]

# NOTE(review): `.gr-button-primary` is a Gradio 3.x class name; newer Gradio
# releases may not emit it, so the red button rule might not apply. Confirm.
css = """
#col-left { min-width: 360px; }
#col-right { min-width: 360px; }
.gr-button-primary { background: #e63946 !important; border-color: #e63946 !important; }
footer { display: none !important; }
"""
def _first_gallery_image(gallery_value):
    """Return the first image in a ``gr.Gallery`` value, or None if there is none.

    ``gr.Gallery`` passes ``(image, caption)`` pairs, so the pair is unwrapped
    before the image reaches ``update_dimensions_on_upload``.
    """
    if not gallery_value:
        return None
    first = gallery_value[0]
    if isinstance(first, (tuple, list)):
        return first[0] if first else None
    return first


with gr.Blocks(css=css, title="FireRed Image Edit") as demo:
    gr.HTML(TITLE)
    with gr.Row():
        # ── Left column ───────────────────────────────────────────────────────
        with gr.Column(elem_id="col-left"):
            input_images = gr.Gallery(
                label="📸 Upload Gambar (1–3 foto)",
                columns=3,
                rows=1,
                height=280,
                type="numpy",
                interactive=True,
            )
            prompt = gr.Textbox(
                label="✏️ Instruksi Edit",
                placeholder="Contoh: ganti warna baju jadi biru tua, tambahkan kacamata hitam...",
                lines=3,
            )
            with gr.Row():
                run_btn = gr.Button("🔥 Generate", variant="primary", scale=3)
                clear_btn = gr.Button("🗑️ Clear", scale=1)
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
                randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
                with gr.Row():
                    width = gr.Slider(512, 1024, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1024, value=1024, step=64, label="Height")
                guidance_scale = gr.Slider(1.0, 7.0, value=3.5, step=0.1, label="Guidance Scale")
                num_inference_steps = gr.Slider(1, 8, value=4, step=1, label="Steps (4 = fast)")
        # ── Right column ──────────────────────────────────────────────────────
        with gr.Column(elem_id="col-right"):
            output_image = gr.Image(label="✨ Hasil Edit", type="pil", height=480)
            used_seed = gr.Number(label="Seed yang dipakai", interactive=False)

    # ── Event wiring ──────────────────────────────────────────────────────────
    # Suggest an output size from the first uploaded image.
    input_images.upload(
        fn=lambda imgs: update_dimensions_on_upload(_first_gallery_image(imgs)),
        inputs=input_images,
        outputs=[width, height],
    )
    run_btn.click(
        fn=generate,
        inputs=[
            input_images, prompt, seed, randomize_seed,
            width, height, guidance_scale, num_inference_steps,
        ],
        outputs=[output_image, used_seed],
    )
    # Restore the default settings and empty the input/output widgets.
    clear_btn.click(
        fn=lambda: (None, "", 42, True, 1024, 1024, 3.5, 4, None, 0),
        outputs=[
            input_images, prompt, seed, randomize_seed, width, height,
            guidance_scale, num_inference_steps, output_image, used_seed,
        ],
    )
    gr.Markdown("""
---
**Tips:**
- 🖼️ Upload 1–3 gambar sekaligus untuk multi-image editing (misal: virtual try-on)
- ⚡ Steps = 4 sudah cukup cepat dan hasilnya bagus
- 🌱 Seed tetap = hasil konsisten; centang *Randomize* untuk variasi
- 🔥 Model: [prithivMLmods/FireRed-Image-Edit-1.0-8bit](https://huggingface.co/prithivMLmods/FireRed-Image-Edit-1.0-8bit)
""")
def _launch_app():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _launch_app()