"""
Universal Cross-Domain Vision Model β€” Gradio Demo
==================================================
Architecture (matches best_model_phase1.pt):
Backbones (loaded from HF Hub at runtime β€” no storage cost):
- CLIP ViT-B/32 via open_clip
- ViT-B/16 via timm
- ResNet-50 via timm
- EfficientNet-B0 via timm
Fine-tuned layers (loaded from head_weights.pt β€” ~25 MB):
- *_proj.* projection adapters per backbone
- fusion.* multi-head attention fusion
- classifier.* final 14-class head
- uncertainty_head.* uncertainty estimation
Run locally: python app.py
HF Spaces: push this folder + head_weights.pt
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
# ─────────────────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────────────────
HEAD_WEIGHTS = os.path.join(os.path.dirname(__file__), "head_weights.pt")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 512
MEDICAL_CLASSES = [
"Normal", "Pneumonia", "COVID-19", "Tuberculosis",
"Cardiomegaly", "Rib Fracture", "Lung Mass", "Pleural Effusion",
]
SPORTS_CLASSES = ["Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football"]
ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES
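
# Hedged sketch - one way head_weights.pt could be exported from the full
# training checkpoint. Assumes best_model_phase1.pt stores a flat
# "model_state_dict" whose keys use the prefixes listed in the module
# docstring; the filenames and helper below are illustrative, not the
# verified training script.
_HEAD_PREFIXES = (
    "clip_vision_proj", "vit_proj", "resnet_proj", "efficientnet_proj",
    "clip_text_proj", "fusion", "classifier", "uncertainty_head",
)

def export_head_weights(full_ckpt="best_model_phase1.pt", out_path="head_weights.pt"):
    """Strip backbone weights, keeping only the fine-tuned head (~25 MB)."""
    ckpt = torch.load(full_ckpt, map_location="cpu")
    state = ckpt.get("model_state_dict", ckpt)
    head = {k: v for k, v in state.items() if k.startswith(_HEAD_PREFIXES)}
    torch.save(head, out_path)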
# ─────────────────────────────────────────────────────────────────────────────
# Model definition (must match training architecture)
# ─────────────────────────────────────────────────────────────────────────────
class UniversalVisionModel(nn.Module):
"""
Multi-backbone fusion model.
Backbones are loaded separately; this module holds only the
projection adapters, fusion transformer, and classifier head.
"""
def __init__(self, embed_dim=EMBED_DIM, num_classes=len(ALL_CLASSES), dropout=0.2):
super().__init__()
# Projection adapters (one per backbone)
self.clip_vision_proj = nn.Linear(embed_dim, embed_dim)
self.vit_proj = nn.Linear(embed_dim, embed_dim)
        self.resnet_proj = nn.Linear(embed_dim, embed_dim)        # ResNet-50 feats, pre-reduced to 512
        self.efficientnet_proj = nn.Linear(embed_dim, embed_dim)  # EfficientNet feats, pre-reduced to 512
self.clip_text_proj = nn.Linear(embed_dim, embed_dim)
# Fusion transformer
self.fusion = nn.ModuleDict({
"attention": nn.MultiheadAttention(embed_dim, num_heads=8, dropout=dropout, batch_first=True),
"ffn": nn.Sequential(
nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(dropout),
nn.Linear(embed_dim * 4, embed_dim), nn.Dropout(dropout),
),
"norm1": nn.LayerNorm(embed_dim),
"norm2": nn.LayerNorm(embed_dim),
})
# Classification head
self.classifier = nn.Sequential(
nn.Linear(embed_dim, embed_dim // 2), nn.GELU(), nn.Dropout(dropout),
nn.Linear(embed_dim // 2, num_classes),
)
# Uncertainty head
self.uncertainty_head = nn.Sequential(
nn.Linear(embed_dim, embed_dim // 4), nn.ReLU(),
nn.Linear(embed_dim // 4, num_classes),
)
def fuse(self, feature_list):
"""Fuse a list of [B, D] feature tensors via multi-head attention."""
stacked = torch.stack(feature_list, dim=1) # [B, N, D]
attn_out, _ = self.fusion["attention"](stacked, stacked, stacked)
stacked = self.fusion["norm1"](stacked + attn_out)
ffn_out = self.fusion["ffn"](stacked)
fused = self.fusion["norm2"](stacked + ffn_out)
return fused.mean(dim=1) # [B, D]
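    # Note: fuse() is a single post-norm transformer encoder layer
    # (self-attention + FFN with residuals and LayerNorm) followed by mean
    # pooling over the N backbone tokens, so it accepts any subset of the
    # five feature streams at inference time.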
    def forward(self, features: dict) -> dict:
        """
        features: dict keyed by backbone name ("clip_vision", "vit", "resnet",
        "efficientnet", "clip_text"), each value a [B, EMBED_DIM] tensor
        (raw backbone outputs must already be reduced to EMBED_DIM).
        """
projected = []
if "clip_vision" in features:
projected.append(self.clip_vision_proj(features["clip_vision"]))
if "vit" in features:
projected.append(self.vit_proj(features["vit"]))
if "resnet" in features:
projected.append(self.resnet_proj(features["resnet"]))
if "efficientnet" in features:
projected.append(self.efficientnet_proj(features["efficientnet"]))
if "clip_text" in features:
projected.append(self.clip_text_proj(features["clip_text"]))
fused = self.fuse(projected)
logits = self.classifier(fused)
uncertainty = self.uncertainty_head(fused)
return {"logits": logits, "uncertainty": uncertainty, "fused": fused}
# ─────────────────────────────────────────────────────────────────────────────
# Backbone loaders (called once, cached)
# ─────────────────────────────────────────────────────────────────────────────
_backbones = {}
_transforms = {}
_model = None
def _load_backbones():
global _backbones, _transforms
import open_clip, timm
from torchvision import transforms as T
    # Standard 224×224 transform for timm models
timm_tfm = T.Compose([
T.Resize(224), T.CenterCrop(224), T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
    # 1. CLIP (via open_clip - uses BiomedCLIP if available, else ViT-B/32)
    print("[INFO] Loading CLIP backbone...")
    try:
        # create_model_and_transforms returns (model, train_tfm, eval_tfm);
        # take the eval transform in both branches
        clip_model, _, clip_tfm = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
    except Exception:
        clip_model, _, clip_tfm = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="openai"
        )
clip_model = clip_model.to(DEVICE).eval()
_backbones["clip"] = clip_model
_transforms["clip"] = clip_tfm
# 2. ViT-B/16 (timm)
print("[INFO] Loading ViT-B/16 backbone...")
vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
vit = vit.to(DEVICE).eval()
_backbones["vit"] = vit
_transforms["vit"] = timm_tfm
# 3. ResNet-50 (timm)
print("[INFO] Loading ResNet-50 backbone...")
resnet = timm.create_model("resnet50", pretrained=True, num_classes=0)
resnet = resnet.to(DEVICE).eval()
_backbones["resnet"] = resnet
_transforms["resnet"] = timm_tfm
# 4. EfficientNet-B0 (timm)
print("[INFO] Loading EfficientNet-B0 backbone...")
effnet = timm.create_model("efficientnet_b0", pretrained=True, num_classes=0)
effnet = effnet.to(DEVICE).eval()
_backbones["efficientnet"] = effnet
_transforms["efficientnet"] = timm_tfm
print("[INFO] All backbones loaded.")
def _load_model():
global _model
_model = UniversalVisionModel().to(DEVICE)
if os.path.isfile(HEAD_WEIGHTS):
ckpt = torch.load(HEAD_WEIGHTS, map_location=DEVICE, weights_only=False)
state = ckpt.get("model_state_dict", ckpt)
missing, unexpected = _model.load_state_dict(state, strict=False)
print(f"[INFO] Head loaded β€” missing: {len(missing)}, unexpected: {len(unexpected)}")
else:
print("[WARN] head_weights.pt not found β€” using random weights.")
_model.eval()
def _ensure_loaded():
if _model is None:
_load_backbones()
_load_model()
# ─────────────────────────────────────────────────────────────────────────────
# Inference
# ─────────────────────────────────────────────────────────────────────────────
def extract_features(pil_image: Image.Image) -> dict:
"""Extract features from all backbones."""
feats = {}
with torch.no_grad():
# CLIP vision features
t = _transforms["clip"](pil_image).unsqueeze(0).to(DEVICE)
clip_feat = _backbones["clip"].encode_image(t)
clip_feat = F.normalize(clip_feat.float(), dim=-1)
feats["clip_vision"] = clip_feat
# ViT features
t = _transforms["vit"](pil_image).unsqueeze(0).to(DEVICE)
vit_feat = _backbones["vit"](t).float()
        # timm's ViT-B/16 pools to 768-dim, but vit_proj expects EMBED_DIM (512),
        # so truncate to the first 512 channels - a crude dimension match,
        # not a learned projection
        if vit_feat.shape[-1] != EMBED_DIM:
            vit_feat = vit_feat[..., :EMBED_DIM]
feats["vit"] = F.normalize(vit_feat, dim=-1)
# ResNet features
t = _transforms["resnet"](pil_image).unsqueeze(0).to(DEVICE)
res_feat = _backbones["resnet"](t).float()
if res_feat.shape[-1] != EMBED_DIM:
res_feat = res_feat[..., :EMBED_DIM]
feats["resnet"] = F.normalize(res_feat, dim=-1)
# EfficientNet features
t = _transforms["efficientnet"](pil_image).unsqueeze(0).to(DEVICE)
eff_feat = _backbones["efficientnet"](t).float()
if eff_feat.shape[-1] != EMBED_DIM:
eff_feat = eff_feat[..., :EMBED_DIM]
feats["efficientnet"] = F.normalize(eff_feat, dim=-1)
return feats
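
# Hedged sketch - the head also carries a clip_text_proj adapter, so CLIP text
# embeddings (e.g. class-name prompts) could join the fusion. The tokenizer
# must match whichever CLIP actually loaded; "ViT-B-32" below assumes the
# openai fallback and is illustrative only.
def extract_text_features(prompts):
    import open_clip
    tokenizer = open_clip.get_tokenizer("ViT-B-32")  # assumption: fallback CLIP
    with torch.no_grad():
        tokens = tokenizer(prompts).to(DEVICE)
        txt_feat = _backbones["clip"].encode_text(tokens).float()
    return F.normalize(txt_feat, dim=-1)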
def predict(pil_image: Image.Image) -> dict:
_ensure_loaded()
feats = extract_features(pil_image)
with torch.no_grad():
out = _model(feats)
probs = F.softmax(out["logits"], dim=-1).squeeze(0).cpu().tolist()
scores = {label: round(p, 6) for label, p in zip(ALL_CLASSES, probs)}
return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))
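
# Hedged sketch - predict() discards the uncertainty head's output. One
# plausible read-out maps its logits through softplus to positive per-class
# scales; the softplus choice is an assumption, not the training objective.
def predict_with_uncertainty(pil_image):
    _ensure_loaded()
    with torch.no_grad():
        out = _model(extract_features(pil_image))
    probs = F.softmax(out["logits"], dim=-1).squeeze(0).cpu()
    sigma = F.softplus(out["uncertainty"]).squeeze(0).cpu()  # assumed positive scale
    return {label: (round(p.item(), 6), round(s.item(), 6))
            for label, p, s in zip(ALL_CLASSES, probs, sigma)}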
def classify(image):
if image is None:
return {}
try:
return predict(Image.fromarray(image))
except Exception as e:
return {"Error": str(e)}
# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
# ─────────────────────────────────────────────────────────────────────────────
DESCRIPTION = """
## πŸ₯🎾 Universal Cross-Domain Vision Model
Classifies images across **medical** (X-ray pathologies) and **sports** domains using an
ensemble of BiomedCLIP, ViT-B/16, ResNet-50, and EfficientNet-B0 backbones
with fine-tuned multi-modal attention fusion.
**Medical classes:** Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion
**Sports classes:** Running, Jumping, Swimming, Cycling, Tennis, Football
Upload any image to get started.
"""
with gr.Blocks(title="Universal Vision Model", theme=gr.themes.Soft()) as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(scale=1):
img_input = gr.Image(label="Upload Image", type="numpy")
submit_btn = gr.Button("Classify", variant="primary")
with gr.Column(scale=1):
label_output = gr.Label(num_top_classes=8, label="Predictions")
submit_btn.click(fn=classify, inputs=img_input, outputs=label_output)
img_input.change(fn=classify, inputs=img_input, outputs=label_output)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=int(os.environ.get("PORT", 7860)),
share=False,
)