"""MegaStyle Image Style Comparison β€” Gradio Space (ZeroGPU).

Compares a test image against 1-8 reference images using MegaStyle-Encoder
(SigLIP-so400m-patch14-384 fine-tuned on MegaStyle-1.4M with style-supervised
contrastive learning). Reports the mean cosine similarity as the headline
score plus a per-reference breakdown.
"""

from __future__ import annotations

from typing import Any

import spaces  # must be imported before torch for ZeroGPU
import gradio as gr
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import SiglipImageProcessor, SiglipVisionModel


MAX_REFERENCES = 8
SIGLIP_ID = "google/siglip-so400m-patch14-384"
ENCODER_REPO = "Gaojunyao/MegaStyle"
ENCODER_FILE = "megastyle_encoder.pth"

LABEL_BANDS: list[tuple[float, str, str]] = [
    (0.75, "Strong style match", "🟒"),
    (0.65, "Good style match", "🟒"),
    (0.55, "Moderate style match", "🟑"),
    (0.45, "Weak style match", "🟠"),
    (-1.0, "Minimal style match", "πŸ”΄"),
]
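# Bands are scanned top-down; the first threshold the score clears wins, so the
# list must stay sorted by descending threshold. The final -1.0 band covers the
# remainder of the valid cosine range.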


def _load_model_cpu() -> tuple[SiglipVisionModel, SiglipImageProcessor]:
    """Load SigLIP on CPU and overlay MegaStyle-Encoder fine-tuned weights.

    Kept on CPU at module scope; moved to GPU lazily inside the @spaces.GPU
    function. ZeroGPU does not bind a GPU at import time.

    Uses SiglipImageProcessor (image-only) rather than AutoProcessor to avoid
    pulling in SiglipTokenizer, which requires sentencepiece. Vision-only
    inference does not need the tokenizer.
    """
    model = SiglipVisionModel.from_pretrained(SIGLIP_ID).eval()
    processor = SiglipImageProcessor.from_pretrained(SIGLIP_ID)

    ckpt_path = hf_hub_download(ENCODER_REPO, ENCODER_FILE)
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state = checkpoint["model"] if isinstance(checkpoint, dict) and "model" in checkpoint else checkpoint
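    # strict=False: the fine-tuned checkpoint may not cover every vision-tower
    # key (or may carry extras); any missing keys keep their pretrained SigLIP
    # values rather than erroring out.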
    model.load_state_dict(state, strict=False)
    return model, processor


MODEL, PROCESSOR = _load_model_cpu()
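# Loaded once at import time and kept on CPU; each @spaces.GPU call moves the
# model to CUDA on demand (see compare_styles).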


def _label_for(score: float) -> tuple[str, str]:
    """Returns (label, emoji) for a cosine-similarity score."""
    for threshold, label, emoji in LABEL_BANDS:
        if score >= threshold:
            return label, emoji
    _, label, emoji = LABEL_BANDS[-1]
    return label, emoji


def _coerce_to_pil(image: Any) -> Image.Image:
    """Gallery/Image components return varied payloads across Gradio versions.

    Accepts: PIL.Image, numpy arrays, filepath strings, dicts with 'image'/'name'
    keys, or tuples of (path, caption). Returns an RGB PIL.Image.
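
    Illustrative payloads (hypothetical values):
        "/tmp/ref.png"               -> opened from disk
        ("/tmp/ref.png", "caption")  -> Gallery tuple; the path is unwrapped
        {"image": ..., "name": ...}  -> dict; first non-None known key is used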
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, tuple) and image:
        return _coerce_to_pil(image[0])
    if isinstance(image, dict):
        for key in ("image", "name", "path", "orig_name"):
            if key in image and image[key] is not None:
                return _coerce_to_pil(image[key])
    if isinstance(image, str):
        return Image.open(image).convert("RGB")
    # numpy array fallback
    try:
        return Image.fromarray(image).convert("RGB")
    except Exception as exc:  # surface a readable error to the user
        raise ValueError(f"Unsupported image input type: {type(image).__name__}") from exc


@torch.no_grad()
def _embed(images: list[Image.Image], device: torch.device) -> torch.Tensor:
    inputs = PROCESSOR(images=images, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)
    embeds = MODEL(pixel_values=pixel_values).pooler_output
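    # pooler_output has shape (batch, hidden). L2-normalizing below makes every
    # subsequent dot product a cosine similarity in [-1, 1].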
    embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True)
    return embeds


@spaces.GPU(duration=30)
def compare_styles(
    test_image: Image.Image | None,
    reference_images: list | None,
) -> tuple[str, list[list]]:
    if test_image is None:
        return "### Upload a test image to compare.", []
    if not reference_images:
        return "### Add at least one reference image.", []

    clipped_note = ""
    total = len(reference_images)
    if total > MAX_REFERENCES:
        clipped_note = f"\n\n_Using the first {MAX_REFERENCES} of {total} references._"
    refs = [_coerce_to_pil(r) for r in reference_images[:MAX_REFERENCES]]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Compare device *types*: torch.device("cuda") != torch.device("cuda:0"),
    # so an exact device comparison would re-trigger .to() on every call.
    if next(MODEL.parameters()).device.type != device.type:
        MODEL.to(device)

    test_emb = _embed([_coerce_to_pil(test_image)], device)
    ref_emb = _embed(refs, device)

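    # test_emb is (1, D) and ref_emb is (N, D); the matmul yields a (1, N) row
    # of cosine similarities, flattened to one score per reference.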
    per_ref = (test_emb @ ref_emb.T).squeeze(0).cpu().tolist()
    mean_score = float(sum(per_ref) / len(per_ref))
    verdict, emoji = _label_for(mean_score)

    summary = (
        f"## {emoji} {verdict}\n\n"
        f"**Score:** `{mean_score:.3f}`  Β·  "
        f"Mean cosine similarity across {len(refs)} reference image"
        f"{'s' if len(refs) != 1 else ''}. Range: `-1` (opposite) … `1` (identical style)."
        f"{clipped_note}"
    )

    table = []
    for i, score in enumerate(per_ref):
        ref_label, ref_emoji = _label_for(score)
        table.append([f"Reference {i + 1}", f"{score:.3f}", f"{ref_emoji} {ref_label}"])
    return summary, table


with gr.Blocks(title="MegaStyle Style Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 🎨 MegaStyle Image Style Comparison\n"
        "Compare how well a **test image** matches the style of **1–"
        f"{MAX_REFERENCES} reference images** using "
        "[MegaStyle-Encoder](https://huggingface.co/Gaojunyao/MegaStyle) "
        "([paper](https://arxiv.org/abs/2604.08364))."
    )

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            gr.Markdown("### 1. Test image")
            test_in = gr.Image(
                type="pil",
                show_label=False,
                height=360,
                sources=["upload", "clipboard"],
            )
        with gr.Column(scale=1):
            gr.Markdown(f"### 2. Reference images (up to {MAX_REFERENCES})")
            refs_in = gr.File(
                file_count="multiple",
                file_types=["image"],
                show_label=False,
                height=140,
            )
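            # gr.File gives a stable list-of-paths payload for multiple
            # uploads; the Gallery below is a read-only preview kept in sync
            # via refs_in.change.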
            refs_preview = gr.Gallery(
                show_label=False,
                columns=4,
                rows=2,
                height=220,
                object_fit="contain",
                preview=False,
                allow_preview=False,
                show_fullscreen_button=False,
                show_share_button=False,
                show_download_button=False,
                interactive=False,
            )

    def _files_to_gallery(files: list | None) -> list[str]:
        if not files:
            return []
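        # Newer Gradio passes plain file paths (str); older builds pass
        # tempfile-like objects exposing .name, so getattr covers both.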
        return [getattr(f, "name", f) for f in files]

    refs_in.change(fn=_files_to_gallery, inputs=refs_in, outputs=refs_preview)

    with gr.Row():
        compare_btn = gr.Button("Compare styles", variant="primary", size="lg")

    gr.Markdown("### 3. Result")
    summary_out = gr.Markdown("_Run a comparison to see the result._")
    breakdown_out = gr.Dataframe(
        headers=["Reference", "Cosine similarity", "Label"],
        datatype=["str", "str", "str"],
        label="Per-reference breakdown",
        interactive=False,
        wrap=True,
    )

    compare_btn.click(
        fn=compare_styles,
        inputs=[test_in, refs_in],
        outputs=[summary_out, breakdown_out],
    )

    gr.Markdown(
        "**Label bands (cosine similarity):** "
        "🟒 β‰₯ 0.75 Strong Β· 🟒 0.65–0.75 Good Β· 🟑 0.55–0.65 Moderate Β· "
        "🟠 0.45–0.55 Weak Β· πŸ”΄ < 0.45 Minimal. "
        "Thresholds are heuristic; the raw cosine score is the source of truth."
    )


if __name__ == "__main__":
    # show_api=False avoids the /api/info endpoint whose JSON-schema walker
    # crashes on gr.Dataframe in older Gradio versions. Harmless for a demo Space.
    demo.queue().launch(show_api=False)