"""FastAPI service that returns mean-pooled text embeddings from a Hugging Face encoder."""

import os

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

# Hugging Face model identifier; override with the MODEL_NAME environment variable.
MODEL_NAME = os.getenv("MODEL_NAME", "jhu-clsp/mmBERT-base")

app = FastAPI(title="ModernBERT Embedding API", version="1.0.0")

# The tokenizer and model are loaded once at import time and shared by all requests.
print("Loading model:", MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model.eval()  # put the module in evaluation mode; this service only runs inference


class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``."""

    # Raw text to embed; surrounding whitespace is stripped by the handler.
    text: str


def _mean_pool(
    last_hidden_state: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Average token vectors along dim 1, counting only unmasked positions.

    Args:
        last_hidden_state: Per-token encoder outputs; expected to be
            ``(batch, seq_len, hidden)`` as produced by the model call in
            ``embed``.
        attention_mask: Tokenizer attention mask with one entry per token,
            expected to be ``(batch, seq_len)`` with 1 for real tokens.

    Returns:
        One pooled vector per batch row.
    """
    # Cast explicitly instead of relying on int * float type promotion.
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    # Guard against a zero token count, which would otherwise yield NaN/inf
    # values that cannot be serialised to JSON. Counts >= 1 are unaffected.
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts


@app.get("/health")
def health():
    """Report that the service is up and which model it was started with."""
    return {"status": "ok", "model": MODEL_NAME}


@app.post("/embed")
def embed(req: EmbedRequest):
    """Embed a single text and return its mean-pooled vector.

    Args:
        req: Request body carrying the text to embed.

    Returns:
        A dict with the model name, the embedding dimension, the first eight
        components rounded to four decimals, and the full embedding as a list
        of floats.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only.
    """
    text = (req.text or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail="Empty text")

    # Gradients are never needed here; skip autograd bookkeeping.
    with torch.no_grad():
        inputs = tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        outputs = model(**inputs)
        embeddings = _mean_pool(
            outputs.last_hidden_state,
            inputs["attention_mask"],
        )

    # A single text was tokenized, so row 0 is the only embedding.
    emb = embeddings[0].tolist()
    return {
        "model": MODEL_NAME,
        "dim": len(emb),
        "preview_first_8": [round(x, 4) for x in emb[:8]],
        "embedding": emb,
    }