| """ |
| Production-grade API server for ml-intern. |
| |
| Features: |
| - FastAPI with async endpoints |
| - Distributed rate limiting (Redis-backed token bucket) |
| - Circuit breaker for external API resilience |
| - Automatic fallback: NIM (primary) -> Cloudflare Workers AI (secondary) -> Gemini (tertiary) -> MLX |
| - Request/response caching with Redis TTL |
| - Multi-tenant session isolation |
| - Health checks and graceful shutdown |
| - Structured logging with correlation IDs |
| - Cost tracking and budget enforcement |
| - Connection pooling for all HTTP clients |
| - Cloudflare Workers AI support via OpenAI-compatible API |
| - Google Gemini support via OpenAI-compatible API |
| - MLX local support for Apple Silicon |
| """ |
|
|
| import asyncio |
| import hashlib |
| import json |
| import logging |
| import os |
| import signal |
| import sys |
| import time |
| import uuid |
| from contextlib import asynccontextmanager |
| from dataclasses import dataclass, field |
| from typing import Any, Optional |
|
|
| import redis.asyncio as aioredis |
| import asyncpg |
| from fastapi import FastAPI, HTTPException, Request, Depends, BackgroundTasks |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.middleware.gzip import GZipMiddleware |
| from pydantic import BaseModel, Field |
| import uvicorn |
| from prometheus_client import Counter, Histogram, Gauge, generate_latest, CONTENT_TYPE_LATEST |
| import httpx |
| from tenacity import retry, stop_after_attempt, wait_exponential, RetryError |
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Core service configuration — every value is overridable via environment.
# ---------------------------------------------------------------------------
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
DATABASE_URL = os.environ.get("DATABASE_URL", "")  # empty => DB persistence disabled
MAX_CONCURRENT_REQUESTS = int(os.environ.get("MAX_CONCURRENT_REQUESTS", "100"))  # in-process semaphore size
DEFAULT_RPM_LIMIT = int(os.environ.get("DEFAULT_RPM_LIMIT", "40"))  # used when a provider has no rpm_limit
REQUEST_TIMEOUT = float(os.environ.get("REQUEST_TIMEOUT", "120"))  # seconds per upstream LLM call
CACHE_TTL_SECONDS = int(os.environ.get("CACHE_TTL_SECONDS", "300"))  # Redis response-cache TTL
BUDGET_USD_PER_SESSION = float(os.environ.get("BUDGET_USD_PER_SESSION", "10.0"))
CIRCUIT_BREAKER_FAILURE_THRESHOLD = int(os.environ.get("CIRCUIT_BREAKER_FAILURE_THRESHOLD", "5"))
CIRCUIT_BREAKER_RECOVERY_TIMEOUT = int(os.environ.get("CIRCUIT_BREAKER_RECOVERY_TIMEOUT", "60"))  # seconds before half-open probe


# NVIDIA NIM (default primary) and Cloudflare Workers AI (default secondary).
NIM_API_BASE = os.environ.get("NIM_API_BASE", "https://integrate.api.nvidia.com/v1")
CLOUDFLARE_API_KEY = os.environ.get("CLOUDFLARE_API_KEY", "")
CLOUDFLARE_ACCOUNT_ID = os.environ.get("CLOUDFLARE_ACCOUNT_ID", "")


# Google Gemini via its OpenAI-compatible endpoint (default tertiary).
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_API_BASE = os.environ.get("GEMINI_API_BASE", "https://generativelanguage.googleapis.com/v1beta/openai")


# Provider fallback chain ordering and master switch.
FALLBACK_ENABLED = os.environ.get("FALLBACK_ENABLED", "true").lower() == "true"
FALLBACK_PRIMARY = os.environ.get("FALLBACK_PRIMARY", "nim")
FALLBACK_SECONDARY = os.environ.get("FALLBACK_SECONDARY", "cloudflare")
FALLBACK_TERTIARY = os.environ.get("FALLBACK_TERTIARY", "gemini")


# MLX local inference (Apple Silicon) — last-resort fallback when enabled.
MLX_API_BASE = os.environ.get("MLX_API_BASE", "http://localhost:8000/v1")
MLX_ENABLED = os.environ.get("MLX_ENABLED", "false").lower() == "true"
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Prometheus metrics, exposed at /metrics.
# ---------------------------------------------------------------------------

