# Uploaded by mjpsm ("Upload 3 files", commit b0bb2e9, verified).
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch
import torch.nn.functional as F
# =========================
# INIT
# =========================
app = FastAPI(
    title="Skill Classification API",
    description="Predicts skill from student check-ins",
    version="1.0",
)

# Model identifier handed to the Transformers ``from_pretrained`` loaders
# for both the tokenizer and the classifier weights.
MODEL_PATH = "mjpsm/skill-classifier-BERT-v1"

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Everything is loaded once, at import time, so each request reuses the same
# tokenizer and weights instead of reloading them.
print("🔄 Loading model...")
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_PATH)
# nn.Module.to() returns the module itself, so placement can be chained.
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
# Inference mode for layers such as dropout.
model.eval()
print("✅ Model loaded!")
# =========================
# INPUT SCHEMA
# =========================
class InputText(BaseModel):
    """Request body for ``/predict``: the check-in text to classify."""

    text: str  # raw free-form text; tokenized and truncated by the endpoint
# =========================
# ROOT
# =========================
@app.get("/")
def home():
    """Report that the service is up (simple liveness check)."""
    return dict(message="Skill Classification API is running")
# =========================
# PREDICT
# =========================
@app.post("/predict")
def predict(input: InputText):
    """Classify the skill expressed in a student check-in.

    The text is tokenized (truncated to at most 128 tokens), passed through the
    classifier, and the highest-probability label is returned together with its
    softmax probability as ``confidence``.
    """
    # NOTE: the parameter name shadows the builtin ``input``; it is kept so the
    # function's Python signature stays unchanged for any direct callers.
    encoded = tokenizer(
        input.text,
        max_length=128,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    # Put every tensor the tokenizer produced on the same device as the model.
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**batch).logits

    # A single string was tokenized, so the batch holds exactly one row.
    probabilities = F.softmax(logits, dim=1)
    best = torch.argmax(probabilities, dim=1).item()

    return {
        "prediction": model.config.id2label[best],
        "confidence": probabilities[0][best].item(),
    }