# Elliot89's picture — Upload 3 files — commit 07c2bbf (verified) — raw / history / blame — 9.13 kB
# (The line above is Hugging Face file-viewer chrome captured with the raw
# download; it is not part of the program and is kept here only as provenance.)
"""
Universal Cross-Domain Vision Model β€” Gradio Demo
==================================================
Runs locally: python app.py
HF Spaces: push this folder to a Space (SDK: gradio)
The app loads the trained BiomedCLIP checkpoint and classifies uploaded images
across medical (8 pathologies) and sports (6 action categories) domains.
"""
import os
import io
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import gradio as gr
# ─────────────────────────────────────────────────────────────────────────────
# Configuration
# ─────────────────────────────────────────────────────────────────────────────
# Default checkpoint location is one directory above this file; override with
# the CHECKPOINT_PATH environment variable.
_DEFAULT_CHECKPOINT = os.path.join(
    os.path.dirname(__file__),
    "..",
    "universal_vision_checkpoints",
    "best_model_phase1.pt",
)
CHECKPOINT_PATH = os.environ.get("CHECKPOINT_PATH", _DEFAULT_CHECKPOINT)

# Prefer GPU when one is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Chest X-ray pathology labels. Order matters: it must match the label
# indices the classifier head was trained with.
MEDICAL_CLASSES = [
    "Normal", "Pneumonia", "COVID-19", "Tuberculosis",
    "Cardiomegaly", "Rib Fracture", "Lung Mass", "Pleural Effusion",
]

# Sports action labels, appended after the medical ones.
SPORTS_CLASSES = [
    "Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football",
]

