"""
Universal Cross-Domain Vision Model — FastAPI Inference Server
==============================================================

Run:
    uvicorn api:app --host 0.0.0.0 --port 8000 --reload

Endpoints
---------
GET  /                 health check
POST /predict          upload an image            → JSON predictions
POST /predict/url      pass an image URL          → JSON predictions
POST /predict/base64   send a base64-encoded image → JSON predictions
"""

import base64
import binascii
import http.client
import io
import logging
import os
import threading
import urllib.error
import urllib.parse
import urllib.request
from typing import Optional

import numpy as np  # noqa: F401  (kept: may be used by other tooling importing this module)
import torch
import torch.nn as nn
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel, HttpUrl  # noqa: F401

logger = logging.getLogger(__name__)

# ─────────────────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────────────────

CHECKPOINT_PATH = os.environ.get(
    "CHECKPOINT_PATH",
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "universal_vision_checkpoints",
        "best_model_phase1.pt",
    ),
)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Upper bound on any single image payload (upload, URL download, or decoded
# base64). Prevents one request from exhausting server memory.
MAX_IMAGE_BYTES = int(os.environ.get("MAX_IMAGE_BYTES", 20 * 1024 * 1024))

# /predict/url download timeout (seconds): used when the client sends none or
# a non-positive value, and clamped to the maximum otherwise.
DEFAULT_URL_TIMEOUT = 10
MAX_URL_TIMEOUT = 60

MEDICAL_CLASSES = [
    "Normal",
    "Pneumonia",
    "COVID-19",
    "Tuberculosis",
    "Cardiomegaly",
    "Rib Fracture",
    "Lung Mass",
    "Pleural Effusion",
]
SPORTS_CLASSES = ["Running", "Jumping", "Swimming", "Cycling", "Tennis", "Football"]
# Order matters: index i of the classifier output corresponds to ALL_CLASSES[i].
ALL_CLASSES = MEDICAL_CLASSES + SPORTS_CLASSES


# ─────────────────────────────────────────────────────────────────────────────
# Model (same architecture as app.py)
# ─────────────────────────────────────────────────────────────────────────────


class BiomedCLIPMultiModalFusion(nn.Module):
    """Attention + feed-forward fusion head that classifies image embeddings.

    Args:
        embed_dim: Size of the incoming embedding vectors.
        num_classes: Number of output logits (defaults to ``len(ALL_CLASSES)``).
        dropout: Dropout rate used in attention, FFN and heads.
    """

    def __init__(self, embed_dim: int = 512, num_classes: int = len(ALL_CLASSES), dropout: float = 0.2):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=8, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Not used by forward(); kept so parameter names line up with saved
        # checkpoints (presumably used for domain-adversarial training — confirm).
        self.domain_discriminator = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` is a batch of embeddings of shape (batch, embed_dim). Each one is
        treated as a length-1 sequence for self-attention.
        """
        x = x.unsqueeze(1)  # (batch, 1, embed_dim) because batch_first=True
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        fused = self.norm2(x + self.ffn(x)).squeeze(1)
        return self.classifier(fused)


# ─────────────────────────────────────────────────────────────────────────────
# Singleton model loader
# ─────────────────────────────────────────────────────────────────────────────

_model = None
_backbone = None
_preprocess = None
# Guards first-time loading now that inference runs in worker threads.
_model_lock = threading.Lock()


def _load_backbone():
    """Create the image encoder and its *evaluation* preprocessing transform.

    ``open_clip.create_model_and_transforms`` returns
    ``(model, preprocess_train, preprocess_val)``. Inference must use the
    third value; the training transform applies random augmentation.
    """
    import open_clip

    try:
        backbone, _, preprocess = open_clip.create_model_and_transforms(
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        )
    except Exception:
        logger.warning(
            "Could not load BiomedCLIP; falling back to OpenAI ViT-B-32. "
            "The fusion head may not match this backbone's features.",
            exc_info=True,
        )
        backbone, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
    return backbone.to(DEVICE).eval(), preprocess


def _load_fusion_head():
    """Build the fusion head and load CHECKPOINT_PATH into it when present."""
    model = BiomedCLIPMultiModalFusion().to(DEVICE).eval()
    if not os.path.isfile(CHECKPOINT_PATH):
        logger.warning(
            "Checkpoint not found at %s; fusion head is randomly initialised.",
            CHECKPOINT_PATH,
        )
        return model

    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # point CHECKPOINT_PATH at files from a trusted source.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
    if isinstance(ckpt, nn.Module):
        state = ckpt.state_dict()
    elif isinstance(ckpt, dict):
        state = ckpt.get("model_state_dict", ckpt)
    else:
        raise TypeError(f"Unsupported checkpoint type {type(ckpt).__name__} in {CHECKPOINT_PATH}")

    # strict=False tolerates extra/missing keys; report them instead of
    # silently serving a partially loaded model.
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        logger.warning(
            "Checkpoint key mismatch: missing=%s unexpected=%s",
            result.missing_keys,
            result.unexpected_keys,
        )
    return model


def get_models():
    """Return ``(fusion_head, backbone, preprocess)``, loading them once.

    The globals are published only after every step has succeeded. A failure
    part-way leaves the cache empty, so the next call retries instead of
    serving a half-initialised model.
    """
    global _model, _backbone, _preprocess
    if _model is not None:
        return _model, _backbone, _preprocess
    with _model_lock:
        if _model is None:
            backbone, preprocess = _load_backbone()
            model = _load_fusion_head()
            _backbone, _preprocess = backbone, preprocess
            _model = model  # assigned last: readers key off _model
    return _model, _backbone, _preprocess


def run_inference(pil_image: Image.Image) -> dict:
    """Classify one PIL image.

    Returns:
        ``{"predictions": [...], "top_prediction": {...}}``. The predictions
        list holds ``{"label", "confidence"}`` dicts, one per entry in
        ALL_CLASSES, sorted by descending confidence.
    """
    model, backbone, preprocess = get_models()
    tensor = preprocess(pil_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        features = backbone.encode_image(tensor)
        features = F.normalize(features.float(), dim=-1)
        logits = model(features)
        probs = F.softmax(logits, dim=-1).squeeze(0).cpu().tolist()
    results = sorted(
        ({"label": lbl, "confidence": round(prob, 6)} for lbl, prob in zip(ALL_CLASSES, probs)),
        key=lambda r: r["confidence"],
        reverse=True,
    )
    return {"predictions": results, "top_prediction": results[0]}


# ─────────────────────────────────────────────────────────────────────────────
# Request helpers
# ─────────────────────────────────────────────────────────────────────────────


def _decode_image(data: bytes) -> Image.Image:
    """Decode raw bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 for empty/undecodable data, 413 if the payload
            exceeds MAX_IMAGE_BYTES.
    """
    if not data:
        raise HTTPException(status_code=400, detail="Empty image payload.")
    if len(data) > MAX_IMAGE_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"Image exceeds the {MAX_IMAGE_BYTES}-byte limit.",
        )
    try:
        with Image.open(io.BytesIO(data)) as img:
            return img.convert("RGB")  # convert() loads pixels into a new image
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError is an OSError subclass.
        raise HTTPException(status_code=400, detail=f"Could not decode image: {e}") from e


def _predict_bytes(data: bytes) -> dict:
    """Decode ``data`` and run inference. Blocking; call via the threadpool."""
    image = _decode_image(data)
    try:
        return run_inference(image)
    except Exception as e:
        logger.exception("Inference failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


def _fetch_url(url: str, timeout: float) -> bytes:
    """Download at most MAX_IMAGE_BYTES + 1 bytes from an http(s) URL.

    Only http/https are allowed, so ``file://`` and similar schemes cannot read
    server-local resources. NOTE(review): internal http hosts are still
    reachable (SSRF); add a host allow-list if this is exposed publicly.
    """
    if urllib.parse.urlparse(url).scheme.lower() not in ("http", "https"):
        raise HTTPException(
            status_code=400,
            detail="Could not fetch image: only http(s) URLs are supported.",
        )
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            # One extra byte lets _decode_image detect oversize payloads.
            return resp.read(MAX_IMAGE_BYTES + 1)
    except (OSError, ValueError, http.client.HTTPException) as e:
        # URLError/HTTPError and socket timeouts are OSError subclasses.
        raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}") from e


def _decode_base64(payload: str) -> bytes:
    """Decode a base64 string, also accepting ``data:<mime>;base64,...`` URLs.

    Raises:
        HTTPException: 400 if the payload is not valid base64.
    """
    if payload.startswith("data:"):
        _, sep, payload = payload.partition(",")
        if not sep:
            raise HTTPException(status_code=400, detail="Malformed data URL.")
    # Drop whitespace/line breaks, then validate strictly so stray characters
    # are rejected rather than silently discarded.
    payload = "".join(payload.split())
    try:
        return base64.b64decode(payload, validate=True)
    except (binascii.Error, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid base64 payload: {e}") from e


# ─────────────────────────────────────────────────────────────────────────────
# FastAPI app
# ─────────────────────────────────────────────────────────────────────────────

app = FastAPI(
    title="Universal Cross-Domain Vision Model API",
    description="Classifies images across medical (X-ray pathologies) and sports domains.",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Pre-load models at startup so first request is fast."""
    get_models()


@app.get("/")
def health():
    """Report service status, device, class list and checkpoint presence."""
    return {
        "status": "ok",
        "device": str(DEVICE),
        "classes": ALL_CLASSES,
        "checkpoint": os.path.isfile(CHECKPOINT_PATH),
    }


@app.post("/predict")
async def predict_upload(file: UploadFile = File(...)):
    """Upload an image file and get predictions."""
    # content_type may be None when the client omits it.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image.")
    contents = await file.read(MAX_IMAGE_BYTES + 1)
    # Decoding and inference are blocking; keep them off the event loop.
    return await run_in_threadpool(_predict_bytes, contents)


class URLRequest(BaseModel):
    """Body for /predict/url."""

    url: str
    timeout: Optional[int] = 10  # seconds; clamped to (0, MAX_URL_TIMEOUT]


@app.post("/predict/url")
async def predict_url(req: URLRequest):
    """Pass an image URL and get predictions."""
    timeout = req.timeout if req.timeout and req.timeout > 0 else DEFAULT_URL_TIMEOUT
    timeout = min(timeout, MAX_URL_TIMEOUT)
    data = await run_in_threadpool(_fetch_url, req.url, timeout)
    return await run_in_threadpool(_predict_bytes, data)


class Base64Request(BaseModel):
    """Body for /predict/base64."""

    image_base64: str  # base64-encoded image bytes (optionally a data: URL)


@app.post("/predict/base64")
async def predict_base64(req: Base64Request):
    """Send a base64-encoded image and get predictions."""
    img_bytes = _decode_base64(req.image_base64)
    return await run_in_threadpool(_predict_bytes, img_bytes)


if __name__ == "__main__":
    uvicorn.run("api:app", host="0.0.0.0", port=int(os.environ.get("PORT", 8000)), reload=True)