""" Universal Cross-Domain Vision Model — Gradio Demo ================================================== Architecture (matches best_model_phase1.pt): Backbones (loaded from HF Hub at runtime — no storage cost): - CLIP ViT-B/32 via open_clip - ViT-B/16 via timm - ResNet-50 via timm - EfficientNet-B0 via timm Fine-tuned layers (loaded from head_weights.pt — ~25 MB): - *_proj.* projection adapters per backbone - fusion.* multi-head attention fusion - classifier.* final 14-class head - uncertainty_head.* uncertainty estimation Run locally: python app.py HF Spaces: push this folder + head_weights.pt """ import os import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image import gradio as gr # ───────────────────────────────────────────────────────────────────────────── # Config # ───────────────────────────────────────────────────────────────────────────── HEAD_WEIGHTS = os.path.join(os.path.dirname(__file__), "head_weights.pt") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") EMBED_DIM = 512 MEDICAL_CLASSES = [ "Normal", "Pneumonia", "COVID-19", "Tuberculosis", "Cardiomegaly", "Rib Fracture", "Lung Mass", "Pleural Effusion", ] SPORTS_CLASSES = ["Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football"] ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES # ───────────────────────────────────────────────────────────────────────────── # Model definition (must match training architecture) # ───────────────────────────────────────────────────────────────────────────── class UniversalVisionModel(nn.Module): """ Multi-backbone fusion model. Backbones are loaded separately; this module holds only the projection adapters, fusion transformer, and classifier head. """ def __init__(self, embed_dim=EMBED_DIM, num_classes=len(ALL_CLASSES), dropout=0.2): super().__init__() # Projection adapters (one per backbone) self.clip_vision_proj = nn.Linear(embed_dim, embed_dim) self.vit_proj = nn.Linear(embed_dim, embed_dim) self.resnet_proj = nn.Linear(embed_dim, embed_dim) # ResNet-50 → 512 via adapter self.efficientnet_proj = nn.Linear(embed_dim, embed_dim) # EfficientNet → 512 via adapter self.clip_text_proj = nn.Linear(embed_dim, embed_dim) # Fusion transformer self.fusion = nn.ModuleDict({ "attention": nn.MultiheadAttention(embed_dim, num_heads=8, dropout=dropout, batch_first=True), "ffn": nn.Sequential( nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(embed_dim * 4, embed_dim), nn.Dropout(dropout), ), "norm1": nn.LayerNorm(embed_dim), "norm2": nn.LayerNorm(embed_dim), }) # Classification head self.classifier = nn.Sequential( nn.Linear(embed_dim, embed_dim // 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(embed_dim // 2, num_classes), ) # Uncertainty head self.uncertainty_head = nn.Sequential( nn.Linear(embed_dim, embed_dim // 4), nn.ReLU(), nn.Linear(embed_dim // 4, num_classes), ) def fuse(self, feature_list): """Fuse a list of [B, D] feature tensors via multi-head attention.""" stacked = torch.stack(feature_list, dim=1) # [B, N, D] attn_out, _ = self.fusion["attention"](stacked, stacked, stacked) stacked = self.fusion["norm1"](stacked + attn_out) ffn_out = self.fusion["ffn"](stacked) fused = self.fusion["norm2"](stacked + ffn_out) return fused.mean(dim=1) # [B, D] def forward(self, features: dict) -> dict: """ features: dict with keys matching backbone names, each value is [B, raw_dim] tensor. """ projected = [] if "clip_vision" in features: projected.append(self.clip_vision_proj(features["clip_vision"])) if "vit" in features: projected.append(self.vit_proj(features["vit"])) if "resnet" in features: projected.append(self.resnet_proj(features["resnet"])) if "efficientnet" in features: projected.append(self.efficientnet_proj(features["efficientnet"])) if "clip_text" in features: projected.append(self.clip_text_proj(features["clip_text"])) fused = self.fuse(projected) logits = self.classifier(fused) uncertainty = self.uncertainty_head(fused) return {"logits": logits, "uncertainty": uncertainty, "fused": fused} # ───────────────────────────────────────────────────────────────────────────── # Backbone loaders (called once, cached) # ───────────────────────────────────────────────────────────────────────────── _backbones = {} _transforms = {} _model = None def _load_backbones(): global _backbones, _transforms import open_clip, timm from torchvision import transforms as T # Standard 224×224 transform for timm models timm_tfm = T.Compose([ T.Resize(224), T.CenterCrop(224), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) # 1. CLIP (via open_clip — uses BiomedCLIP if available, else ViT-B/32) print("[INFO] Loading CLIP backbone...") try: clip_model, clip_tfm, _ = open_clip.create_model_and_transforms( "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224" ) except Exception: clip_model, _, clip_tfm = open_clip.create_model_and_transforms( "ViT-B-32", pretrained="openai" ) clip_model = clip_model.to(DEVICE).eval() _backbones["clip"] = clip_model _transforms["clip"] = clip_tfm # 2. ViT-B/16 (timm) print("[INFO] Loading ViT-B/16 backbone...") vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0) vit = vit.to(DEVICE).eval() _backbones["vit"] = vit _transforms["vit"] = timm_tfm # 3. ResNet-50 (timm) print("[INFO] Loading ResNet-50 backbone...") resnet = timm.create_model("resnet50", pretrained=True, num_classes=0) resnet = resnet.to(DEVICE).eval() _backbones["resnet"] = resnet _transforms["resnet"] = timm_tfm # 4. EfficientNet-B0 (timm) print("[INFO] Loading EfficientNet-B0 backbone...") effnet = timm.create_model("efficientnet_b0", pretrained=True, num_classes=0) effnet = effnet.to(DEVICE).eval() _backbones["efficientnet"] = effnet _transforms["efficientnet"] = timm_tfm print("[INFO] All backbones loaded.") def _load_model(): global _model _model = UniversalVisionModel().to(DEVICE) if os.path.isfile(HEAD_WEIGHTS): ckpt = torch.load(HEAD_WEIGHTS, map_location=DEVICE, weights_only=False) state = ckpt.get("model_state_dict", ckpt) missing, unexpected = _model.load_state_dict(state, strict=False) print(f"[INFO] Head loaded — missing: {len(missing)}, unexpected: {len(unexpected)}") else: print("[WARN] head_weights.pt not found — using random weights.") _model.eval() def _ensure_loaded(): if _model is None: _load_backbones() _load_model() # ───────────────────────────────────────────────────────────────────────────── # Inference # ───────────────────────────────────────────────────────────────────────────── def extract_features(pil_image: Image.Image) -> dict: """Extract features from all backbones.""" feats = {} with torch.no_grad(): # CLIP vision features t = _transforms["clip"](pil_image).unsqueeze(0).to(DEVICE) clip_feat = _backbones["clip"].encode_image(t) clip_feat = F.normalize(clip_feat.float(), dim=-1) feats["clip_vision"] = clip_feat # ViT features t = _transforms["vit"](pil_image).unsqueeze(0).to(DEVICE) vit_feat = _backbones["vit"](t).float() # ViT-B/16 outputs 768-dim; project down via linear if needed if vit_feat.shape[-1] != EMBED_DIM: # Simple mean-pool trick to match dim (head_weights.pt has proper projection) vit_feat = vit_feat[..., :EMBED_DIM] feats["vit"] = F.normalize(vit_feat, dim=-1) # ResNet features t = _transforms["resnet"](pil_image).unsqueeze(0).to(DEVICE) res_feat = _backbones["resnet"](t).float() if res_feat.shape[-1] != EMBED_DIM: res_feat = res_feat[..., :EMBED_DIM] feats["resnet"] = F.normalize(res_feat, dim=-1) # EfficientNet features t = _transforms["efficientnet"](pil_image).unsqueeze(0).to(DEVICE) eff_feat = _backbones["efficientnet"](t).float() if eff_feat.shape[-1] != EMBED_DIM: eff_feat = eff_feat[..., :EMBED_DIM] feats["efficientnet"] = F.normalize(eff_feat, dim=-1) return feats def predict(pil_image: Image.Image) -> dict: _ensure_loaded() feats = extract_features(pil_image) with torch.no_grad(): out = _model(feats) probs = F.softmax(out["logits"], dim=-1).squeeze(0).cpu().tolist() scores = {label: round(p, 6) for label, p in zip(ALL_CLASSES, probs)} return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True)) def classify(image): if image is None: return {} try: return predict(Image.fromarray(image)) except Exception as e: return {"Error": str(e)} # ───────────────────────────────────────────────────────────────────────────── # Gradio UI # ───────────────────────────────────────────────────────────────────────────── DESCRIPTION = """ ## 🏥🎾 Universal Cross-Domain Vision Model Classifies images across **medical** (X-ray pathologies) and **sports** domains using an ensemble of BiomedCLIP, ViT-B/16, ResNet-50, and EfficientNet-B0 backbones with fine-tuned multi-modal attention fusion. **Medical classes:** Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion **Sports classes:** Running, Jumping, Swimming, Cycling, Tennis, Football Upload any image to get started. """ with gr.Blocks(title="Universal Vision Model", theme=gr.themes.Soft()) as demo: gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(scale=1): img_input = gr.Image(label="Upload Image", type="numpy") submit_btn = gr.Button("Classify", variant="primary") with gr.Column(scale=1): label_output = gr.Label(num_top_classes=8, label="Predictions") submit_btn.click(fn=classify, inputs=img_input, outputs=label_output) img_input.change(fn=classify, inputs=img_input, outputs=label_output) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), share=False, )