"""MegaStyle Image Style Comparison – Gradio Space (ZeroGPU).
Compares a test image against 1-8 reference images using MegaStyle-Encoder
(SigLIP-so400m-patch14-384 fine-tuned on MegaStyle-1.4M with style-supervised
contrastive learning). Reports the mean cosine similarity as the headline
score plus a per-reference breakdown.
"""
from __future__ import annotations
from typing import Any
import spaces # must be imported before torch for ZeroGPU
import gradio as gr
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import SiglipImageProcessor, SiglipVisionModel
MAX_REFERENCES = 8
SIGLIP_ID = "google/siglip-so400m-patch14-384"
ENCODER_REPO = "Gaojunyao/MegaStyle"
ENCODER_FILE = "megastyle_encoder.pth"
LABEL_BANDS: list[tuple[float, str, str]] = [
(0.75, "Strong style match", "π’"),
(0.65, "Good style match", "π’"),
(0.55, "Moderate style match", "π‘"),
(0.45, "Weak style match", "π "),
(-1.0, "Minimal style match", "π΄"),
]
def _load_model_cpu() -> tuple[SiglipVisionModel, SiglipImageProcessor]:
"""Load SigLIP on CPU and overlay MegaStyle-Encoder fine-tuned weights.
Kept on CPU at module scope; moved to GPU lazily inside the @spaces.GPU
function. ZeroGPU does not bind a GPU at import time.
Uses SiglipImageProcessor (image-only) rather than AutoProcessor to avoid
pulling in SiglipTokenizer, which requires sentencepiece. Vision-only
inference does not need the tokenizer.
"""
model = SiglipVisionModel.from_pretrained(SIGLIP_ID).eval()
processor = SiglipImageProcessor.from_pretrained(SIGLIP_ID)
ckpt_path = hf_hub_download(ENCODER_REPO, ENCODER_FILE)
checkpoint = torch.load(ckpt_path, map_location="cpu")
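    # Some checkpoints nest the weights under a "model" key; accept either
    # layout. strict=False below skips keys that do not line up with the
    # vision-only model (e.g. projection heads saved alongside the backbone).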
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        state = checkpoint["model"]
    else:
        state = checkpoint
model.load_state_dict(state, strict=False)
return model, processor
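
# Loaded once at import time, on CPU: ZeroGPU binds no GPU during import, so
# the move to CUDA happens lazily inside compare_styles().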
MODEL, PROCESSOR = _load_model_cpu()
def _label_for(score: float) -> tuple[str, str]:
"""Returns (label, emoji) for a cosine-similarity score."""
for threshold, label, emoji in LABEL_BANDS:
if score >= threshold:
return label, emoji
_, label, emoji = LABEL_BANDS[-1]
return label, emoji
def _coerce_to_pil(image: Any) -> Image.Image:
"""Gallery/Image components return varied payloads across Gradio versions.
Accepts: PIL.Image, numpy arrays, filepath strings, dicts with 'image'/'name'
keys, or tuples of (path, caption). Returns an RGB PIL.Image.
"""
if isinstance(image, Image.Image):
return image.convert("RGB")
if isinstance(image, tuple) and image:
return _coerce_to_pil(image[0])
if isinstance(image, dict):
for key in ("image", "name", "path", "orig_name"):
if key in image and image[key] is not None:
return _coerce_to_pil(image[key])
if isinstance(image, str):
return Image.open(image).convert("RGB")
# numpy array fallback
try:
return Image.fromarray(image).convert("RGB")
except Exception as exc: # surface a readable error to the user
raise ValueError(f"Unsupported image input type: {type(image).__name__}") from exc
@torch.no_grad()
def _embed(images: list[Image.Image], device: torch.device) -> torch.Tensor:
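    """Return L2-normalized pooled embeddings, shape (len(images), hidden_size)."""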
inputs = PROCESSOR(images=images, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(device)
embeds = MODEL(pixel_values=pixel_values).pooler_output
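    # L2-normalize so a plain dot product between embeddings equals cosine similarity.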
embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True)
return embeds
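
# duration=30 tells ZeroGPU this call should need at most ~30 s of GPU time;
# the device is attached only while the decorated function runs.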
@spaces.GPU(duration=30)
def compare_styles(
test_image: Image.Image | None,
reference_images: list | None,
) -> tuple[str, list[list]]:
if test_image is None:
return "### Upload a test image to compare.", []
if not reference_images:
return "### Add at least one reference image.", []
clipped_note = ""
total = len(reference_images)
if total > MAX_REFERENCES:
clipped_note = f"\n\n_Using the first {MAX_REFERENCES} of {total} references._"
refs = [_coerce_to_pil(r) for r in reference_images[:MAX_REFERENCES]]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Compare device *types*: torch.device("cuda") != torch.device("cuda:0"),
    # so an exact-equality check would re-trigger .to() on every call.
    if next(MODEL.parameters()).device.type != device.type:
        MODEL.to(device)
test_emb = _embed([_coerce_to_pil(test_image)], device)
ref_emb = _embed(refs, device)
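    # (1, D) @ (D, R) -> (1, R): both sides are unit vectors, so each entry is
    # the cosine similarity between the test image and one reference.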
per_ref = (test_emb @ ref_emb.T).squeeze(0).cpu().tolist()
mean_score = float(sum(per_ref) / len(per_ref))
verdict, emoji = _label_for(mean_score)
summary = (
f"## {emoji} {verdict}\n\n"
f"**Score:** `{mean_score:.3f}` Β· "
f"Mean cosine similarity across {len(refs)} reference image"
f"{'s' if len(refs) != 1 else ''}. Range: `-1` (opposite) β¦ `1` (identical style)."
f"{clipped_note}"
)
table = []
for i, score in enumerate(per_ref):
ref_label, ref_emoji = _label_for(score)
table.append([f"Reference {i + 1}", f"{score:.3f}", f"{ref_emoji} {ref_label}"])
return summary, table
with gr.Blocks(title="MegaStyle Style Comparison", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"# π¨ MegaStyle Image Style Comparison\n"
"Compare how well a **test image** matches the style of **1β"
f"{MAX_REFERENCES} reference images** using "
"[MegaStyle-Encoder](https://huggingface.co/Gaojunyao/MegaStyle) "
"([paper](https://arxiv.org/abs/2604.08364))."
)
with gr.Row(equal_height=True):
with gr.Column(scale=1):
gr.Markdown("### 1. Test image")
test_in = gr.Image(
type="pil",
show_label=False,
height=360,
sources=["upload", "clipboard"],
)
with gr.Column(scale=1):
gr.Markdown(f"### 2. Reference images (up to {MAX_REFERENCES})")
refs_in = gr.File(
file_count="multiple",
file_types=["image"],
show_label=False,
height=140,
)
refs_preview = gr.Gallery(
show_label=False,
columns=4,
rows=2,
height=220,
object_fit="contain",
preview=False,
allow_preview=False,
show_fullscreen_button=False,
show_share_button=False,
show_download_button=False,
interactive=False,
)
def _files_to_gallery(files: list | None) -> list[str]:
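        # Depending on the Gradio version, gr.File yields plain path strings or
        # tempfile-like objects exposing the path via .name; normalize to str.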
if not files:
return []
return [getattr(f, "name", f) for f in files]
refs_in.change(fn=_files_to_gallery, inputs=refs_in, outputs=refs_preview)
with gr.Row():
compare_btn = gr.Button("Compare styles", variant="primary", size="lg")
gr.Markdown("### 3. Result")
summary_out = gr.Markdown("_Run a comparison to see the result._")
breakdown_out = gr.Dataframe(
headers=["Reference", "Cosine similarity", "Label"],
datatype=["str", "str", "str"],
label="Per-reference breakdown",
interactive=False,
wrap=True,
)
compare_btn.click(
fn=compare_styles,
inputs=[test_in, refs_in],
outputs=[summary_out, breakdown_out],
)
gr.Markdown(
"**Label bands (cosine similarity):** "
"π’ β₯ 0.75 Strong Β· π’ 0.65β0.75 Good Β· π‘ 0.55β0.65 Moderate Β· "
"π 0.45β0.55 Weak Β· π΄ < 0.45 Minimal. "
"Thresholds are heuristic; the raw cosine score is the source of truth."
)
if __name__ == "__main__":
# show_api=False avoids the /api/info endpoint whose JSON-schema walker
# crashes on gr.Dataframe in older Gradio versions. Harmless for a demo Space.
demo.queue().launch(show_api=False)