Spaces:
Sleeping
Sleeping
| """ | |
SW Identifier — FastAPI server

Routes
------
GET    /                      SPA frontend
POST   /predict               Internal SPA endpoint (no auth)
POST   /api/v1/predict        Public API (requires X-API-Key header)
GET    /api/v1/keys           List API keys (requires X-Admin-Key header)
POST   /api/v1/keys           Create API key (requires X-Admin-Key header)
DELETE /api/v1/keys/{key}     Revoke API key (requires X-Admin-Key header)
GET    /api/docs              OpenAPI / Swagger UI
| """ | |
| import csv | |
| import io | |
| import json | |
| import logging | |
| import os | |
| import secrets | |
| import sys | |
| import time | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime, timezone | |
| from typing import List, Optional | |
| import numpy as np | |
| from PIL import Image | |
| from fastapi import APIRouter, Depends, FastAPI, File, HTTPException, Security, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.security import APIKeyHeader | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
# ── paths ─────────────────────────────────────────────────────────────────────
# Directory that contains this file; every bundled asset lives beneath it.
BASE = os.path.dirname(os.path.abspath(__file__))


def _in_base(*parts: str) -> str:
    """Return the path formed by joining *parts* onto BASE."""
    return os.path.join(BASE, *parts)


# Model artefacts.
DETECTOR_PATH = _in_base("detector", "model.pt")
SEGMENTATOR_PATH = _in_base("segmentator", "model.ts")
CLASSIFIER_CKPT = _in_base("classification_model", "model.ckpt")
DATABASE_PATH = _in_base("classification_model", "database.pt")

# Frontend files and data files.
STATIC_DIR = _in_base("static")
TAXONS_CSV = _in_base("taxons.csv")
KEYS_FILE = _in_base("api_keys.json")

