| """ |
| FastAPI application for YOLO PyTorch image classification |
| Accepts image uploads and returns verdict (green/red) based on detected classes |
| Includes GPU detection and automatic device selection |
| """ |
|
|
| from fastapi import FastAPI, File, UploadFile, HTTPException |
| from fastapi.responses import JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| from typing import List, Dict, Optional |
| import numpy as np |
| import cv2 |
| from io import BytesIO |
| from PIL import Image |
| import logging |
| import base64 |
|
|
| from yolo_inference_pytorch import YOLOInference |
| from verdict_logic import VerdictEngine |
|
|
# Module logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


# The FastAPI application object. Title, description and version are the
# values shown in the generated OpenAPI documentation.
app = FastAPI(
    version="2.0.0",
    title="YOLO PyTorch Image Classifier - Engineer Selfie",
    description=(
        "Upload an image to get classification verdict (green/red) based on "
        "detected objects. Supports GPU acceleration."
    ),
)


# Fully permissive CORS policy: any origin, method and header is accepted.
# NOTE(review): the FastAPI CORS documentation states that wildcard values
# must not be combined with allow_credentials=True. Confirm whether
# cross-origin credentials are really needed; if not, disable them, otherwise
# list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
class Base64ImageRequest(BaseModel):
    """Request body for ``POST /check-image``.

    ``image`` carries the base64-encoded image. It may optionally start with
    a data-URI header such as ``data:image/jpeg;base64,``; the endpoint strips
    everything up to and including ``base64,`` before decoding.
    """

    image: str

    # Pydantic v2 style configuration. ``json_schema_extra`` is a v2 setting,
    # and v2 deprecates the nested ``class Config`` form in favour of
    # ``model_config`` (a plain dict is accepted, so no extra import is
    # needed). The example below is what the generated API docs display.
    model_config = {
        "json_schema_extra": {
            "example": {
                "image": "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
            }
        }
    }
|
|
class Detection(BaseModel):
    # One detected object as exposed in the API response. The values are
    # copied by create_verdict_response() from the dicts returned by
    # yolo_model.predict().
    class_id: int
    class_name: str
    confidence: float
    # Box coordinates as floats; the coordinate layout is not visible in this
    # file -- confirm against YOLOInference.
    bbox: list[float]
|
|
class VerdictResponse(BaseModel):
    # Response schema shared by the single-image endpoints.
    # verdict / confidence / message come straight from the dict returned by
    # verdict_engine.get_verdict(); per the module docstring the verdict is
    # expected to be green or red.
    verdict: str
    confidence: float
    detections: list[Detection]
    message: str
    # Populated with the keys "width" and "height" (pixels of the decoded image).
    image_size: dict[str, int]
|
|
# Process-wide singletons. Both start out empty and are assigned by
# startup_event(); request handlers treat None as "not loaded".
verdict_engine: "Optional[VerdictEngine]" = None
yolo_model: "Optional[YOLOInference]" = None
|
|
# NOTE(review): recent FastAPI releases deprecate on_event() in favour of
# lifespan handlers; migrating would also touch the FastAPI(...) constructor.
@app.on_event("startup")
async def startup_event():
    """Load the YOLO model (when its weights exist) and the verdict engine.

    Registered as the application's startup hook. If ``model/best.pt`` is
    missing, the service still starts: warnings are logged, ``yolo_model``
    stays ``None`` and only the verdict engine is created. Otherwise the
    model is loaded first and the verdict engine afterwards.

    Raises:
        Exception: anything raised while loading is logged with its
            traceback and re-raised, which aborts application startup.
    """
    global yolo_model, verdict_engine

    import os

    model_path = "model/best.pt"
    # Single source of truth for the threshold that is both passed to the
    # model and reported in the log line below.
    conf_threshold = 0.5

    try:
        if os.path.exists(model_path):
            logger.info("Loading YOLO PyTorch model with GPU detection...")
            yolo_model = YOLOInference(
                model_path=model_path,
                labels_path="model/labels.txt",
                conf_threshold=conf_threshold,
                iou_threshold=0.45,
                # Presumably lets YOLOInference pick the device itself (the
                # module docstring mentions automatic selection) -- confirm.
                device=None,
            )
            logger.info("YOLO model loaded successfully")
            logger.info(
                f"Using confidence threshold: {conf_threshold} ({conf_threshold:.0%})"
            )
        else:
            logger.warning(f"⚠️ YOLO model not found at {model_path}")
            logger.warning("⚠️ The API will start but /predict endpoint will not work until you add the model")
            logger.warning("⚠️ Please place your best.pt file in the model/ directory")

        # The verdict engine does not depend on the model being present, so
        # it is built exactly once for both paths.
        logger.info("Initializing verdict engine...")
        verdict_engine = VerdictEngine(
            green_classes=[],
            red_classes=[],
            rules_path="config/verdict_rules.json",
        )
        logger.info("Verdict engine initialized successfully")

    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Error during startup: %s", e)
        raise
