"""
Universal Cross-Domain Vision Model - Gradio Demo
==================================================
Architecture (matches best_model_phase1.pt):
  Backbones (loaded from HF Hub at runtime; no storage cost):
    - CLIP ViT-B/32         via open_clip
    - ViT-B/16              via timm
    - ResNet-50             via timm
    - EfficientNet-B0       via timm

  Fine-tuned layers (loaded from head_weights.pt, ~25 MB):
    - *_proj.*              projection adapters per backbone
    - fusion.*              multi-head attention fusion
    - classifier.*          final 14-class head
    - uncertainty_head.*    uncertainty estimation

Run locally:   python app.py
HF Spaces:     push this folder + head_weights.pt
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
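
# Assumed runtime dependencies (not pinned in this file): torch, torchvision, timm,
# open_clip_torch, gradio, pillow. A typical install:
#   pip install torch torchvision timm open_clip_torch gradio pillow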

# ─────────────────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────────────────
HEAD_WEIGHTS = os.path.join(os.path.dirname(__file__), "head_weights.pt")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 512

MEDICAL_CLASSES = [
    "Normal", "Pneumonia", "COVID-19", "Tuberculosis",
    "Cardiomegaly", "Rib Fracture", "Lung Mass", "Pleural Effusion",
]
SPORTS_CLASSES = ["Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football"]
ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES

# ─────────────────────────────────────────────────────────────────────────────
# Model definition (must match training architecture)
# ─────────────────────────────────────────────────────────────────────────────
class UniversalVisionModel(nn.Module):
    """
    Multi-backbone fusion model.
    Backbones are loaded separately; this module holds only the
    projection adapters, fusion transformer, and classifier head.
    """

    def __init__(self, embed_dim=EMBED_DIM, num_classes=len(ALL_CLASSES), dropout=0.2):
        super().__init__()

        # Projection adapters (one per backbone)
        self.clip_vision_proj   = nn.Linear(embed_dim, embed_dim)
        self.vit_proj           = nn.Linear(embed_dim, embed_dim)
        self.resnet_proj        = nn.Linear(embed_dim, embed_dim)  # expects ResNet-50 features reduced to 512-d
        self.efficientnet_proj  = nn.Linear(embed_dim, embed_dim)  # expects EfficientNet features reduced to 512-d
        self.clip_text_proj     = nn.Linear(embed_dim, embed_dim)

        # Fusion transformer
        self.fusion = nn.ModuleDict({
            "attention": nn.MultiheadAttention(embed_dim, num_heads=8, dropout=dropout, batch_first=True),
            "ffn": nn.Sequential(
                nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(dropout),
                nn.Linear(embed_dim * 4, embed_dim), nn.Dropout(dropout),
            ),
            "norm1": nn.LayerNorm(embed_dim),
            "norm2": nn.LayerNorm(embed_dim),
        })

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes),
        )

        # Uncertainty head
        self.uncertainty_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 4), nn.ReLU(),
            nn.Linear(embed_dim // 4, num_classes),
        )

    def fuse(self, feature_list):
        """Fuse a list of [B, D] feature tensors via multi-head attention."""
        stacked = torch.stack(feature_list, dim=1)          # [B, N, D]
        attn_out, _ = self.fusion["attention"](stacked, stacked, stacked)
        stacked = self.fusion["norm1"](stacked + attn_out)
        ffn_out = self.fusion["ffn"](stacked)
        fused = self.fusion["norm2"](stacked + ffn_out)
        return fused.mean(dim=1)                             # [B, D]

    def forward(self, features: dict) -> dict:
        """
        features: dict with keys matching backbone names,
                  each value is [B, raw_dim] tensor.
        """
        projected = []
        if "clip_vision" in features:
            projected.append(self.clip_vision_proj(features["clip_vision"]))
        if "vit" in features:
            projected.append(self.vit_proj(features["vit"]))
        if "resnet" in features:
            projected.append(self.resnet_proj(features["resnet"]))
        if "efficientnet" in features:
            projected.append(self.efficientnet_proj(features["efficientnet"]))
        if "clip_text" in features:
            projected.append(self.clip_text_proj(features["clip_text"]))

        fused = self.fuse(projected)
        logits = self.classifier(fused)
        uncertainty = self.uncertainty_head(fused)
        return {"logits": logits, "uncertainty": uncertainty, "fused": fused}
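
# Example call (a minimal sketch, not part of the original pipeline): each backbone
# contributes a [B, 512] feature tensor and the head returns 14-class logits.
#   model = UniversalVisionModel().eval()
#   out = model({"clip_vision": torch.randn(1, 512), "vit": torch.randn(1, 512)})
#   out["logits"].shape   # torch.Size([1, 14])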


# ─────────────────────────────────────────────────────────────────────────────
# Backbone loaders (called once, cached)
# ─────────────────────────────────────────────────────────────────────────────
_backbones = {}
_transforms = {}
_model = None


def _load_backbones():
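    """Load the pretrained backbones from the HF Hub / timm and cache them with their preprocessing transforms."""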
    global _backbones, _transforms

    import open_clip, timm
    from torchvision import transforms as T

    # Standard 224×224 transform for timm models
    timm_tfm = T.Compose([
        T.Resize(224), T.CenterCrop(224), T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # 1. CLIP (via open_clip; uses BiomedCLIP if available, else ViT-B/32)
    print("[INFO] Loading CLIP backbone...")
    try:
        # create_model_and_transforms returns (model, preprocess_train, preprocess_val);
        # use the eval-time transform for inference, matching the fallback below.
        clip_model, _, clip_tfm = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
    except Exception:
        clip_model, _, clip_tfm = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="openai"
        )
    clip_model = clip_model.to(DEVICE).eval()
    _backbones["clip"] = clip_model
    _transforms["clip"] = clip_tfm

    # 2. ViT-B/16 (timm)
    print("[INFO] Loading ViT-B/16 backbone...")
    vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
    vit = vit.to(DEVICE).eval()
    _backbones["vit"] = vit
    _transforms["vit"] = timm_tfm

    # 3. ResNet-50 (timm)
    print("[INFO] Loading ResNet-50 backbone...")
    resnet = timm.create_model("resnet50", pretrained=True, num_classes=0)
    resnet = resnet.to(DEVICE).eval()
    _backbones["resnet"] = resnet
    _transforms["resnet"] = timm_tfm

    # 4. EfficientNet-B0 (timm)
    print("[INFO] Loading EfficientNet-B0 backbone...")
    effnet = timm.create_model("efficientnet_b0", pretrained=True, num_classes=0)
    effnet = effnet.to(DEVICE).eval()
    _backbones["efficientnet"] = effnet
    _transforms["efficientnet"] = timm_tfm

    print("[INFO] All backbones loaded.")


def _load_model():
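    """Build the fusion head and load the fine-tuned weights from head_weights.pt when available."""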
    global _model
    _model = UniversalVisionModel().to(DEVICE)
    if os.path.isfile(HEAD_WEIGHTS):
        ckpt = torch.load(HEAD_WEIGHTS, map_location=DEVICE, weights_only=False)
        state = ckpt.get("model_state_dict", ckpt)
        missing, unexpected = _model.load_state_dict(state, strict=False)
        print(f"[INFO] Head loaded; missing: {len(missing)}, unexpected: {len(unexpected)}")
    else:
        print("[WARN] head_weights.pt not found; using random weights.")
    _model.eval()


def _ensure_loaded():
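    """Lazily initialize the backbones and fusion head on first request."""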
    if _model is None:
        _load_backbones()
        _load_model()


# ─────────────────────────────────────────────────────────────────────────────
# Inference
# ─────────────────────────────────────────────────────────────────────────────
def extract_features(pil_image: Image.Image) -> dict:
    """Extract features from all backbones."""
    feats = {}
    with torch.no_grad():
        # CLIP vision features
        t = _transforms["clip"](pil_image).unsqueeze(0).to(DEVICE)
        clip_feat = _backbones["clip"].encode_image(t)
        clip_feat = F.normalize(clip_feat.float(), dim=-1)
        feats["clip_vision"] = clip_feat

        # ViT features
        t = _transforms["vit"](pil_image).unsqueeze(0).to(DEVICE)
        vit_feat = _backbones["vit"](t).float()
        # ViT-B/16 outputs 768-dim features, but the projection adapters expect EMBED_DIM (512);
        # truncate to the first EMBED_DIM channels so the head receives the input size it was built for.
        if vit_feat.shape[-1] != EMBED_DIM:
            vit_feat = vit_feat[..., :EMBED_DIM]
        feats["vit"] = F.normalize(vit_feat, dim=-1)

        # ResNet features
        t = _transforms["resnet"](pil_image).unsqueeze(0).to(DEVICE)
        res_feat = _backbones["resnet"](t).float()
        if res_feat.shape[-1] != EMBED_DIM:
            res_feat = res_feat[..., :EMBED_DIM]
        feats["resnet"] = F.normalize(res_feat, dim=-1)

        # EfficientNet features
        t = _transforms["efficientnet"](pil_image).unsqueeze(0).to(DEVICE)
        eff_feat = _backbones["efficientnet"](t).float()
        if eff_feat.shape[-1] != EMBED_DIM:
            eff_feat = eff_feat[..., :EMBED_DIM]
        feats["efficientnet"] = F.normalize(eff_feat, dim=-1)

    return feats


def predict(pil_image: Image.Image) -> dict:
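    """Extract backbone features for a PIL image and return {class: probability}, sorted descending."""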
    _ensure_loaded()
    feats = extract_features(pil_image)
    with torch.no_grad():
        out = _model(feats)
        probs = F.softmax(out["logits"], dim=-1).squeeze(0).cpu().tolist()
    scores = {label: round(p, 6) for label, p in zip(ALL_CLASSES, probs)}
    return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))


def classify(image):
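    """Gradio handler: convert the numpy image to PIL, run prediction, and surface errors in the Label output."""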
    if image is None:
        return {}
    try:
        return predict(Image.fromarray(image))
    except Exception as e:
        return {"Error": str(e)}


# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
# ─────────────────────────────────────────────────────────────────────────────
DESCRIPTION = """
## 🏥🎾 Universal Cross-Domain Vision Model

Classifies images across **medical** (X-ray pathologies) and **sports** domains using an
ensemble of BiomedCLIP, ViT-B/16, ResNet-50, and EfficientNet-B0 backbones
with fine-tuned multi-modal attention fusion.

**Medical classes:** Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion

**Sports classes:** Running, Jumping, Swimming, Cycling, Tennis, Football

Upload any image to get started.
"""

with gr.Blocks(title="Universal Vision Model", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(label="Upload Image", type="numpy")
            submit_btn = gr.Button("Classify", variant="primary")
        with gr.Column(scale=1):
            label_output = gr.Label(num_top_classes=8, label="Predictions")

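    # Classify on button press and automatically whenever a new image is uploaded.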
    submit_btn.click(fn=classify, inputs=img_input, outputs=label_output)
    img_input.change(fn=classify, inputs=img_input, outputs=label_output)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        share=False,
    )