# Commit c0a71d7 (Maxim Kruglikov): Replace Gallery-as-input with
# File picker + Gallery preview.
"""MegaStyle Image Style Comparison β€” Gradio Space (ZeroGPU).
Compares a test image against 1-8 reference images using MegaStyle-Encoder
(SigLIP-so400m-patch14-384 fine-tuned on MegaStyle-1.4M with style-supervised
contrastive learning). Reports the mean cosine similarity as the headline
score plus a per-reference breakdown.
"""
from __future__ import annotations
from typing import Any
import spaces # must be imported before torch for ZeroGPU
import gradio as gr
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import SiglipImageProcessor, SiglipVisionModel
# Maximum number of reference images scored per comparison; extras are dropped.
MAX_REFERENCES = 8
# Base SigLIP backbone; the MegaStyle encoder is a fine-tune of this model.
SIGLIP_ID = "google/siglip-so400m-patch14-384"
# Hub location of the fine-tuned MegaStyle-Encoder checkpoint.
ENCODER_REPO = "Gaojunyao/MegaStyle"
ENCODER_FILE = "megastyle_encoder.pth"
# (threshold, label, emoji) bands scanned top-down: the first band whose
# threshold the score reaches wins; the -1.0 band catches everything else.
# NOTE: the emoji literals were mojibake (UTF-8 mis-decoded); restored here.
LABEL_BANDS: list[tuple[float, str, str]] = [
    (0.75, "Strong style match", "🟢"),
    (0.65, "Good style match", "🟢"),
    (0.55, "Moderate style match", "🟡"),
    (0.45, "Weak style match", "🟠"),
    (-1.0, "Minimal style match", "🔴"),
]
def _load_model_cpu() -> tuple[SiglipVisionModel, SiglipImageProcessor]:
    """Load SigLIP on CPU and overlay MegaStyle-Encoder fine-tuned weights.

    Kept on CPU at module scope; moved to GPU lazily inside the @spaces.GPU
    function. ZeroGPU does not bind a GPU at import time.

    Uses SiglipImageProcessor (image-only) rather than AutoProcessor to avoid
    pulling in SiglipTokenizer, which requires sentencepiece. Vision-only
    inference does not need the tokenizer.

    Returns:
        (model, processor): the eval-mode vision model with MegaStyle weights
        overlaid, and the matching image preprocessor.
    """
    model = SiglipVisionModel.from_pretrained(SIGLIP_ID).eval()
    processor = SiglipImageProcessor.from_pretrained(SIGLIP_ID)
    ckpt_path = hf_hub_download(ENCODER_REPO, ENCODER_FILE)
    # weights_only=True: the checkpoint is expected to be a plain state dict
    # (optionally nested under "model"); avoid unpickling arbitrary objects
    # from a downloaded file.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    state = checkpoint["model"] if isinstance(checkpoint, dict) and "model" in checkpoint else checkpoint
    # strict=False tolerates key mismatches, but surface what was skipped so a
    # bad checkpoint doesn't silently fall back to vanilla SigLIP weights.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        print(
            f"[MegaStyle] load_state_dict: {len(missing)} missing, "
            f"{len(unexpected)} unexpected keys"
        )
    return model, processor


MODEL, PROCESSOR = _load_model_cpu()
def _label_for(score: float) -> tuple[str, str]:
    """Map a cosine-similarity score to its (label, emoji) band."""
    hit = next(
        ((band_label, band_emoji)
         for threshold, band_label, band_emoji in LABEL_BANDS
         if score >= threshold),
        None,
    )
    if hit is not None:
        return hit
    # Below every threshold (score < -1.0, only reachable via float error):
    # fall back to the lowest band.
    _, fallback_label, fallback_emoji = LABEL_BANDS[-1]
    return fallback_label, fallback_emoji