|
|
@app.get("/")
async def root():
    """Describe the service: name, version, endpoint index and usage hints."""
    endpoint_index = {
        "check_image_base64": "/check-image (POST with base64)",
        "check_image_upload": "/check-image/upload (POST with file)",
        "check_images_batch": "/check-images/batch (POST with multiple files)",
        "health": "/health (GET)",
        "model_info": "/model/info (GET)",
        "device_info": "/device/info (GET)",
    }
    usage_hints = {
        "base64": 'POST /check-image with JSON body: {"image": "base64_string"}',
        "file_upload": "POST /check-image/upload with multipart/form-data",
        "batch": "POST /check-images/batch with multiple files",
    }
    return {
        "message": "YOLO PyTorch Image Classifier API - Engineer Selfie",
        "version": "2.0.0",
        "description": "Upload images to check for object detection and get approval verdicts. GPU-accelerated when available.",
        "endpoints": endpoint_index,
        "usage": usage_hints,
    }
|
|
@app.get("/health")
async def health_check():
    """Report liveness plus whether the model and verdict engine are loaded."""
    # Fall back to a placeholder string while no model object is available.
    if yolo_model:
        active_device = yolo_model.device
    else:
        active_device = "not loaded"

    return {
        "status": "healthy",
        "model_loaded": yolo_model is not None,
        "verdict_engine_loaded": verdict_engine is not None,
        "device": active_device,
    }
|
|
@app.get("/device/info")
async def device_info():
    """Return whatever the loaded model reports via ``get_device_info()``.

    Raises:
        HTTPException: 503 when no model has been loaded.
    """
    if yolo_model is not None:
        return yolo_model.get_device_info()

    raise HTTPException(status_code=503, detail="Model not loaded")
|
|
@app.get("/model/info")
async def model_info():
    """Return static information about the loaded model.

    The payload contains the model path, class names and their count, input
    size, confidence/IoU thresholds and the model's device information.

    Raises:
        HTTPException: 503 when no model has been loaded.
    """
    if yolo_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Fetch the class names once so "classes" and "num_classes" are always
    # derived from the same value (and the accessor is not called twice).
    class_names = yolo_model.get_class_names()

    return {
        "model_type": "YOLO PyTorch",
        "model_path": yolo_model.model_path,
        "classes": class_names,
        "num_classes": len(class_names),
        "input_size": yolo_model.get_input_size(),
        "confidence_threshold": yolo_model.conf_threshold,
        "iou_threshold": yolo_model.iou_threshold,
        "device_info": yolo_model.get_device_info(),
    }