# Request count by route/status and which LLM provider served it.
REQUEST_COUNT = Counter(
    "ml_intern_requests_total",
    "Total requests",
    ["method", "endpoint", "status", "provider"],
)
# End-to-end latency; buckets extend to 60s to cover slow LLM calls.
REQUEST_LATENCY = Histogram(
    "ml_intern_request_duration_seconds",
    "Request duration",
    ["method", "endpoint", "provider"],
    buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0],
)
# NOTE(review): never updated anywhere in this file — /health always reports 0.
ACTIVE_SESSIONS = Gauge(
    "ml_intern_active_sessions",
    "Number of active sessions",
)
# Cumulative estimated spend, labeled by provider and model.
LLM_COST_USD = Counter(
    "ml_intern_llm_cost_usd_total",
    "Total LLM cost in USD",
    ["provider", "model"],
)
# Redis response-cache effectiveness counters.
CACHE_HIT_COUNT = Counter(
    "ml_intern_cache_hits_total",
    "Cache hits",
    ["cache_type"],
)
CACHE_MISS_COUNT = Counter(
    "ml_intern_cache_misses_total",
    "Cache misses",
    ["cache_type"],
)
# Mirrors the Redis-backed breaker state for dashboards.
CIRCUIT_BREAKER_STATE = Gauge(
    "ml_intern_circuit_breaker_state",
    "Circuit breaker state (0=closed, 1=half-open, 2=open)",
    ["provider"],
)
# Counts provider switches, labeled with the reason for the switch.
FALLBACK_COUNT = Counter(
    "ml_intern_fallback_total",
    "Fallback events between providers",
    ["from_provider", "to_provider", "reason"],
)
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Logging. The root format string interpolates %(correlation_id)s, so EVERY
# record that reaches the stdout handler must carry that attribute. The filter
# is therefore attached to the handler (not just this module's logger):
# previously, records emitted by other loggers (uvicorn, asyncpg, ...) lacked
# the attribute and produced "--- Logging error ---" tracebacks on stderr.
# ---------------------------------------------------------------------------


class CorrelationIdFilter(logging.Filter):
    """Guarantee a ``correlation_id`` attribute on every log record."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Default to "none" for records logged outside a request context;
        # the request middleware stamps the real id via the record factory.
        record.correlation_id = getattr(record, "correlation_id", "none")
        return True


_stdout_handler = logging.StreamHandler(sys.stdout)
_stdout_handler.addFilter(CorrelationIdFilter())

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | correlation_id=%(correlation_id)s | %(name)s | %(message)s",
    handlers=[_stdout_handler],
)
logger = logging.getLogger("ml_intern.production")
|
|
| |
| |
| |
|
|
class DatabasePool:
    """Thin async wrapper around an asyncpg connection pool.

    Degrades gracefully: with no DSN configured (or before ``connect()``)
    every query helper is a no-op returning a neutral value, so the rest of
    the server can run without a database.
    """

    def __init__(self, dsn: str):
        self.dsn = dsn
        self._pool: Optional[asyncpg.Pool] = None

    async def connect(self):
        """Open the pool; silently skipped when no DSN is configured."""
        if not self.dsn:
            logger.warning("No DATABASE_URL set — skipping database connection")
            return
        self._pool = await asyncpg.create_pool(
            self.dsn,
            min_size=2,
            max_size=10,
            command_timeout=60,
        )
        logger.info("Database pool connected")

    async def disconnect(self):
        """Close the pool if one was ever opened."""
        if self._pool:
            await self._pool.close()
        logger.info("Database pool disconnected")

    async def execute(self, query: str, *args):
        """Run a statement; returns the status string, or None when offline."""
        pool = self._pool
        if pool is None:
            return None
        async with pool.acquire() as conn:
            return await conn.execute(query, *args)

    async def fetch(self, query: str, *args):
        """Run a query; returns the row list, or [] when offline."""
        pool = self._pool
        if pool is None:
            return []
        async with pool.acquire() as conn:
            return await conn.fetch(query, *args)

    async def fetchval(self, query: str, *args):
        """Run a query; returns the first value of the first row, or None when offline."""
        pool = self._pool
        if pool is None:
            return None
        async with pool.acquire() as conn:
            return await conn.fetchval(query, *args)
|
|
| |
| |
| |
|
|
class RedisManager:
    """Redis access layer: response cache, distributed token-bucket rate
    limits, and circuit-breaker state shared across all worker processes."""

    def __init__(self, url: str):
        self.url = url
        self._redis: Optional[aioredis.Redis] = None

    async def connect(self):
        self._redis = aioredis.from_url(self.url, decode_responses=True)
        await self._redis.ping()  # fail fast at startup if Redis is unreachable
        logger.info("Redis connected")

    async def disconnect(self):
        if self._redis:
            await self._redis.close()
        logger.info("Redis disconnected")

    async def get_cache(self, key: str) -> Optional[str]:
        """Fetch a cached LLM response, recording hit/miss metrics."""
        val = await self._redis.get(key)
        if val:
            CACHE_HIT_COUNT.labels(cache_type="llm_response").inc()
        else:
            CACHE_MISS_COUNT.labels(cache_type="llm_response").inc()
        return val

    async def set_cache(self, key: str, value: str, ttl: int = CACHE_TTL_SECONDS):
        await self._redis.setex(key, ttl, value)

    async def check_rate_limit(self, key: str, rpm: int) -> tuple[bool, float]:
        """Distributed token bucket (capacity 1, refill rate rpm/60 per sec).

        Returns (allowed, retry_after_seconds).

        Fixes versus the previous Lua script:
        - On denial the old script stored the *refilled* token balance with
          the *old* timestamp, so the same elapsed interval was credited
          again on the next call and sessions could exceed their RPM limit.
          The timestamp is now always advanced together with the balance.
        - Redis truncates Lua numbers to integers in replies, collapsing
          sub-second retry_after values to 0; it is now returned as a string
          so the fractional part survives.
        """
        now = time.time()
        bucket_key = f"ratelimit:{key}"
        script = """
        local key = KEYS[1]
        local now = tonumber(ARGV[1])
        local rpm = tonumber(ARGV[2])
        local interval = 60.0 / rpm
        local last = redis.call('hget', key, 'last')
        local tokens = redis.call('hget', key, 'tokens')
        if not last then
            last = now
            tokens = 1
        else
            last = tonumber(last)
            tokens = tonumber(tokens)
        end
        local elapsed = now - last
        tokens = math.min(1, tokens + elapsed / interval)
        if tokens >= 1 then
            redis.call('hmset', key, 'last', now, 'tokens', tokens - 1)
            redis.call('expire', key, 120)
            return {1, '0'}
        else
            redis.call('hmset', key, 'last', now, 'tokens', tokens)
            redis.call('expire', key, 120)
            return {0, tostring((1 - tokens) * interval)}
        end
        """
        result = await self._redis.eval(script, 1, bucket_key, now, rpm)
        allowed = bool(result[0])
        retry_after = float(result[1]) if not allowed else 0.0
        return allowed, retry_after

    async def get_circuit_state(self, provider: str) -> dict:
        """Read a provider's breaker state; defaults to closed with 0 failures."""
        key = f"circuit:{provider}"
        val = await self._redis.get(key)
        if val:
            return json.loads(val)
        return {"state": "closed", "failures": 0, "last_failure": 0}

    async def set_circuit_state(self, provider: str, state: dict):
        # 1h TTL: a provider that stops receiving traffic eventually resets
        # to the default closed state.
        key = f"circuit:{provider}"
        await self._redis.setex(key, 3600, json.dumps(state))