# Put BASE at the front of the import path; presumably so the local
# `segmentator` and `classification_model` packages imported later resolve
# regardless of the working directory.
sys.path.insert(0, BASE)
# ── logging ───────────────────────────────────────────────────────────────────
# Module logger; the root configuration only lets WARNING and above through,
# which is why progress messages elsewhere in this file use log.warning().
log = logging.getLogger("sw.app")
logging.basicConfig(level=logging.WARNING)
# ── common name lookup ────────────────────────────────────────────────────────
def _load_common_names(path: str) -> dict:
    """Build a taxon -> common-name lookup from a CSV file.

    The CSV must have ``taxon`` and ``common_name`` header columns.

    Args:
        path: Location of the CSV file.

    Returns:
        A dict keyed by the stripped ``taxon`` value. The value is the
        stripped ``common_name``, or the taxon itself when the common name
        is empty or the cell is missing. Rows with an empty taxon are skipped.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If a required header column is absent.
    """
    mapping = {}
    # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading BOM, which
    # would otherwise become part of the first header name.
    with open(path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            # DictReader fills cells missing from a short row with None.
            taxon = (row["taxon"] or "").strip()
            if not taxon:
                continue
            common = (row["common_name"] or "").strip()
            mapping[taxon] = common or taxon
    return mapping
# Taxon -> display-name table, loaded once when the module is imported.
# Any error raised while reading TAXONS_CSV therefore surfaces at import time.
COMMON_NAMES: dict = _load_common_names(TAXONS_CSV)
# ── API key store ─────────────────────────────────────────────────────────────
def _load_keys() -> list:
    """Return the key records stored in KEYS_FILE, or [] if the file is absent."""
    if not os.path.exists(KEYS_FILE):
        return []
    with open(KEYS_FILE, encoding="utf-8") as handle:
        records = json.load(handle)
    return records
def _save_keys(keys: list) -> None:
    """Persist *keys* to KEYS_FILE as indented JSON.

    The data is first written to a sibling ``.tmp`` file and then moved over
    KEYS_FILE with ``os.replace``. A failure while serialising or writing
    therefore leaves the previous key store untouched instead of truncated.

    Args:
        keys: JSON-serialisable list of key records.

    Raises:
        TypeError, ValueError, OSError: Propagated from ``json.dump`` or the
            file operations; the temporary file is removed first.
    """
    tmp_path = KEYS_FILE + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(keys, f, indent=2)
        os.replace(tmp_path, KEYS_FILE)
    except Exception:
        # Best-effort cleanup; the original error is what the caller needs.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def _valid_key_set() -> set:
    """Return the set of API keys that are currently accepted.

    A non-blank SW_API_KEYS environment variable (comma-separated) takes
    precedence over the key file. Per the original note, this is needed for
    stateless deployments such as HF Spaces where the filesystem is ephemeral.
    """
    raw = os.environ.get("SW_API_KEYS", "").strip()
    if not raw:
        return {record["key"] for record in _load_keys()}
    trimmed = (part.strip() for part in raw.split(","))
    return {candidate for candidate in trimmed if candidate}
def _new_key(name: str) -> dict:
    """Create a fresh key record labelled *name*.

    The record holds a random ``fsh_``-prefixed token, the given name, and a
    UTC creation timestamp in ISO 8601 form.
    """
    token = secrets.token_urlsafe(32)
    created = datetime.now(timezone.utc)
    record = {
        "key": "fsh_" + token,
        "name": name,
        "created_at": created.isoformat(),
    }
    return record
def _bootstrap_keys() -> None:
    """Make sure at least one file-based API key exists at startup.

    Does nothing when keys come from the SW_API_KEYS environment variable or
    when the key file already holds records. Otherwise a "default" key is
    generated, saved, and printed once to the console.
    """
    if os.environ.get("SW_API_KEYS", "").strip():
        return
    if _load_keys():
        return

    default_key = _new_key("default")
    _save_keys([default_key])

    rule = "β" * 60
    print("\n" + rule)
    print(" No API keys found β generated a default key:")
    print(f" {default_key['key']}")
    print(" Store this somewhere safe; it won't be shown again.")
    print(rule + "\n")
# Admin key — set the SW_ADMIN_KEY env var, or one is auto-generated once
# and kept in this file next to the application code.
_ADMIN_KEY_FILE = os.path.join(BASE, ".admin_key")
def _get_admin_key() -> str:
    """Return the admin key, generating and storing one if necessary.

    Resolution order:
      1. A non-empty SW_ADMIN_KEY environment variable.
      2. The stripped contents of _ADMIN_KEY_FILE, when non-empty.
      3. A newly generated ``fadm_`` token, which is written to
         _ADMIN_KEY_FILE and printed to the console.

    A missing, empty, or whitespace-only key file is treated as "no key" so
    that an empty string is never handed out as the admin secret.

    Returns:
        The admin key string.

    Raises:
        OSError: If the key file exists but cannot be read, or cannot be written.
    """
    env = os.environ.get("SW_ADMIN_KEY")
    if env:
        return env

    try:
        with open(_ADMIN_KEY_FILE, encoding="utf-8") as f:
            stored = f.read().strip()
    except FileNotFoundError:
        stored = ""
    if stored:
        return stored

    key = "fadm_" + secrets.token_urlsafe(32)
    # Request owner-only permissions when the file is created; the mode is
    # ignored if the file already exists.
    fd = os.open(_ADMIN_KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(key)

    print("\n" + "β" * 60)
    print(" Admin key (manage API keys):")
    print(f" {key}")
    print(" Stored in .admin_key β keep it out of version control.")
    print("β" * 60 + "\n")
    return key
# Placeholder until the lifespan handler assigns the real admin key.
ADMIN_KEY: str = ""

# ── model globals ─────────────────────────────────────────────────────────────
# All three models start unset and are assigned by the lifespan handler.
detector = segmentator = classifier = None
# ── lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare keys and models before the app serves requests.

    Before yielding, this bootstraps the API key store, resolves the admin
    key, and loads the detector, segmentator and classifier into the
    module-level globals. After the yield it only logs a shutdown message.

    The function is wrapped with ``asynccontextmanager`` so that the object
    handed to ``FastAPI(lifespan=...)`` is an async context manager factory
    rather than a bare async generator function.

    Args:
        app: The FastAPI application (unused in the body).
    """
    global detector, segmentator, classifier, ADMIN_KEY
    _bootstrap_keys()
    ADMIN_KEY = _get_admin_key()

    # Model libraries are imported here rather than at module top, so they
    # are only loaded when the application actually starts.
    # Progress is logged at WARNING because that is the configured log level.
    from ultralytics import YOLO
    log.warning("Loading detector β¦")
    detector = YOLO(DETECTOR_PATH)

    log.warning("Loading segmentator β¦")
    from segmentator.inference import Inference as Segmentator
    segmentator = Segmentator(SEGMENTATOR_PATH)

    log.warning("Loading classifier β¦")
    from classification_model.inference import EmbeddingClassifier
    classifier = EmbeddingClassifier({
        "log_level": "WARNING",
        "dataset": {"path": DATABASE_PATH},
        "model": {
            "checkpoint_path": CLASSIFIER_CKPT,
            "backbone_model_name": "beitv2_base_patch16_224.in1k_ft_in22k_in1k",
            "embedding_dim": 512,
            "num_classes": 775,
            "arcface_s": 64.0,
            "arcface_m": 0.2,
            "pooling_type": "attention",
            "device": "cpu",
        },
        "use_knn": True,
        "arcface_min_score": 0.1,
        "centroid_fallback_score": 0.1,
        "topk_centroid": 5,
        "topk_neighbors": 10,
        "topk_arcface": 5,
        "centroid_threshold": 0.7,
        "neighbor_threshold": 0.8,
        "use_albumentations": False,
    })
    log.warning("All models ready.")
    yield
    log.warning("Shutting down.")
# ── Pydantic response models ──────────────────────────────────────────────────
class BoundingBox(BaseModel):
    """Rectangular detection box plus the detector's score for it."""

    # Corner coordinates in pixels; the pipeline clamps them to the image
    # bounds before building this model.
    x1: int
    y1: int
    x2: int
    y2: int
    # Detector score for the box (the pipeline rounds it to 3 decimals).
    confidence: float
class Prediction(BaseModel):
    """One candidate species for a detection."""

    name: str  # common name (falls back to the taxon when none is known)
    taxon: str  # scientific name
    accuracy: float  # confidence 0-1
    # Classifier-provided species identifier, stored as a string.
    species_id: str
class Detection(BaseModel):
    """A detected object: its box, optional outline, and species candidates."""

    bbox: BoundingBox
    # [[x, y], ...] in original image coords; None when the pipeline obtained
    # no segmentation result for the crop.
    polygon: Optional[List[List[int]]]
    # Species candidates; the pipeline keeps at most three per detection.
    predictions: List[Prediction]
class ImageSize(BaseModel):
    """Pixel dimensions of the decoded input image."""

    width: int
    height: int
class Timing(BaseModel):
    """Per-stage wall-clock timings, in whole milliseconds."""

    detect_ms: int
    # Segmentation and classification times are summed over all detections.
    segment_ms: int
    classify_ms: int
    # Measured from just after image decoding to the end of the pipeline.
    total_ms: int
class PredictResponse(BaseModel):
    """Full result of running the pipeline on one image."""

    detections: List[Detection]
    image_size: ImageSize
    timing: Timing
# ── shared pipeline ───────────────────────────────────────────────────────────
def _decode_image(raw: bytes) -> np.ndarray:
    """Decode uploaded bytes into an RGB pixel array, or raise HTTP 400."""
    try:
        return np.array(Image.open(io.BytesIO(raw)).convert("RGB"))
    except Exception as exc:
        # Chain the cause so the underlying decoder error stays in tracebacks.
        raise HTTPException(
            status_code=400, detail=f"Cannot decode image: {exc}"
        ) from exc


def _segment_crop(crop_rgb, x1: int, y1: int):
    """Segment one crop; return (polygon or None, crop to classify, elapsed ms)."""
    polygon_coords = None
    masked_crop = crop_rgb
    t0 = time.perf_counter()
    try:
        seg_results = segmentator.predict(crop_rgb)
        if seg_results:
            poly = seg_results[0]
            # Shift crop-relative points back into original image coordinates.
            polygon_coords = [[int(px) + x1, int(py) + y1] for px, py in poly.points]
            masked_crop = poly.mask_polygon(crop_rgb)
    except Exception as exc:
        # Best effort: on failure the unmasked crop is classified instead.
        log.warning("Segmentator error: %s", exc)
    return polygon_coords, masked_crop, (time.perf_counter() - t0) * 1000


def _classify_crop(masked_crop):
    """Classify one crop; return (up to three Prediction objects, elapsed ms)."""
    pred_list: List[Prediction] = []
    t0 = time.perf_counter()
    try:
        preds = classifier(masked_crop)
        for p in (preds or [])[:3]:
            pred_list.append(Prediction(
                name=COMMON_NAMES.get(p.name, p.name),
                taxon=p.name,
                accuracy=round(float(p.accuracy), 4),
                species_id=str(p.species_id),
            ))
    except Exception as exc:
        # Best effort: predictions gathered before the failure are kept.
        log.warning("Classifier error: %s", exc)
    return pred_list, (time.perf_counter() - t0) * 1000


async def _run_pipeline(raw: bytes) -> PredictResponse:
    """Run detection, segmentation and classification on one uploaded image.

    Args:
        raw: Encoded image bytes as received from the client.

    Returns:
        A PredictResponse holding one Detection per usable box, the decoded
        image size, and per-stage timings in milliseconds.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image.

    NOTE(review): the body contains no ``await``, so all inference runs
    inline in the calling coroutine.
    """
    image_rgb = _decode_image(raw)
    h, w = image_rgb.shape[:2]
    t_start = time.perf_counter()

    # 1. Detection
    t0 = time.perf_counter()
    yolo_out = detector.predict(
        source=image_rgb, imgsz=640, conf=0.25, iou=0.45,
        device="cpu", verbose=False, save=False,
    )
    detect_ms = (time.perf_counter() - t0) * 1000
    boxes_raw = yolo_out[0].boxes.data.cpu().numpy() if yolo_out else []

    detections: List[Detection] = []
    seg_ms_total = 0.0
    cls_ms_total = 0.0
    for box in boxes_raw:
        # Rows are read as [x1, y1, x2, y2, confidence, ...]; corners are
        # truncated to ints and clamped to the image.
        x1 = max(0, int(box[0]))
        y1 = max(0, int(box[1]))
        x2 = min(w, int(box[2]))
        y2 = min(h, int(box[3]))
        confidence = float(box[4])
        if x2 <= x1 or y2 <= y1:
            # Degenerate box after clamping: nothing to crop.
            continue
        crop_rgb = image_rgb[y1:y2, x1:x2]

        # 2. Segmentation
        polygon_coords, masked_crop, seg_ms = _segment_crop(crop_rgb, x1, y1)
        seg_ms_total += seg_ms

        # 3. Classification
        pred_list, cls_ms = _classify_crop(masked_crop)
        cls_ms_total += cls_ms

        detections.append(Detection(
            bbox=BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2,
                             confidence=round(confidence, 3)),
            polygon=polygon_coords,
            predictions=pred_list,
        ))

    total_ms = (time.perf_counter() - t_start) * 1000
    return PredictResponse(
        detections=detections,
        image_size=ImageSize(width=w, height=h),
        timing=Timing(
            detect_ms=round(detect_ms),
            segment_ms=round(seg_ms_total),
            classify_ms=round(cls_ms_total),
            total_ms=round(total_ms),
        ),
    )
# ── auth dependencies ─────────────────────────────────────────────────────────
# Header-based security schemes: X-API-Key for the public API and
# X-Admin-Key for key management. Both are declared with auto_error=True.
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=True)
_admin_key_header = APIKeyHeader(name="X-Admin-Key", auto_error=True)
def _require_api_key(key: str = Security(_api_key_header)):
    """Dependency that admits only requests carrying a known API key.

    Returns the key when it is in the current valid set; otherwise raises
    HTTPException with status 401.
    """
    if key in _valid_key_set():
        return key
    raise HTTPException(status_code=401, detail="Invalid or missing API key.")