# Combined 14-way label space (medical first, then sports).
ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES
# ─────────────────────────────────────────────────────────────────────────────
# Model Definition (must match training architecture)
# ─────────────────────────────────────────────────────────────────────────────
class BiomedCLIPMultiModalFusion(nn.Module):
    """Inference-only fusion head mirroring the training-time architecture.

    Maps pre-extracted image embeddings to class logits through one
    transformer-style layer (self-attention + FFN, each with a residual
    LayerNorm) followed by an MLP classifier.

    Attribute names (``domain_discriminator``, ``attention``, ``ffn``,
    ``norm1``, ``norm2``, ``classifier``) must stay as-is so checkpoint
    state-dict keys line up.
    """

    def __init__(self, embed_dim: int = 512, num_classes: int = len(ALL_CLASSES), dropout: float = 0.2):
        super().__init__()
        self.embed_dim = embed_dim
        half_dim = embed_dim // 2

        # Domain discriminator: unused at inference, retained purely so the
        # checkpoint loads without missing/unexpected keys.
        self.domain_discriminator = nn.Sequential(
            nn.Linear(embed_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_dim, 2),
        )

        # Single-token multi-head self-attention (batch-first layout).
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=8, dropout=dropout, batch_first=True
        )

        # Position-wise feed-forward network with a 4x expansion.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Final MLP classifier head.
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, half_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(half_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return [B, num_classes] logits for [B, embed_dim] feature vectors."""
        # Treat each feature vector as a one-token sequence: [B, 1, D].
        tokens = x.unsqueeze(1)
        attn_out, _ = self.attention(tokens, tokens, tokens)
        tokens = self.norm1(tokens + attn_out)
        tokens = self.norm2(tokens + self.ffn(tokens))
        return self.classifier(tokens.squeeze(1))
# ─────────────────────────────────────────────────────────────────────────────
# Load model + backbone
# ─────────────────────────────────────────────────────────────────────────────
# Lazily-initialized module-level singletons, populated by _load_models() on
# first call so importing this module stays cheap.
_model = None       # fusion classifier head (BiomedCLIPMultiModalFusion)
_backbone = None    # CLIP image encoder used for feature extraction
_preprocess = None  # image transform matching the loaded backbone
def _load_models() -> None:
    """Lazily load the CLIP backbone, its preprocess transform, and the fusion head.

    Populates the module-level ``_model``, ``_backbone`` and ``_preprocess``
    singletons; subsequent calls are no-ops. Falls back to OpenAI CLIP when
    BiomedCLIP cannot be loaded, and to random fusion-head weights when the
    checkpoint file is absent or unreadable.
    """
    global _model, _backbone, _preprocess
    if _model is not None:  # already initialized
        return
    print(f"[INFO] Loading models on {DEVICE} …")

    import open_clip  # single local import shared by both branches

    # Try BiomedCLIP first, fall back to standard CLIP.
    try:
        # create_model_and_transforms returns (model, preprocess_train,
        # preprocess_val). Use the *val* transform for inference — the
        # original code grabbed the train transform here, which applies
        # random augmentation and skews predictions.
        _backbone, _, _preprocess = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
        print("[INFO] BiomedCLIP backbone loaded.")
    except Exception as e:
        print(f"[WARN] BiomedCLIP failed ({e}), using CLIP-ViT-B/32.")
        _backbone, _, _preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="openai"
        )
    # Both ViT-B/16 (BiomedCLIP) and ViT-B/32 project images to 512-d.
    embed_dim = 512
    _backbone = _backbone.to(DEVICE).eval()

    # Build fusion model.
    _model = BiomedCLIPMultiModalFusion(embed_dim=embed_dim, num_classes=len(ALL_CLASSES))

    # Load checkpoint weights (graceful fallback if checkpoint is missing).
    if os.path.isfile(CHECKPOINT_PATH):
        try:
            # NOTE(security): weights_only=False unpickles arbitrary objects.
            # Acceptable only because CHECKPOINT_PATH is deployer-controlled,
            # never user-supplied.
            ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
            state = ckpt.get("model_state_dict", ckpt)
            _model.load_state_dict(state, strict=False)
            print(f"[INFO] Checkpoint loaded from {CHECKPOINT_PATH}")
        except Exception as e:
            print(f"[WARN] Could not load checkpoint: {e}. Running with random weights.")
    else:
        print(f"[WARN] Checkpoint not found at {CHECKPOINT_PATH}. Running with random weights.")
    _model = _model.to(DEVICE).eval()
    print("[INFO] Model ready.")
# ─────────────────────────────────────────────────────────────────────────────
# Inference
# ─────────────────────────────────────────────────────────────────────────────
def predict(image: Image.Image) -> dict:
    """Classify a PIL image; return a ``{class label: softmax confidence}`` dict."""
    _load_models()
    # Transform to a [1, C, H, W] batch on the inference device.
    batch = _preprocess(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        embedding = _backbone.encode_image(batch)            # [1, D]
        embedding = F.normalize(embedding.float(), dim=-1)   # unit-norm, fp32
        logits = _model(embedding)                           # [1, num_classes]
        scores = F.softmax(logits, dim=-1)
    confidences = scores.squeeze(0).cpu().numpy()
    return dict(zip(ALL_CLASSES, (float(c) for c in confidences)))
def classify(image):
    """Gradio callback: numpy frame in, confidence-sorted {label: score} out.

    Returns ``{}`` for an empty input; any inference failure is surfaced in
    the Label widget as an ``{"Error": message}`` entry instead of crashing
    the UI.
    """
    if image is None:
        return {}
    try:
        frame = Image.fromarray(image).convert("RGB")
        ranked = sorted(predict(frame).items(), key=lambda kv: kv[1], reverse=True)
        return dict(ranked)
    except Exception as e:
        return {"Error": str(e)}
# ─────────────────────────────────────────────────────────────────────────────
# Gradio Interface
# ─────────────────────────────────────────────────────────────────────────────
# Markdown shown at the top of the app (runtime string — kept verbatim).
DESCRIPTION = """
## πŸ₯🎾 Universal Cross-Domain Vision Model
Classifies images across **medical** (X-ray pathologies) and **sports** domains using a
BiomedCLIP backbone with multi-modal attention fusion.
**Medical classes:** Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion
**Sports classes:** Running, Jumping, Swimming, Cycling, Tennis, Football
Upload any image to get started.
"""

with gr.Blocks(title="Universal Vision Model", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: input image and trigger button.
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload Image", type="numpy")
            classify_button = gr.Button("Classify", variant="primary")
        # Right column: ranked prediction scores.
        with gr.Column(scale=1):
            predictions = gr.Label(num_top_classes=8, label="Predictions")
    # Run inference both on explicit click and whenever the image changes.
    classify_button.click(fn=classify, inputs=image_input, outputs=predictions)
    image_input.change(fn=classify, inputs=image_input, outputs=predictions)
    gr.Examples(
        examples=[],  # Add example image paths here if available
        inputs=image_input,
    )
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside HF Spaces/Docker;
    # the port defaults to Gradio's 7860 unless PORT is set by the platform.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port, share=False)