Elliot89 committed
Commit 9b58add · verified · 1 Parent(s): a5a19ee

Upload 2 files

Files changed (2):
  1. app.py +196 -135
  2. head_weights.pt +3 -0
app.py CHANGED
@@ -1,196 +1,263 @@
 """
 Universal Cross-Domain Vision Model — Gradio Demo
 ==================================================
- Runs locally: python app.py
- HF Spaces: push this folder to a Space (SDK: gradio)
-
- The app loads the trained BiomedCLIP checkpoint and classifies uploaded images
- across medical (8 pathologies) and sports (6 action categories) domains.
 """

 import os
- import io
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
- import numpy as np
 from PIL import Image
 import gradio as gr

 # ─────────────────────────────────────────────────────────────────────────────
- # Configuration
 # ─────────────────────────────────────────────────────────────────────────────
- CHECKPOINT_PATH = os.environ.get(
-     "CHECKPOINT_PATH",
-     os.path.join(os.path.dirname(__file__), "..", "universal_vision_checkpoints", "best_model_phase1.pt"),
- )
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 MEDICAL_CLASSES = [
-     "Normal",
-     "Pneumonia",
-     "COVID-19",
-     "Tuberculosis",
-     "Cardiomegaly",
-     "Rib Fracture",
-     "Lung Mass",
-     "Pleural Effusion",
- ]
-
- SPORTS_CLASSES = [
-     "Running",
-     "Jumping",
-     "Swimming",
-     "Cycling",
-     "Tennis",
-     "Football",
 ]
-
 ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES

 # ─────────────────────────────────────────────────────────────────────────────
- # Model Definition (must match training architecture)
 # ─────────────────────────────────────────────────────────────────────────────
- class BiomedCLIPMultiModalFusion(nn.Module):
-     """Lightweight inference-only wrapper matching the training architecture."""
-
-     def __init__(self, embed_dim: int = 512, num_classes: int = len(ALL_CLASSES), dropout: float = 0.2):
        super().__init__()
-         self.embed_dim = embed_dim
-
-         # Domain discriminator (kept for architecture compatibility)
-         self.domain_discriminator = nn.Sequential(
-             nn.Linear(embed_dim, embed_dim // 2),
-             nn.ReLU(),
-             nn.Dropout(dropout),
-             nn.Linear(embed_dim // 2, 2),
-         )

-         # Multi-head attention fusion
-         self.attention = nn.MultiheadAttention(
-             embed_dim=embed_dim, num_heads=8, dropout=dropout, batch_first=True
-         )
-
-         # Feed-forward network
-         self.ffn = nn.Sequential(
-             nn.Linear(embed_dim, embed_dim * 4),
-             nn.GELU(),
-             nn.Dropout(dropout),
-             nn.Linear(embed_dim * 4, embed_dim),
-             nn.Dropout(dropout),
-         )
-
-         self.norm1 = nn.LayerNorm(embed_dim)
-         self.norm2 = nn.LayerNorm(embed_dim)
-
-         # Classifier head
        self.classifier = nn.Sequential(
-             nn.Linear(embed_dim, embed_dim // 2),
-             nn.GELU(),
-             nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes),
        )

-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         # x: [B, embed_dim] — pre-extracted image features
-         x = x.unsqueeze(1)  # [B, 1, D]
-         attn_out, _ = self.attention(x, x, x)
-         x = self.norm1(x + attn_out)
-         ffn_out = self.ffn(x)
-         fused = self.norm2(x + ffn_out).squeeze(1)  # [B, D]
-         return self.classifier(fused)


 # ─────────────────────────────────────────────────────────────────────────────
- # Load model + backbone
 # ─────────────────────────────────────────────────────────────────────────────
 _model = None
- _backbone = None
- _preprocess = None


- def _load_models():
-     global _model, _backbone, _preprocess

-     if _model is not None:
-         return

-     print(f"[INFO] Loading models on {DEVICE} …")

-     # Try BiomedCLIP first, fall back to standard CLIP
    try:
-         import open_clip
-         _backbone, _preprocess, _ = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
-         embed_dim = 512
-         print("[INFO] BiomedCLIP backbone loaded.")
-     except Exception as e:
-         print(f"[WARN] BiomedCLIP failed ({e}), using CLIP-ViT-B/32.")
-         import open_clip
-         _backbone, _, _preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
-         embed_dim = 512
-
-     _backbone = _backbone.to(DEVICE).eval()
-
-     # Build fusion model
-     _model = BiomedCLIPMultiModalFusion(embed_dim=embed_dim, num_classes=len(ALL_CLASSES))
-
-     # Load checkpoint weights (graceful fallback if checkpoint is missing)
-     if os.path.isfile(CHECKPOINT_PATH):
-         try:
-             ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
-             state = ckpt.get("model_state_dict", ckpt)
-             _model.load_state_dict(state, strict=False)
-             print(f"[INFO] Checkpoint loaded from {CHECKPOINT_PATH}")
-         except Exception as e:
-             print(f"[WARN] Could not load checkpoint: {e}. Running with random weights.")
    else:
-         print(f"[WARN] Checkpoint not found at {CHECKPOINT_PATH}. Running with random weights.")

-     _model = _model.to(DEVICE).eval()
-     print("[INFO] Model ready.")


 # ─────────────────────────────────────────────────────────────────────────────
 # Inference
 # ─────────────────────────────────────────────────────────────────────────────
- def predict(image: Image.Image) -> dict:
-     """Run inference on a PIL image. Returns a {label: confidence} dict."""
-     _load_models()
-
-     # Pre-process
-     tensor = _preprocess(image).unsqueeze(0).to(DEVICE)
-
    with torch.no_grad():
-         features = _backbone.encode_image(tensor)  # [1, D]
-         features = F.normalize(features.float(), dim=-1)
-         logits = _model(features)  # [1, num_classes]
-         probs = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
-
-     return {label: float(prob) for label, prob in zip(ALL_CLASSES, probs)}


 def classify(image):
     if image is None:
         return {}
     try:
-         pil_image = Image.fromarray(image).convert("RGB")
-         scores = predict(pil_image)
-         # Sort by confidence descending
-         return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))
     except Exception as e:
         return {"Error": str(e)}


 # ─────────────────────────────────────────────────────────────────────────────
- # Gradio Interface
 # ─────────────────────────────────────────────────────────────────────────────
 DESCRIPTION = """
 ## 🏥🎾 Universal Cross-Domain Vision Model

- Classifies images across **medical** (X-ray pathologies) and **sports** domains using a
- BiomedCLIP backbone with multi-modal attention fusion.

 **Medical classes:** Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion
 **Sports classes:** Running, Jumping, Swimming, Cycling, Tennis, Football
@@ -200,7 +267,6 @@ Upload any image to get started.

 with gr.Blocks(title="Universal Vision Model", theme=gr.themes.Soft()) as demo:
     gr.Markdown(DESCRIPTION)
-
     with gr.Row():
         with gr.Column(scale=1):
             img_input = gr.Image(label="Upload Image", type="numpy")
@@ -211,11 +277,6 @@ with gr.Blocks(title="Universal Vision Model", theme=gr.themes.Soft()) as demo:
     submit_btn.click(fn=classify, inputs=img_input, outputs=label_output)
     img_input.change(fn=classify, inputs=img_input, outputs=label_output)

-     gr.Examples(
-         examples=[],  # Add example image paths here if available
-         inputs=img_input,
-     )
-
 if __name__ == "__main__":
     demo.launch(
         server_name="0.0.0.0",

 """
 Universal Cross-Domain Vision Model — Gradio Demo
 ==================================================
+ Architecture (matches best_model_phase1.pt):
+   Backbones (loaded from HF Hub at runtime — no storage cost):
+     - CLIP ViT-B/32 via open_clip
+     - ViT-B/16 via timm
+     - ResNet-50 via timm
+     - EfficientNet-B0 via timm
+
+   Fine-tuned layers (loaded from head_weights.pt — ~25 MB):
+     - *_proj.* projection adapters per backbone
+     - fusion.* multi-head attention fusion
+     - classifier.* final 14-class head
+     - uncertainty_head.* uncertainty estimation
+
+ Run locally: python app.py
+ HF Spaces: push this folder + head_weights.pt
 """

 import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
 import gradio as gr

 # ─────────────────────────────────────────────────────────────────────────────
+ # Config
 # ─────────────────────────────────────────────────────────────────────────────
+ HEAD_WEIGHTS = os.path.join(os.path.dirname(__file__), "head_weights.pt")
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ EMBED_DIM = 512

 MEDICAL_CLASSES = [
+     "Normal", "Pneumonia", "COVID-19", "Tuberculosis",
+     "Cardiomegaly", "Rib Fracture", "Lung Mass", "Pleural Effusion",
 ]
+ SPORTS_CLASSES = ["Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football"]
 ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES

 # ─────────────────────────────────────────────────────────────────────────────
+ # Model definition (must match training architecture)
 # ─────────────────────────────────────────────────────────────────────────────
+ class UniversalVisionModel(nn.Module):
+     """
+     Multi-backbone fusion model.
+     Backbones are loaded separately; this module holds only the
+     projection adapters, fusion transformer, and classifier head.
+     """
+
+     def __init__(self, embed_dim=EMBED_DIM, num_classes=len(ALL_CLASSES), dropout=0.2):
        super().__init__()

+         # Projection adapters (one per backbone)
+         self.clip_vision_proj = nn.Linear(embed_dim, embed_dim)
+         self.vit_proj = nn.Linear(embed_dim, embed_dim)
+         self.resnet_proj = nn.Linear(embed_dim, embed_dim)        # ResNet-50 → 512 via adapter
+         self.efficientnet_proj = nn.Linear(embed_dim, embed_dim)  # EfficientNet → 512 via adapter
+         self.clip_text_proj = nn.Linear(embed_dim, embed_dim)
+
+         # Fusion transformer
+         self.fusion = nn.ModuleDict({
+             "attention": nn.MultiheadAttention(embed_dim, num_heads=8, dropout=dropout, batch_first=True),
+             "ffn": nn.Sequential(
+                 nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(dropout),
+                 nn.Linear(embed_dim * 4, embed_dim), nn.Dropout(dropout),
+             ),
+             "norm1": nn.LayerNorm(embed_dim),
+             "norm2": nn.LayerNorm(embed_dim),
+         })
+
+         # Classification head
        self.classifier = nn.Sequential(
+             nn.Linear(embed_dim, embed_dim // 2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes),
        )

+         # Uncertainty head
+         self.uncertainty_head = nn.Sequential(
+             nn.Linear(embed_dim, embed_dim // 4), nn.ReLU(),
+             nn.Linear(embed_dim // 4, num_classes),
+         )
+
+     def fuse(self, feature_list):
+         """Fuse a list of [B, D] feature tensors via multi-head attention."""
+         stacked = torch.stack(feature_list, dim=1)  # [B, N, D]
+         attn_out, _ = self.fusion["attention"](stacked, stacked, stacked)
+         stacked = self.fusion["norm1"](stacked + attn_out)
+         ffn_out = self.fusion["ffn"](stacked)
+         fused = self.fusion["norm2"](stacked + ffn_out)
+         return fused.mean(dim=1)  # [B, D]
+
+     def forward(self, features: dict) -> dict:
+         """
+         features: dict with keys matching backbone names,
+         each value is [B, raw_dim] tensor.
+         """
+         projected = []
+         if "clip_vision" in features:
+             projected.append(self.clip_vision_proj(features["clip_vision"]))
+         if "vit" in features:
+             projected.append(self.vit_proj(features["vit"]))
+         if "resnet" in features:
+             projected.append(self.resnet_proj(features["resnet"]))
+         if "efficientnet" in features:
+             projected.append(self.efficientnet_proj(features["efficientnet"]))
+         if "clip_text" in features:
+             projected.append(self.clip_text_proj(features["clip_text"]))
+
+         fused = self.fuse(projected)
+         logits = self.classifier(fused)
+         uncertainty = self.uncertainty_head(fused)
+         return {"logits": logits, "uncertainty": uncertainty, "fused": fused}


 # ─────────────────────────────────────────────────────────────────────────────
+ # Backbone loaders (called once, cached)
 # ─────────────────────────────────────────────────────────────────────────────
+ _backbones = {}
+ _transforms = {}
 _model = None


+ def _load_backbones():
+     global _backbones, _transforms

+     import open_clip, timm
+     from torchvision import transforms as T

+     # Standard 224×224 transform for timm models
+     timm_tfm = T.Compose([
+         T.Resize(224), T.CenterCrop(224), T.ToTensor(),
+         T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ])

+     # 1. CLIP (via open_clip — uses BiomedCLIP if available, else ViT-B/32)
+     print("[INFO] Loading CLIP backbone...")
    try:
+         clip_model, clip_tfm, _ = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
+     except Exception:
+         clip_model, _, clip_tfm = open_clip.create_model_and_transforms(
+             "ViT-B-32", pretrained="openai"
+         )
+     clip_model = clip_model.to(DEVICE).eval()
+     _backbones["clip"] = clip_model
+     _transforms["clip"] = clip_tfm
+
+     # 2. ViT-B/16 (timm)
+     print("[INFO] Loading ViT-B/16 backbone...")
+     vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
+     vit = vit.to(DEVICE).eval()
+     _backbones["vit"] = vit
+     _transforms["vit"] = timm_tfm
+
+     # 3. ResNet-50 (timm)
+     print("[INFO] Loading ResNet-50 backbone...")
+     resnet = timm.create_model("resnet50", pretrained=True, num_classes=0)
+     resnet = resnet.to(DEVICE).eval()
+     _backbones["resnet"] = resnet
+     _transforms["resnet"] = timm_tfm
+
+     # 4. EfficientNet-B0 (timm)
+     print("[INFO] Loading EfficientNet-B0 backbone...")
+     effnet = timm.create_model("efficientnet_b0", pretrained=True, num_classes=0)
+     effnet = effnet.to(DEVICE).eval()
+     _backbones["efficientnet"] = effnet
+     _transforms["efficientnet"] = timm_tfm
+
+     print("[INFO] All backbones loaded.")
+
+
+ def _load_model():
+     global _model
+     _model = UniversalVisionModel().to(DEVICE)
+     if os.path.isfile(HEAD_WEIGHTS):
+         ckpt = torch.load(HEAD_WEIGHTS, map_location=DEVICE, weights_only=False)
+         state = ckpt.get("model_state_dict", ckpt)
+         missing, unexpected = _model.load_state_dict(state, strict=False)
+         print(f"[INFO] Head loaded — missing: {len(missing)}, unexpected: {len(unexpected)}")
    else:
+         print("[WARN] head_weights.pt not found — using random weights.")
+     _model.eval()

+
+ def _ensure_loaded():
+     if _model is None:
+         _load_backbones()
+         _load_model()


 # ─────────────────────────────────────────────────────────────────────────────
 # Inference
 # ─────────────────────────────────────────────────────────────────────────────
+ def extract_features(pil_image: Image.Image) -> dict:
+     """Extract features from all backbones."""
+     feats = {}
    with torch.no_grad():
+         # CLIP vision features
+         t = _transforms["clip"](pil_image).unsqueeze(0).to(DEVICE)
+         clip_feat = _backbones["clip"].encode_image(t)
+         clip_feat = F.normalize(clip_feat.float(), dim=-1)
+         feats["clip_vision"] = clip_feat
+
+         # ViT features
+         t = _transforms["vit"](pil_image).unsqueeze(0).to(DEVICE)
+         vit_feat = _backbones["vit"](t).float()
+         # ViT-B/16 outputs 768-dim features; reduce to EMBED_DIM if needed
+         if vit_feat.shape[-1] != EMBED_DIM:
+             # Simple truncation to match dims (the trained projection lives in head_weights.pt)
+             vit_feat = vit_feat[..., :EMBED_DIM]
+         feats["vit"] = F.normalize(vit_feat, dim=-1)
+
+         # ResNet features
+         t = _transforms["resnet"](pil_image).unsqueeze(0).to(DEVICE)
+         res_feat = _backbones["resnet"](t).float()
+         if res_feat.shape[-1] != EMBED_DIM:
+             res_feat = res_feat[..., :EMBED_DIM]
+         feats["resnet"] = F.normalize(res_feat, dim=-1)
+
+         # EfficientNet features
+         t = _transforms["efficientnet"](pil_image).unsqueeze(0).to(DEVICE)
+         eff_feat = _backbones["efficientnet"](t).float()
+         if eff_feat.shape[-1] != EMBED_DIM:
+             eff_feat = eff_feat[..., :EMBED_DIM]
+         feats["efficientnet"] = F.normalize(eff_feat, dim=-1)
+
+     return feats
+
+
+ def predict(pil_image: Image.Image) -> dict:
+     _ensure_loaded()
+     feats = extract_features(pil_image)
+     with torch.no_grad():
+         out = _model(feats)
+     probs = F.softmax(out["logits"], dim=-1).squeeze(0).cpu().tolist()
+     scores = {label: round(p, 6) for label, p in zip(ALL_CLASSES, probs)}
+     return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))


 def classify(image):
     if image is None:
         return {}
     try:
+         return predict(Image.fromarray(image))
     except Exception as e:
         return {"Error": str(e)}


 # ─────────────────────────────────────────────────────────────────────────────
+ # Gradio UI
 # ─────────────────────────────────────────────────────────────────────────────
 DESCRIPTION = """
 ## 🏥🎾 Universal Cross-Domain Vision Model

+ Classifies images across **medical** (X-ray pathologies) and **sports** domains using an
+ ensemble of BiomedCLIP, ViT-B/16, ResNet-50, and EfficientNet-B0 backbones
+ with fine-tuned multi-modal attention fusion.

 **Medical classes:** Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion
 **Sports classes:** Running, Jumping, Swimming, Cycling, Tennis, Football

 with gr.Blocks(title="Universal Vision Model", theme=gr.themes.Soft()) as demo:
     gr.Markdown(DESCRIPTION)
     with gr.Row():
         with gr.Column(scale=1):
             img_input = gr.Image(label="Upload Image", type="numpy")

     submit_btn.click(fn=classify, inputs=img_input, outputs=label_output)
     img_input.change(fn=classify, inputs=img_input, outputs=label_output)

 if __name__ == "__main__":
     demo.launch(
         server_name="0.0.0.0",
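
The new docstring lists the fine-tuned modules that head_weights.pt is expected to carry (the *_proj adapters, fusion, classifier, and uncertainty_head), and _load_model() restores them with load_state_dict(strict=False). A minimal offline sanity check along those lines, assuming the checkpoint is either a plain state dict or a {"model_state_dict": ...} wrapper, which is exactly what the app itself handles:

import torch

ckpt = torch.load("head_weights.pt", map_location="cpu", weights_only=False)
state = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt

# Key prefixes that correspond to the fine-tuned modules of UniversalVisionModel
expected = ("clip_vision_proj", "vit_proj", "resnet_proj", "efficientnet_proj",
            "clip_text_proj", "fusion", "classifier", "uncertainty_head")
unknown = [k for k in state if not k.startswith(expected)]
print(f"{len(state)} tensors in checkpoint; {len(unknown)} keys outside the expected prefixes")

Keys outside those prefixes would simply be ignored by strict=False loading, so this check is a quick way to see how much of the 25 MB file the app actually uses.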
head_weights.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dae17f3ebc3025aa5d4bfe007741aab77c4b956fb3205ad1d7a8059ed54595f4
+ size 25521282
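
head_weights.pt is committed as the Git LFS pointer shown above, so a Space or local clone only works if LFS actually materialized the ~25 MB payload. A small check against the pointer's own oid and size (both taken from the three lines above); the path is assumed to be the repo-local head_weights.pt:

import hashlib, os

path = "head_weights.pt"
expected_size = 25521282
expected_oid = "dae17f3ebc3025aa5d4bfe007741aab77c4b956fb3205ad1d7a8059ed54595f4"

# An unfetched pointer file is only ~130 bytes, so the size check alone catches a missing LFS fetch.
assert os.path.getsize(path) == expected_size, "head_weights.pt looks like an unfetched LFS pointer"
with open(path, "rb") as f:
    assert hashlib.sha256(f.read()).hexdigest() == expected_oid, "sha256 does not match the LFS oid"
print("head_weights.pt matches its LFS pointer")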