| from fastapi import FastAPI |
| from pydantic import BaseModel |
| from transformers import pipeline |
|
|
# FastAPI application instance; HTTP routes are attached to it with
# decorators such as ``@app.post``.
app = FastAPI()

# Text-classification pipeline built once, when this module is imported, and
# reused by every request. ``top_k=None`` makes the pipeline return a score
# for every label instead of only the single best label.
classifier = pipeline(
    task="text-classification",
    model="LokeshDevCreates/tone-baseline-v3",
    top_k=None,
)
|
|
class TextRequest(BaseModel):
    """JSON request body for the tone-prediction endpoint."""

    # Raw text whose tone should be classified; passed to the classifier as-is.
    text: str
|
|
@app.post("/predict")
def predict_tone(req: TextRequest):
    """Classify the tone of ``req.text``.

    Returns a JSON-serialisable dict with:
      * ``detected_tone`` - the label with the highest score,
      * ``confidence`` - that label's score rounded to 4 decimal places,
      * ``all_probs`` - each label returned by the classifier mapped to its
        score rounded to 4 decimal places, in descending score order.
    """
    # ``truncation=True`` is forwarded to the tokenizer so over-long inputs are
    # cut to the model's maximum sequence length. Without it, such inputs can
    # make the forward pass fail, which the endpoint surfaces as an HTTP 500.
    # ``[0]`` selects the per-label score dicts for the single input string.
    scores = classifier(req.text, truncation=True)[0]

    # Rank by descending score so the best label comes first and ``all_probs``
    # is emitted most-likely first (dicts preserve insertion order).
    ranked = sorted(scores, key=lambda item: item["score"], reverse=True)
    best = ranked[0]

    return {
        "detected_tone": best["label"],
        "confidence": round(best["score"], 4),
        "all_probs": {item["label"]: round(item["score"], 4) for item in ranked},
    }
|
|