def _coerce_to_pil(image: Any) -> Image.Image:
    """Normalize the varied payloads Gradio components emit into an RGB PIL image.

    Accepts: PIL.Image, numpy arrays, filepath strings, dicts with
    'image'/'name'/'path'/'orig_name' keys, (path, caption) tuples or lists,
    and tempfile-like objects exposing a ``.name`` path (what gr.File yields
    on older Gradio versions).

    Raises:
        ValueError: if the payload cannot be interpreted as an image.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    # Gallery items arrive as (path, caption) pairs; some versions use lists.
    if isinstance(image, (tuple, list)) and image:
        return _coerce_to_pil(image[0])
    if isinstance(image, dict):
        for key in ("image", "name", "path", "orig_name"):
            if image.get(key) is not None:
                return _coerce_to_pil(image[key])
    if isinstance(image, str):
        return Image.open(image).convert("RGB")
    # gr.File may hand back tempfile wrappers whose .name holds the path.
    name = getattr(image, "name", None)
    if isinstance(name, str):
        return Image.open(name).convert("RGB")
    # numpy array fallback
    try:
        return Image.fromarray(image).convert("RGB")
    except Exception as exc:  # surface a readable error to the user
        raise ValueError(f"Unsupported image input type: {type(image).__name__}") from exc
@torch.no_grad()
def _embed(images: list[Image.Image], device: torch.device) -> torch.Tensor:
    """Return L2-normalized pooled SigLIP embeddings for a batch of images."""
    batch = PROCESSOR(images=images, return_tensors="pt")["pixel_values"]
    pooled = MODEL(pixel_values=batch.to(device)).pooler_output
    # Unit-normalize so downstream dot products are cosine similarities.
    return pooled / pooled.norm(p=2, dim=-1, keepdim=True)
@spaces.GPU(duration=30)
def compare_styles(
    test_image: Image.Image | None,
    reference_images: list | None,
) -> tuple[str, list[list]]:
    """Score the style similarity of ``test_image`` against reference images.

    Args:
        test_image: PIL image from the gr.Image input, or None.
        reference_images: payloads from the gr.File input (filepaths or
            file-like objects), or None/empty.

    Returns:
        (summary_markdown, breakdown_rows) for the Markdown and Dataframe
        outputs; each row is [name, score, label] strings.
    """
    if test_image is None:
        return "### Upload a test image to compare.", []
    if not reference_images:
        return "### Add at least one reference image.", []
    clipped_note = ""
    total = len(reference_images)
    if total > MAX_REFERENCES:
        clipped_note = f"\n\n_Using the first {MAX_REFERENCES} of {total} references._"
    refs = [_coerce_to_pil(r) for r in reference_images[:MAX_REFERENCES]]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Compare device *types*: torch.device("cuda") != torch.device("cuda:0"),
    # so comparing whole device objects would re-issue MODEL.to() every call.
    if next(MODEL.parameters()).device.type != device.type:
        MODEL.to(device)
    test_emb = _embed([_coerce_to_pil(test_image)], device)
    ref_emb = _embed(refs, device)
    # Embeddings are unit-norm, so the dot product is cosine similarity.
    per_ref = (test_emb @ ref_emb.T).squeeze(0).cpu().tolist()
    mean_score = float(sum(per_ref) / len(per_ref))
    verdict, emoji = _label_for(mean_score)
    summary = (
        f"## {emoji} {verdict}\n\n"
        f"**Score:** `{mean_score:.3f}`  ·  "
        f"Mean cosine similarity across {len(refs)} reference image"
        f"{'s' if len(refs) != 1 else ''}. Range: `-1` (opposite) … `1` (identical style)."
        f"{clipped_note}"
    )
    table = []
    for i, score in enumerate(per_ref):
        ref_label, ref_emoji = _label_for(score)
        table.append([f"Reference {i + 1}", f"{score:.3f}", f"{ref_emoji} {ref_label}"])
    return summary, table
# UI layout. String literals below were mojibake (UTF-8 mis-decoded emoji,
# dashes, middots) and have been restored to real Unicode.
with gr.Blocks(title="MegaStyle Style Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 🎨 MegaStyle Image Style Comparison\n"
        "Compare how well a **test image** matches the style of **1–"
        f"{MAX_REFERENCES} reference images** using "
        "[MegaStyle-Encoder](https://huggingface.co/Gaojunyao/MegaStyle) "
        "([paper](https://arxiv.org/abs/2604.08364))."
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            gr.Markdown("### 1. Test image")
            test_in = gr.Image(
                type="pil",
                show_label=False,
                height=360,
                sources=["upload", "clipboard"],
            )
        with gr.Column(scale=1):
            gr.Markdown(f"### 2. Reference images (up to {MAX_REFERENCES})")
            # The File picker is the actual model input; the Gallery below is
            # a read-only preview kept in sync via refs_in.change.
            refs_in = gr.File(
                file_count="multiple",
                file_types=["image"],
                show_label=False,
                height=140,
            )
            refs_preview = gr.Gallery(
                show_label=False,
                columns=4,
                rows=2,
                height=220,
                object_fit="contain",
                preview=False,
                allow_preview=False,
                show_fullscreen_button=False,
                show_share_button=False,
                show_download_button=False,
                interactive=False,
            )

    def _files_to_gallery(files: list | None) -> list[str]:
        """Map gr.File payloads to filepaths for the preview Gallery."""
        if not files:
            return []
        # Older Gradio yields tempfile objects (.name is the path); newer
        # versions yield plain path strings.
        return [getattr(f, "name", f) for f in files]

    refs_in.change(fn=_files_to_gallery, inputs=refs_in, outputs=refs_preview)

    with gr.Row():
        compare_btn = gr.Button("Compare styles", variant="primary", size="lg")
    gr.Markdown("### 3. Result")
    summary_out = gr.Markdown("_Run a comparison to see the result._")
    breakdown_out = gr.Dataframe(
        headers=["Reference", "Cosine similarity", "Label"],
        datatype=["str", "str", "str"],
        label="Per-reference breakdown",
        interactive=False,
        wrap=True,
    )
    compare_btn.click(
        fn=compare_styles,
        inputs=[test_in, refs_in],
        outputs=[summary_out, breakdown_out],
    )
    gr.Markdown(
        "**Label bands (cosine similarity):** "
        "🟢 ≥ 0.75 Strong · 🟢 0.65–0.75 Good · 🟡 0.55–0.65 Moderate · "
        "🟠 0.45–0.55 Weak · 🔴 < 0.45 Minimal. "
        "Thresholds are heuristic; the raw cosine score is the source of truth."
    )
if __name__ == "__main__":
    # show_api=False avoids the /api/info endpoint whose JSON-schema walker
    # crashes on gr.Dataframe in older Gradio versions. Harmless for a demo Space.
    # NOTE(review): queue() looks required for the @spaces.GPU handler on
    # ZeroGPU hardware — confirm against the Spaces docs before removing.
    demo.queue().launch(show_api=False)