""" Admin API endpoints for API key management and audit logs. These endpoints should be protected (e.g., by an admin API key) in production. """ from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body from pydantic import BaseModel from typing import Optional, List, Dict, Any from datetime import datetime import uuid from app.core.usage_tracker import tracker, Tier router = APIRouter(prefix="/admin", tags=["admin"]) # Simple in‑memory admin key (replace with proper auth in production) ADMIN_API_KEY = "admin_secret_change_me" def verify_admin(admin_key: str = Query(..., alias="admin_key")): if admin_key != ADMIN_API_KEY: raise HTTPException(status_code=403, detail="Invalid admin key") return True class CreateKeyRequest(BaseModel): tier: str class UpdateTierRequest(BaseModel): tier: str @router.post("/keys", dependencies=[Depends(verify_admin)]) async def create_api_key(req: CreateKeyRequest): if req.tier not in [t.value for t in Tier]: raise HTTPException(status_code=400, detail=f"Invalid tier. Must be one of {[t.value for t in Tier]}") new_key = f"sk_live_{uuid.uuid4().hex[:24]}" tier_enum = Tier(req.tier) tracker.get_or_create_api_key(new_key, tier_enum) return {"api_key": new_key, "tier": req.tier} @router.get("/keys", dependencies=[Depends(verify_admin)]) async def list_api_keys(limit: int = 100, offset: int = 0): with tracker._get_conn() as conn: rows = conn.execute( "SELECT key, tier, created_at, last_used_at, is_active FROM api_keys ORDER BY created_at DESC LIMIT ? OFFSET ?", (limit, offset) ).fetchall() keys = [] for row in rows: month = tracker._get_month_key() usage_row = conn.execute( "SELECT count FROM monthly_counts WHERE api_key = ? AND year_month = ?", (row["key"], month) ).fetchone() usage = usage_row["count"] if usage_row else 0 keys.append({ "key": row["key"], "tier": row["tier"], "created_at": datetime.fromtimestamp(row["created_at"]).isoformat(), "last_used_at": datetime.fromtimestamp(row["last_used_at"]).isoformat() if row["last_used_at"] else None, "is_active": bool(row["is_active"]), "current_month_usage": usage, }) return {"keys": keys, "total": len(keys)} @router.patch("/keys/{api_key}/tier", dependencies=[Depends(verify_admin)]) async def update_key_tier( api_key: str = Path(..., description="The API key to update"), req: UpdateTierRequest = Body(...), ): if req.tier not in [t.value for t in Tier]: raise HTTPException(status_code=400, detail=f"Invalid tier. Must be one of {[t.value for t in Tier]}") with tracker._get_conn() as conn: row = conn.execute("SELECT key FROM api_keys WHERE key = ?", (api_key,)).fetchone() if not row: raise HTTPException(status_code=404, detail="API key not found") conn.execute("UPDATE api_keys SET tier = ? WHERE key = ?", (req.tier, api_key)) conn.commit() return {"message": f"Tier updated to {req.tier}"} @router.delete("/keys/{api_key}", dependencies=[Depends(verify_admin)]) async def deactivate_api_key(api_key: str = Path(..., description="The API key to deactivate")): with tracker._get_conn() as conn: row = conn.execute("SELECT key FROM api_keys WHERE key = ?", (api_key,)).fetchone() if not row: raise HTTPException(status_code=404, detail="API key not found") conn.execute("UPDATE api_keys SET is_active = 0 WHERE key = ?", (api_key,)) conn.commit() return {"message": "API key deactivated"} @router.get("/audit/{api_key}", dependencies=[Depends(verify_admin)]) async def get_audit_logs( api_key: str = Path(..., description="The API key to audit"), start_date: Optional[str] = Query(None), end_date: Optional[str] = Query(None), limit: int = 100, ): start = datetime.fromisoformat(start_date) if start_date else None end = datetime.fromisoformat(end_date) if end_date else None logs = tracker.get_audit_logs(api_key, start, end, limit) return {"api_key": api_key, "logs": logs} @router.get("/stats", dependencies=[Depends(verify_admin)]) async def get_global_stats(): with tracker._get_conn() as conn: total_keys = conn.execute("SELECT COUNT(*) FROM api_keys WHERE is_active = 1").fetchone()[0] total_requests = conn.execute("SELECT COUNT(*) FROM usage_log").fetchone()[0] by_tier = conn.execute( "SELECT tier, COUNT(*) as count FROM usage_log GROUP BY tier" ).fetchall() month = tracker._get_month_key() current_month_requests = conn.execute( "SELECT SUM(count) FROM monthly_counts WHERE year_month = ?", (month,) ).fetchone()[0] or 0 return { "active_api_keys": total_keys, "total_evaluations": total_requests, "current_month_evaluations": current_month_requests, "by_tier": [{"tier": row[0], "count": row[1]} for row in by_tier], }