from typing import Optional

from fastapi import APIRouter
from fastapi import Depends
from fastapi import FastAPI

from schemas import ClassificationResult
from utils import load_image
from utils import load_model
|
|
|
|
| |
|
|
# The classifier is loaded exactly once, when this module is imported, and
# is then reused by the route handlers below.
model = load_model()

# Application metadata; FastAPI surfaces title/version/description in the
# generated OpenAPI schema served at `openapi_url`.
app = FastAPI(
    title="MosAl",
    version="0.1.0",
    description="""Obtain classification predictions for mosquito image""",
    openapi_url="/openapi.json",
)

# Routes are declared on this router and later mounted onto `app` with
# `app.include_router`.
api_router = APIRouter()
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
def _get_model():
    """Return the classifier loaded at import time.

    Used as a FastAPI dependency so the model is injected server-side rather
    than being exposed to clients as an overridable query parameter.
    """
    return model


@api_router.get("/classify", status_code=200, response_model=ClassificationResult)
async def predict_image(image_name: str, model=Depends(_get_model)):
    """Classify the image identified by *image_name*.

    Args:
        image_name: Query parameter forwarded unchanged to ``load_image``.
            NOTE(review): if ``load_image`` resolves this against the
            filesystem, confirm it rejects path traversal (e.g. ``../``).
        model: Classifier exposing ``predict(img)`` that returns a
            ``(prediction, pred_idx, probs)`` triple. FastAPI injects it via
            ``_get_model``; direct Python callers may pass one explicitly.

    Returns:
        ``{"prediction": ..., "score": ...}`` where ``score`` is the
        probability at ``pred_idx`` rounded to 3 decimals, or
        ``{"message": [0]}`` when the prediction is falsy.
    """
    # NOTE(review): load_image and model.predict are synchronous and block the
    # event loop while they run. Offloading them to a threadpool would need
    # confirmation that model.predict is safe to call concurrently.
    img = load_image(image_name)
    prediction, pred_idx, probs = model.predict(img)
    if prediction:
        # Convert to a Python float *before* rounding. Rounding a numpy
        # float32 keeps float32 precision, so 0.123 would serialise as
        # 0.12300000339746475 once widened to a JSON (double) float.
        score = round(float(probs.numpy()[pred_idx]), 3)
        return {"prediction": prediction, "score": score}
    # NOTE(review): this payload has no "prediction"/"score" keys. Confirm it
    # validates against ClassificationResult; otherwise FastAPI's response
    # validation will turn it into a 500.
    return {"message": [0]}
|
|
|
|
|
|
# Mount every route declared on `api_router` onto the application.
app.include_router(api_router)


if __name__ == "__main__":
    # Running the file directly starts a uvicorn server. uvicorn is imported
    # here so that merely importing this module does not need it.
    import uvicorn

    host, port = "0.0.0.0", 7860
    uvicorn.run(app, host=host, port=port, log_level="debug")
|
|