"""MegaStyle Image Style Comparison โ€” Gradio Space (ZeroGPU). Compares a test image against 1-8 reference images using MegaStyle-Encoder (SigLIP-so400m-patch14-384 fine-tuned on MegaStyle-1.4M with style-supervised contrastive learning). Reports the mean cosine similarity as the headline score plus a per-reference breakdown. """ from __future__ import annotations from typing import Any import spaces # must be imported before torch for ZeroGPU import gradio as gr import torch from PIL import Image from huggingface_hub import hf_hub_download from transformers import SiglipImageProcessor, SiglipVisionModel MAX_REFERENCES = 8 SIGLIP_ID = "google/siglip-so400m-patch14-384" ENCODER_REPO = "Gaojunyao/MegaStyle" ENCODER_FILE = "megastyle_encoder.pth" LABEL_BANDS: list[tuple[float, str, str]] = [ (0.75, "Strong style match", "๐ŸŸข"), (0.65, "Good style match", "๐ŸŸข"), (0.55, "Moderate style match", "๐ŸŸก"), (0.45, "Weak style match", "๐ŸŸ "), (-1.0, "Minimal style match", "๐Ÿ”ด"), ] def _load_model_cpu() -> tuple[SiglipVisionModel, SiglipImageProcessor]: """Load SigLIP on CPU and overlay MegaStyle-Encoder fine-tuned weights. Kept on CPU at module scope; moved to GPU lazily inside the @spaces.GPU function. ZeroGPU does not bind a GPU at import time. Uses SiglipImageProcessor (image-only) rather than AutoProcessor to avoid pulling in SiglipTokenizer, which requires sentencepiece. Vision-only inference does not need the tokenizer. """ model = SiglipVisionModel.from_pretrained(SIGLIP_ID).eval() processor = SiglipImageProcessor.from_pretrained(SIGLIP_ID) ckpt_path = hf_hub_download(ENCODER_REPO, ENCODER_FILE) checkpoint = torch.load(ckpt_path, map_location="cpu") state = checkpoint["model"] if isinstance(checkpoint, dict) and "model" in checkpoint else checkpoint model.load_state_dict(state, strict=False) return model, processor MODEL, PROCESSOR = _load_model_cpu() def _label_for(score: float) -> tuple[str, str]: """Returns (label, emoji) for a cosine-similarity score.""" for threshold, label, emoji in LABEL_BANDS: if score >= threshold: return label, emoji _, label, emoji = LABEL_BANDS[-1] return label, emoji def _coerce_to_pil(image: Any) -> Image.Image: """Gallery/Image components return varied payloads across Gradio versions. Accepts: PIL.Image, numpy arrays, filepath strings, dicts with 'image'/'name' keys, or tuples of (path, caption). Returns an RGB PIL.Image. """ if isinstance(image, Image.Image): return image.convert("RGB") if isinstance(image, tuple) and image: return _coerce_to_pil(image[0]) if isinstance(image, dict): for key in ("image", "name", "path", "orig_name"): if key in image and image[key] is not None: return _coerce_to_pil(image[key]) if isinstance(image, str): return Image.open(image).convert("RGB") # numpy array fallback try: return Image.fromarray(image).convert("RGB") except Exception as exc: # surface a readable error to the user raise ValueError(f"Unsupported image input type: {type(image).__name__}") from exc @torch.no_grad() def _embed(images: list[Image.Image], device: torch.device) -> torch.Tensor: inputs = PROCESSOR(images=images, return_tensors="pt") pixel_values = inputs["pixel_values"].to(device) embeds = MODEL(pixel_values=pixel_values).pooler_output embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True) return embeds @spaces.GPU(duration=30) def compare_styles( test_image: Image.Image | None, reference_images: list | None, ) -> tuple[str, list[list]]: if test_image is None: return "### Upload a test image to compare.", [] if not reference_images: return "### Add at least one reference image.", [] clipped_note = "" total = len(reference_images) if total > MAX_REFERENCES: clipped_note = f"\n\n_Using the first {MAX_REFERENCES} of {total} references._" refs = [_coerce_to_pil(r) for r in reference_images[:MAX_REFERENCES]] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if next(MODEL.parameters()).device != device: MODEL.to(device) test_emb = _embed([_coerce_to_pil(test_image)], device) ref_emb = _embed(refs, device) per_ref = (test_emb @ ref_emb.T).squeeze(0).cpu().tolist() mean_score = float(sum(per_ref) / len(per_ref)) verdict, emoji = _label_for(mean_score) summary = ( f"## {emoji} {verdict}\n\n" f"**Score:** `{mean_score:.3f}`  ยท  " f"Mean cosine similarity across {len(refs)} reference image" f"{'s' if len(refs) != 1 else ''}. Range: `-1` (opposite) โ€ฆ `1` (identical style)." f"{clipped_note}" ) table = [] for i, score in enumerate(per_ref): ref_label, ref_emoji = _label_for(score) table.append([f"Reference {i + 1}", f"{score:.3f}", f"{ref_emoji} {ref_label}"]) return summary, table with gr.Blocks(title="MegaStyle Style Comparison", theme=gr.themes.Soft()) as demo: gr.Markdown( "# ๐ŸŽจ MegaStyle Image Style Comparison\n" "Compare how well a **test image** matches the style of **1โ€“" f"{MAX_REFERENCES} reference images** using " "[MegaStyle-Encoder](https://huggingface.co/Gaojunyao/MegaStyle) " "([paper](https://arxiv.org/abs/2604.08364))." ) with gr.Row(equal_height=True): with gr.Column(scale=1): gr.Markdown("### 1. Test image") test_in = gr.Image( type="pil", show_label=False, height=360, sources=["upload", "clipboard"], ) with gr.Column(scale=1): gr.Markdown(f"### 2. Reference images (up to {MAX_REFERENCES})") refs_in = gr.File( file_count="multiple", file_types=["image"], show_label=False, height=140, ) refs_preview = gr.Gallery( show_label=False, columns=4, rows=2, height=220, object_fit="contain", preview=False, allow_preview=False, show_fullscreen_button=False, show_share_button=False, show_download_button=False, interactive=False, ) def _files_to_gallery(files: list | None) -> list[str]: if not files: return [] return [getattr(f, "name", f) for f in files] refs_in.change(fn=_files_to_gallery, inputs=refs_in, outputs=refs_preview) with gr.Row(): compare_btn = gr.Button("Compare styles", variant="primary", size="lg") gr.Markdown("### 3. Result") summary_out = gr.Markdown("_Run a comparison to see the result._") breakdown_out = gr.Dataframe( headers=["Reference", "Cosine similarity", "Label"], datatype=["str", "str", "str"], label="Per-reference breakdown", interactive=False, wrap=True, ) compare_btn.click( fn=compare_styles, inputs=[test_in, refs_in], outputs=[summary_out, breakdown_out], ) gr.Markdown( "**Label bands (cosine similarity):** " "๐ŸŸข โ‰ฅ 0.75 Strong ยท ๐ŸŸข 0.65โ€“0.75 Good ยท ๐ŸŸก 0.55โ€“0.65 Moderate ยท " "๐ŸŸ  0.45โ€“0.55 Weak ยท ๐Ÿ”ด < 0.45 Minimal. " "Thresholds are heuristic; the raw cosine score is the source of truth." ) if __name__ == "__main__": # show_api=False avoids the /api/info endpoint whose JSON-schema walker # crashes on gr.Dataframe in older Gradio versions. Harmless for a demo Space. demo.queue().launch(show_api=False)