Elliot89 committed (verified) · Commit 07c2bbf · 1 Parent(s): 6d327cd

Upload 3 files

Files changed (3):
  1. README.md +46 -6
  2. app.py +224 -0
  3. requirements.txt +38 -0
README.md CHANGED
@@ -1,13 +1,53 @@
  ---
  title: Universal Cross-Domain Vision Model
- emoji: 🏆
- colorFrom: purple
- colorTo: indigo
  sdk: gradio
- sdk_version: 6.14.0
- python_version: '3.13'
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
  title: Universal Cross-Domain Vision Model
+ emoji: 🏥🎾
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
+ sdk_version: "4.0.0"
  app_file: app.py
  pinned: false
+ license: mit
  ---

+ # Universal Cross-Domain Vision Model
+
+ A BiomedCLIP-powered vision model that classifies images across **medical** and **sports** domains using multi-modal attention fusion.
+
+ ## How to deploy to Hugging Face Spaces
+
+ 1. Create a new Space at https://huggingface.co/new-space
+    - SDK: **Gradio**
+    - Visibility: Public or Private
+
+ 2. Upload these files to the Space repository:
+    ```
+    app.py
+    requirements.txt
+    README_HF_SPACES.md  ← rename this to README.md in the Space
+    ```
+
+ 3. Upload your checkpoint:
+    ```
+    universal_vision_checkpoints/best_model_phase1.pt
+    ```
+    > For large files (>1 GB) use Git LFS:
+    > ```bash
+    > git lfs install
+    > git lfs track "*.pt"
+    > git add .gitattributes
+    > ```
+
+ 4. Set the environment variable in Space Settings → Variables:
+    ```
+    CHECKPOINT_PATH = universal_vision_checkpoints/best_model_phase1.pt
+    ```
+
+ 5. The Space will build automatically; the first build takes ~5 minutes.
+
+ ## Classes
+
+ | Domain  | Classes |
+ |---------|---------|
+ | Medical | Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion |
+ | Sports  | Running, Jumping, Swimming, Cycling, Tennis, Football |
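
The numbered steps above use the web UI; the same upload can also be scripted with `huggingface_hub` (already pinned in requirements.txt). A minimal sketch, assuming a write token is configured locally and using a placeholder `repo_id` that is not part of this commit:

```python
# push_to_space.py (illustrative only; repo_id is a placeholder)
from huggingface_hub import HfApi, create_repo

repo_id = "your-username/universal-vision-model"  # hypothetical Space name

# Create the Space (no-op if it already exists) and upload the app files.
create_repo(repo_id, repo_type="space", space_sdk="gradio", exist_ok=True)
api = HfApi()
for path in ["app.py", "requirements.txt", "README.md"]:
    api.upload_file(path_or_fileobj=path, path_in_repo=path, repo_id=repo_id, repo_type="space")

# The checkpoint is uploaded the same way; the Hub API streams large binaries
# without requiring a local git-lfs setup.
api.upload_file(
    path_or_fileobj="universal_vision_checkpoints/best_model_phase1.pt",
    path_in_repo="universal_vision_checkpoints/best_model_phase1.pt",
    repo_id=repo_id,
    repo_type="space",
)
```

The `CHECKPOINT_PATH` variable from step 4 still has to be set in the Space settings so `app.py` can find the uploaded file.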
app.py ADDED
@@ -0,0 +1,224 @@
+ """
+ Universal Cross-Domain Vision Model — Gradio Demo
+ ==================================================
+ Runs locally: python app.py
+ HF Spaces: push this folder to a Space (SDK: gradio)
+
+ The app loads the trained BiomedCLIP checkpoint and classifies uploaded images
+ across medical (8 pathologies) and sports (6 action categories) domains.
+ """
+
+ import os
+ import io
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image
+ import gradio as gr
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Configuration
+ # ─────────────────────────────────────────────────────────────────────────────
+ CHECKPOINT_PATH = os.environ.get(
+     "CHECKPOINT_PATH",
+     os.path.join(os.path.dirname(__file__), "..", "universal_vision_checkpoints", "best_model_phase1.pt"),
+ )
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ MEDICAL_CLASSES = [
+     "Normal",
+     "Pneumonia",
+     "COVID-19",
+     "Tuberculosis",
+     "Cardiomegaly",
+     "Rib Fracture",
+     "Lung Mass",
+     "Pleural Effusion",
+ ]
+
+ SPORTS_CLASSES = [
+     "Running",
+     "Jumping",
+     "Swimming",
+     "Cycling",
+     "Tennis",
+     "Football",
+ ]
+
+ ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Model Definition (must match training architecture)
+ # ─────────────────────────────────────────────────────────────────────────────
+ class BiomedCLIPMultiModalFusion(nn.Module):
+     """Lightweight inference-only wrapper matching the training architecture."""
+
+     def __init__(self, embed_dim: int = 512, num_classes: int = len(ALL_CLASSES), dropout: float = 0.2):
+         super().__init__()
+         self.embed_dim = embed_dim
+
+         # Domain discriminator (kept for architecture compatibility)
+         self.domain_discriminator = nn.Sequential(
+             nn.Linear(embed_dim, embed_dim // 2),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(embed_dim // 2, 2),
+         )
+
+         # Multi-head attention fusion
+         self.attention = nn.MultiheadAttention(
+             embed_dim=embed_dim, num_heads=8, dropout=dropout, batch_first=True
+         )
+
+         # Feed-forward network
+         self.ffn = nn.Sequential(
+             nn.Linear(embed_dim, embed_dim * 4),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(embed_dim * 4, embed_dim),
+             nn.Dropout(dropout),
+         )
+
+         self.norm1 = nn.LayerNorm(embed_dim)
+         self.norm2 = nn.LayerNorm(embed_dim)
+
+         # Classifier head
+         self.classifier = nn.Sequential(
+             nn.Linear(embed_dim, embed_dim // 2),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(embed_dim // 2, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: [B, embed_dim] — pre-extracted image features
+         x = x.unsqueeze(1)                           # [B, 1, D]
+         attn_out, _ = self.attention(x, x, x)        # self-attention over the single token
+         x = self.norm1(x + attn_out)
+         ffn_out = self.ffn(x)
+         fused = self.norm2(x + ffn_out).squeeze(1)   # [B, D]
+         return self.classifier(fused)
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Load model + backbone
+ # ─────────────────────────────────────────────────────────────────────────────
+ _model = None
+ _backbone = None
+ _preprocess = None
+
+
+ def _load_models():
+     global _model, _backbone, _preprocess
+
+     if _model is not None:
+         return
+
+     print(f"[INFO] Loading models on {DEVICE} …")
+
+     # Try BiomedCLIP first, fall back to standard CLIP
+     try:
+         import open_clip
+         # create_model_and_transforms returns (model, preprocess_train, preprocess_val);
+         # use the eval transform for inference
+         _backbone, _, _preprocess = open_clip.create_model_and_transforms(
+             "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
+         )
+         embed_dim = 512
+         print("[INFO] BiomedCLIP backbone loaded.")
+     except Exception as e:
+         print(f"[WARN] BiomedCLIP failed ({e}), using CLIP-ViT-B/32.")
+         import open_clip
+         _backbone, _, _preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
+         embed_dim = 512
+
+     _backbone = _backbone.to(DEVICE).eval()
+
+     # Build fusion model
+     _model = BiomedCLIPMultiModalFusion(embed_dim=embed_dim, num_classes=len(ALL_CLASSES))
+
+     # Load checkpoint weights (graceful fallback if checkpoint is missing)
+     if os.path.isfile(CHECKPOINT_PATH):
+         try:
+             ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
+             state = ckpt.get("model_state_dict", ckpt)
+             _model.load_state_dict(state, strict=False)
+             print(f"[INFO] Checkpoint loaded from {CHECKPOINT_PATH}")
+         except Exception as e:
+             print(f"[WARN] Could not load checkpoint: {e}. Running with random weights.")
+     else:
+         print(f"[WARN] Checkpoint not found at {CHECKPOINT_PATH}. Running with random weights.")
+
+     _model = _model.to(DEVICE).eval()
+     print("[INFO] Model ready.")
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Inference
+ # ─────────────────────────────────────────────────────────────────────────────
+ def predict(image: Image.Image) -> dict:
+     """Run inference on a PIL image. Returns a {label: confidence} dict."""
+     _load_models()
+
+     # Pre-process
+     tensor = _preprocess(image).unsqueeze(0).to(DEVICE)
+
+     with torch.no_grad():
+         features = _backbone.encode_image(tensor)    # [1, D]
+         features = F.normalize(features.float(), dim=-1)
+         logits = _model(features)                    # [1, num_classes]
+         probs = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
+
+     return {label: float(prob) for label, prob in zip(ALL_CLASSES, probs)}
+
+
+ def classify(image):
+     if image is None:
+         return {}
+     try:
+         pil_image = Image.fromarray(image).convert("RGB")
+         scores = predict(pil_image)
+         # Sort by confidence descending
+         return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))
+     except Exception as e:
+         return {"Error": str(e)}
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Gradio Interface
+ # ─────────────────────────────────────────────────────────────────────────────
+ DESCRIPTION = """
+ ## 🏥🎾 Universal Cross-Domain Vision Model
+
+ Classifies images across **medical** (X-ray pathologies) and **sports** domains using a
+ BiomedCLIP backbone with multi-modal attention fusion.
+
+ **Medical classes:** Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion
+ **Sports classes:** Running, Jumping, Swimming, Cycling, Tennis, Football
+
+ Upload any image to get started.
+ """
+
+ with gr.Blocks(title="Universal Vision Model", theme=gr.themes.Soft()) as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             img_input = gr.Image(label="Upload Image", type="numpy")
+             submit_btn = gr.Button("Classify", variant="primary")
+         with gr.Column(scale=1):
+             label_output = gr.Label(num_top_classes=8, label="Predictions")
+
+     submit_btn.click(fn=classify, inputs=img_input, outputs=label_output)
+     img_input.change(fn=classify, inputs=img_input, outputs=label_output)
+
+     gr.Examples(
+         examples=[],  # Add example image paths here if available
+         inputs=img_input,
+     )
+
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=int(os.environ.get("PORT", 7860)),
+         share=False,
+     )
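
Because `_load_models()` downloads the backbone and checkpoint lazily, the fusion head itself can be sanity-checked offline. A minimal sketch (not part of this commit, and assuming `app.py` imports cleanly in the local environment) that feeds random 512-dimensional features through `BiomedCLIPMultiModalFusion` to confirm the output shape:

```python
# smoke_test.py (illustrative; run from the same folder as app.py)
import torch

from app import ALL_CLASSES, BiomedCLIPMultiModalFusion

# Random embeddings stand in for _backbone.encode_image output: [batch, embed_dim]
fake_features = torch.randn(4, 512)

model = BiomedCLIPMultiModalFusion(embed_dim=512, num_classes=len(ALL_CLASSES)).eval()
with torch.no_grad():
    logits = model(fake_features)

assert logits.shape == (4, len(ALL_CLASSES))       # 8 medical + 6 sports = 14 classes
print(torch.softmax(logits, dim=-1).sum(dim=-1))   # each row sums to ~1.0
```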
requirements.txt ADDED
@@ -0,0 +1,38 @@
+ # ── Core ML ──────────────────────────────────────────────────────────────────
+ torch>=2.1.0
+ torchvision>=0.16.0
+ timm>=0.9.12
+
+ # ── Vision-Language Models ───────────────────────────────────────────────────
+ transformers>=4.35.0
+ open-clip-torch>=2.24.0
+
+ # ── Medical Datasets ─────────────────────────────────────────────────────────
+ medmnist>=2.2.0
+
+ # ── Image Processing ─────────────────────────────────────────────────────────
+ Pillow>=10.0.0
+ opencv-python-headless>=4.8.0
+ albumentations>=1.3.0
+
+ # ── Data / Metrics ───────────────────────────────────────────────────────────
+ numpy>=1.24.0
+ scikit-learn>=1.3.0
+ scipy>=1.11.0
+ pandas>=2.0.0
+
+ # ── Visualisation ────────────────────────────────────────────────────────────
+ matplotlib>=3.7.0
+ seaborn>=0.12.0
+ tqdm>=4.65.0
+
+ # ── Web Demo ─────────────────────────────────────────────────────────────────
+ gradio>=4.0.0
+
+ # ── REST API ─────────────────────────────────────────────────────────────────
+ fastapi>=0.104.0
+ uvicorn[standard]>=0.24.0
+ python-multipart>=0.0.6
+
+ # ── Utilities ────────────────────────────────────────────────────────────────
+ huggingface_hub>=0.19.0
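
The REST API dependencies (fastapi, uvicorn, python-multipart) are pinned here, but app.py only serves the Gradio demo; an HTTP endpoint would live in a separate module. A minimal sketch of how those dependencies could wrap the same `predict` helper (hypothetical file, not included in this commit):

```python
# rest_api.py (hypothetical companion module, not part of this commit)
import io

from fastapi import FastAPI, File, UploadFile
from PIL import Image

from app import predict  # reuses the lazily loaded backbone + fusion head

api = FastAPI(title="Universal Vision Model API")

@api.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    # python-multipart handles the multipart form parsing behind File(...)
    image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    return predict(image)  # {label: confidence}

# Run with: uvicorn rest_api:api --port 8000
```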