| """ |
| Usage Tracker for ARF API – quotas, tiers, and audit logging. |
| Non‑invasive, configurable, thread‑safe, and background‑task ready. |
| """ |
|
|
| import os |
| import json |
| import sqlite3 |
| import threading |
| import time |
| from contextlib import contextmanager |
| from datetime import datetime, timedelta |
| from typing import Dict, Any, Optional, List |
| from enum import Enum |
| from dataclasses import dataclass |
| from fastapi import BackgroundTasks |
|
|
| |
| try: |
| import redis |
| REDIS_AVAILABLE = True |
| except ImportError: |
| REDIS_AVAILABLE = False |
| redis = None |
|
|
|
|
class Tier(str, Enum):
    """Subscription tier of an API key; each tier carries its own quota and log retention."""

    FREE = "free"
    PRO = "pro"
    PREMIUM = "premium"
    ENTERPRISE = "enterprise"

    @property
    def monthly_evaluation_limit(self) -> Optional[int]:
        """Evaluations allowed per month for this tier; ``None`` means unlimited."""
        if self is Tier.ENTERPRISE:
            return None
        caps_by_value = {"free": 1000, "pro": 10_000, "premium": 50_000}
        return caps_by_value[self.value]

    @property
    def audit_log_retention_days(self) -> int:
        """Number of days audit-log rows for this tier are kept before cleanup."""
        days_by_value = {"free": 7, "pro": 30, "premium": 90, "enterprise": 365}
        return days_by_value[self.value]
|
|
|
|
@dataclass()
class UsageRecord:
    """One metered API call, in the shape persisted to the ``usage_log`` table."""

    # Caller identity and the tier the call is billed against.
    api_key: str
    tier: Tier
    # Epoch seconds of the call (compared against time.time() during cleanup).
    timestamp: float
    endpoint: str
    # Optional payloads; the dict fields are stored as JSON text.
    request_body: Optional[Dict[str, Any]] = None
    response: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    processing_ms: Optional[float] = None
