| """ |
| Universal Cross-Domain Vision Model β Gradio Demo |
| ================================================== |
| Runs locally: python app.py |
| HF Spaces: push this folder to a Space (SDK: gradio) |
| |
| The app loads the trained BiomedCLIP checkpoint and classifies uploaded images |
| across medical (8 pathologies) and sports (6 action categories) domains. |
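
Requires: torch, open_clip_torch, gradio, Pillow.

Environment variables (both optional):
    CHECKPOINT_PATH  path to the trained checkpoint
                     (default: ../universal_vision_checkpoints/best_model_phase1.pt)
    PORT             server port for the Gradio app (default: 7860)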
| """ |


import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr


CHECKPOINT_PATH = os.environ.get(
    "CHECKPOINT_PATH",
    os.path.join(os.path.dirname(__file__), "..", "universal_vision_checkpoints", "best_model_phase1.pt"),
)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
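
# Label order must match the indices used at training time: predict() zips
# ALL_CLASSES against the classifier's output probabilities by position.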
MEDICAL_CLASSES = [
    "Normal",
    "Pneumonia",
    "COVID-19",
    "Tuberculosis",
    "Cardiomegaly",
    "Rib Fracture",
    "Lung Mass",
    "Pleural Effusion",
]

SPORTS_CLASSES = [
    "Running",
    "Jumping",
    "Swimming",
    "Cycling",
    "Tennis",
    "Football",
]

ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES


class BiomedCLIPMultiModalFusion(nn.Module):
    """Lightweight inference-only wrapper matching the training architecture."""

    def __init__(self, embed_dim: int = 512, num_classes: int = len(ALL_CLASSES), dropout: float = 0.2):
        super().__init__()
        self.embed_dim = embed_dim
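
        # Not used in forward(); presumably part of the domain-adversarial
        # training objective. Kept so the checkpoint's state_dict keys match.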
        self.domain_discriminator = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, 2),
        )

        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=8, dropout=dropout, batch_first=True
        )

        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
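        # The single image embedding is treated as a length-1 token sequence
        # so it can pass through one post-norm attention + FFN block.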
        x = x.unsqueeze(1)
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        ffn_out = self.ffn(x)
        fused = self.norm2(x + ffn_out).squeeze(1)
        return self.classifier(fused)
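

# Lazily initialised globals so the heavy backbone is loaded once per process.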
_model = None
_backbone = None
_preprocess = None


def _load_models():
    global _model, _backbone, _preprocess

    if _model is not None:
        return

    print(f"[INFO] Loading models on {DEVICE} ...")

    try:
        import open_clip
        # create_model_and_transforms returns (model, train_preprocess,
        # val_preprocess); inference must use the val transform.
        _backbone, _, _preprocess = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
        embed_dim = 512
        print("[INFO] BiomedCLIP backbone loaded.")
    except Exception as e:
        print(f"[WARN] BiomedCLIP failed ({e}), falling back to CLIP ViT-B/32.")
        import open_clip
        _backbone, _, _preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
        embed_dim = 512

    _backbone = _backbone.to(DEVICE).eval()

    _model = BiomedCLIPMultiModalFusion(embed_dim=embed_dim, num_classes=len(ALL_CLASSES))

    if os.path.isfile(CHECKPOINT_PATH):
        try:
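            # The training script saves a dict with a "model_state_dict" key;
            # strict=False tolerates any extra or missing keys it may contain.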
            ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
            state = ckpt.get("model_state_dict", ckpt)
            _model.load_state_dict(state, strict=False)
            print(f"[INFO] Checkpoint loaded from {CHECKPOINT_PATH}")
        except Exception as e:
            print(f"[WARN] Could not load checkpoint: {e}. Running with random weights.")
    else:
        print(f"[WARN] Checkpoint not found at {CHECKPOINT_PATH}. Running with random weights.")

    _model = _model.to(DEVICE).eval()
    print("[INFO] Model ready.")


def predict(image: Image.Image) -> dict:
    """Run inference on a PIL image. Returns a {label: confidence} dict."""
    _load_models()

    tensor = _preprocess(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        features = _backbone.encode_image(tensor)
        features = F.normalize(features.float(), dim=-1)
        logits = _model(features)
        probs = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()

    return {label: float(prob) for label, prob in zip(ALL_CLASSES, probs)}


def classify(image):
    if image is None:
        return {}
    try:
        pil_image = Image.fromarray(image).convert("RGB")
        scores = predict(pil_image)
        # Sort descending so gr.Label lists the most confident classes first.
        return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))
    except Exception as e:
        return {"Error": str(e)}


DESCRIPTION = """
## 🏥🎾 Universal Cross-Domain Vision Model

Classifies images across **medical** (X-ray pathologies) and **sports** domains using a
BiomedCLIP backbone with multi-modal attention fusion.

**Medical classes:** Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion
**Sports classes:** Running, Jumping, Swimming, Cycling, Tennis, Football

Upload any image to get started.
"""


with gr.Blocks(title="Universal Vision Model", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(label="Upload Image", type="numpy")
            submit_btn = gr.Button("Classify", variant="primary")
        with gr.Column(scale=1):
            label_output = gr.Label(num_top_classes=8, label="Predictions")

    submit_btn.click(fn=classify, inputs=img_input, outputs=label_output)
    img_input.change(fn=classify, inputs=img_input, outputs=label_output)
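    # NOTE: .change also fires when the image is cleared; classify() returns {}
    # for None input, so the Label simply resets.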

    # Placeholder for bundled sample images; the path below is illustrative:
    # gr.Examples(examples=["path/to/sample.png"], inputs=img_input)


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        share=False,
    )