| """ |
| Universal Cross-Domain Vision Model β FastAPI Inference Server |
| ============================================================== |
| Run: uvicorn api:app --host 0.0.0.0 --port 8000 --reload |
| |
| Endpoints |
| --------- |
| GET / health check |
| POST /predict upload an image β JSON predictions |
| POST /predict/url pass an image URL β JSON predictions |
| """ |
|
|
| import io |
| import os |
| import base64 |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from PIL import Image |
|
|
| from fastapi import FastAPI, File, UploadFile, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, HttpUrl |
| import uvicorn |
|
|
| |
| |
| |
# Path to the fine-tuned fusion-head checkpoint; overridable via env var.
CHECKPOINT_PATH = os.environ.get(
    "CHECKPOINT_PATH",
    os.path.join(os.path.dirname(__file__), "..", "universal_vision_checkpoints", "best_model_phase1.pt"),
)
# Prefer GPU when available; models and tensors are moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Chest X-ray pathology labels (medical domain).
MEDICAL_CLASSES = [
    "Normal", "Pneumonia", "COVID-19", "Tuberculosis",
    "Cardiomegaly", "Rib Fracture", "Lung Mass", "Pleural Effusion",
]
# Sports-action labels (sports domain).
SPORTS_CLASSES = ["Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football"]
# Combined label space; order must match the classifier's output indices.
ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES
|
|
| |
| |
| |
class BiomedCLIPMultiModalFusion(nn.Module):
    """Transformer-style fusion head over pre-computed image embeddings.

    One self-attention + feed-forward block (the embedding is treated as a
    length-1 token sequence) followed by a linear classifier over the
    combined medical + sports label space.  The ``domain_discriminator``
    head is not used in ``forward``; it is kept so checkpoints trained with
    a domain-adversarial objective still load.
    """

    def __init__(self, embed_dim: int = 512, num_classes: int = len(ALL_CLASSES), dropout: float = 0.2):
        super().__init__()
        # Self-attention over the (length-1) feature sequence.
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=8, dropout=dropout, batch_first=True
        )
        # Position-wise FFN with the conventional 4x hidden expansion.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Unused at inference time; retained for checkpoint compatibility.
        self.domain_discriminator = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes),
        )

    def forward(self, x):
        """Map (batch, embed_dim) features to (batch, num_classes) logits."""
        tokens = x.unsqueeze(1)  # -> (batch, 1, embed_dim)
        attended, _ = self.attention(tokens, tokens, tokens)
        residual = self.norm1(tokens + attended)
        fused = self.norm2(residual + self.ffn(residual)).squeeze(1)
        return self.classifier(fused)
|
|
|
|
| |
| |
| |
# Lazily-initialized singletons, populated by get_models() on first use.
_model = None       # BiomedCLIPMultiModalFusion fusion head
_backbone = None    # CLIP image encoder
_preprocess = None  # PIL image -> tensor transform matching the backbone
|
|
|
|
def get_models():
    """Lazily build and cache the backbone, fusion head, and preprocess transform.

    Returns:
        (model, backbone, preprocess): the fusion classifier, the frozen CLIP
        image encoder, and the *eval* image transform for that encoder.

    The first call loads BiomedCLIP from the HF hub, falling back to OpenAI
    ViT-B/32 when that fails (e.g. no network).  Later calls return the
    cached instances.  Not thread-safe by itself; the startup hook calls it
    once before requests are served.
    """
    global _model, _backbone, _preprocess
    if _model is not None:
        return _model, _backbone, _preprocess

    import open_clip

    try:
        # create_model_and_transforms returns (model, train_tf, eval_tf).
        # Use the *eval* transform for inference; the previous code took the
        # train transform (second return value) on this path, which was
        # inconsistent with the fallback branch below.
        _backbone, _, _preprocess = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
    except Exception:
        # Fallback backbone when the BiomedCLIP download is unavailable.
        _backbone, _, _preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="openai"
        )

    _backbone = _backbone.to(DEVICE).eval()
    _model = BiomedCLIPMultiModalFusion().to(DEVICE).eval()

    if os.path.isfile(CHECKPOINT_PATH):
        # NOTE(security): weights_only=False unpickles arbitrary objects --
        # only load checkpoints from trusted sources.
        ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
        # Accept either a full training checkpoint or a bare state dict.
        state = ckpt.get("model_state_dict", ckpt)
        # strict=False tolerates checkpoints with extra or missing heads.
        _model.load_state_dict(state, strict=False)

    return _model, _backbone, _preprocess