|
|
|
|
class UsageTracker:
    """
    Usage tracker with SQLite storage and optional Redis for monthly counters.

    Each thread gets its own SQLite connection (held in ``threading.local``).
    API keys and audit logs always live in SQLite. Monthly quota counters live
    in Redis when a ``redis_url`` is supplied. Otherwise they live in the
    ``monthly_counts`` SQLite table.
    """

    def __init__(self, db_path: str = "arf_usage.db", redis_url: Optional[str] = None):
        """
        Create the SQLite schema (if missing) and the optional Redis client.

        Args:
            db_path: SQLite database file path.
            redis_url: Optional Redis connection URL. When set, quota counters
                are kept in Redis instead of SQLite.

        Raises:
            ImportError: ``redis_url`` was given but the ``redis`` package is
                not installed.
        """
        # Validate first so a misconfiguration fails before any database
        # file or schema is created.
        if redis_url and not REDIS_AVAILABLE:
            raise ImportError("Redis client not installed. Run: pip install redis")

        self.db_path = db_path
        self._local = threading.local()
        self._init_db()

        self._redis_client = None
        if redis_url:
            self._redis_client = redis.from_url(redis_url)

    @contextmanager
    def _get_conn(self):
        """
        Yield this thread's SQLite connection, opening it on first use.

        If the ``with`` body raises, any pending transaction is rolled back.
        The connection is reused by later calls on the same thread, so an
        uncommitted write must not linger on it.
        """
        conn = getattr(self._local, "conn", None)
        if conn is None:
            conn = sqlite3.connect(self.db_path, check_same_thread=False)
            conn.row_factory = sqlite3.Row
            self._local.conn = conn
        try:
            yield conn
        except Exception:
            conn.rollback()
            raise

    def _init_db(self):
        """Create the ``api_keys``, ``usage_log`` and ``monthly_counts`` tables if absent."""
        with self._get_conn() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS api_keys (
                    key TEXT PRIMARY KEY,
                    tier TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    last_used_at REAL,
                    is_active INTEGER DEFAULT 1
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS usage_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    api_key TEXT NOT NULL,
                    tier TEXT NOT NULL,
                    timestamp REAL NOT NULL,
                    endpoint TEXT NOT NULL,
                    request_body TEXT,
                    response TEXT,
                    error TEXT,
                    processing_ms REAL
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_api_key_timestamp
                ON usage_log (api_key, timestamp)
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS monthly_counts (
                    api_key TEXT NOT NULL,
                    year_month TEXT NOT NULL,
                    count INTEGER DEFAULT 0,
                    PRIMARY KEY (api_key, year_month)
                )
            """)
            conn.commit()

    def _get_month_key(self) -> str:
        """Return the current month as ``YYYY-MM`` (counter bucket key)."""
        # NOTE(review): uses server local time, so counters roll over at local
        # midnight on the 1st; confirm whether UTC is intended.
        return datetime.now().strftime("%Y-%m")

    @staticmethod
    def _redis_quota_key(api_key: str, month: str) -> str:
        """Redis key holding ``api_key``'s evaluation count for ``month``."""
        return f"arf:quota:{api_key}:{month}"

    def get_or_create_api_key(self, key: str, tier: Tier = Tier.FREE) -> bool:
        """
        Register ``key`` with ``tier`` unless it already exists.

        An existing key keeps its current tier and active flag. Always returns
        True (the key exists afterwards).
        """
        with self._get_conn() as conn:
            # A single INSERT OR IGNORE makes check-and-create atomic, so two
            # concurrent registrations of the same key cannot hit an
            # IntegrityError on the primary key.
            conn.execute(
                "INSERT OR IGNORE INTO api_keys (key, tier, created_at, is_active) "
                "VALUES (?, ?, ?, ?)",
                (key, tier.value, time.time(), 1),
            )
            conn.commit()
        return True

    def get_tier(self, api_key: str) -> Optional[Tier]:
        """Return the tier for a given API key, or None if key invalid/inactive."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT tier FROM api_keys WHERE key = ? AND is_active = 1",
                (api_key,),
            ).fetchone()
        if not row:
            return None
        return Tier(row["tier"])

    def update_api_key_tier(self, api_key: str, new_tier: Tier) -> bool:
        """
        Set the tier of an existing API key.

        Returns True if the key exists (and was updated), False otherwise.
        """
        with self._get_conn() as conn:
            cur = conn.execute(
                "UPDATE api_keys SET tier = ? WHERE key = ?",
                (new_tier.value, api_key),
            )
            conn.commit()
            # rowcount is the number of rows the UPDATE matched.
            return cur.rowcount > 0

    def get_remaining_quota(self, api_key: str, tier: Tier) -> Optional[int]:
        """Return remaining evaluations for the month (never negative), or None if unlimited."""
        limit = tier.monthly_evaluation_limit
        if limit is None:
            return None

        month = self._get_month_key()
        if self._redis_client:
            raw = self._redis_client.get(self._redis_quota_key(api_key, month))
            count = int(raw or 0)
            return max(0, limit - count)

        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT count FROM monthly_counts WHERE api_key = ? AND year_month = ?",
                (api_key, month),
            ).fetchone()
        count = row["count"] if row else 0
        return max(0, limit - count)

    def _increment_quota(self, api_key: str, tier: Tier) -> None:
        """Unconditionally add one to this month's counter (no-op for unlimited tiers)."""
        if tier.monthly_evaluation_limit is None:
            return
        month = self._get_month_key()
        if self._redis_client:
            redis_key = self._redis_quota_key(api_key, month)
            # Send INCR and EXPIRE together so the counter cannot be left
            # without a TTL if the process dies between the two commands.
            pipe = self._redis_client.pipeline()
            pipe.incr(redis_key)
            pipe.expire(redis_key, timedelta(days=31))
            pipe.execute()
            return
        with self._get_conn() as conn:
            conn.execute(
                """INSERT INTO monthly_counts (api_key, year_month, count)
                   VALUES (?, ?, 1)
                   ON CONFLICT(api_key, year_month) DO UPDATE SET count = count + 1""",
                (api_key, month),
            )
            conn.commit()

    def _try_consume_quota(self, api_key: str, limit: int) -> bool:
        """
        Atomically take one unit of this month's quota.

        If the counter is below ``limit``, increments it and returns True.
        Otherwise returns False, and the stored count is not increased.
        Check and increment happen in one store operation, so concurrent
        callers cannot both pass on the last remaining unit.
        """
        if limit <= 0:
            return False
        month = self._get_month_key()
        if self._redis_client:
            redis_key = self._redis_quota_key(api_key, month)
            pipe = self._redis_client.pipeline()
            pipe.incr(redis_key)
            pipe.expire(redis_key, timedelta(days=31))
            new_count = pipe.execute()[0]
            if new_count > limit:
                # Over the limit: hand the unit back.
                self._redis_client.decr(redis_key)
                return False
            return True
        with self._get_conn() as conn:
            # The conditional DO UPDATE only bumps the count while it is below
            # the limit. When the WHERE is false the statement changes no rows,
            # which rowcount reports as 0.
            cur = conn.execute(
                """INSERT INTO monthly_counts (api_key, year_month, count)
                   VALUES (?, ?, 1)
                   ON CONFLICT(api_key, year_month) DO UPDATE SET count = count + 1
                   WHERE count < ?""",
                (api_key, month, limit),
            )
            conn.commit()
            return cur.rowcount > 0

    def _insert_audit_log(self, record: UsageRecord) -> None:
        """Append ``record`` to ``usage_log``; dict payloads are stored as JSON text."""
        with self._get_conn() as conn:
            conn.execute(
                """INSERT INTO usage_log
                   (api_key, tier, timestamp, endpoint, request_body, response, error, processing_ms)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                (
                    record.api_key,
                    record.tier.value,
                    record.timestamp,
                    record.endpoint,
                    # `is not None` so an empty payload ({}) is kept as "{}"
                    # rather than being indistinguishable from "no payload".
                    json.dumps(record.request_body) if record.request_body is not None else None,
                    json.dumps(record.response) if record.response is not None else None,
                    record.error,
                    record.processing_ms,
                ),
            )
            conn.commit()

    def increment_usage_sync(self, record: UsageRecord) -> bool:
        """
        Synchronously consume quota and write the audit log.

        Returns True if the call was within quota (counter incremented and log
        written). Returns False if the quota is exhausted, in which case
        nothing is recorded. Unlimited tiers are never counted.
        """
        limit = record.tier.monthly_evaluation_limit
        if limit is not None and not self._try_consume_quota(record.api_key, limit):
            return False
        self._insert_audit_log(record)
        return True

    async def increment_usage_async(self, record: UsageRecord, background_tasks: BackgroundTasks) -> bool:
        """
        Check quota now; schedule counter increment and audit log as background tasks.

        Returns True if the quota allowed the call (work scheduled), False if
        the quota is exhausted.
        """
        # NOTE(review): the quota check and the deferred increment are not
        # atomic, so concurrent requests can overshoot the limit slightly; use
        # increment_usage_sync where strict enforcement matters. The check also
        # performs blocking DB/Redis I/O on the event loop.
        tier = record.tier
        limit = tier.monthly_evaluation_limit
        if limit is not None:
            remaining = self.get_remaining_quota(record.api_key, tier)
            if remaining <= 0:
                return False

        background_tasks.add_task(self._increment_quota, record.api_key, tier)
        background_tasks.add_task(self._insert_audit_log, record)
        return True

    def get_audit_logs(
        self,
        api_key: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """
        Return up to ``limit`` audit rows for ``api_key``, newest first.

        ``start_date``/``end_date`` are inclusive bounds on the row timestamp.
        ``request_body`` and ``response`` are returned as stored (JSON text or None).
        """
        query = "SELECT * FROM usage_log WHERE api_key = ?"
        params: List[Any] = [api_key]
        if start_date is not None:
            query += " AND timestamp >= ?"
            params.append(start_date.timestamp())
        if end_date is not None:
            query += " AND timestamp <= ?"
            params.append(end_date.timestamp())
        query += " ORDER BY timestamp DESC LIMIT ?"
        params.append(limit)

        with self._get_conn() as conn:
            rows = conn.execute(query, params).fetchall()
        return [dict(row) for row in rows]

    def clean_old_logs(self):
        """Delete audit rows older than the retention period of the tier they were logged under."""
        now = time.time()
        with self._get_conn() as conn:
            for tier in Tier:
                retention_days = tier.audit_log_retention_days
                if retention_days is None:
                    continue
                cutoff = now - retention_days * 86400
                conn.execute(
                    "DELETE FROM usage_log WHERE tier = ? AND timestamp < ?",
                    (tier.value, cutoff),
                )
            conn.commit()
|
|
|
|
| |
# Process-wide tracker instance. It stays None until init_tracker() runs;
# enforce_quota treats None as "usage tracking disabled".
tracker: Optional[UsageTracker]
tracker = None
|
|
|
|
def init_tracker(db_path: str = "arf_usage.db", redis_url: Optional[str] = None):
    """
    Build a ``UsageTracker`` and install it as the module-level ``tracker``.

    Calling this again replaces the previous instance. If construction raises
    (for example ``ImportError`` when a Redis URL is given but ``redis`` is not
    installed), the existing ``tracker`` is left unchanged.
    """
    global tracker
    created = UsageTracker(db_path=db_path, redis_url=redis_url)
    tracker = created
|
|
|
|
def update_key_tier(api_key: str, new_tier: Tier) -> bool:
    """
    Change ``api_key``'s tier through the module-level tracker.

    Returns False when tracking has not been initialised; otherwise returns
    the result of ``UsageTracker.update_api_key_tier``.
    """
    active = tracker
    return False if active is None else active.update_api_key_tier(api_key, new_tier)
|
|
|
|
| |
| from fastapi import HTTPException, Request |
|
|
async def enforce_quota(request: Request, api_key: Optional[str] = None):
    """
    FastAPI dependency that authenticates the caller's API key and checks quota.

    Use in an endpoint as ``quota = Depends(enforce_quota)``.

    The key is taken from the first available source:

    1. the ``api_key`` argument;
    2. an ``Authorization: Bearer <key>`` header (scheme matched
       case-insensitively, token whitespace-trimmed);
    3. the ``api_key`` query parameter.

    On success, the key and tier are also stored on ``request.state``.

    If usage tracking is disabled (no tracker initialised), returns a default
    dict and performs no enforcement.

    Returns:
        dict with ``api_key``, ``tier`` and ``remaining`` (``None`` = unlimited).

    Raises:
        HTTPException: 401 if no key was supplied, 403 if the key is unknown or
            inactive, 429 if the monthly quota is exhausted.
    """
    if tracker is None:
        return {"api_key": api_key or "disabled", "tier": Tier.FREE, "remaining": None}

    if api_key is None:
        auth_header = request.headers.get("Authorization") or ""
        scheme, sep, token = auth_header.partition(" ")
        # Auth schemes are case-insensitive (RFC 7235), so accept "bearer" too.
        if sep and scheme.lower() == "bearer":
            api_key = token.strip()
        else:
            api_key = request.query_params.get("api_key")

    if not api_key:
        # RFC 7235 requires a WWW-Authenticate challenge on 401 responses.
        raise HTTPException(
            status_code=401,
            detail="Missing API key",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # NOTE(review): the tracker calls below do synchronous SQLite/Redis I/O
    # inside an async dependency and therefore block the event loop; consider
    # running them in a threadpool.
    tier = tracker.get_tier(api_key)
    if tier is None:
        raise HTTPException(status_code=403, detail="Invalid or inactive API key")

    remaining = tracker.get_remaining_quota(api_key, tier)
    if remaining is not None and remaining <= 0:
        raise HTTPException(status_code=429, detail="Monthly evaluation quota exceeded")

    request.state.api_key = api_key
    request.state.tier = tier
    return {"api_key": api_key, "tier": tier, "remaining": remaining}
|
|