petter2025's picture
Add FastAPI app
2d521fd verified
"""
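Admin API endpoints for API key management, audit logs, and usage statistics.
Every endpoint requires the ``admin_key`` query parameter, which is checked
against a single shared secret; replace this with proper authentication in production.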
"""
import os
import secrets
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query
from pydantic import BaseModel

from app.core.usage_tracker import Tier, tracker
# Every route registered on this router is mounted under the "/admin" prefix
# and labelled with the "admin" tag.
router = APIRouter(prefix="/admin", tags=["admin"])
# Shared-secret admin key checked by verify_admin.
# It is read from the ADMIN_API_KEY environment variable so the real secret
# does not have to live in source control. The fallback is only suitable for
# local development. An empty variable also falls back, so a blank value
# never becomes a blank (trivially guessable) admin password.
# Replace with proper authentication in production.
ADMIN_API_KEY = os.environ.get("ADMIN_API_KEY") or "admin_secret_change_me"
def verify_admin(admin_key: str = Query(..., alias="admin_key")):
    """Reject the request unless ``admin_key`` matches ``ADMIN_API_KEY``.

    Used as a route dependency by the admin endpoints. The key is taken from
    the ``admin_key`` query parameter.

    Returns:
        True when the key is valid.

    Raises:
        HTTPException: 403 if the key does not match.
    """
    # NOTE(review): query parameters often end up in access logs; a header
    # would be safer, but switching would break existing clients.
    # compare_digest takes time independent of where the inputs differ, so
    # response timing cannot be used to guess the key incrementally. Bytes
    # are compared (not str) because compare_digest raises TypeError on
    # non-ASCII str input.
    supplied = admin_key.encode("utf-8")
    expected = ADMIN_API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Invalid admin key")
    return True
class CreateKeyRequest(BaseModel):
    """Request body for ``POST /admin/keys``."""

    # Tier name as a plain string; create_api_key checks it against the
    # ``Tier`` enum values and returns 400 if it does not match.
    tier: str
class UpdateTierRequest(BaseModel):
    """Request body for ``PATCH /admin/keys/{api_key}/tier``."""

    # New tier name as a plain string; update_key_tier checks it against the
    # ``Tier`` enum values and returns 400 if it does not match.
    tier: str
@router.post("/keys", dependencies=[Depends(verify_admin)])
async def create_api_key(req: CreateKeyRequest):
    """Create and register a new API key on the requested tier.

    Args:
        req: Request body; ``req.tier`` must be one of the ``Tier`` values.

    Returns:
        ``{"api_key": <new key>, "tier": <tier>}``. The key is ``sk_live_``
        followed by 24 lowercase hex characters.

    Raises:
        HTTPException: 400 if the tier is not a valid ``Tier`` value.
    """
    valid_tiers = [t.value for t in Tier]
    if req.tier not in valid_tiers:
        raise HTTPException(status_code=400, detail=f"Invalid tier. Must be one of {valid_tiers}")
    # A uuid4 hex prefix would embed uuid4's fixed version nibble and
    # constrained variant bits (~90 random bits, plus a constant '4' in every
    # key). token_hex(12) yields the same 24-char format with a full 96 bits
    # from the OS CSPRNG, which is what the secrets module is for.
    new_key = f"sk_live_{secrets.token_hex(12)}"
    tracker.get_or_create_api_key(new_key, Tier(req.tier))
    return {"api_key": new_key, "tier": req.tier}
@router.get("/keys", dependencies=[Depends(verify_admin)])
async def list_api_keys(limit: int = 100, offset: int = 0):
    """List API keys, newest first, with each key's usage for the current month.

    Args:
        limit: Maximum number of keys to return.
        offset: Number of keys to skip, for pagination.

    Returns:
        ``{"keys": [...], "total": n}``. ``total`` is the number of keys in
        this page (``len(keys)``), not the number of keys stored.

    Raises:
        HTTPException: 400 if ``limit`` or ``offset`` is negative.
    """
    # SQLite treats a negative LIMIT as "no limit", which would turn a bad
    # parameter into a full-table dump; reject negatives explicitly.
    if limit < 0 or offset < 0:
        raise HTTPException(status_code=400, detail="limit and offset must be non-negative")
    # Resolve the month once so every row in the page is measured against the
    # same month, even if the request straddles a month boundary.
    month = tracker._get_month_key()
    with tracker._get_conn() as conn:
        # The correlated subquery fetches each key's monthly count in the same
        # statement, instead of issuing one extra query per listed key.
        # Rows are accessed by column name, so the connection presumably uses
        # sqlite3.Row as its row_factory (confirm in usage_tracker).
        rows = conn.execute(
            """
            SELECT k.key, k.tier, k.created_at, k.last_used_at, k.is_active,
                   (SELECT m.count FROM monthly_counts AS m
                     WHERE m.api_key = k.key AND m.year_month = ?) AS month_usage
            FROM api_keys AS k
            ORDER BY k.created_at DESC
            LIMIT ? OFFSET ?
            """,
            (month, limit, offset),
        ).fetchall()
        keys = [
            {
                "key": row["key"],
                "tier": row["tier"],
                "created_at": datetime.fromtimestamp(row["created_at"]).isoformat(),
                "last_used_at": (
                    datetime.fromtimestamp(row["last_used_at"]).isoformat()
                    if row["last_used_at"]
                    else None
                ),
                "is_active": bool(row["is_active"]),
                # NULL from the subquery means there is no usage row this month.
                "current_month_usage": (
                    row["month_usage"] if row["month_usage"] is not None else 0
                ),
            }
            for row in rows
        ]
    return {"keys": keys, "total": len(keys)}
@router.patch("/keys/{api_key}/tier", dependencies=[Depends(verify_admin)])
async def update_key_tier(
    api_key: str = Path(..., description="The API key to update"),
    req: UpdateTierRequest = Body(...),
):
    """Change the tier of an existing API key.

    Args:
        api_key: Key whose tier should change (path parameter).
        req: Request body; ``req.tier`` must be one of the ``Tier`` values.

    Returns:
        A confirmation message naming the new tier.

    Raises:
        HTTPException: 400 for an invalid tier, 404 if the key does not exist.
    """
    valid_tiers = [t.value for t in Tier]
    if req.tier not in valid_tiers:
        raise HTTPException(status_code=400, detail=f"Invalid tier. Must be one of {valid_tiers}")
    with tracker._get_conn() as conn:
        # A single UPDATE both applies the change and reports, via rowcount,
        # whether the key exists. This avoids a separate existence SELECT and
        # the window between that check and the write.
        cursor = conn.execute("UPDATE api_keys SET tier = ? WHERE key = ?", (req.tier, api_key))
        if cursor.rowcount == 0:
            raise HTTPException(status_code=404, detail="API key not found")
        conn.commit()
    return {"message": f"Tier updated to {req.tier}"}
@router.delete("/keys/{api_key}", dependencies=[Depends(verify_admin)])
async def deactivate_api_key(api_key: str = Path(..., description="The API key to deactivate")):
    """Soft-delete an API key by clearing its ``is_active`` flag.

    The row is kept; only ``is_active`` is set to 0. Deactivating an
    already-inactive key still succeeds.

    Args:
        api_key: Key to deactivate (path parameter).

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 404 if the key does not exist.
    """
    with tracker._get_conn() as conn:
        # One UPDATE both performs the change and reports, via rowcount,
        # whether the key exists, so there is no separate check-then-act step.
        cursor = conn.execute("UPDATE api_keys SET is_active = 0 WHERE key = ?", (api_key,))
        if cursor.rowcount == 0:
            raise HTTPException(status_code=404, detail="API key not found")
        conn.commit()
    return {"message": "API key deactivated"}
def _parse_optional_iso_datetime(value: Optional[str], field: str) -> Optional[datetime]:
    """Parse an optional ISO 8601 query value, turning bad input into HTTP 400.

    Missing or empty values yield None, meaning no bound.
    """
    if not value:
        return None
    try:
        return datetime.fromisoformat(value)
    except ValueError as exc:
        # Without this, a malformed date surfaces as an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {field}: expected an ISO 8601 date or datetime",
        ) from exc


@router.get("/audit/{api_key}", dependencies=[Depends(verify_admin)])
async def get_audit_logs(
    api_key: str = Path(..., description="The API key to audit"),
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None),
    limit: int = 100,
):
    """Return audit log entries for one API key.

    Args:
        api_key: Key to audit (path parameter).
        start_date: Optional ISO 8601 lower bound.
        end_date: Optional ISO 8601 upper bound.
        limit: Forwarded to ``tracker.get_audit_logs``.

    Returns:
        ``{"api_key": api_key, "logs": <result of tracker.get_audit_logs>}``.

    Raises:
        HTTPException: 400 if either date is not valid ISO 8601.
    """
    # NOTE(review): whether the bounds are inclusive is decided by
    # tracker.get_audit_logs, which is not visible here.
    start = _parse_optional_iso_datetime(start_date, "start_date")
    end = _parse_optional_iso_datetime(end_date, "end_date")
    logs = tracker.get_audit_logs(api_key, start, end, limit)
    return {"api_key": api_key, "logs": logs}
@router.get("/stats", dependencies=[Depends(verify_admin)])
async def get_global_stats():
    """Return aggregate usage figures across all API keys.

    Response fields:
        active_api_keys: number of ``api_keys`` rows with ``is_active = 1``.
        total_evaluations: number of rows in ``usage_log``.
        current_month_evaluations: sum of ``monthly_counts.count`` for the
            month key from ``tracker._get_month_key()``, or 0 when there is none.
        by_tier: ``usage_log`` row counts grouped by ``tier``.
    """
    with tracker._get_conn() as conn:

        def first_column(sql, params=()):
            # Run a query that yields one row and return its first column.
            return conn.execute(sql, params).fetchone()[0]

        active_count = first_column("SELECT COUNT(*) FROM api_keys WHERE is_active = 1")
        evaluation_count = first_column("SELECT COUNT(*) FROM usage_log")
        tier_rows = conn.execute(
            "SELECT tier, COUNT(*) as count FROM usage_log GROUP BY tier"
        ).fetchall()
        this_month = tracker._get_month_key()
        monthly_total = first_column(
            "SELECT SUM(count) FROM monthly_counts WHERE year_month = ?", (this_month,)
        )
        # SUM over zero rows is NULL; report that as 0.
        if not monthly_total:
            monthly_total = 0

    breakdown = [{"tier": entry[0], "count": entry[1]} for entry in tier_rows]
    return {
        "active_api_keys": active_count,
        "total_evaluations": evaluation_count,
        "current_month_evaluations": monthly_total,
        "by_tier": breakdown,
    }