Spaces:
Running on Zero
Running on Zero
"""MegaStyle Image Style Comparison — Gradio Space (ZeroGPU).

Compares a test image against 1-8 reference images using MegaStyle-Encoder
(SigLIP-so400m-patch14-384 fine-tuned on MegaStyle-1.4M with style-supervised
contrastive learning). Reports the mean cosine similarity as the headline
score plus a per-reference breakdown.
"""
| from __future__ import annotations | |
| from typing import Any | |
| import spaces # must be imported before torch for ZeroGPU | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from transformers import SiglipImageProcessor, SiglipVisionModel | |
# Maximum number of reference images used per comparison; extras are ignored.
MAX_REFERENCES = 8
# Base SigLIP vision backbone on the Hugging Face Hub.
SIGLIP_ID = "google/siglip-so400m-patch14-384"
# Repo and filename of the MegaStyle-Encoder fine-tuned checkpoint.
ENCODER_REPO = "Gaojunyao/MegaStyle"
ENCODER_FILE = "megastyle_encoder.pth"
# (threshold, label, emoji) bands, checked top-down: the first band whose
# threshold the score meets wins. The final -1.0 entry catches every real
# cosine value (cosine similarity is bounded below by -1).
# NOTE: emoji restored from mojibake (UTF-8 previously decoded as a Greek
# codepage); both top bands intentionally share the green circle.
LABEL_BANDS: list[tuple[float, str, str]] = [
    (0.75, "Strong style match", "🟢"),
    (0.65, "Good style match", "🟢"),
    (0.55, "Moderate style match", "🟡"),
    (0.45, "Weak style match", "🟠"),
    (-1.0, "Minimal style match", "🔴"),
]
def _load_model_cpu() -> tuple[SiglipVisionModel, SiglipImageProcessor]:
    """Load SigLIP on CPU and overlay MegaStyle-Encoder fine-tuned weights.

    Kept on CPU at module scope; moved to GPU lazily inside the @spaces.GPU
    function. ZeroGPU does not bind a GPU at import time.

    Uses SiglipImageProcessor (image-only) rather than AutoProcessor to avoid
    pulling in SiglipTokenizer, which requires sentencepiece. Vision-only
    inference does not need the tokenizer.

    Returns:
        (model, processor): the vision model in eval mode on CPU, plus the
        matching image preprocessor.
    """
    model = SiglipVisionModel.from_pretrained(SIGLIP_ID).eval()
    processor = SiglipImageProcessor.from_pretrained(SIGLIP_ID)
    # Downloads (or reuses the cached copy of) the fine-tuned checkpoint.
    ckpt_path = hf_hub_download(ENCODER_REPO, ENCODER_FILE)
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects; the checkpoint repo is trusted here, but weights_only=True
    # would be safer if the file loads cleanly with it — TODO confirm.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Some checkpoints are saved wrapped as {"model": state_dict, ...}; unwrap.
    state = checkpoint["model"] if isinstance(checkpoint, dict) and "model" in checkpoint else checkpoint
    # strict=False tolerates key mismatches (e.g. text-tower keys absent from
    # the vision-only model) but also silently ignores missing weights.
    model.load_state_dict(state, strict=False)
    return model, processor


# Loaded once at import time, on CPU (no GPU is bound yet under ZeroGPU).
MODEL, PROCESSOR = _load_model_cpu()
def _label_for(score: float) -> tuple[str, str]:
    """Map a cosine-similarity score to its (label, emoji) band."""
    matched = next(
        (
            (band_label, band_emoji)
            for threshold, band_label, band_emoji in LABEL_BANDS
            if score >= threshold
        ),
        None,
    )
    if matched is not None:
        return matched
    # No band matched (e.g. NaN never satisfies >=): reuse the lowest band.
    _, fallback_label, fallback_emoji = LABEL_BANDS[-1]
    return fallback_label, fallback_emoji
