import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from fastapi import FastAPI
from pydantic import BaseModel

# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# facebook/bart-large-mnli is a natural-language-inference model; pairing
# the input text with a label hypothesis turns it into a zero-shot classifier.
model_name = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
model.eval()

# A temperature above 1 softens the softmax, which makes the entropy-based
# uncertainty check in classify() more likely to trigger.
temperature = 1.2

app = FastAPI(title="Text Classification API")


class InputData(BaseModel):
    text: str
    label: str
|


def classify(sequence, label):
    # Zero-shot NLI framing: the text is the premise and the candidate
    # label is phrased as an entailment hypothesis.
    hypothesis = f"This example is {label}."
    inputs = tokenizer(
        sequence,
        hypothesis,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # BART-MNLI logits are ordered [contradiction, neutral, entailment];
    # discard "neutral" and keep only the contradiction/entailment pair.
    idx = torch.tensor([0, 2], device=logits.device)
    logits = logits.index_select(1, idx)

    probs = torch.softmax(logits / temperature, dim=1)
    entropy = -torch.sum(probs * torch.log(probs + 1e-8))

    # Two-class entropy is at most ln 2 ≈ 0.693, so the threshold must sit
    # below that to be reachable; a near-uniform split means "unsure".
    if entropy > 0.6:
        return "Uncertain"
    # probs[:, 1] is the entailment column, i.e. P(label applies).
    return round(probs[0, 1].item(), 3)
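
# Direct-call sketch (the input text is illustrative, not from the original):
#
#   classify("I really enjoyed this film.", "positive")
#
# returns the entailment probability rounded to three decimals, or the
# string "Uncertain" when the contradiction/entailment split is near-uniform.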


@app.post("/predict")
def predict(data: InputData):
    # The value is either a float or the string "Uncertain".
    result = classify(data.text, data.label)
    return {"probability": result}