import logging
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import pipeline
# Build the zero-shot classification pipeline once, at import time, so every
# request handled by this module reuses the same loaded model instead of
# reloading weights per call.
# The model is named explicitly. When it is omitted, transformers falls back to
# a task default, logs a warning that this is not recommended in production,
# and that default may change between library releases.
# facebook/bart-large-mnli is the task's default, so the served model is
# unchanged by pinning it here.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)


# ASGI application object; the route decorators below register on it.
app = FastAPI()

| class ClassificationRequest(BaseModel): |
| text: str = Field(..., example="This is a course about the Transformers library") |
| labels: List[str] = Field(..., example=["education", "politics", "technology"]) |
|
|
@app.get("/")
def greet_json():
    """Respond to ``GET /`` with a constant JSON greeting.

    Takes no input and always returns the same single-key mapping.
    """
    return dict(Hello="World!")

@app.post("/classify")
def zero_shot_classification(request: ClassificationRequest):
    """
    Classify ``request.text`` against ``request.labels`` with zero-shot inference.

    Args:
        request: Parsed request body holding the text and its candidate labels.

    Returns:
        The classifier's output for the single input text, passed through
        unchanged (presumably a mapping with ``sequence``, ``labels`` and
        ``scores`` keys -- confirm against the installed transformers version).

    Raises:
        HTTPException: 422 when ``request.labels`` is empty; 500 with a
            generic detail for any failure inside the classifier. The
            underlying error is logged server-side rather than returned.
    """
    # The pipeline cannot score zero labels. Without this check its exception
    # would be reported as a 500, although the fault is in the client's input.
    if not request.labels:
        raise HTTPException(
            status_code=422,
            detail="At least one candidate label is required.",
        )

    try:
        result = classifier(request.text, candidate_labels=request.labels)
    except Exception as exc:
        # Record the full traceback for operators, but do not echo str(exc) to
        # the caller: it can expose internals such as paths or model details.
        logging.getLogger(__name__).exception("Zero-shot classification failed")
        raise HTTPException(
            status_code=500,
            detail="Internal error during classification.",
        ) from exc
    return result
