# Hugging Face upload-page residue (kept as a comment so this module stays
# importable): Elliot89's picture / Upload 2 files / 6f0e045 verified
"""
Universal Cross-Domain Vision Model β€” FastAPI Inference Server
==============================================================
Run: uvicorn api:app --host 0.0.0.0 --port 8000 --reload
Endpoints
---------
GET / health check
POST /predict upload an image β†’ JSON predictions
POST /predict/url pass an image URL β†’ JSON predictions
"""
import io
import os
import base64
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl
import uvicorn
# ─────────────────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────────────────
# Fine-tuned fusion-head checkpoint; overridable via the CHECKPOINT_PATH env var.
CHECKPOINT_PATH = os.environ.get(
    "CHECKPOINT_PATH",
    os.path.join(os.path.dirname(__file__), "..", "universal_vision_checkpoints", "best_model_phase1.pt"),
)
# Prefer GPU when available; everything (backbone, head, tensors) lives on DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Label order must match the classifier output order of the trained checkpoint.
MEDICAL_CLASSES = [
    "Normal", "Pneumonia", "COVID-19", "Tuberculosis",
    "Cardiomegaly", "Rib Fracture", "Lung Mass", "Pleural Effusion",
]
SPORTS_CLASSES = ["Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football"]
ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES  # 14 labels total
# ─────────────────────────────────────────────────────────────────────────────
# Model (same architecture as app.py)
# ─────────────────────────────────────────────────────────────────────────────
class BiomedCLIPMultiModalFusion(nn.Module):
    """Fusion/classification head applied to frozen CLIP image embeddings.

    A single transformer-encoder-style refinement (self-attention + FFN with
    residual connections and LayerNorm) of the embedding, treated as a
    one-token sequence, followed by a linear classifier.

    The ``domain_discriminator`` head is constructed for checkpoint
    compatibility (presumably used during adversarial domain-alignment
    training) but is not invoked in :meth:`forward`.
    """

    def __init__(self, embed_dim: int = 512, num_classes: Optional[int] = None, dropout: float = 0.2):
        """
        Args:
            embed_dim: Dimensionality of the input image embedding.
            num_classes: Number of output classes. When ``None`` (the
                default), resolves to ``len(ALL_CLASSES)`` lazily at
                construction time rather than at class-definition time,
                so the class no longer depends on module import order.
            dropout: Dropout probability used throughout the head.
        """
        super().__init__()
        if num_classes is None:
            num_classes = len(ALL_CLASSES)
        # Module creation order is kept stable so RNG-dependent init matches
        # earlier versions of this file.
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=8, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim), nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Unused at inference; kept so checkpoint state_dicts load cleanly.
        self.domain_discriminator = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(embed_dim // 2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(embed_dim // 2, num_classes),
        )

    def forward(self, x):
        """Classify a batch of embeddings.

        Args:
            x: Float tensor of shape ``(batch, embed_dim)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        x = x.unsqueeze(1)  # (batch, 1, embed_dim): a one-token "sequence"
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)                    # residual + norm
        fused = self.norm2(x + self.ffn(x)).squeeze(1)  # back to (batch, embed_dim)
        return self.classifier(fused)
# ─────────────────────────────────────────────────────────────────────────────
# Singleton model loader
# ─────────────────────────────────────────────────────────────────────────────
_model = None       # cached fusion head (BiomedCLIPMultiModalFusion)
_backbone = None    # cached CLIP image encoder
_preprocess = None  # cached image transform matching the backbone

def get_models():
    """Lazily build and cache the backbone, fusion head, and preprocessing.

    Tries BiomedCLIP first; falls back to OpenAI ViT-B/32 if the hub model
    cannot be loaded. Loads fine-tuned weights from CHECKPOINT_PATH when the
    file exists.

    Returns:
        Tuple of ``(model, backbone, preprocess)``.
    """
    global _model, _backbone, _preprocess
    if _model is not None:
        return _model, _backbone, _preprocess
    try:
        import open_clip
        # create_model_and_transforms returns
        # (model, preprocess_train, preprocess_val); use the *validation*
        # transform for inference — the train transform applies augmentation.
        _backbone, _, _preprocess = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
    except Exception:
        import open_clip
        _backbone, _, _preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
    _backbone = _backbone.to(DEVICE).eval()
    _model = BiomedCLIPMultiModalFusion().to(DEVICE).eval()
    if os.path.isfile(CHECKPOINT_PATH):
        # NOTE(security): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from trusted sources.
        ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
        state = ckpt.get("model_state_dict", ckpt)
        # strict=False tolerates extra/missing keys (e.g. training-only heads).
        _model.load_state_dict(state, strict=False)
    return _model, _backbone, _preprocess
def run_inference(pil_image: Image.Image) -> dict:
    """Encode *pil_image* with the CLIP backbone and classify it.

    Returns:
        Dict with ``"predictions"`` (all labels ranked by confidence,
        descending) and ``"top_prediction"`` (the highest-confidence entry).
    """
    model, backbone, preprocess = get_models()
    batch = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        embedding = F.normalize(backbone.encode_image(batch).float(), dim=-1)
        logits = model(embedding)
    probabilities = F.softmax(logits, dim=-1).squeeze(0).cpu().tolist()
    ranked = sorted(
        ({"label": name, "confidence": round(p, 6)} for name, p in zip(ALL_CLASSES, probabilities)),
        key=lambda entry: entry["confidence"],
        reverse=True,
    )
    return {"predictions": ranked, "top_prediction": ranked[0]}
# ─────────────────────────────────────────────────────────────────────────────
# FastAPI app
# ─────────────────────────────────────────────────────────────────────────────
# FastAPI application object; served by uvicorn (see __main__ guard below).
app = FastAPI(
    title="Universal Cross-Domain Vision Model API",
    description="Classifies images across medical (X-ray pathologies) and sports domains.",
    version="1.0.0",
)
# NOTE(security): wildcard CORS allows any origin/method/header — fine for a
# demo; tighten allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE: @app.on_event is deprecated in newer FastAPI in favor of lifespan
# handlers; kept here for compatibility with the file's current style.
@app.on_event("startup")
async def startup_event():
    """Pre-load models at startup so first request is fast."""
    get_models()
@app.get("/")
def health():
    """Liveness probe: report device, the label set, and whether the
    fine-tuned checkpoint file is present on disk."""
    payload = {
        "status": "ok",
        "device": str(DEVICE),
        "classes": ALL_CLASSES,
        "checkpoint": os.path.isfile(CHECKPOINT_PATH),
    }
    return payload
@app.post("/predict")
async def predict_upload(file: UploadFile = File(...)):
    """Upload an image file and get predictions.

    Raises:
        HTTPException 400: missing/non-image content type, or the payload
            cannot be decoded as an image.
        HTTPException 500: unexpected inference failure.
    """
    # content_type may be None for some clients; guard before startswith().
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image.")
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception:
        # A bad payload is a client error, not a server error.
        raise HTTPException(status_code=400, detail="Could not decode image file.")
    try:
        return run_inference(image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
class URLRequest(BaseModel):
    """Request body for POST /predict/url."""
    # Image URL to fetch; validated only as a string here (the endpoint does
    # its own checks before fetching).
    url: str
    # Fetch timeout in seconds.
    timeout: Optional[int] = 10
@app.post("/predict/url")
async def predict_url(req: URLRequest):
    """Pass an image URL and get predictions.

    Only http/https schemes are accepted, so ``file://`` and other local
    schemes cannot be used to read server-side resources.

    Raises:
        HTTPException 400: non-http(s) URL, fetch failure, or undecodable image.
    """
    import urllib.request
    from urllib.parse import urlparse

    scheme = urlparse(req.url).scheme.lower()
    if scheme not in ("http", "https"):
        raise HTTPException(status_code=400, detail="URL must use http or https.")
    # NOTE(security): this still permits requests to internal hosts (SSRF);
    # deploy behind an allow-list or egress proxy if exposed publicly.
    try:
        with urllib.request.urlopen(req.url, timeout=req.timeout) as resp:
            image = Image.open(io.BytesIO(resp.read())).convert("RGB")
        return run_inference(image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}")
class Base64Request(BaseModel):
    """Request body for POST /predict/base64."""
    # base64-encoded image bytes
    image_base64: str
@app.post("/predict/base64")
async def predict_base64(req: Base64Request):
    """Decode a base64-encoded image payload and run classification.

    Raises:
        HTTPException 400: bad base64, undecodable image, or inference error.
    """
    try:
        raw = base64.b64decode(req.image_base64)
        pil = Image.open(io.BytesIO(raw)).convert("RGB")
        return run_inference(pil)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
if __name__ == "__main__":
    # Development entrypoint: the PORT env var overrides the default 8000;
    # reload=True is for local development only.
    uvicorn.run("api:app", host="0.0.0.0", port=int(os.environ.get("PORT", 8000)), reload=True)