---
title: Universal Cross-Domain Vision Model
emoji: 🏥🎾
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 6.14.0
app_file: app.py
pinned: false
license: mit
---
# πŸ₯🎾 Universal Cross-Domain Vision Model
A multi-backbone vision model that classifies images across two domains, **medical X-ray pathology** and **sports action recognition**, using fine-tuned multi-head attention fusion on top of four pretrained encoders.
[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-blue)](https://huggingface.co/spaces/Elliot89/Universal_Cross-Domain_Vision_Model)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
---
## 🧠 Model Architecture
The model fuses features from four pretrained backbone encoders through a learned multi-head attention fusion layer:
| Backbone | Source | Projected Dim |
|---|---|---|
| BiomedCLIP ViT-B/16 | `microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224` | 512 |
| ViT-B/16 | `timm` (ImageNet pretrained) | 512 |
| ResNet-50 | `timm` (ImageNet pretrained) | 512 |
| EfficientNet-B0 | `timm` (ImageNet pretrained) | 512 |
Each backbone's features are projected to a shared 512-dim space, then fused via an 8-head attention transformer block. The final classifier head outputs 14 class probabilities with an uncertainty estimate.
```
Image → [BiomedCLIP, ViT-B/16, ResNet-50, EfficientNet-B0]
      → Projection Adapters (per backbone)
      → 8-Head Attention Fusion
      → Classifier → 14 classes + Uncertainty estimate
```
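The fusion stage can be sketched in plain NumPy. All weights below are random stand-ins for the learned parameters, and the entropy-based uncertainty is only an assumption about how the estimate might be computed; see `app.py` for the actual head.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
D, H, C = 512, 8, 14          # shared embed dim, attention heads, classes
d_h = D // H                  # per-head dim

# Four backbone features, already projected to the shared 512-dim space,
# treated as a 4-token sequence.
tokens = rng.standard_normal((4, D))

# One self-attention block over the 4 backbone tokens (random weights
# stand in for the trained fusion parameters).
Wq, Wk, Wv = (rng.standard_normal((D, D)) * D**-0.5 for _ in range(3))
Q = (tokens @ Wq).reshape(4, H, d_h).transpose(1, 0, 2)        # (H, 4, d_h)
K = (tokens @ Wk).reshape(4, H, d_h).transpose(1, 0, 2)
V = (tokens @ Wv).reshape(4, H, d_h).transpose(1, 0, 2)
attn = softmax(Q @ K.transpose(0, 2, 1) / d_h**0.5, axis=-1)   # (H, 4, 4)
fused = (attn @ V).transpose(1, 0, 2).reshape(4, D).mean(axis=0)  # pool tokens

# Classifier head: 14 logits -> probabilities; normalized entropy as a
# simple stand-in for the uncertainty estimate (always in [0, 1]).
logits = fused @ (rng.standard_normal((D, C)) * D**-0.5)
probs = softmax(logits)
uncertainty = float(-(probs * np.log(probs)).sum() / np.log(C))
```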
---
## 🏷️ Classes
| Domain | Classes |
|---|---|
| πŸ₯ Medical (X-ray) | Normal, Pneumonia, COVID-19, Tuberculosis, Cardiomegaly, Rib Fracture, Lung Mass, Pleural Effusion |
| 🎾 Sports | Running, Jumping, Swimming, Cycling, Tennis, Football |
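For scripting against the model, the 14 labels above can be captured as constants. The index order here is an assumption (check `app.py` for the order the trained head actually uses), and `domain_of` is a hypothetical helper, not a repo function:

```python
# Labels from the table above; index order is an assumption.
MEDICAL_CLASSES = ["Normal", "Pneumonia", "COVID-19", "Tuberculosis",
                   "Cardiomegaly", "Rib Fracture", "Lung Mass", "Pleural Effusion"]
SPORTS_CLASSES = ["Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football"]
ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES  # 14 classes total

def domain_of(label: str) -> str:
    """Map a predicted label to its domain."""
    return "medical" if label in MEDICAL_CLASSES else "sports"
```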
---
## 🚀 Running the Demo
### Option 1 – Hugging Face Spaces (live)
Visit the live demo – no setup needed:
👉 **https://huggingface.co/spaces/Elliot89/Universal_Cross-Domain_Vision_Model**
Upload any image and click **Classify**.
### Option 2 – Run locally
**Requirements:** Python 3.9+, ~4 GB RAM on CPU; a GPU is recommended.
```bash
# 1. Clone this repo
git clone https://huggingface.co/spaces/Elliot89/Universal_Cross-Domain_Vision_Model
cd Universal_Cross-Domain_Vision_Model
# 2. Install dependencies
pip install -r requirements.txt
# 3. Launch
python app.py
# Opens at http://localhost:7860
```
### Option 3 – REST API
```bash
# Start the API server
uvicorn api:app --host 0.0.0.0 --port 8000
# Classify an image file
curl -X POST http://localhost:8000/predict -F "file=@your_image.jpg"
# Classify from URL
curl -X POST http://localhost:8000/predict/url \
-H "Content-Type: application/json" \
-d '{"url": "https://example.com/xray.jpg"}'
```
Interactive API docs at **http://localhost:8000/docs**
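The same endpoints can be called from Python with only the standard library. This sketch assumes the request and response shapes shown in the curl examples above:

```python
import json
import urllib.request

API = "http://localhost:8000"  # adjust host/port for your deployment

def classify_url(image_url: str) -> dict:
    """POST an image URL to the /predict/url endpoint shown above."""
    body = json.dumps({"url": image_url}).encode()
    req = urllib.request.Request(
        f"{API}/predict/url", data=body,
        headers={"Content-Type": "application/json"}, method="POST")
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

# Example (requires the API server to be running):
# result = classify_url("https://example.com/xray.jpg")
# print(result["top_prediction"]["label"])
```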
### Option 4 – Google Colab
Open `colab_deploy.ipynb` in Colab, set runtime to **T4 GPU**, and run all cells.
---
## 📦 Repository Structure
```
├── app.py             # Gradio web demo (main entry point)
├── api.py             # FastAPI REST inference server
├── requirements.txt   # Python dependencies
├── head_weights.pt    # Fine-tuned fusion + classifier weights (~25 MB)
├── extract_head.py    # Utility: extract head weights from full checkpoint
├── colab_deploy.ipynb # One-click Google Colab notebook
└── README.md          # This file
```
> **Note on weights:** The four backbone encoders (~1 GB total) are downloaded
> automatically from Hugging Face Hub at first startup and cached. Only the
> fine-tuned head (`head_weights.pt`, ~25 MB) is stored in this repo.
---
## 🔧 Training Details
| Setting | Value |
|---|---|
| Base model | BiomedCLIP (Microsoft), pretrained on PMC-15M medical image-text pairs |
| Additional backbones | ViT-B/16, ResNet-50, EfficientNet-B0 (ImageNet pretrained via timm) |
| Medical data | Synthesized X-ray images across 8 pathology classes |
| Sports data | Stanford40 action recognition dataset |
| Fusion | 8-head multi-head attention, 512-dim embedding space |
| Optimizer | AdamW with cosine annealing LR schedule |
| Regularization | Dropout (0.2), domain adversarial training |
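The cosine-annealing schedule in the table follows the standard closed form. The `lr_max`/`lr_min` values below are illustrative, not the repo's actual hyperparameters:

```python
import math

def cosine_annealed_lr(step: int, total_steps: int,
                       lr_max: float = 3e-4, lr_min: float = 1e-6) -> float:
    """Cosine annealing: lr_max at step 0, decaying smoothly to lr_min."""
    cos = math.cos(math.pi * step / total_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + cos)
```

With `torch.optim.AdamW`, the same curve is what `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=lr_min)` produces.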
---
## 📋 API Response Format
```json
{
  "top_prediction": {
    "label": "Pneumonia",
    "confidence": 0.412
  },
  "predictions": [
    { "label": "Pneumonia", "confidence": 0.412 },
    { "label": "Normal", "confidence": 0.198 },
    { "label": "COVID-19", "confidence": 0.134 },
    { "label": "Tuberculosis", "confidence": 0.089 },
    { "label": "Cardiomegaly", "confidence": 0.061 },
    { "label": "Running", "confidence": 0.044 },
    { "label": "Lung Mass", "confidence": 0.031 },
    { "label": "Pleural Effusion", "confidence": 0.021 }
  ]
}
```
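A client might consume this payload as follows. The `summarize` helper, its 0.3 threshold, and the trimmed sample are illustrative, not part of the API:

```python
# Trimmed sample in the response format documented above.
sample = {
    "top_prediction": {"label": "Pneumonia", "confidence": 0.412},
    "predictions": [
        {"label": "Pneumonia", "confidence": 0.412},
        {"label": "Normal", "confidence": 0.198},
    ],
}

def summarize(resp: dict, threshold: float = 0.3) -> str:
    """One-line summary of the top prediction; threshold is illustrative."""
    top = resp["top_prediction"]
    flag = "confident" if top["confidence"] >= threshold else "uncertain"
    return f'{top["label"]} ({top["confidence"]:.1%}, {flag})'

print(summarize(sample))  # Pneumonia (41.2%, confident)
```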
---
## βš™οΈ Environment Variables
| Variable | Default | Description |
|---|---|---|
| `PORT` | `7860` (Gradio) / `8000` (API) | Server port |
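Resolving the port in either server then reduces to one line; `server_port` is a hypothetical helper, not a function in the repo:

```python
import os

def server_port(default: int) -> int:
    """Read the PORT env var, falling back to the server's default."""
    return int(os.environ.get("PORT", default))

# app.py would use server_port(7860); api.py would use server_port(8000).
```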
---
## πŸ› οΈ Troubleshooting
**Slow first startup** β€” The four backbones (~1 GB total) are downloaded from HF Hub on first run and cached. On HF Spaces this happens automatically during the build phase.
**`head_weights.pt` not found** β€” The app still runs but uses random weights for the fusion and classifier layers. Predictions will not reflect actual training. Upload `head_weights.pt` to the repo to enable real predictions.
**Out of memory** β€” The model runs on CPU if no GPU is detected. If memory is tight, reduce image resolution or comment out extra backbones in `app.py`.
**Regenerating `head_weights.pt` from the original checkpoint** β€” If you have `best_model_phase1.pt`, run:
```bash
python extract_head.py
```
This strips the large backbone weights (which are loaded from HF Hub) and saves only the fine-tuned layers (~25 MB) as `head_weights.pt`.
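In outline, the extraction step keeps only non-backbone tensors. The prefix names below are assumptions about the checkpoint's key layout, and real usage would wrap `torch.load`/`torch.save` as in the comments:

```python
# Assumed key prefixes for backbone parameters -- check the checkpoint's
# actual state_dict keys before relying on these.
BACKBONE_PREFIXES = ("biomedclip.", "vit.", "resnet.", "efficientnet.")

def split_out_head(state_dict: dict) -> dict:
    """Keep only fusion/classifier weights, dropping backbone tensors."""
    return {k: v for k, v in state_dict.items()
            if not k.startswith(BACKBONE_PREFIXES)}

# Real usage would wrap torch.load / torch.save:
# head = split_out_head(torch.load("best_model_phase1.pt", map_location="cpu"))
# torch.save(head, "head_weights.pt")
```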
---
## 📄 License
MIT – see [https://opensource.org/licenses/MIT](https://opensource.org/licenses/MIT)
---
## πŸ™ Acknowledgements
- [Microsoft BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) β€” vision-language model pretrained on 15M medical image-text pairs from PubMed Central
- [Stanford40](http://vision.stanford.edu/Datasets/40actions.html) β€” sports and human action recognition dataset
- [timm](https://github.com/huggingface/pytorch-image-models) β€” PyTorch Image Models library
- [open_clip](https://github.com/mlfoundations/open_clip) β€” open source CLIP implementation
- [Gradio](https://gradio.app) β€” web demo framework
- [FastAPI](https://fastapi.tiangolo.com) β€” REST API framework