|
|
| |
| |
| |
|
|
class CircuitBreaker:
    """Redis-backed circuit breaker (closed -> open -> half-open -> closed).

    State lives in Redis (via RedisManager) so every worker process observes
    the same breaker. NOTE(review): the read-modify-write is not atomic, so
    concurrent updates may lose a failure count — acceptable for a heuristic,
    but worth confirming under heavy multi-worker load.
    """

    def __init__(self, redis: RedisManager, provider: str):
        self.redis = redis
        self.provider = provider
        self.failure_threshold = CIRCUIT_BREAKER_FAILURE_THRESHOLD
        self.recovery_timeout = CIRCUIT_BREAKER_RECOVERY_TIMEOUT

    async def can_execute(self) -> bool:
        """True when a request may be sent.

        An open breaker transitions to half-open (allowing one probe) once
        the recovery timeout has elapsed since the last recorded failure.
        """
        state = await self.redis.get_circuit_state(self.provider)
        if state["state"] == "open":
            if time.time() - state["last_failure"] > self.recovery_timeout:
                state["state"] = "half-open"
                state["failures"] = 0
                await self.redis.set_circuit_state(self.provider, state)
                CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(1)
                logger.info(f"Circuit breaker {self.provider} half-open")
                return True
            CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(2)
            return False
        CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(0 if state["state"] == "closed" else 1)
        return True

    async def record_success(self):
        """A half-open success closes the breaker; successes while closed are no-ops."""
        state = await self.redis.get_circuit_state(self.provider)
        if state["state"] == "half-open":
            state["state"] = "closed"
            state["failures"] = 0
            await self.redis.set_circuit_state(self.provider, state)
            CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(0)
            logger.info(f"Circuit breaker {self.provider} closed")

    async def record_failure(self):
        """Count a failure and open the breaker when warranted.

        Fix: a failed half-open probe now re-opens the breaker immediately.
        Previously a half-open failure only incremented the (reset) counter,
        so another full threshold of failures was needed before re-opening.
        """
        state = await self.redis.get_circuit_state(self.provider)
        state["failures"] += 1
        state["last_failure"] = time.time()
        if state["state"] == "half-open" or state["failures"] >= self.failure_threshold:
            state["state"] = "open"
            CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(2)
            logger.warning(f"Circuit breaker {self.provider} OPENED after {state['failures']} failures")
        await self.redis.set_circuit_state(self.provider, state)
|
|
| |
| |
| |
|
|
@dataclass
class FallbackConfig:
    """Ordered provider chain consulted by FallbackManager.

    Values are provider keys understood by _get_provider_config
    ("nim", "cloudflare", "gemini", "mlx").
    """

    primary: str = "nim"  # tried first
    secondary: str = "cloudflare"  # tried when primary's breaker is open
    tertiary: str = "gemini"  # tried when secondary is also down
    enabled: bool = True  # False => always use primary, no fallback