def _require_admin_key(key: str = Security(_admin_key_header)):
    """Dependency that admits only requests carrying the admin key.

    The comparison uses ``secrets.compare_digest`` on UTF-8 bytes, so the
    time taken does not reveal how many leading characters matched and
    non-ASCII header values cannot raise TypeError. An unset (empty)
    ADMIN_KEY never matches.

    Returns:
        The supplied key when it equals ADMIN_KEY.

    Raises:
        HTTPException: 401 when the key does not match.
    """
    expected = ADMIN_KEY
    matches = bool(expected) and secrets.compare_digest(
        key.encode("utf-8"), expected.encode("utf-8")
    )
    if not matches:
        raise HTTPException(status_code=401, detail="Invalid admin key.")
    return key
# ── app & middleware ──────────────────────────────────────────────────────────
# Application object. `lifespan` is registered as the lifespan handler, and
# the interactive docs are served at /api/docs and /api/redoc.
app = FastAPI(
    title = "SW Identifier API",
    description = "Fish detection, segmentation, and species classification.",
    version = "1.0.0",
    lifespan = lifespan,
    docs_url = "/api/docs",
    redoc_url = "/api/redoc",
)
# CORS: every origin and every request header is allowed; cross-origin
# methods are limited to GET, POST and DELETE.
app.add_middleware(
    CORSMiddleware,
    allow_origins = ["*"],
    allow_methods = ["GET", "POST", "DELETE"],
    allow_headers = ["*"],
)
# Serve the contents of STATIC_DIR under the /static URL prefix.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# ── SPA routes ────────────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the single-page frontend.

    Reads ``index.html`` from STATIC_DIR on each request and returns its
    text, which is sent as an HTML response.
    """
    with open(os.path.join(STATIC_DIR, "index.html"), encoding="utf-8") as fh:
        return fh.read()
@app.post("/predict", response_model=PredictResponse)
async def predict_spa(file: UploadFile = File(...)):
    """Internal endpoint used by the SPA - no auth required.

    Args:
        file: Uploaded image.

    Returns:
        The PredictResponse produced by the shared pipeline.

    Raises:
        HTTPException: 400 when the upload's content type is missing or does
            not start with ``image/``, or when the image cannot be decoded.
    """
    # content_type is None when the client sends no Content-Type for the part.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Upload must be an image file.")
    return await _run_pipeline(await file.read())
# ── public API v1 ─────────────────────────────────────────────────────────────
# Router for the versioned public API; its routes live under /api/v1.
api = APIRouter(prefix="/api/v1", tags=["SW Identifier API"])
@api.post("/predict", response_model=PredictResponse)
async def predict_api(
    file: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP, β¦)"),
    _key: str = Depends(_require_api_key),
):
    """Public prediction endpoint; requires a valid X-API-Key header.

    Args:
        file: Uploaded image.
        _key: The validated API key (unused beyond enforcing auth).

    Returns:
        The PredictResponse produced by the shared pipeline.

    Raises:
        HTTPException: 401 from the auth dependency; 400 when the upload's
            content type is missing or does not start with ``image/``, or
            when the image cannot be decoded.
    """
    # content_type is None when the client sends no Content-Type for the part.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Upload must be an image file.")
    return await _run_pipeline(await file.read())
# ── key management ────────────────────────────────────────────────────────────
class KeyRecord(BaseModel):
    """A stored API key together with its label and creation time."""

    key: str
    name: str
    # ISO 8601 timestamp string, as produced when the key is generated.
    created_at: str
class CreateKeyRequest(BaseModel):
    """Request body for creating an API key."""

    # Label for the new key; "unnamed" when the client supplies none.
    name: str = "unnamed"
@api.get("/keys", response_model=List[KeyRecord])
async def list_keys(_admin: str = Depends(_require_admin_key)):
    """List the API keys stored in the key file (admin only).

    Keys supplied through the SW_API_KEYS environment variable are not part
    of the key file and therefore do not appear here.
    """
    return _load_keys()
@api.post("/keys", response_model=KeyRecord)
async def create_key(
    body: CreateKeyRequest = CreateKeyRequest(),
    _admin: str = Depends(_require_admin_key),
):
    """Create a new API key and append it to the key file (admin only).

    Args:
        body: Optional request body carrying the key's label.
        _admin: The validated admin key (unused beyond enforcing auth).

    Returns:
        The newly created key record.

    NOTE(review): while SW_API_KEYS is set, validation reads only that
    variable, so a key created here is stored but not accepted.
    """
    keys = _load_keys()
    k = _new_key(body.name)
    keys.append(k)
    _save_keys(keys)
    return k
@api.delete("/keys/{key}")
async def revoke_key(key: str, _admin: str = Depends(_require_admin_key)):
    """Remove an API key from the key file (admin only).

    Args:
        key: The key value to revoke, taken from the URL path.
        _admin: The validated admin key (unused beyond enforcing auth).

    Raises:
        HTTPException: 404 when no stored record has this key.
    """
    keys = _load_keys()
    remaining = [k for k in keys if k["key"] != key]
    # Equal lengths mean the filter removed nothing, i.e. the key was unknown.
    if len(remaining) == len(keys):
        raise HTTPException(status_code=404, detail="Key not found.")
    _save_keys(remaining)
# Attach the /api/v1 router to the application.
app.include_router(api)

# ── entry point ───────────────────────────────────────────────────────────────
# Direct execution starts uvicorn on all interfaces; the port comes from the
# PORT environment variable and defaults to 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        reload=False,
    )