|
|
def process_image_to_numpy(image_data: bytes) -> tuple[np.ndarray, int, int]:
    """Decode encoded image bytes into a BGR array plus its dimensions.

    Args:
        image_data: Raw bytes of an encoded image in any format Pillow can open.

    Returns:
        ``(image_bgr, width, height)`` where ``image_bgr`` is the decoded
        image with its channels in BGR order and ``width``/``height`` are its
        pixel dimensions.

    Raises:
        OSError: Raised by Pillow when the bytes cannot be identified or
            decoded (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    # The context manager releases Pillow's handle once the pixels are copied.
    with Image.open(BytesIO(image_data)) as image:
        # Normalise every mode (grayscale, palette, RGBA, ...) to 3-channel RGB.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        image_np = np.array(rgb_image)

    # After the RGB normalisation the array always has three channels, so the
    # former grayscale branch could never run; a single RGB->BGR swap suffices.
    image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    original_height, original_width = image_np.shape[:2]
    return image_np, original_width, original_height
|
|
|
|
def create_verdict_response(detections: List[Dict], verdict_result: Dict, width: int, height: int) -> VerdictResponse:
    """Assemble the API response model from raw detections and a verdict.

    Each detection dict must provide ``class_id``, ``class_name``,
    ``confidence`` and ``bbox``; ``verdict_result`` must provide ``verdict``,
    ``confidence`` and ``message``. ``width`` and ``height`` are reported
    back under ``image_size``.
    """
    converted = []
    for raw in detections:
        # Force plain Python floats for the confidence and box coordinates.
        converted.append(
            Detection(
                class_id=raw["class_id"],
                class_name=raw["class_name"],
                confidence=float(raw["confidence"]),
                bbox=list(map(float, raw["bbox"])),
            )
        )

    return VerdictResponse(
        verdict=verdict_result["verdict"],
        confidence=verdict_result["confidence"],
        detections=converted,
        message=verdict_result["message"],
        image_size={"width": width, "height": height},
    )
|
|
|
|
@app.post("/check-image", response_model=VerdictResponse)
async def check_image_base64(request: Base64ImageRequest):
    """
    CHECK IMAGE (Base64 Format)

    Send a base64 encoded image and get a verdict (GREEN/RED) based on detected objects.

    The ``image`` field may include a data-URI header; everything up to and
    including ``base64,`` is dropped before decoding.

    Raises:
        HTTPException: 503 when the model or verdict engine is not loaded,
            400 when the payload is not valid base64 or does not decode to an
            image, 500 for any other processing failure.
    """
    if yolo_model is None or verdict_engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        base64_string = request.image

        if "base64," in base64_string:
            base64_string = base64_string.split("base64,")[1]

        try:
            image_bytes = base64.b64decode(base64_string)
        except ValueError as exc:
            # binascii.Error is a ValueError subclass; catching ValueError
            # also covers non-ASCII input and avoids relying on the
            # undocumented ``base64.binascii`` attribute.
            raise HTTPException(status_code=400, detail="Invalid base64 image data") from exc
        logger.info("Processing base64 encoded image")

        try:
            image_np, original_width, original_height = process_image_to_numpy(image_bytes)
        except OSError as exc:
            # Pillow reports unidentifiable or truncated image data through
            # OSError subclasses: that is a client error, not a server fault.
            raise HTTPException(status_code=400, detail="Invalid or unsupported image data") from exc
        logger.info(f"Image size: {original_width}x{original_height}")

        detections = yolo_model.predict(image_np)
        logger.info(f"Found {len(detections)} detections")

        verdict_result = verdict_engine.get_verdict(detections)

        response = create_verdict_response(detections, verdict_result, original_width, original_height)

        logger.info(f"Verdict: {verdict_result['verdict']} (confidence: {verdict_result['confidence']:.2f})")
        return response

    except HTTPException:
        # Let the deliberate 4xx responses above pass through unchanged.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error processing image: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
|
|
|
|
@app.post("/check-image/upload", response_model=VerdictResponse)
async def check_image_upload(file: UploadFile = File(...)):
    """
    CHECK IMAGE (File Upload)

    Upload an image file and get a verdict (GREEN/RED) based on detected objects.

    Raises:
        HTTPException: 503 when the model or verdict engine is not loaded,
            400 when the upload is not declared as an image or cannot be
            decoded as one, 500 for any other processing failure.
    """
    if yolo_model is None or verdict_engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # content_type may be None when the uploaded part carries no Content-Type
    # header; treat that as "not an image" instead of crashing on None.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type: {file.content_type}. Please upload an image."
        )

    try:
        logger.info(f"Processing image: {file.filename}")
        contents = await file.read()

        try:
            image_np, original_width, original_height = process_image_to_numpy(contents)
        except OSError as exc:
            # Pillow reports unidentifiable or truncated image data through
            # OSError subclasses: that is a client error, not a server fault.
            raise HTTPException(status_code=400, detail="Invalid or unsupported image data") from exc
        logger.info(f"Image size: {original_width}x{original_height}")

        detections = yolo_model.predict(image_np)
        logger.info(f"Found {len(detections)} detections")

        verdict_result = verdict_engine.get_verdict(detections)

        response = create_verdict_response(detections, verdict_result, original_width, original_height)

        logger.info(f"Verdict: {verdict_result['verdict']} (confidence: {verdict_result['confidence']:.2f})")
        return response

    except HTTPException:
        # Let the deliberate 400 above pass through unchanged.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error processing image: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
|
|
|
|
@app.post("/predict", response_model=VerdictResponse)
async def predict(file: UploadFile = File(...)):
    """
    [LEGACY] Alias of /check-image/upload kept for older clients.

    Forwards the uploaded file to check_image_upload() and returns its
    result unchanged. Prefer /check-image/upload in new code.
    """
    verdict = await check_image_upload(file)
    return verdict
|
|
@app.post("/check-images/batch")
async def check_images_batch(files: List[UploadFile] = File(...)):
    """
    CHECK MULTIPLE IMAGES (Batch Upload)

    Upload multiple image files (max 10) and get verdicts for each.

    Files are processed one after another through check_image_upload(). A
    failure of one file does not abort the batch: its entry carries an
    ``error`` message (plus ``status_code`` when the failure was an HTTP
    error) instead of a ``result``.

    Raises:
        HTTPException: 400 when more than 10 files are submitted.
    """
    if len(files) > 10:
        raise HTTPException(
            status_code=400,
            detail="Maximum 10 images allowed per batch"
        )

    results = []
    for file in files:
        entry = {"filename": file.filename}
        try:
            entry["result"] = await check_image_upload(file)
        except HTTPException as exc:
            # str(HTTPException) differs between Starlette versions (it can
            # be empty), so report the detail and status code explicitly.
            entry["error"] = str(exc.detail)
            entry["status_code"] = exc.status_code
        except Exception as e:
            logger.exception("Error processing batch item %s", file.filename)
            entry["error"] = str(e)
        results.append(entry)

    return {"results": results}
|
|
|
|
@app.post("/predict/batch")
async def predict_batch(files: List[UploadFile] = File(...)):
    """
    [LEGACY] Alias of /check-images/batch kept for older clients.

    Forwards the uploaded files to check_images_batch() and returns its
    result unchanged. Prefer /check-images/batch in new code.
    """
    batch_result = await check_images_batch(files)
    return batch_result
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import sys

    import uvicorn

    # Optional first CLI argument overrides the default port 8000.
    cli_args = sys.argv[1:]
    port = int(cli_args[0]) if cli_args else 8000

    separator = "=" * 70
    print(f"\n{separator}")
    print(f"🚀 Starting Engineer Selfie API on http://0.0.0.0:{port}")
    print(f"📖 API docs available at http://localhost:{port}/docs")
    print("🖥️ GPU acceleration will be used if available")
    print(f"{separator}\n")

    uvicorn.run(app, host="0.0.0.0", port=port)
|
|