|
|
class FallbackManager:
    """Selects the first provider in the chain whose circuit breaker allows
    traffic: primary -> secondary -> tertiary -> (optional) MLX local."""

    def __init__(self, redis: RedisManager, config: Optional[FallbackConfig] = None):
        self.redis = redis
        # Fix: default to the environment-derived FALLBACK_* chain. The
        # previous default was FallbackConfig()'s static values, which meant
        # the FALLBACK_PRIMARY/SECONDARY/TERTIARY/FALLBACK_ENABLED env vars
        # were silently ignored (lifespan constructs this without a config).
        self.config = config or FallbackConfig(
            primary=FALLBACK_PRIMARY,
            secondary=FALLBACK_SECONDARY,
            tertiary=FALLBACK_TERTIARY,
            enabled=FALLBACK_ENABLED,
        )
        self._http_client: Optional[httpx.AsyncClient] = None

    async def init_client(self):
        """Build the shared pooled HTTP client (idempotent)."""
        if not self._http_client:
            self._http_client = httpx.AsyncClient(
                limits=httpx.Limits(max_connections=50, max_keepalive_connections=20),
                timeout=httpx.Timeout(REQUEST_TIMEOUT),
            )

    async def close_client(self):
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None  # allow init_client() to rebuild after close

    async def get_active_provider(self) -> tuple[str, dict]:
        """Return (provider_name, provider_config) for the first healthy provider.

        Raises HTTPException 503 when every provider's breaker is open.
        """
        if not self.config.enabled:
            return self.config.primary, self._get_provider_config(self.config.primary)

        primary_breaker = CircuitBreaker(self.redis, self.config.primary)
        if await primary_breaker.can_execute():
            return self.config.primary, self._get_provider_config(self.config.primary)

        secondary_breaker = CircuitBreaker(self.redis, self.config.secondary)
        if await secondary_breaker.can_execute():
            FALLBACK_COUNT.labels(
                from_provider=self.config.primary,
                to_provider=self.config.secondary,
                reason="circuit_open",
            ).inc()
            logger.warning(
                f"Fallback: {self.config.primary} unavailable, switching to {self.config.secondary}"
            )
            return self.config.secondary, self._get_provider_config(self.config.secondary)

        tertiary_breaker = CircuitBreaker(self.redis, self.config.tertiary)
        if await tertiary_breaker.can_execute():
            FALLBACK_COUNT.labels(
                from_provider=self.config.secondary,
                to_provider=self.config.tertiary,
                reason="secondary_down",
            ).inc()
            logger.warning(
                f"Fallback: both {self.config.primary} and {self.config.secondary} down, "
                f"switching to {self.config.tertiary}"
            )
            return self.config.tertiary, self._get_provider_config(self.config.tertiary)

        # Last resort: local MLX inference, only when explicitly enabled.
        if MLX_ENABLED:
            mlx_breaker = CircuitBreaker(self.redis, "mlx")
            if await mlx_breaker.can_execute():
                FALLBACK_COUNT.labels(
                    from_provider=self.config.tertiary,
                    to_provider="mlx",
                    reason="all_cloud_down",
                ).inc()
                logger.warning("All cloud providers down — falling back to MLX local")
                return "mlx", self._get_provider_config("mlx")

        raise HTTPException(status_code=503, detail="All LLM providers unavailable.")

    def _get_provider_config(self, provider: str) -> dict:
        """Static per-provider connection/rate/cost settings.

        NOTE(review): unknown provider names silently fall back to the NIM
        config — confirm that is preferable to rejecting the request.
        """
        configs = {
            "nim": {
                "api_base": NIM_API_BASE,
                "api_key": os.environ.get("NVIDIA_API_KEY", "no-key"),
                "rpm_limit": 40,
                "cost_per_1m_input": 0.0,
                "cost_per_1m_output": 0.0,
            },
            "cloudflare": {
                "api_base": f"https://api.cloudflare.com/client/v4/accounts/{CLOUDFLARE_ACCOUNT_ID}/ai/v1",
                "api_key": CLOUDFLARE_API_KEY,
                "rpm_limit": 100,
                "cost_per_1m_input": 0.0,
                "cost_per_1m_output": 0.0,
            },
            "gemini": {
                "api_base": GEMINI_API_BASE,
                "api_key": GEMINI_API_KEY,
                "rpm_limit": 60,
                "cost_per_1m_input": 0.075,
                "cost_per_1m_output": 0.30,
            },
            "mlx": {
                "api_base": MLX_API_BASE,
                "api_key": "no-key",
                "rpm_limit": 1000,
                "cost_per_1m_input": 0.0,
                "cost_per_1m_output": 0.0,
            },
        }
        return configs.get(provider, configs["nim"])
|
|
| |
| |
| |
|
|
@dataclass
class CostTracker:
    """Per-session spend accounting against a fixed USD budget.

    NOTE(review): instantiated fresh for each request with spent_usd=0, so
    can_spend() never sees prior spend — budget enforcement appears to rely
    on persistence that is not wired up here. Confirm intended behavior.
    """

    session_id: str
    budget_usd: float = BUDGET_USD_PER_SESSION  # per-session ceiling
    spent_usd: float = 0.0  # running total for this tracker's lifetime
    provider: str = "unknown"  # Prometheus label
    model: str = "unknown"  # Prometheus label

    def can_spend(self, estimated_cost: float) -> bool:
        """True when the estimated charge still fits within the budget."""
        return (self.spent_usd + estimated_cost) <= self.budget_usd

    def record_spend(self, cost_usd: float):
        """Accumulate a realized charge and export it to the cost counter."""
        self.spent_usd += cost_usd
        LLM_COST_USD.labels(provider=self.provider, model=self.model).inc(cost_usd)
        logger.info(f"Session {self.session_id}: spent ${cost_usd:.4f}, total ${self.spent_usd:.4f} / ${self.budget_usd:.2f}")
|
|
| |
| |
| |
|
|
class ConcurrencyLimiter:
    """Caps the number of in-flight requests with an asyncio semaphore.

    Supports the existing explicit acquire()/release() protocol, and
    additionally ``async with limiter:`` which guarantees the slot is
    returned on every exit path (backward-compatible addition).
    """

    def __init__(self, max_concurrent: int):
        self.semaphore = asyncio.Semaphore(max_concurrent)

    async def acquire(self):
        """Wait until a slot is free, then claim it."""
        await self.semaphore.acquire()

    def release(self):
        """Return a previously claimed slot."""
        self.semaphore.release()

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self.release()
        return False  # never suppress exceptions