def _coerce_to_pil(image: Any) -> Image.Image:
    """Normalize the varied payloads Gradio components produce into RGB PIL.

    Handles PIL images, (path, caption) tuples, dicts carrying an
    'image'/'name'/'path'/'orig_name' entry, filepath strings, and numpy
    arrays. Raises ValueError for anything else.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, tuple) and image:
        # Gallery items arrive as (path, caption); only the path matters.
        return _coerce_to_pil(image[0])
    if isinstance(image, dict):
        for candidate in ("image", "name", "path", "orig_name"):
            payload = image.get(candidate)
            if payload is not None:
                return _coerce_to_pil(payload)
    if isinstance(image, str):
        return Image.open(image).convert("RGB")
    # Last resort: treat the payload as an array-like pixel buffer.
    try:
        return Image.fromarray(image).convert("RGB")
    except Exception as exc:  # surface a readable error to the user
        raise ValueError(f"Unsupported image input type: {type(image).__name__}") from exc
def _embed(images: list[Image.Image], device: torch.device) -> torch.Tensor:
    """Encode a batch of PIL images into L2-normalized style embeddings."""
    batch = PROCESSOR(images=images, return_tensors="pt")["pixel_values"]
    features = MODEL(pixel_values=batch.to(device)).pooler_output
    # Unit-normalize so that dot products downstream are cosine similarities.
    return features / features.norm(p=2, dim=-1, keepdim=True)
@spaces.GPU  # FIX: required on ZeroGPU — without it no GPU is ever bound and
# torch.cuda.is_available() is always False (see _load_model_cpu docstring,
# which already refers to "the @spaces.GPU function").
def compare_styles(
    test_image: Image.Image | None,
    reference_images: list | None,
) -> tuple[str, list[list]]:
    """Score how closely *test_image* matches the style of the references.

    Args:
        test_image: the probe image, or None if nothing was uploaded.
        reference_images: raw component payloads; at most MAX_REFERENCES used.

    Returns:
        (summary_markdown, table_rows): a headline verdict plus one
        [name, score, label] row per reference.
    """
    # Friendly prompts instead of errors for incomplete input.
    if test_image is None:
        return "### Upload a test image to compare.", []
    if not reference_images:
        return "### Add at least one reference image.", []

    clipped_note = ""
    total = len(reference_images)
    if total > MAX_REFERENCES:
        clipped_note = f"\n\n_Using the first {MAX_REFERENCES} of {total} references._"
    refs = [_coerce_to_pil(r) for r in reference_images[:MAX_REFERENCES]]

    # Inside @spaces.GPU a CUDA device is bound on ZeroGPU; fall back to CPU
    # when running locally without a GPU. Move the model only when needed.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if next(MODEL.parameters()).device != device:
        MODEL.to(device)

    test_emb = _embed([_coerce_to_pil(test_image)], device)
    ref_emb = _embed(refs, device)
    # Embeddings are unit-norm, so this matrix product is cosine similarity.
    per_ref = (test_emb @ ref_emb.T).squeeze(0).cpu().tolist()
    mean_score = float(sum(per_ref) / len(per_ref))

    verdict, emoji = _label_for(mean_score)
    # "·" and "…" restored from mojibake in the original source.
    summary = (
        f"## {emoji} {verdict}\n\n"
        f"**Score:** `{mean_score:.3f}` · "
        f"Mean cosine similarity across {len(refs)} reference image"
        f"{'s' if len(refs) != 1 else ''}. Range: `-1` (opposite) … `1` (identical style)."
        f"{clipped_note}"
    )
    table = []
    for i, score in enumerate(per_ref):
        ref_label, ref_emoji = _label_for(score)
        table.append([f"Reference {i + 1}", f"{score:.3f}", f"{ref_emoji} {ref_label}"])
    return summary, table
# UI layout. User-visible strings below restore emoji/dashes/≥/· that were
# mojibake in the original source.
with gr.Blocks(title="MegaStyle Style Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 🎨 MegaStyle Image Style Comparison\n"
        "Compare how well a **test image** matches the style of **1–"
        f"{MAX_REFERENCES} reference images** using "
        "[MegaStyle-Encoder](https://huggingface.co/Gaojunyao/MegaStyle) "
        "([paper](https://arxiv.org/abs/2604.08364))."
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            gr.Markdown("### 1. Test image")
            test_in = gr.Image(
                type="pil",
                show_label=False,
                height=360,
                sources=["upload", "clipboard"],
            )
        with gr.Column(scale=1):
            gr.Markdown(f"### 2. Reference images (up to {MAX_REFERENCES})")
            refs_in = gr.File(
                file_count="multiple",
                file_types=["image"],
                show_label=False,
                height=140,
            )
            # Read-only thumbnail strip mirroring the uploaded files.
            refs_preview = gr.Gallery(
                show_label=False,
                columns=4,
                rows=2,
                height=220,
                object_fit="contain",
                preview=False,
                allow_preview=False,
                show_fullscreen_button=False,
                show_share_button=False,
                show_download_button=False,
                interactive=False,
            )

    def _files_to_gallery(files: list | None) -> list[str]:
        """Convert gr.File payloads to plain paths for the preview gallery."""
        if not files:
            return []
        # gr.File may yield tempfile wrappers (with .name) or plain path strings.
        return [getattr(f, "name", f) for f in files]

    refs_in.change(fn=_files_to_gallery, inputs=refs_in, outputs=refs_preview)

    with gr.Row():
        compare_btn = gr.Button("Compare styles", variant="primary", size="lg")
    gr.Markdown("### 3. Result")
    summary_out = gr.Markdown("_Run a comparison to see the result._")
    breakdown_out = gr.Dataframe(
        headers=["Reference", "Cosine similarity", "Label"],
        datatype=["str", "str", "str"],
        label="Per-reference breakdown",
        interactive=False,
        wrap=True,
    )
    compare_btn.click(
        fn=compare_styles,
        inputs=[test_in, refs_in],
        outputs=[summary_out, breakdown_out],
    )
    gr.Markdown(
        "**Label bands (cosine similarity):** "
        "🟢 ≥ 0.75 Strong · 🟢 0.65–0.75 Good · 🟡 0.55–0.65 Moderate · "
        "🟠 0.45–0.55 Weak · 🔴 < 0.45 Minimal. "
        "Thresholds are heuristic; the raw cosine score is the source of truth."
    )
if __name__ == "__main__":
    # show_api=False avoids the /api/info endpoint whose JSON-schema walker
    # crashes on gr.Dataframe in older Gradio versions. Harmless for a demo Space.
    # NOTE(review): queue() presumably enables request queuing for the GPU
    # worker — standard for ZeroGPU Spaces; confirm against the spaces docs.
    demo.queue().launch(show_api=False)