|
|
|
|
def run_inference(pil_image: Image.Image) -> dict:
    """Classify a PIL image; return per-class probabilities sorted descending.

    Returns a dict with "predictions" (all labels with confidences, highest
    first) and "top_prediction" (the single best entry).
    """
    model, backbone, preprocess = get_models()
    batch = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        embedding = F.normalize(backbone.encode_image(batch).float(), dim=-1)
        probabilities = F.softmax(model(embedding), dim=-1).squeeze(0).cpu().tolist()
    ranked = sorted(
        ({"label": name, "confidence": round(p, 6)} for name, p in zip(ALL_CLASSES, probabilities)),
        key=lambda entry: entry["confidence"],
        reverse=True,
    )
    return {"predictions": ranked, "top_prediction": ranked[0]}
|
|
|
|
| |
| |
| |
# FastAPI application; interactive docs are served at /docs and /redoc.
app = FastAPI(
    title="Universal Cross-Domain Vision Model API",
    description="Classifies images across medical (X-ray pathologies) and sports domains.",
    version="1.0.0",
)


# NOTE(review): wildcard CORS lets any web origin call this API --
# tighten allow_origins before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
| @app.on_event("startup") |
| async def startup_event(): |
| """Pre-load models at startup so first request is fast.""" |
| get_models() |
|
|
|
|
| @app.get("/") |
| def health(): |
| return { |
| "status": "ok", |
| "device": str(DEVICE), |
| "classes": ALL_CLASSES, |
| "checkpoint": os.path.isfile(CHECKPOINT_PATH), |
| } |
|
|
|
|
| @app.post("/predict") |
| async def predict_upload(file: UploadFile = File(...)): |
| """Upload an image file and get predictions.""" |
| if not file.content_type.startswith("image/"): |
| raise HTTPException(status_code=400, detail="File must be an image.") |
| try: |
| contents = await file.read() |
| image = Image.open(io.BytesIO(contents)).convert("RGB") |
| return run_inference(image) |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
class URLRequest(BaseModel):
    """Request body for POST /predict/url."""
    url: str                     # image URL to fetch
    timeout: Optional[int] = 10  # fetch timeout in seconds
|
|
|
|
| @app.post("/predict/url") |
| async def predict_url(req: URLRequest): |
| """Pass an image URL and get predictions.""" |
| import urllib.request |
| try: |
| with urllib.request.urlopen(req.url, timeout=req.timeout) as resp: |
| image = Image.open(io.BytesIO(resp.read())).convert("RGB") |
| return run_inference(image) |
| except Exception as e: |
| raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}") |
|
|
|
|
class Base64Request(BaseModel):
    """Request body for POST /predict/base64."""
    image_base64: str  # base64-encoded image bytes
|
|
|
|
| @app.post("/predict/base64") |
| async def predict_base64(req: Base64Request): |
| """Send a base64-encoded image and get predictions.""" |
| try: |
| img_bytes = base64.b64decode(req.image_base64) |
| image = Image.open(io.BytesIO(img_bytes)).convert("RGB") |
| return run_inference(image) |
| except Exception as e: |
| raise HTTPException(status_code=400, detail=str(e)) |
|
|
|
|
| if __name__ == "__main__": |
| uvicorn.run("api:app", host="0.0.0.0", port=int(os.environ.get("PORT", 8000)), reload=True) |
|
|