|
|
| |
| |
| |
|
|
class ChatRequest(BaseModel):
    """OpenAI-compatible chat completion request body."""

    model: str = Field(..., description="Model ID (e.g., gemma-4-31b-bf16)")
    messages: list[dict] = Field(..., description="OpenAI-compatible messages")
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 4096
    stream: bool = False  # streaming requests bypass the Redis response cache
    tools: Optional[list[dict]] = None  # OpenAI tool/function definitions
    tool_choice: Optional[str] = "auto"
    session_id: Optional[str] = None  # generated server-side when omitted
    provider_override: Optional[str] = None  # force one provider, skipping fallback
|
|
class ChatResponse(BaseModel):
    """Normalized completion response, identical shape for every provider."""

    id: str
    session_id: str
    model: str
    provider: str  # provider that actually served the request
    content: Optional[str] = None
    tool_calls: Optional[list[dict]] = None
    usage: dict = Field(default_factory=dict)  # provider-reported token counts
    cost_usd: float = 0.0  # estimated from the provider's rate card
    cached: bool = False  # True when served from the Redis response cache
    finish_reason: Optional[str] = None
    fallback_used: bool = False  # True when a non-primary provider served it
|
|
class HealthResponse(BaseModel):
    """Payload of GET /health."""

    status: str  # "healthy" or "degraded" (degraded when Redis is unreachable)
    version: str = "1.0.0"
    uptime_seconds: float
    active_sessions: int  # NOTE(review): currently always reported as 0
    redis_connected: bool
    db_connected: bool
    circuit_breakers: dict[str, str]  # provider -> closed/half-open/open
    fallback_status: dict[str, str]  # provider -> up/down/unknown
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Module-level singletons, populated by lifespan() at startup.
# ---------------------------------------------------------------------------
db_pool: Optional[DatabasePool] = None
redis_manager: Optional[RedisManager] = None
concurrency_limiter: Optional[ConcurrencyLimiter] = None
fallback_manager: Optional[FallbackManager] = None
start_time: float = 0.0  # set by lifespan(); used for /health uptime
# NOTE(review): Event constructed at import time, outside a running loop —
# fine on Python >= 3.10 (loop-agnostic primitives); verify minimum version.
shutdown_event: asyncio.Event = asyncio.Event()
|
|
| |
| |
| |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle: connect backing services on startup, install
    signal handlers, and tear everything down in reverse order on shutdown."""
    global db_pool, redis_manager, concurrency_limiter, fallback_manager, start_time

    start_time = time.time()

    db_pool = DatabasePool(DATABASE_URL)
    await db_pool.connect()

    redis_manager = RedisManager(REDIS_URL)
    await redis_manager.connect()

    concurrency_limiter = ConcurrencyLimiter(MAX_CONCURRENT_REQUESTS)
    fallback_manager = FallbackManager(redis_manager)
    await fallback_manager.init_client()

    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() is deprecated here on modern Pythons.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        try:
            loop.add_signal_handler(sig, lambda: asyncio.create_task(_shutdown()))
        except NotImplementedError:
            # add_signal_handler is unavailable on Windows event loops.
            logger.warning("Signal handlers not supported on this platform")
            break

    if DATABASE_URL:
        await _init_schema()

    logger.info("ml-intern production server started")

    yield

    logger.info("Shutting down...")
    shutdown_event.set()

    # Close in reverse order of acquisition.
    if fallback_manager:
        await fallback_manager.close_client()
    if redis_manager:
        await redis_manager.disconnect()
    if db_pool:
        await db_pool.disconnect()

    logger.info("ml-intern production server stopped")
|
|
async def _shutdown():
    """Signal-handler callback: log once and set the global shutdown flag."""
    logger.info("Shutdown signal received")
    shutdown_event.set()
|
|
async def _init_schema():
    """Create the sessions/requests tables if absent (idempotent, startup-only)."""
    await db_pool.execute("""
        CREATE TABLE IF NOT EXISTS sessions (
            id TEXT PRIMARY KEY,
            created_at TIMESTAMP DEFAULT NOW(),
            last_active_at TIMESTAMP DEFAULT NOW(),
            budget_usd NUMERIC DEFAULT 10.0,
            spent_usd NUMERIC DEFAULT 0.0,
            metadata JSONB DEFAULT '{}'
        )
    """)
    await db_pool.execute("""
        CREATE TABLE IF NOT EXISTS requests (
            id TEXT PRIMARY KEY,
            session_id TEXT,
            model TEXT,
            provider TEXT,
            input_tokens INTEGER,
            output_tokens INTEGER,
            cost_usd NUMERIC,
            latency_ms INTEGER,
            cached BOOLEAN DEFAULT FALSE,
            fallback_used BOOLEAN DEFAULT FALSE,
            created_at TIMESTAMP DEFAULT NOW()
        )
    """)
    logger.info("Database schema initialized")
|
|
# FastAPI application; lifespan() owns startup/shutdown of all singletons.
app = FastAPI(
    title="ml-intern Production API",
    description="Production-grade API with NIM/Cloudflare/Gemini fallback and MLX local support",
    version="1.0.0",
    lifespan=lifespan,
)


# Compress responses larger than 1 KB.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# NOTE(review): wildcard origins combined with allow_credentials=True is
# disallowed by the CORS spec for credentialed requests (Starlette copes by
# echoing the request origin) — confirm that credentialed cross-origin access
# from ANY site is really intended before shipping.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
| |
| |
|
|
@app.middleware("http")
async def correlation_id_middleware(request: Request, call_next):
    """Attach a correlation id to request state, log records, metrics and the
    response headers.

    Fix: the previous version installed a wrapping log-record factory per
    request and never restored it, so the factory chain grew by one closure
    on every request (unbounded memory/latency growth). The old factory is
    now restored in a finally block.

    NOTE(review): a process-global record factory still isn't isolated per
    request — concurrent requests can stamp each other's ids. A contextvars
    based id would be the fully correct fix.
    """
    correlation_id = request.headers.get("X-Correlation-ID", str(uuid.uuid4()))
    request.state.correlation_id = correlation_id

    old_factory = logging.getLogRecordFactory()

    def record_factory(*args, **kwargs):
        record = old_factory(*args, **kwargs)
        record.correlation_id = correlation_id
        return record

    logging.setLogRecordFactory(record_factory)

    start = time.time()
    try:
        response = await call_next(request)
    finally:
        # Always unwind, even when the downstream handler raises.
        logging.setLogRecordFactory(old_factory)
    latency = time.time() - start

    REQUEST_COUNT.labels(
        method=request.method,
        endpoint=request.url.path,
        status=response.status_code,
        provider=getattr(request.state, "provider", "unknown"),
    ).inc()

    REQUEST_LATENCY.labels(
        method=request.method,
        endpoint=request.url.path,
        provider=getattr(request.state, "provider", "unknown"),
    ).observe(latency)

    response.headers["X-Correlation-ID"] = correlation_id
    return response
|
|
| |
| |
| |
|
|
def estimate_cost(provider_config: dict, input_tokens: int, output_tokens: int) -> float:
    """Estimated USD cost of a call from per-million-token input/output rates.

    Providers without rate entries (NIM, Cloudflare, MLX) cost 0.0.
    """
    rate_in = provider_config.get("cost_per_1m_input", 0.0)
    rate_out = provider_config.get("cost_per_1m_output", 0.0)
    return (input_tokens / 1_000_000) * rate_in + (output_tokens / 1_000_000) * rate_out
|
|
def generate_cache_key(request: "ChatRequest") -> str:
    """Deterministic Redis key for a non-streaming completion request.

    Hashes every request field that can change the model's output.
    Fix: ``tool_choice`` was previously omitted, so two requests differing
    only in tool_choice would incorrectly share one cached response.
    """
    content = json.dumps({
        "model": request.model,
        "messages": request.messages,
        "temperature": request.temperature,
        "max_tokens": request.max_tokens,
        "tools": request.tools,
        "tool_choice": request.tool_choice,
    }, sort_keys=True)
    return f"cache:llm:{hashlib.sha256(content.encode()).hexdigest()}"
|
|
| |
| |
| |
|
|
async def call_llm(
    provider: str,
    provider_config: dict,
    request: ChatRequest,
    session_id: str,
) -> ChatResponse:
    """Call one provider's OpenAI-compatible /chat/completions endpoint.

    Fixes:
    - ``tools``/``tool_choice`` are now forwarded when present; previously
      they were never included in the payload, so providers could not return
      tool_calls even though the response parser looks for them.
    - An empty ``choices`` list no longer raises IndexError.
    - Removed the dead cloudflare branch that re-set an identical header.

    Raises HTTPException 502 on provider HTTP errors, 503 on network errors.
    """
    if not fallback_manager or not fallback_manager._http_client:
        raise HTTPException(status_code=503, detail="HTTP client not initialized")

    api_base = provider_config["api_base"]
    api_key = provider_config["api_key"]

    payload = {
        "model": request.model,
        "messages": request.messages,
        "temperature": request.temperature,
        "max_tokens": request.max_tokens,
        "stream": False,
    }
    if request.tools:
        payload["tools"] = request.tools
        payload["tool_choice"] = request.tool_choice

    # All supported providers use the OpenAI-compatible bearer-token scheme.
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    try:
        response = await fallback_manager._http_client.post(
            f"{api_base}/chat/completions",
            json=payload,
            headers=headers,
            timeout=REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        data = response.json()

        usage = data.get("usage") or {}
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)
        cost = estimate_cost(provider_config, input_tokens, output_tokens)

        content = None
        tool_calls = None
        finish_reason = None
        choices = data.get("choices") or []
        if choices:  # guard: some error payloads carry an empty choices list
            choice = choices[0]
            message = choice.get("message") or {}
            content = message.get("content")
            tool_calls = message.get("tool_calls")
            finish_reason = choice.get("finish_reason")

        # Latency is recorded by the metrics middleware; nothing to do here.
        return ChatResponse(
            id=data.get("id", str(uuid.uuid4())),
            session_id=session_id,
            model=request.model,
            provider=provider,
            content=content,
            tool_calls=tool_calls,
            usage=usage,
            cost_usd=cost,
            cached=False,
            finish_reason=finish_reason,
            fallback_used=False,
        )

    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error from {provider}: {e.response.status_code} - {e.response.text[:200]}")
        raise HTTPException(status_code=502, detail=f"Provider {provider} returned HTTP {e.response.status_code}")
    except httpx.RequestError as e:
        logger.error(f"Network error calling {provider}: {e}")
        raise HTTPException(status_code=503, detail=f"Cannot reach provider {provider}: {str(e)}")
|
|
| |
| |
| |
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Liveness/readiness probe: backing stores plus per-provider circuit state."""
    # Redis reachability determines healthy vs. degraded overall status.
    redis_ok = False
    try:
        await redis_manager._redis.ping()
        redis_ok = True
    except Exception:
        redis_ok = False

    # Only probe the database when one is configured at all.
    db_ok = False
    if DATABASE_URL:
        try:
            await db_pool.fetchval("SELECT 1")
            db_ok = True
        except Exception:
            db_ok = False

    breaker_states: dict[str, str] = {}
    availability: dict[str, str] = {}
    for name in ("nim", "cloudflare", "gemini", "mlx"):
        try:
            circuit = await redis_manager.get_circuit_state(name)
        except Exception:
            breaker_states[name] = "unknown"
            availability[name] = "unknown"
            continue
        breaker_states[name] = circuit["state"]
        availability[name] = "up" if circuit["state"] == "closed" else "down"

    return HealthResponse(
        status="healthy" if redis_ok else "degraded",
        version="1.0.0",
        uptime_seconds=time.time() - start_time,
        active_sessions=0,
        redis_connected=redis_ok,
        db_connected=db_ok,
        circuit_breakers=breaker_states,
        fallback_status=availability,
    )
