import logging

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
| |
# FastAPI application instance; route handlers below register themselves on it.
app = FastAPI()
|
|
| |
# Hugging Face checkpoint to serve.
model_name = "Canstralian/RabbitRedux"
# Load tokenizer and classification model once at import time so every
# request reuses the same in-memory weights instead of reloading per call.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Inference mode: disables dropout and other train-time behavior.
model.eval()
|
|
| |
class PredictionRequest(BaseModel):
    """Request body for POST /predict."""

    # Raw text to classify.
    text: str
|
|
class PredictionResponse(BaseModel):
    """Response body for POST /predict."""

    # Echo of the input text.
    text: str
    # Human-readable predicted class label.
    prediction: str
|
|
| |
| @app.post("/predict", response_model=PredictionResponse) |
| async def predict(request: PredictionRequest): |
| try: |
| |
| inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True) |
| |
| |
| with torch.no_grad(): |
| outputs = model(**inputs) |
| |
| |
| prediction = torch.argmax(outputs.logits, dim=-1).item() |
| |
| |
| labels = ["Label 1", "Label 2", "Label 3"] |
| predicted_label = labels[prediction] |
| |
| |
| return PredictionResponse(text=request.text, prediction=predicted_label) |
| except Exception as e: |
| raise HTTPException(status_code=500, detail="Prediction failed") |
|
|
| |
| @app.get("/health") |
| async def health_check(): |
| return {"status": "healthy"} |
|
|