| """ |
| Universal Cross-Domain Vision Model β Gradio Demo |
| ================================================== |
| Architecture (matches best_model_phase1.pt): |
| Backbones (loaded from HF Hub at runtime β no storage cost): |
| - CLIP ViT-B/32 via open_clip |
| - ViT-B/16 via timm |
| - ResNet-50 via timm |
| - EfficientNet-B0 via timm |
| |
| Fine-tuned layers (loaded from head_weights.pt β ~25 MB): |
| - *_proj.* projection adapters per backbone |
| - fusion.* multi-head attention fusion |
| - classifier.* final 14-class head |
| - uncertainty_head.* uncertainty estimation |
| |
| Run locally: python app.py |
| HF Spaces: push this folder + head_weights.pt |
| """ |
|
|
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
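
# Programmatic use (illustrative sketch - the Gradio UI below wraps the same
# calls; `predict` is defined further down in this file):
#
#   img = Image.open("chest_xray.jpg").convert("RGB")   # hypothetical file
#   scores = predict(img)   # {label: probability}, sorted descending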

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

HEAD_WEIGHTS = os.path.join(os.path.dirname(__file__), "head_weights.pt")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 512  # shared fusion width; wider backbone features are truncated to it

MEDICAL_CLASSES = [
    "Normal", "Pneumonia", "COVID-19", "Tuberculosis",
    "Cardiomegaly", "Rib Fracture", "Lung Mass", "Pleural Effusion",
]
SPORTS_CLASSES = ["Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football"]
ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES
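# NOTE: this ordering must match the label order used to train head_weights.pt;
# `predict` maps classifier outputs onto ALL_CLASSES purely by position.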

# ---------------------------------------------------------------------------
# Model: projection adapters, attention fusion, classification heads
# ---------------------------------------------------------------------------

class UniversalVisionModel(nn.Module):
    """
    Multi-backbone fusion model.
    Backbones are loaded separately; this module holds only the
    projection adapters, fusion transformer, and classifier head.
    """

    def __init__(self, embed_dim=EMBED_DIM, num_classes=len(ALL_CLASSES), dropout=0.2):
        super().__init__()

        # One linear adapter per backbone, mapping into the shared fusion space.
        self.clip_vision_proj = nn.Linear(embed_dim, embed_dim)
        self.vit_proj = nn.Linear(embed_dim, embed_dim)
        self.resnet_proj = nn.Linear(embed_dim, embed_dim)
        self.efficientnet_proj = nn.Linear(embed_dim, embed_dim)
        self.clip_text_proj = nn.Linear(embed_dim, embed_dim)

        # Transformer-style fusion block: self-attention + FFN, each followed
        # by a residual connection and LayerNorm.
        self.fusion = nn.ModuleDict({
            "attention": nn.MultiheadAttention(embed_dim, num_heads=8, dropout=dropout, batch_first=True),
            "ffn": nn.Sequential(
                nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(dropout),
                nn.Linear(embed_dim * 4, embed_dim), nn.Dropout(dropout),
            ),
            "norm1": nn.LayerNorm(embed_dim),
            "norm2": nn.LayerNorm(embed_dim),
        })

        # 14-class classification head over the fused representation.
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes),
        )

        # Per-class uncertainty estimates.
        self.uncertainty_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 4), nn.ReLU(),
            nn.Linear(embed_dim // 4, num_classes),
        )

    def fuse(self, feature_list):
        """Fuse a list of [B, D] feature tensors via multi-head attention.

        Stacks the features into a [B, N, D] sequence (one token per
        backbone), applies self-attention and an FFN with residual
        connections, then mean-pools over the N tokens back to [B, D].
        """
        stacked = torch.stack(feature_list, dim=1)
        attn_out, _ = self.fusion["attention"](stacked, stacked, stacked)
        stacked = self.fusion["norm1"](stacked + attn_out)
        ffn_out = self.fusion["ffn"](stacked)
        fused = self.fusion["norm2"](stacked + ffn_out)
        return fused.mean(dim=1)

    def forward(self, features: dict) -> dict:
        """
        features: dict keyed by backbone name, each value a [B, EMBED_DIM]
        tensor (wider raw features are truncated before reaching here).
        Only the backbones present in `features` take part in fusion.
        """
        projected = []
        if "clip_vision" in features:
            projected.append(self.clip_vision_proj(features["clip_vision"]))
        if "vit" in features:
            projected.append(self.vit_proj(features["vit"]))
        if "resnet" in features:
            projected.append(self.resnet_proj(features["resnet"]))
        if "efficientnet" in features:
            projected.append(self.efficientnet_proj(features["efficientnet"]))
        if "clip_text" in features:
            projected.append(self.clip_text_proj(features["clip_text"]))

        fused = self.fuse(projected)
        logits = self.classifier(fused)
        uncertainty = self.uncertainty_head(fused)
        return {"logits": logits, "uncertainty": uncertainty, "fused": fused}


# ---------------------------------------------------------------------------
# Lazy-loaded globals (populated on the first request)
# ---------------------------------------------------------------------------

_backbones = {}
_transforms = {}
_model = None


def _load_backbones():
    global _backbones, _transforms

    import open_clip
    import timm
    from torchvision import transforms as T

    # Standard ImageNet preprocessing, shared by all timm backbones.
    timm_tfm = T.Compose([
        T.Resize(224), T.CenterCrop(224), T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # CLIP: try BiomedCLIP first; fall back to OpenAI ViT-B/32 if unavailable.
    # create_model_and_transforms returns (model, train_tfm, eval_tfm), so
    # both branches take the eval transform for inference.
    print("[INFO] Loading CLIP backbone...")
    try:
        clip_model, _, clip_tfm = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
    except Exception:
        clip_model, _, clip_tfm = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="openai"
        )
    clip_model = clip_model.to(DEVICE).eval()
    _backbones["clip"] = clip_model
    _transforms["clip"] = clip_tfm

    print("[INFO] Loading ViT-B/16 backbone...")
    vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
    _backbones["vit"] = vit.to(DEVICE).eval()
    _transforms["vit"] = timm_tfm

    print("[INFO] Loading ResNet-50 backbone...")
    resnet = timm.create_model("resnet50", pretrained=True, num_classes=0)
    _backbones["resnet"] = resnet.to(DEVICE).eval()
    _transforms["resnet"] = timm_tfm

    print("[INFO] Loading EfficientNet-B0 backbone...")
    effnet = timm.create_model("efficientnet_b0", pretrained=True, num_classes=0)
    _backbones["efficientnet"] = effnet.to(DEVICE).eval()
    _transforms["efficientnet"] = timm_tfm

    print("[INFO] All backbones loaded.")


def _load_model():
    global _model
    _model = UniversalVisionModel().to(DEVICE)
    if os.path.isfile(HEAD_WEIGHTS):
        # weights_only=False in case the checkpoint bundles metadata alongside
        # the state dict - only load checkpoint files you trust.
        ckpt = torch.load(HEAD_WEIGHTS, map_location=DEVICE, weights_only=False)
        state = ckpt.get("model_state_dict", ckpt)
        missing, unexpected = _model.load_state_dict(state, strict=False)
        print(f"[INFO] Head loaded - missing: {len(missing)}, unexpected: {len(unexpected)}")
    else:
        print("[WARN] head_weights.pt not found - using random weights.")
    _model.eval()


def _ensure_loaded():
    # Lazy initialization: backbones and the head are loaded on the first
    # classification request so the app starts quickly.
    if _model is None:
        _load_backbones()
        _load_model()


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

def extract_features(pil_image: Image.Image) -> dict:
    """Extract one [1, EMBED_DIM] feature tensor per backbone."""
    feats = {}
    with torch.no_grad():
        # CLIP image encoder: already EMBED_DIM-wide, no truncation needed.
        t = _transforms["clip"](pil_image).unsqueeze(0).to(DEVICE)
        clip_feat = _backbones["clip"].encode_image(t)
        feats["clip_vision"] = F.normalize(clip_feat.float(), dim=-1)

        # ViT-B/16: timm emits 768-dim pooled features.
        t = _transforms["vit"](pil_image).unsqueeze(0).to(DEVICE)
        vit_feat = _backbones["vit"](t).float()
        if vit_feat.shape[-1] != EMBED_DIM:
            # Keep the first EMBED_DIM channels to match the trained adapters.
            vit_feat = vit_feat[..., :EMBED_DIM]
        feats["vit"] = F.normalize(vit_feat, dim=-1)

        # ResNet-50: 2048-dim pooled features.
        t = _transforms["resnet"](pil_image).unsqueeze(0).to(DEVICE)
        res_feat = _backbones["resnet"](t).float()
        if res_feat.shape[-1] != EMBED_DIM:
            res_feat = res_feat[..., :EMBED_DIM]
        feats["resnet"] = F.normalize(res_feat, dim=-1)

        # EfficientNet-B0: 1280-dim pooled features.
        t = _transforms["efficientnet"](pil_image).unsqueeze(0).to(DEVICE)
        eff_feat = _backbones["efficientnet"](t).float()
        if eff_feat.shape[-1] != EMBED_DIM:
            eff_feat = eff_feat[..., :EMBED_DIM]
        feats["efficientnet"] = F.normalize(eff_feat, dim=-1)

    return feats
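
# Design note: truncating 768/2048/1280-dim features to the first EMBED_DIM
# channels is lossy but parameter-free, and it must mirror whatever the
# training pipeline did. A learned nn.Linear(raw_dim, EMBED_DIM) per backbone
# would be the usual alternative, but it would not match head_weights.pt,
# whose adapters are Linear(EMBED_DIM, EMBED_DIM).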


def predict(pil_image: Image.Image) -> dict:
    """Run the full pipeline on a PIL image and return {label: probability}."""
    _ensure_loaded()
    feats = extract_features(pil_image)
    with torch.no_grad():
        out = _model(feats)
    probs = F.softmax(out["logits"], dim=-1).squeeze(0).cpu().tolist()
    scores = {label: round(p, 6) for label, p in zip(ALL_CLASSES, probs)}
    return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))
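
# Example return value (illustrative numbers only):
#   {"Pneumonia": 0.612300, "Normal": 0.201100, "COVID-19": 0.071200, ...}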


def classify(image):
    """Gradio callback: numpy image in, {label: score} dict out."""
    if image is None:
        return {}
    try:
        # Gradio delivers a numpy array that may be grayscale or RGBA;
        # convert to RGB so every backbone transform gets 3 channels.
        return predict(Image.fromarray(image).convert("RGB"))
    except Exception as e:
        return {"Error": str(e)}


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

DESCRIPTION = """
## Universal Cross-Domain Vision Model

Classifies images across **medical** (X-ray pathologies) and **sports** domains
using an ensemble of BiomedCLIP, ViT-B/16, ResNet-50, and EfficientNet-B0
backbones with fine-tuned multi-head attention fusion.

**Medical classes:** Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion

**Sports classes:** Running, Jumping, Swimming, Cycling, Tennis, Football

Upload any image to get started.
"""


with gr.Blocks(title="Universal Vision Model", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(label="Upload Image", type="numpy")
            submit_btn = gr.Button("Classify", variant="primary")
        with gr.Column(scale=1):
            label_output = gr.Label(num_top_classes=8, label="Predictions")

    # Classify on button press, and also whenever a new image is uploaded.
    submit_btn.click(fn=classify, inputs=img_input, outputs=label_output)
    img_input.change(fn=classify, inputs=img_input, outputs=label_output)


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (required on HF Spaces)
        server_port=int(os.environ.get("PORT", 7860)),
        share=False,
    )