|
|
@app.get("/metrics")
async def metrics():
    """Prometheus scrape endpoint: current metrics in text exposition format."""
    # Local import keeps starlette's Response out of the module namespace.
    from starlette.responses import Response
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)
|
|
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(
    request: ChatRequest,
    http_request: Request,
    background_tasks: BackgroundTasks,
):
    """OpenAI-compatible chat completion with rate limiting, caching,
    provider fallback and cost tracking.

    Fixes:
    - ``request`` is the ChatRequest *body model*, which has no ``state``
      attribute — the previous code read/wrote ``request.state`` and raised
      AttributeError on every call. Per-request state (correlation id,
      provider tag for the metrics middleware) now lives on the ASGI
      ``http_request``.
    - Provider-side 502/503 HTTPExceptions from call_llm previously bypassed
      ``record_failure`` (they were re-raised untouched), so the circuit
      breaker could never open on provider HTTP errors.
    """
    session_id = request.session_id or str(uuid.uuid4())

    logger.info(f"Chat request: model={request.model}, stream={request.stream}, session={session_id}")

    await concurrency_limiter.acquire()
    llm_called = False  # True only once we actually dispatch to a provider
    try:
        # --- provider selection: explicit override or automatic fallback ---
        if request.provider_override:
            provider = request.provider_override
            provider_config = fallback_manager._get_provider_config(provider)
            breaker = CircuitBreaker(redis_manager, provider)
            if not await breaker.can_execute():
                raise HTTPException(status_code=503, detail=f"Provider {provider} circuit breaker is open")
        else:
            provider, provider_config = await fallback_manager.get_active_provider()

        # Tag the ASGI request so the metrics middleware can label by provider.
        http_request.state.provider = provider

        # --- distributed per-session, per-provider rate limit ---
        rpm = provider_config.get("rpm_limit", DEFAULT_RPM_LIMIT)
        rate_limit_key = f"{provider}:{session_id}"
        allowed, retry_after = await redis_manager.check_rate_limit(rate_limit_key, rpm)
        if not allowed:
            logger.warning(f"Rate limit exceeded for {rate_limit_key}")
            raise HTTPException(status_code=429, detail=f"Rate limit exceeded. Retry after {retry_after:.1f}s", headers={"Retry-After": str(int(retry_after))})

        # --- cache lookup (non-streaming requests only) ---
        if not request.stream:
            cache_key = generate_cache_key(request)
            cached = await redis_manager.get_cache(cache_key)
            if cached:
                logger.info(f"Cache hit for {cache_key}")
                data = json.loads(cached)
                return ChatResponse(
                    id=str(uuid.uuid4()),
                    session_id=session_id,
                    model=request.model,
                    provider=provider,
                    content=data.get("content"),
                    tool_calls=data.get("tool_calls"),
                    usage=data.get("usage", {}),
                    cost_usd=0.0,  # cached responses cost nothing
                    cached=True,
                    finish_reason=data.get("finish_reason"),
                    fallback_used=False,
                )

        cost_tracker = CostTracker(session_id, provider=provider, model=request.model)

        # --- upstream LLM call ---
        llm_called = True
        response = await call_llm(provider, provider_config, request, session_id)

        breaker = CircuitBreaker(redis_manager, provider)
        await breaker.record_success()

        cost_tracker.record_spend(response.cost_usd)

        # --- populate the cache for subsequent identical requests ---
        if not request.stream:
            cache_key = generate_cache_key(request)
            await redis_manager.set_cache(
                cache_key,
                json.dumps({
                    "content": response.content,
                    "tool_calls": response.tool_calls,
                    "usage": response.usage,
                    "finish_reason": response.finish_reason,
                }),
            )

        if DATABASE_URL:
            background_tasks.add_task(_persist_request, session_id, request, response, provider)

        if provider != FALLBACK_PRIMARY and FALLBACK_ENABLED:
            response.fallback_used = True

        return response

    except HTTPException as e:
        # Provider-side errors surfaced by call_llm must count against the
        # breaker; 429s and pre-dispatch 503s (llm_called False) must not.
        if llm_called and e.status_code in (502, 503):
            await CircuitBreaker(redis_manager, provider).record_failure()
        raise
    except Exception as e:
        logger.exception(f"Error processing request: {e}")
        breaker = CircuitBreaker(redis_manager, provider if 'provider' in locals() else "unknown")
        await breaker.record_failure()
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        concurrency_limiter.release()
|
|
async def _persist_request(session_id: str, request: ChatRequest, response: ChatResponse, provider: str):
    """Background task: write one request's accounting row; never raises.

    NOTE(review): latency_ms is persisted as a hard-coded 0 — the measured
    latency is not threaded through ChatResponse. Confirm whether the column
    is still wanted.
    """
    try:
        await db_pool.execute(
            """
            INSERT INTO requests (id, session_id, model, provider, input_tokens,
                                  output_tokens, cost_usd, latency_ms, cached, fallback_used)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
            """,
            response.id,
            session_id,
            request.model,
            provider,
            response.usage.get("prompt_tokens", 0),
            response.usage.get("completion_tokens", 0),
            response.cost_usd,
            0,  # latency_ms placeholder (see NOTE above)
            response.cached,
            response.fallback_used,
        )
    except Exception as e:
        # Best-effort persistence: accounting failures must not fail requests.
        logger.error(f"Failed to persist request: {e}")
