Elliot89 committed on
Commit 6f0e045 · verified · 1 Parent(s): 25589b2

Upload 2 files

Files changed (2)
  1. api.py +198 -0
  2. extract_head.py +53 -0
api.py ADDED
@@ -0,0 +1,198 @@
"""
Universal Cross-Domain Vision Model - FastAPI Inference Server
==============================================================
Run: uvicorn api:app --host 0.0.0.0 --port 8000 --reload

Endpoints
---------
GET  /                 health check
POST /predict          upload an image → JSON predictions
POST /predict/url      pass an image URL → JSON predictions
POST /predict/base64   send a base64-encoded image → JSON predictions
"""

import io
import os
import base64
import urllib.request
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

# ─────────────────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────────────────
CHECKPOINT_PATH = os.environ.get(
    "CHECKPOINT_PATH",
    os.path.join(os.path.dirname(__file__), "..", "universal_vision_checkpoints", "best_model_phase1.pt"),
)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MEDICAL_CLASSES = [
    "Normal", "Pneumonia", "COVID-19", "Tuberculosis",
    "Cardiomegaly", "Rib Fracture", "Lung Mass", "Pleural Effusion",
]
SPORTS_CLASSES = ["Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football"]
ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES

# ─────────────────────────────────────────────────────────────────────────────
# Model (same architecture as app.py)
# ─────────────────────────────────────────────────────────────────────────────
class BiomedCLIPMultiModalFusion(nn.Module):
    def __init__(self, embed_dim: int = 512, num_classes: int = len(ALL_CLASSES), dropout: float = 0.2):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=8, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim), nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.domain_discriminator = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(embed_dim // 2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(embed_dim // 2, num_classes),
        )

    def forward(self, x):
        # Treat each feature vector as a length-1 sequence so self-attention applies.
        x = x.unsqueeze(1)
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        fused = self.norm2(x + self.ffn(x)).squeeze(1)
        return self.classifier(fused)


# ─────────────────────────────────────────────────────────────────────────────
# Singleton model loader
# ─────────────────────────────────────────────────────────────────────────────
_model = None
_backbone = None
_preprocess = None


def get_models():
    global _model, _backbone, _preprocess
    if _model is not None:
        return _model, _backbone, _preprocess

    import open_clip

    # open_clip.create_model_and_transforms returns
    # (model, preprocess_train, preprocess_val); use the val transform
    # for inference in both branches.
    try:
        _backbone, _, _preprocess = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
    except Exception:
        # Fall back to a generic CLIP backbone if BiomedCLIP is unavailable.
        _backbone, _, _preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")

    _backbone = _backbone.to(DEVICE).eval()
    _model = BiomedCLIPMultiModalFusion().to(DEVICE).eval()

    if os.path.isfile(CHECKPOINT_PATH):
        ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
        state = ckpt.get("model_state_dict", ckpt)
        _model.load_state_dict(state, strict=False)

    return _model, _backbone, _preprocess


def run_inference(pil_image: Image.Image) -> dict:
    model, backbone, preprocess = get_models()
    tensor = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        features = backbone.encode_image(tensor)
        features = F.normalize(features.float(), dim=-1)
        logits = model(features)
        probs = F.softmax(logits, dim=-1).squeeze(0).cpu().tolist()
    results = [{"label": lbl, "confidence": round(prob, 6)} for lbl, prob in zip(ALL_CLASSES, probs)]
    results.sort(key=lambda x: x["confidence"], reverse=True)
    return {"predictions": results, "top_prediction": results[0]}


# ─────────────────────────────────────────────────────────────────────────────
# FastAPI app
# ─────────────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Universal Cross-Domain Vision Model API",
    description="Classifies images across medical (X-ray pathologies) and sports domains.",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Pre-load models at startup so the first request is fast."""
    get_models()


@app.get("/")
def health():
    return {
        "status": "ok",
        "device": str(DEVICE),
        "classes": ALL_CLASSES,
        "checkpoint": os.path.isfile(CHECKPOINT_PATH),
    }


@app.post("/predict")
async def predict_upload(file: UploadFile = File(...)):
    """Upload an image file and get predictions."""
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image.")
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
        return run_inference(image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


class URLRequest(BaseModel):
    url: str
    timeout: Optional[int] = 10


@app.post("/predict/url")
async def predict_url(req: URLRequest):
    """Pass an image URL and get predictions."""
    try:
        with urllib.request.urlopen(req.url, timeout=req.timeout) as resp:
            image = Image.open(io.BytesIO(resp.read())).convert("RGB")
        return run_inference(image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}")


class Base64Request(BaseModel):
    image_base64: str  # base64-encoded image bytes


@app.post("/predict/base64")
async def predict_base64(req: Base64Request):
    """Send a base64-encoded image and get predictions."""
    try:
        img_bytes = base64.b64decode(req.image_base64)
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        return run_inference(image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


if __name__ == "__main__":
    uvicorn.run("api:app", host="0.0.0.0", port=int(os.environ.get("PORT", 8000)), reload=True)
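
For reference, a minimal client sketch for exercising the endpoints above. It assumes the server is running locally on port 8000; chest_xray.png is a placeholder filename, and the requests library is an assumed dependency not used by the server itself.

# client_example.py - hypothetical client for api.py (not part of this commit)
import base64
import requests

BASE = "http://localhost:8000"  # assumes a local dev server

# Health check: reports device, class list, and checkpoint availability
print(requests.get(f"{BASE}/").json())

# Multipart file upload to /predict (placeholder filename)
with open("chest_xray.png", "rb") as f:
    r = requests.post(f"{BASE}/predict", files={"file": ("chest_xray.png", f, "image/png")})
print(r.json()["top_prediction"])

# Same image via the /predict/base64 JSON variant
with open("chest_xray.png", "rb") as f:
    payload = {"image_base64": base64.b64encode(f.read()).decode()}
print(requests.post(f"{BASE}/predict/base64", json=payload).json()["top_prediction"])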
extract_head.py ADDED
@@ -0,0 +1,53 @@
"""
extract_head.py
===============
Run this ONCE on your local machine (where torch is installed):

    cd D:\CoE\deploy
    python extract_head.py

Reads best_model_phase1.pt (1.1 GB) and saves ONLY the fine-tuned layers:
  - fusion.*           (attention + FFN + norms)  ~12 MB
  - classifier.*       (final classification head)
  - uncertainty_head.*
  - *_proj.*           (lightweight projection adapters)

These total ~25 MB, well within HF's 1 GB limit.
The four backbone encoders (CLIP, ViT, ResNet, EfficientNet) are NOT saved
because app.py downloads them from HF Hub at runtime for free.
"""

import os

import torch

CHECKPOINT = os.path.join(
    os.path.dirname(__file__),
    "..", "universal_vision_checkpoints", "best_model_phase1.pt"
)
OUTPUT = os.path.join(os.path.dirname(__file__), "head_weights.pt")

print(f"Loading: {os.path.abspath(CHECKPOINT)}")
ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
state = ckpt.get("model_state_dict", ckpt)

# These are the BACKBONE prefixes; we drop them (loaded from HF Hub instead)
BACKBONE_PREFIXES = ("clip_model.", "vit.", "resnet.", "efficientnet.")

head_state = {
    k: v for k, v in state.items()
    if not any(k.startswith(p) for p in BACKBONE_PREFIXES)
}

total_mb = sum(v.numel() * v.element_size() for v in state.values()) / 1024**2
head_mb = sum(v.numel() * v.element_size() for v in head_state.values()) / 1024**2

print(f"\nFull checkpoint : {total_mb:.1f} MB ({len(state)} tensors)")
print(f"Head only       : {head_mb:.2f} MB ({len(head_state)} tensors)")
print("\nSaved keys:")
for k, v in head_state.items():
    kb = v.numel() * v.element_size() / 1024
    print(f"  {k:55s} {str(tuple(v.shape)):25s} {kb:.1f} KB")

torch.save({"model_state_dict": head_state}, OUTPUT)
print(f"\n✅ Saved to: {os.path.abspath(OUTPUT)}")
print(f"   Size: {os.path.getsize(OUTPUT)/1024**2:.2f} MB")
print("\nNext step: push head_weights.pt to your HF Space repo (no LFS needed).")
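
Once head_weights.pt exists, the Space-side code can restore it with strict=False, mirroring the checkpoint-loading path in api.py above. A minimal sketch, assuming the fusion-head class from api.py (the actual app.py model class and key prefixes may differ):

# Sketch: loading the extracted head back into the fusion model.
# Assumption: the class from api.py above matches the saved key names.
import torch

from api import BiomedCLIPMultiModalFusion

model = BiomedCLIPMultiModalFusion()
ckpt = torch.load("head_weights.pt", map_location="cpu", weights_only=False)
missing, unexpected = model.load_state_dict(ckpt["model_state_dict"], strict=False)
# strict=False because backbone keys were deliberately dropped; inspect the
# returned lists to confirm only backbone-related keys are absent.
print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")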