File size: 5,185 Bytes
2d521fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Admin API endpoints for API key management and audit logs.
These endpoints should be protected (e.g., by an admin API key) in production.
"""
import os
import secrets
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any

from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
from pydantic import BaseModel

from app.core.usage_tracker import tracker, Tier

router = APIRouter(prefix="/admin", tags=["admin"])

# Shared admin secret checked by verify_admin (replace with proper auth in
# production). It can be supplied through the ADMIN_API_KEY environment
# variable; the hard-coded fallback is only meant for local development.
# An empty environment value is ignored so that an empty `admin_key` query
# parameter can never authenticate.
ADMIN_API_KEY = os.environ.get("ADMIN_API_KEY") or "admin_secret_change_me"

def verify_admin(admin_key: str = Query(..., alias="admin_key")):
    """FastAPI dependency that rejects requests lacking the correct admin key.

    Args:
        admin_key: Value of the required ``admin_key`` query parameter.

    Returns:
        ``True`` when the supplied key matches ``ADMIN_API_KEY``.

    Raises:
        HTTPException: 403 when the key does not match.
    """
    # Constant-time comparison so response timing does not reveal how much of
    # the secret matched. Compare bytes: compare_digest raises TypeError for
    # str arguments containing non-ASCII characters, which would otherwise
    # surface to the client as a 500 instead of a 403.
    if not secrets.compare_digest(
        admin_key.encode("utf-8"), ADMIN_API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid admin key")
    return True

class CreateKeyRequest(BaseModel):
    """Request body for ``POST /admin/keys``."""

    # Tier name; create_api_key rejects values that are not a Tier value.
    tier: str

class UpdateTierRequest(BaseModel):
    """Request body for ``PATCH /admin/keys/{api_key}/tier``."""

    # New tier name; update_key_tier rejects values that are not a Tier value.
    tier: str


@router.post("/keys", dependencies=[Depends(verify_admin)])
async def create_api_key(req: CreateKeyRequest):
    if req.tier not in [t.value for t in Tier]:
        raise HTTPException(status_code=400, detail=f"Invalid tier. Must be one of {[t.value for t in Tier]}")
    new_key = f"sk_live_{uuid.uuid4().hex[:24]}"
    tier_enum = Tier(req.tier)
    tracker.get_or_create_api_key(new_key, tier_enum)
    return {"api_key": new_key, "tier": req.tier}


@router.get("/keys", dependencies=[Depends(verify_admin)])
async def list_api_keys(limit: int = 100, offset: int = 0):
    with tracker._get_conn() as conn:
        rows = conn.execute(
            "SELECT key, tier, created_at, last_used_at, is_active FROM api_keys ORDER BY created_at DESC LIMIT ? OFFSET ?",
            (limit, offset)
        ).fetchall()
        keys = []
        for row in rows:
            month = tracker._get_month_key()
            usage_row = conn.execute(
                "SELECT count FROM monthly_counts WHERE api_key = ? AND year_month = ?",
                (row["key"], month)
            ).fetchone()
            usage = usage_row["count"] if usage_row else 0
            keys.append({
                "key": row["key"],
                "tier": row["tier"],
                "created_at": datetime.fromtimestamp(row["created_at"]).isoformat(),
                "last_used_at": datetime.fromtimestamp(row["last_used_at"]).isoformat() if row["last_used_at"] else None,
                "is_active": bool(row["is_active"]),
                "current_month_usage": usage,
            })
        return {"keys": keys, "total": len(keys)}


@router.patch("/keys/{api_key}/tier", dependencies=[Depends(verify_admin)])
async def update_key_tier(
    api_key: str = Path(..., description="The API key to update"),
    req: UpdateTierRequest = Body(...),
):
    """Change the tier of an existing API key.

    Args:
        api_key: The key whose tier is changed (path parameter).
        req: Request body carrying the new tier name.

    Returns:
        A confirmation message naming the new tier.

    Raises:
        HTTPException: 400 for an unknown tier, 404 if the key does not exist.
    """
    valid_tiers = [t.value for t in Tier]
    if req.tier not in valid_tiers:
        raise HTTPException(status_code=400, detail=f"Invalid tier. Must be one of {valid_tiers}")
    with tracker._get_conn() as conn:
        # One UPDATE both applies the change and, via rowcount, tells us whether
        # the key exists. This avoids a separate SELECT round trip and the gap
        # between checking for the key and writing to it.
        cursor = conn.execute("UPDATE api_keys SET tier = ? WHERE key = ?", (req.tier, api_key))
        if cursor.rowcount == 0:
            raise HTTPException(status_code=404, detail="API key not found")
        conn.commit()
    return {"message": f"Tier updated to {req.tier}"}


@router.delete("/keys/{api_key}", dependencies=[Depends(verify_admin)])
async def deactivate_api_key(api_key: str = Path(..., description="The API key to deactivate")):
    """Soft-delete an API key by setting ``is_active = 0``; the row is kept.

    Args:
        api_key: The key to deactivate (path parameter).

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 404 if the key does not exist.
    """
    with tracker._get_conn() as conn:
        # A single UPDATE whose rowcount doubles as the existence check,
        # replacing the separate SELECT-then-UPDATE round trips.
        cursor = conn.execute("UPDATE api_keys SET is_active = 0 WHERE key = ?", (api_key,))
        if cursor.rowcount == 0:
            raise HTTPException(status_code=404, detail="API key not found")
        conn.commit()
    return {"message": "API key deactivated"}


@router.get("/audit/{api_key}", dependencies=[Depends(verify_admin)])
async def get_audit_logs(
    api_key: str = Path(..., description="The API key to audit"),
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None),
    limit: int = 100,
):
    start = datetime.fromisoformat(start_date) if start_date else None
    end = datetime.fromisoformat(end_date) if end_date else None
    logs = tracker.get_audit_logs(api_key, start, end, limit)
    return {"api_key": api_key, "logs": logs}


@router.get("/stats", dependencies=[Depends(verify_admin)])
async def get_global_stats():
    with tracker._get_conn() as conn:
        total_keys = conn.execute("SELECT COUNT(*) FROM api_keys WHERE is_active = 1").fetchone()[0]
        total_requests = conn.execute("SELECT COUNT(*) FROM usage_log").fetchone()[0]
        by_tier = conn.execute(
            "SELECT tier, COUNT(*) as count FROM usage_log GROUP BY tier"
        ).fetchall()
        month = tracker._get_month_key()
        current_month_requests = conn.execute(
            "SELECT SUM(count) FROM monthly_counts WHERE year_month = ?", (month,)
        ).fetchone()[0] or 0
    return {
        "active_api_keys": total_keys,
        "total_evaluations": total_requests,
        "current_month_evaluations": current_month_requests,
        "by_tier": [{"tier": row[0], "count": row[1]} for row in by_tier],
    }