|
|
@app.get("/v1/models")
async def list_models():
    """Advertise models for every provider whose credentials/flags are set,
    in provider priority order (NIM, Cloudflare, Gemini, MLX)."""
    # (enabled?, owner label, model ids) per provider — data-driven so adding
    # a model is a one-line change.
    sections: list[tuple[bool, str, list[str]]] = [
        (bool(os.environ.get("NVIDIA_API_KEY")), "nvidia", [
            "nim/llama-3.1-405b-instruct",
            "nim/llama-3.1-70b-instruct",
            "nim/llama-3.1-8b-instruct",
            "nim/mistral-7b-instruct",
        ]),
        (bool(CLOUDFLARE_API_KEY and CLOUDFLARE_ACCOUNT_ID), "cloudflare", [
            "cloudflare/@cf/meta/llama-3.1-8b-instruct",
            "cloudflare/@cf/meta/llama-3.1-70b-instruct",
            "cloudflare/@cf/mistral/mistral-7b-instruct",
            "cloudflare/@cf/qwen/qwen1.5-14b-chat-awq",
            "cloudflare/@cf/google/gemma-4-26b-a4b-it",
        ]),
        (bool(GEMINI_API_KEY), "google", [
            "gemini/gemini-2.5-pro-preview",
            "gemini/gemini-2.5-flash-preview",
            "gemini/gemma-4-26b",
            "gemini/gemma-4-9b",
        ]),
        (bool(MLX_ENABLED), "mlx", [
            "mlx/llama-3.1-8b",
            "mlx/llama-3.1-70b",
            "mlx/gemma-4-26b-a4b-it",
            "mlx/gemma-4-31b-bf16",
            "mlx/gemma-4-e4b-it",
        ]),
    ]

    models = [
        {"id": model_id, "object": "model", "owned_by": owner}
        for enabled, owner, ids in sections
        if enabled
        for model_id in ids
    ]
    return {"object": "list", "data": models}
|
|
@app.get("/v1/fallback/status")
async def fallback_status():
    """Report each provider's circuit state plus the provider that would
    serve a request arriving right now."""
    providers_report: dict[str, dict] = {}
    for name in ("nim", "cloudflare", "gemini", "mlx"):
        breaker = CircuitBreaker(redis_manager, name)
        # can_execute() first: it may flip an expired "open" to "half-open",
        # and the state read below should reflect that transition.
        available = await breaker.can_execute()
        circuit = await redis_manager.get_circuit_state(name)
        providers_report[name] = {
            "circuit_state": circuit["state"],
            "failures": circuit["failures"],
            "available": available,
            "last_failure": circuit["last_failure"],
        }

    return {
        "fallback_enabled": FALLBACK_ENABLED,
        "primary": FALLBACK_PRIMARY,
        "secondary": FALLBACK_SECONDARY,
        "tertiary": FALLBACK_TERTIARY,
        "providers": providers_report,
        "active_provider": await _get_active_provider_name(),
    }
|
|
async def _get_active_provider_name() -> str:
    """Name of the provider the fallback chain would pick right now."""
    try:
        provider, _ = await fallback_manager.get_active_provider()
        return provider
    except HTTPException:
        # get_active_provider raises 503 when every breaker is open.
        return "none_available"
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    port = int(os.environ.get("PORT", "8000"))
    workers = int(os.environ.get("WORKERS", "1"))

    # workers > 1 spawns separate processes (hence the import-string app
    # reference). Cross-process state — rate limits, circuit breakers,
    # response cache — lives in Redis, so multiple workers are safe; only
    # in-process Prometheus metrics are per-worker.
    uvicorn.run(
        "production_server:app",
        host="0.0.0.0",
        port=port,
        workers=workers,
        log_level="info",
        access_log=True,
    )
|
|